diff --git a/server/websocket.go b/server/websocket.go index a8ead8b1..1942fdab 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -53,6 +53,7 @@ type websocketChannel struct { requestHeader http.Header alive bool aliveLock sync.Mutex + addrDescs []string // subscribed address descriptors as strings } // WebsocketServer is a handle to websocket server @@ -722,7 +723,7 @@ func (s *WebsocketServer) unsubscribeNewTransaction(c *websocketChannel) (res in return &subscriptionResponse{false}, nil } -func (s *WebsocketServer) unmarshalAddresses(params []byte) ([]bchain.AddressDescriptor, error) { +func (s *WebsocketServer) unmarshalAddresses(params []byte) ([]string, error) { r := struct { Addresses []string `json:"addresses"` }{} @@ -730,38 +731,41 @@ func (s *WebsocketServer) unmarshalAddresses(params []byte) ([]bchain.AddressDes if err != nil { return nil, err } - rv := make([]bchain.AddressDescriptor, len(r.Addresses)) + rv := make([]string, len(r.Addresses)) for i, a := range r.Addresses { ad, err := s.chainParser.GetAddrDescFromAddress(a) if err != nil { return nil, err } - rv[i] = ad + rv[i] = string(ad) } return rv, nil } // unsubscribe addresses without addressSubscriptionsLock - can be called only from subscribeAddresses and unsubscribeAddresses func (s *WebsocketServer) doUnsubscribeAddresses(c *websocketChannel) { - for ads, sa := range s.addressSubscriptions { - for sc := range sa { - if sc == c { - delete(sa, c) + for _, ads := range c.addrDescs { + sa, e := s.addressSubscriptions[ads] + if e { + for sc := range sa { + if sc == c { + delete(sa, c) + } + } + if len(sa) == 0 { + delete(s.addressSubscriptions, ads) } } - if len(sa) == 0 { - delete(s.addressSubscriptions, ads) - } } + c.addrDescs = nil } -func (s *WebsocketServer) subscribeAddresses(c *websocketChannel, addrDesc []bchain.AddressDescriptor, req *websocketReq) (res interface{}, err error) { +func (s *WebsocketServer) subscribeAddresses(c *websocketChannel, addrDesc []string, req *websocketReq) (res interface{}, err error) { s.addressSubscriptionsLock.Lock() defer s.addressSubscriptionsLock.Unlock() // unsubscribe all previous subscriptions s.doUnsubscribeAddresses(c) - for i := range addrDesc { - ads := string(addrDesc[i]) + for _, ads := range addrDesc { as, ok := s.addressSubscriptions[ads] if !ok { as = make(map[*websocketChannel]string) @@ -769,6 +773,7 @@ func (s *WebsocketServer) subscribeAddresses(c *websocketChannel, addrDesc []bch } as[c] = req.ID } + c.addrDescs = addrDesc s.metrics.WebsocketSubscribes.With((common.Labels{"method": "subscribeAddresses"})).Set(float64(len(s.addressSubscriptions))) return &subscriptionResponse{true}, nil }