用Go语言编写一个简单的WebSocket推送服务

Golang 作者:sasabebe 2025-04-16 05:51:41

推送服务实现

基本原理

server 启动以后会注册两个 handler。

websocketHandler 用于提供浏览器端发送 Upgrade 请求并升级为 WebSocket 连接。

pushHandler 用于提供外部推送端发送推送数据的请求。

浏览器首先连接 websocketHandler (默认地址为 ws://ip:port/ws)升级请求为 WebSocket 连接,当连接建立之后需要发送注册信息进行注册。这里注册信息中包含一个 token 信息。

server 会对提供的 token 进行验证并获取到相应的 userId(通常来说,一个 userId 可能同时关联许多 token),并保存维护好 token, userId 和 conn(连接)之间的关系。

立即学习“go语言免费学习笔记(深入)”;

推送端发送推送数据的请求到 pushHandler(默认地址为 ws://ip:port/push),请求中包含了 userId 字段和 message 字段。server 会根据 userId 获取到所有此时连接到该 server 的 conn,然后将 message 一一进行推送。

由于推送服务的实时性,推送的数据并没有也不需要进行缓存。

代码详解

我在此处会稍微讲述一下代码的基本构成,也顺便说说 Go 语言中一些常用的写法和模式(本人也是从其他语言转向 Go 语言,毕竟 Go 语言也相当年轻。所以有建议的话,敬请提出。)。

由于Go语言的发明人和一些主要维护者大都来自于 C/C++ 语言,所以 Go 语言的代码也更偏向于 C/C++ 系。

首先先看一下 Server 的结构:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

// Server defines parameters for running websocket server.

type Server struct {

    // Address for server to listen on

    Addr string

 

    // Path for websocket request, default "/ws".

    WSPath string

 

    // Path for push message, default "/push".

    PushPath string

 

    // Upgrader is for upgrade connection to websocket connection using

    // "github.com/gorilla/websocket".

    //

    // If Upgrader is nil, default upgrader will be used. Default upgrader is

    // set ReadBufferSize and WriteBufferSize to 1024, and CheckOrigin always

    // returns true.

    Upgrader *websocket.Upgrader

 

    // Check token if it's valid and return userID. If token is valid, userID

    // must be returned and ok should be true. Otherwise ok should be false.

    AuthToken func(token string) (userID string, ok bool)

 

    // Authorize push request. Message will be sent if it returns true,

    // otherwise the request will be discarded. Default nil and push request

    // will always be accepted.

    PushAuth func(r *http.Request) bool

 

    wh *websocketHandler

    ph *pushHandler

}

这里说一下 Upgrader *websocket.Upgrader,这是 gorilla/websocket 包的对象,它用来升级 HTTP 请求。

如果一个结构体参数过多,通常不建议直接初始化,而是使用它提供的 New 方法。这里是:

1

2

3

4

5

6

// NewServer creates a new Server.func NewServer(addr string) *Server {    return &Server{

        Addr:     addr,

        WSPath:   serverDefaultWSPath,

        PushPath: serverDefaultPushPath,

    }

}

这也是 Go 语言对外提供初始化方法的一种常见用法。

然后 Server 使用 ListenAndServe 方法启动并监听端口,与 http 包的使用类似:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

// ListenAndServe listens on the TCP network address and handle websocket

// request.

func (s *Server) ListenAndServe() error {

    b := &binder{

        userID2EventConnMap: make(map[string]*[]eventConn),

        connID2UserIDMap:    make(map[string]string),

    }

 

    // websocket request handler

    wh := websocketHandler{

        upgrader: defaultUpgrader,

        binder:   b,

    }

    if s.Upgrader != nil {

        wh.upgrader = s.Upgrader

    }

    if s.AuthToken != nil {

        wh.calcUserIDFunc = s.AuthToken

    }

    s.wh = &wh

    http.Handle(s.WSPath, s.wh)

 

    // push request handler

    ph := pushHandler{

        binder: b,

    }

    if s.PushAuth != nil {

        ph.authFunc = s.PushAuth

    }

    s.ph = &ph

    http.Handle(s.PushPath, s.ph)

 

    return http.ListenAndServe(s.Addr, nil)

}

这里我们生成了两个 Handler,分别为 websocketHandler 和 pushHandler。websocketHandler 负责与浏览器建立连接并传输数据,而 pushHandler 则处理推送端的请求。

可以看到,这里两个 Handler 都封装了一个 binder 对象。这个 binder 用于维护 token userID Conn 的关系:

1

2

3

4

5

6

7

8

9

10

// binder is defined to store the relation of userID and eventConn

type binder struct {

    mu sync.RWMutex

 

    // map stores key: userID and value of related slice of eventConn

    userID2EventConnMap map[string]*[]eventConn

 

    // map stores key: connID and value: userID

    connID2UserIDMap map[string]string

}

websocketHandler

具体看一下 websocketHandler 的实现。

1

2

3

4

5

6

7

8

9

10

11

12

// websocketHandler defines to handle websocket upgrade request.

