Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions hscontrol/notifier/notifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,26 @@ func (n *Notifier) Close() {
n.closed = true
n.b.close()

for _, c := range n.nodes {
close(c)
// Close channels safely using the helper method
for nodeID, c := range n.nodes {
n.safeCloseChannel(nodeID, c)
}

// Clear node map after closing channels
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}

// safeCloseChannel closes a channel and panic recovers if already closed
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() {
if r := recover(); r != nil {
log.Error().
Uint64("node.id", nodeID.Uint64()).
Any("recover", r).
Msg("recovered from panic when closing channel in Close()")
}
}()
close(c)
}

func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
Expand All @@ -90,7 +107,11 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
// connection. Close the old channel and replace it.
if curr, ok := n.nodes[nodeID]; ok {
n.tracef(nodeID, "channel present, closing and replacing")
close(curr)
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
// if/when someone is waiting to send on this channel
go func(ch chan<- types.StateUpdate) {
n.safeCloseChannel(nodeID, ch)
}(curr)
}

n.nodes[nodeID] = c
Expand Down Expand Up @@ -161,6 +182,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return false
}

// LikelyConnectedMap returns a thread safe map of connected nodes
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
}
Expand Down
78 changes: 78 additions & 0 deletions hscontrol/notifier/notifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package notifier

import (
"context"
"fmt"
"math/rand"
"net/netip"
"sort"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -263,3 +266,78 @@ func TestBatcher(t *testing.T) {
})
}
}

// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
Tuning: types.Tuning{
NotifierSendTimeout: 1 * time.Second,
BatchChangeDelay: 1 * time.Second,
NodeMapSessionBufferedChanSize: 30,
},
}

notifier := NewNotifier(cfg)
defer notifier.Close()

nodeID := types.NodeID(1)
updateChan := make(chan types.StateUpdate, 10)

var wg sync.WaitGroup

// Number of goroutines to spawn for concurrent access
concurrentAccessors := 100
iterations := 100

// Add node to notifier
notifier.AddNode(nodeID, updateChan)

// Track errors
errChan := make(chan string, concurrentAccessors*iterations)

// Start goroutines to cause a race
wg.Add(concurrentAccessors)
for i := 0; i < concurrentAccessors; i++ {
go func(routineID int) {
defer wg.Done()

for j := 0; j < iterations; j++ {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
if routineID%3 == 0 {
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
} else if routineID%3 == 1 {
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
} else {
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}

// Small random delay to increase chance of races
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
}
}(i)
}

wg.Wait()
close(errChan)

// Collate errors
var errors []string
for err := range errChan {
errors = append(errors, err)
}

if len(errors) > 0 {
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
}
}
Loading