diff --git a/lib/beacon/beacon.go b/lib/beacon/beacon.go index 47a28637..b7436e17 100644 --- a/lib/beacon/beacon.go +++ b/lib/beacon/beacon.go @@ -8,7 +8,6 @@ package beacon import ( "net" - stdsync "sync" "github.com/thejerf/suture" ) @@ -24,21 +23,3 @@ type Interface interface { Recv() ([]byte, net.Addr) Error() error } - -type errorHolder struct { - err error - mut stdsync.Mutex // uses stdlib sync as I want this to be trivially embeddable, and there is no risk of blocking -} - -func (e *errorHolder) setError(err error) { - e.mut.Lock() - e.err = err - e.mut.Unlock() -} - -func (e *errorHolder) Error() error { - e.mut.Lock() - err := e.err - e.mut.Unlock() - return err -} diff --git a/lib/beacon/broadcast.go b/lib/beacon/broadcast.go index d41e9ad1..c5580f31 100644 --- a/lib/beacon/broadcast.go +++ b/lib/beacon/broadcast.go @@ -11,8 +11,9 @@ import ( "net" "time" - "github.com/syncthing/syncthing/lib/sync" "github.com/thejerf/suture" + + "github.com/syncthing/syncthing/lib/util" ) type Broadcast struct { @@ -44,16 +45,16 @@ func NewBroadcast(port int) *Broadcast { } b.br = &broadcastReader{ - port: port, - outbox: b.outbox, - connMut: sync.NewMutex(), + port: port, + outbox: b.outbox, } + b.br.ServiceWithError = util.AsServiceWithError(b.br.serve) b.Add(b.br) b.bw = &broadcastWriter{ - port: port, - inbox: b.inbox, - connMut: sync.NewMutex(), + port: port, + inbox: b.inbox, } + b.bw.ServiceWithError = util.AsServiceWithError(b.bw.serve) b.Add(b.bw) return b @@ -76,34 +77,42 @@ func (b *Broadcast) Error() error { } type broadcastWriter struct { - port int - inbox chan []byte - conn *net.UDPConn - connMut sync.Mutex - errorHolder + util.ServiceWithError + port int + inbox chan []byte } -func (w *broadcastWriter) Serve() { +func (w *broadcastWriter) serve(stop chan struct{}) error { l.Debugln(w, "starting") defer l.Debugln(w, "stopping") conn, err := net.ListenUDP("udp4", nil) if err != nil { l.Debugln(err) - w.setError(err) - return + return err } - defer conn.Close() + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-stop: + case <-done: + } + conn.Close() + }() - w.connMut.Lock() - w.conn = conn - w.connMut.Unlock() + for { + var bs []byte + select { + case bs = <-w.inbox: + case <-stop: + return nil + } - for bs := range w.inbox { addrs, err := net.InterfaceAddrs() if err != nil { l.Debugln(err) - w.setError(err) + w.SetError(err) continue } @@ -134,14 +143,13 @@ func (w *broadcastWriter) Serve() { // Write timeouts should not happen. We treat it as a fatal // error on the socket. l.Debugln(err) - w.setError(err) - return + return err } if err != nil { // Some other error that we don't expect. Debug and continue. l.Debugln(err) - w.setError(err) + w.SetError(err) continue } @@ -150,57 +158,49 @@ func (w *broadcastWriter) Serve() { } if success > 0 { - w.setError(nil) + w.SetError(nil) } } } -func (w *broadcastWriter) Stop() { - w.connMut.Lock() - if w.conn != nil { - w.conn.Close() - } - w.connMut.Unlock() -} - func (w *broadcastWriter) String() string { return fmt.Sprintf("broadcastWriter@%p", w) } type broadcastReader struct { - port int - outbox chan recv - conn *net.UDPConn - connMut sync.Mutex - errorHolder + util.ServiceWithError + port int + outbox chan recv } -func (r *broadcastReader) Serve() { +func (r *broadcastReader) serve(stop chan struct{}) error { l.Debugln(r, "starting") defer l.Debugln(r, "stopping") conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: r.port}) if err != nil { l.Debugln(err) - r.setError(err) - return + return err } - defer conn.Close() - - r.connMut.Lock() - r.conn = conn - r.connMut.Unlock() + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-stop: + case <-done: + } + conn.Close() + }() bs := make([]byte, 65536) for { n, addr, err := conn.ReadFrom(bs) if err != nil { l.Debugln(err) - r.setError(err) - return + return err } - r.setError(nil) + r.SetError(nil) l.Debugf("recv %d bytes from %s", n, addr) @@ -208,19 +208,12 @@ func (r *broadcastReader) Serve() { copy(c, bs) select { case r.outbox <- recv{c, addr}: + case <-stop: + return nil default: l.Debugln("dropping message") } } - -} - -func (r *broadcastReader) Stop() { - r.connMut.Lock() - if r.conn != nil { - r.conn.Close() - } - r.connMut.Unlock() } func (r *broadcastReader) String() string { diff --git a/lib/beacon/multicast.go b/lib/beacon/multicast.go index befc5598..fe592f85 100644 --- a/lib/beacon/multicast.go +++ b/lib/beacon/multicast.go @@ -48,14 +48,14 @@ func NewMulticast(addr string) *Multicast { addr: addr, outbox: m.outbox, } - m.mr.Service = util.AsService(m.mr.serve) + m.mr.ServiceWithError = util.AsServiceWithError(m.mr.serve) m.Add(m.mr) m.mw = &multicastWriter{ addr: addr, inbox: m.inbox, } - m.mw.Service = util.AsService(m.mw.serve) + m.mw.ServiceWithError = util.AsServiceWithError(m.mw.serve) m.Add(m.mw) return m @@ -78,29 +78,35 @@ func (m *Multicast) Error() error { } type multicastWriter struct { - suture.Service + util.ServiceWithError addr string inbox <-chan []byte - errorHolder } -func (w *multicastWriter) serve(stop chan struct{}) { +func (w *multicastWriter) serve(stop chan struct{}) error { l.Debugln(w, "starting") defer l.Debugln(w, "stopping") gaddr, err := net.ResolveUDPAddr("udp6", w.addr) if err != nil { l.Debugln(err) - w.setError(err) - return + return err } conn, err := net.ListenPacket("udp6", ":0") if err != nil { l.Debugln(err) - w.setError(err) - return + return err } + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-stop: + case <-done: + } + conn.Close() + }() pconn := ipv6.NewPacketConn(conn) @@ -113,14 +119,13 @@ func (w *multicastWriter) serve(stop chan struct{}) { select { case bs = <-w.inbox: case <-stop: - return + return nil } intfs, err := net.Interfaces() if err != nil { l.Debugln(err) - w.setError(err) - return + return err } success := 0 @@ -132,7 +137,7 @@ func (w *multicastWriter) serve(stop chan struct{}) { if err != nil { l.Debugln(err, "on write to", gaddr, intf.Name) - w.setError(err) + w.SetError(err) continue } @@ -142,16 +147,13 @@ func (w *multicastWriter) serve(stop chan struct{}) { select { case <-stop: - return + return nil default: } } if success > 0 { - w.setError(nil) - } else { - l.Debugln(err) - w.setError(err) + w.SetError(nil) } } } @@ -161,35 +163,40 @@ func (w *multicastWriter) String() string { } type multicastReader struct { - suture.Service + util.ServiceWithError addr string outbox chan<- recv - errorHolder } -func (r *multicastReader) serve(stop chan struct{}) { +func (r *multicastReader) serve(stop chan struct{}) error { l.Debugln(r, "starting") defer l.Debugln(r, "stopping") gaddr, err := net.ResolveUDPAddr("udp6", r.addr) if err != nil { l.Debugln(err) - r.setError(err) - return + return err } conn, err := net.ListenPacket("udp6", r.addr) if err != nil { l.Debugln(err) - r.setError(err) - return + return err } + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-stop: + case <-done: + } + conn.Close() + }() intfs, err := net.Interfaces() if err != nil { l.Debugln(err) - r.setError(err) - return + return err } pconn := ipv6.NewPacketConn(conn) @@ -206,16 +213,20 @@ func (r *multicastReader) serve(stop chan struct{}) { if joined == 0 { l.Debugln("no multicast interfaces available") - r.setError(errors.New("no multicast interfaces available")) - return + return errors.New("no multicast interfaces available") } bs := make([]byte, 65536) for { + select { + case <-stop: + return nil + default: + } n, _, addr, err := pconn.ReadFrom(bs) if err != nil { l.Debugln(err) - r.setError(err) + r.SetError(err) continue } l.Debugf("recv %d bytes from %s", n, addr) @@ -224,8 +235,6 @@ func (r *multicastReader) serve(stop chan struct{}) { copy(c, bs) select { case r.outbox <- recv{c, addr}: - case <-stop: - return default: l.Debugln("dropping message") } diff --git a/lib/nat/registry.go b/lib/nat/registry.go index 91e18843..0c04e11a 100644 --- a/lib/nat/registry.go +++ b/lib/nat/registry.go @@ -19,7 +19,7 @@ func Register(provider DiscoverFunc) { providers = append(providers, provider) } -func discoverAll(renewal, timeout time.Duration) map[string]Device { +func discoverAll(renewal, timeout time.Duration, stop chan struct{}) map[string]Device { wg := &sync.WaitGroup{} wg.Add(len(providers)) @@ -28,20 +28,32 @@ func discoverAll(renewal, timeout time.Duration) map[string]Device { for _, discoverFunc := range providers { go func(f DiscoverFunc) { + defer wg.Done() for _, dev := range f(renewal, timeout) { - c <- dev + select { + case c <- dev: + case <-stop: + return + } } - wg.Done() }(discoverFunc) } nats := make(map[string]Device) go func() { - for dev := range c { - nats[dev.ID()] = dev + defer close(done) + for { + select { + case dev, ok := <-c: + if !ok { + return + } + nats[dev.ID()] = dev + case <-stop: + return + } } - close(done) }() wg.Wait() diff --git a/lib/nat/service.go b/lib/nat/service.go index 5beeea60..c564e246 100644 --- a/lib/nat/service.go +++ b/lib/nat/service.go @@ -14,17 +14,21 @@ import ( stdsync "sync" "time" + "github.com/thejerf/suture" + "github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/sync" + "github.com/syncthing/syncthing/lib/util" ) // Service runs a loop for discovery of IGDs (Internet Gateway Devices) and // setup/renewal of a port mapping. type Service struct { - id protocol.DeviceID - cfg config.Wrapper - stop chan struct{} + suture.Service + + id protocol.DeviceID + cfg config.Wrapper mappings []*Mapping timer *time.Timer @@ -32,27 +36,28 @@ type Service struct { } func NewService(id protocol.DeviceID, cfg config.Wrapper) *Service { - return &Service{ + s := &Service{ id: id, cfg: cfg, timer: time.NewTimer(0), mut: sync.NewRWMutex(), } + s.Service = util.AsService(s.serve) + return s } -func (s *Service) Serve() { +func (s *Service) serve(stop chan struct{}) { announce := stdsync.Once{} s.mut.Lock() s.timer.Reset(0) - s.stop = make(chan struct{}) s.mut.Unlock() for { select { case <-s.timer.C: - if found := s.process(); found != -1 { + if found := s.process(stop); found != -1 { announce.Do(func() { suffix := "s" if found == 1 { @@ -61,7 +66,7 @@ func (s *Service) Serve() { l.Infoln("Detected", found, "NAT service"+suffix) }) } - case <-s.stop: + case <-stop: s.timer.Stop() s.mut.RLock() for _, mapping := range s.mappings { @@ -73,7 +78,7 @@ func (s *Service) Serve() { } } -func (s *Service) process() int { +func (s *Service) process(stop chan struct{}) int { // toRenew are mappings which are due for renewal // toUpdate are the remaining mappings, which will only be updated if one of // the old IGDs has gone away, or a new IGD has appeared, but only if we @@ -115,25 +120,19 @@ func (s *Service) process() int { return -1 } - nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second) + nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second, stop) for _, mapping := range toRenew { - s.updateMapping(mapping, nats, true) + s.updateMapping(mapping, nats, true, stop) } for _, mapping := range toUpdate { - s.updateMapping(mapping, nats, false) + s.updateMapping(mapping, nats, false, stop) } return len(nats) } -func (s *Service) Stop() { - s.mut.RLock() - close(s.stop) - s.mut.RUnlock() -} - func (s *Service) NewMapping(protocol Protocol, ip net.IP, port int) *Mapping { mapping := &Mapping{ protocol: protocol, @@ -178,17 +177,17 @@ func (s *Service) RemoveMapping(mapping *Mapping) { // acquire mappings for natds which the mapping was unaware of before. // Optionally takes renew flag which indicates whether or not we should renew // mappings with existing natds -func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool) { +func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) { var added, removed []Address renewalTime := time.Duration(s.cfg.Options().NATRenewalM) * time.Minute mapping.expires = time.Now().Add(renewalTime) - newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew) + newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew, stop) added = append(added, newAdded...) removed = append(removed, newRemoved...) - newAdded, newRemoved = s.acquireNewMappings(mapping, nats) + newAdded, newRemoved = s.acquireNewMappings(mapping, nats, stop) added = append(added, newAdded...) removed = append(removed, newRemoved...) @@ -197,12 +196,18 @@ func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew } } -func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool) ([]Address, []Address) { +func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) ([]Address, []Address) { var added, removed []Address leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute for id, address := range mapping.addressMap() { + select { + case <-stop: + return nil, nil + default: + } + // Delete addresses for NATDevice's that do not exist anymore nat, ok := nats[id] if !ok { @@ -242,13 +247,19 @@ func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Devic return added, removed } -func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device) ([]Address, []Address) { +func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device, stop chan struct{}) ([]Address, []Address) { var added, removed []Address leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute addrMap := mapping.addressMap() for id, nat := range nats { + select { + case <-stop: + return nil, nil + default: + } + if _, ok := addrMap[id]; ok { continue } diff --git a/lib/relay/client/dynamic.go b/lib/relay/client/dynamic.go index 6a3751b3..f2ad0c40 100644 --- a/lib/relay/client/dynamic.go +++ b/lib/relay/client/dynamic.go @@ -69,15 +69,7 @@ func (c *dynamicClient) serve(stop chan struct{}) error { addrs = append(addrs, ruri.String()) } - defer func() { - c.mut.RLock() - if c.client != nil { - c.client.Stop() - } - c.mut.RUnlock() - }() - - for _, addr := range relayAddressesOrder(addrs) { + for _, addr := range relayAddressesOrder(addrs, stop) { select { case <-stop: l.Debugln(c, "stopping") @@ -104,6 +96,15 @@ func (c *dynamicClient) serve(stop chan struct{}) error { return fmt.Errorf("could not find a connectable relay") } +func (c *dynamicClient) Stop() { + c.mut.RLock() + if c.client != nil { + c.client.Stop() + } + c.mut.RUnlock() + c.commonClient.Stop() +} + func (c *dynamicClient) Error() error { c.mut.RLock() defer c.mut.RUnlock() @@ -147,7 +148,7 @@ type dynamicAnnouncement struct { // the closest 50ms, and puts them in buckets of 50ms latency ranges. Then // shuffles each bucket, and returns all addresses starting with the ones from // the lowest latency bucket, ending with the highest latency buceket. -func relayAddressesOrder(input []string) []string { +func relayAddressesOrder(input []string, stop chan struct{}) []string { buckets := make(map[int][]string) for _, relay := range input { @@ -159,6 +160,12 @@ func relayAddressesOrder(input []string) []string { id := int(latency/time.Millisecond) / 50 buckets[id] = append(buckets[id], relay) + + select { + case <-stop: + return nil + default: + } } var ids []int diff --git a/lib/stun/stun.go b/lib/stun/stun.go index 16ae0bdd..75e5ff04 100644 --- a/lib/stun/stun.go +++ b/lib/stun/stun.go @@ -109,8 +109,8 @@ func New(cfg config.Wrapper, subscriber Subscriber, conn net.PacketConn) (*Servi } func (s *Service) Stop() { - s.Service.Stop() _ = s.stunConn.Close() + s.Service.Stop() } func (s *Service) serve(stop chan struct{}) { @@ -163,7 +163,11 @@ func (s *Service) serve(stop chan struct{}) { // We failed to contact all provided stun servers or the nat is not punchable. // Chillout for a while. - time.Sleep(stunRetryInterval) + select { + case <-time.After(stunRetryInterval): + case <-stop: + return + } } } diff --git a/lib/util/utils.go b/lib/util/utils.go index 8044ddf5..68b738ab 100644 --- a/lib/util/utils.go +++ b/lib/util/utils.go @@ -187,6 +187,7 @@ func AsService(fn func(stop chan struct{})) suture.Service { type ServiceWithError interface { suture.Service Error() error + SetError(error) } // AsServiceWithError does the same as AsService, except that it keeps track @@ -244,3 +245,9 @@ func (s *service) Error() error { defer s.mut.Unlock() return s.err } + +func (s *service) SetError(err error) { + s.mut.Lock() + s.err = err + s.mut.Unlock() +}