type websocketHandler struct {

    // upgrader is used to upgrade request.

    upgrader *websocket.Upgrader

 

    // binder stores relations about websocket connection and userID.

    binder *binder

 

    // calcUserIDFunc defines to calculate userID by token. The userID will

    // be equal to token if this function is nil.

    calcUserIDFunc func(token string) (userID string, ok bool)

}

很简单的结构。websocketHandler 实现了 http.Handler 接口:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

// First try to upgrade connection to websocket. If success, connection will

// be kept until client send close message or server drop them.

func (wh *websocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

    wsConn, err := wh.upgrader.Upgrade(w, r, nil)

    if err != nil {

        return

    }

    defer wsConn.Close()

 

    // handle Websocket request

    conn := NewConn(wsConn)

    conn.AfterReadFunc = func(messageType int, r io.Reader) {

        var rm RegisterMessage

        decoder := json.NewDecoder(r)

        if err := decoder.Decode(&rm); err != nil {

            return

        }

 

        // calculate userID by token

        userID := rm.Token

        if wh.calcUserIDFunc != nil {

            uID, ok := wh.calcUserIDFunc(rm.Token)

            if !ok {

                return

            }

            userID = uID

        }

 

        // bind

        wh.binder.Bind(userID, rm.Event, conn)

    }

    conn.BeforeCloseFunc = func() {

        // unbind

        wh.binder.Unbind(conn)

    }

 

    conn.Listen()

}

首先将传入的 http.Request 转换为 websocket.Conn,再将其分装为我们自定义的一个 wserver.Conn(封装,或者说是组合,是 Go 语言的典型用法。记住,Go 语言没有继承,只有组合)。

然后设置了 Conn 的 AfterReadFunc 和 BeforeCloseFunc 方法,接着启动了 conn.Listen()。AfterReadFunc 意思是当 Conn 读取到数据后,尝试验证并根据 token 计算 userID,然乎 bind 注册绑定。BeforeCloseFunc 则为 Conn 关闭前进行解绑操作。

pushHandler

pushHandler 则容易理解。它解析请求然后推送数据:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

// Authorize if needed. Then decode the request and push message to each

// realted websocket connection.

func (s *pushHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

    if r.Method != http.MethodPost {

        w.WriteHeader(http.StatusMethodNotAllowed)

        return

    }

 

    // authorize

    if s.authFunc != nil {

        if ok := s.authFunc(r); !ok {

            w.WriteHeader(http.StatusUnauthorized)

            return

        }

    }

 

    // read request

    var pm PushMessage

    decoder := json.NewDecoder(r.Body)

    if err := decoder.Decode(&pm); err != nil {

        w.WriteHeader(http.StatusBadRequest)

        w.Write([]byte(ErrRequestIllegal.Error()))

        return

    }

 

    // validate the data

    if pm.UserID == "" || pm.Event == "" || pm.Message == "" {

        w.WriteHeader(http.StatusBadRequest)

        w.Write([]byte(ErrRequestIllegal.Error()))

        return

    }

 

    cnt, err := s.push(pm.UserID, pm.Event, pm.Message)

    if err != nil {

        w.WriteHeader(http.StatusInternalServerError)

        w.Write([]byte(err.Error()))

        return

    }

 

    result := strings.NewReader(fmt.Sprintf("message sent to %d clients", cnt))

    io.Copy(w, result)

}

Conn

Conn (此处指 wserver.Conn) 为 websocket.Conn 的包装。

 

// Conn wraps websocket.Conn with Conn. It defines to listen and read

// data from Conn.

type Conn struct {

    Conn *websocket.Conn

 

    AfterReadFunc   func(messageType int, r io.Reader)

    BeforeCloseFunc func()

 

    once   sync.Once

    id     string

    stopCh chan struct{}

}

最主要的方法为 Listen():

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

// Listen listens for receive data from websocket connection. It blocks

// until websocket connection is closed.

func (c *Conn) Listen() {

    c.Conn.SetCloseHandler(func(code int, text string) error {

        if c.BeforeCloseFunc != nil {

            c.BeforeCloseFunc()

        }

 

        if err := c.Close(); err != nil {

            log.Println(err)

        }

 

        message := websocket.FormatCloseMessage(code, "")

        c.Conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))

        return nil

    })

 

    // Keeps reading from Conn util get error.

ReadLoop:

    for {

        select {

        case <-c.stopCh:

            break ReadLoop

        default:

            messageType, r, err := c.Conn.NextReader()

            if err != nil {

                // TODO: handle read error maybe

                break ReadLoop

            }

 

            if c.AfterReadFunc != nil {

                c.AfterReadFunc(messageType, r)

            }

        }

    }

}

主要设置了当 websocket 连接关闭时的处理和不停地读取数据。

关注公众号:拾黑(shiheibook)了解更多

友情链接:

下软件就上简单下载站:https://www.jdsec.com/
四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/

公众号 关注网络尖刀微信公众号
随时掌握互联网精彩
赞助链接