Go WebSocket

WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,允许服务器和客户端之间进行实时、双向的数据传输。与 HTTP 相比,WebSocket 提供了更低的延迟和更高效的通信。

Go 标准库没有直接提供 WebSocket 支持,最常用的是 gorilla/websocket 库。

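下面是一个完整示例,通过命令行参数选择三种模式:基本聊天室(chat)、实时数据推送(data)和支持房间的高级聊天室(advanced)。运行前需先执行 go get github.com/gorilla/websocket: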

package main

import (
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "os"
    "sync"
    "time"
    "unicode/utf8"

    "github.com/gorilla/websocket"
)

// 1. Basic WebSocket server

// WebSocketServer is a single-room chat server: every chat message received
// from one client is broadcast to every connected client.
//
// gorilla/websocket supports at most one concurrent writer per connection,
// so every registered connection is paired with its own write mutex and all
// writes (broadcasts as well as welcome/pong/error replies) hold that mutex.
type WebSocketServer struct {
    upgrader  websocket.Upgrader
    clients   map[*websocket.Conn]*sync.Mutex // registered connection -> its write lock
    clientsMu sync.RWMutex                    // guards clients
    broadcast chan []byte                     // JSON-encoded messages queued for every client
}

const (
    // basicWriteWait bounds a single write so that one stalled client cannot
    // block the broadcaster indefinitely.
    basicWriteWait = 10 * time.Second

    // basicMaxMessageSize caps one incoming frame: chat text is limited to
    // 1000 runes (at most 4 bytes each in UTF-8) plus JSON framing.
    basicMaxMessageSize = 8192
)

// NewWebSocketServer returns a server with no clients and a broadcast queue
// that buffers up to 256 messages.
func NewWebSocketServer() *WebSocketServer {
    return &WebSocketServer{
        upgrader: websocket.Upgrader{
            CheckOrigin: func(r *http.Request) bool {
                // Every origin is accepted; production code should validate
                // the Origin header here.
                return true
            },
        },
        clients:   make(map[*websocket.Conn]*sync.Mutex),
        broadcast: make(chan []byte, 256),
    }
}

// Start registers the page and WebSocket handlers on http.DefaultServeMux,
// launches the broadcaster, and serves on port (e.g. ":8080"). It only
// returns by terminating the process via log.Fatal.
func (s *WebSocketServer) Start(port string) {
    http.HandleFunc("/", s.serveHome)
    http.HandleFunc("/ws", s.handleWebSocket)

    // The broadcaster drains s.broadcast for the lifetime of the process.
    go s.broadcastMessages()

    fmt.Printf("WebSocket 服务器启动在 http://localhost%s\n", port)
    fmt.Printf("访问 http://localhost%s 测试聊天室\n", port)
    log.Fatal(http.ListenAndServe(port, nil))
}

// serveHome serves the single-page chat client for GET /. The page renders
// incoming fields with textContent so user-supplied text is never
// interpreted as HTML.
func (s *WebSocketServer) serveHome(w http.ResponseWriter, r *http.Request) {
    if r.URL.Path != "/" {
        http.Error(w, "Not found", http.StatusNotFound)
        return
    }
    if r.Method != "GET" {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        return
    }

    // A minimal self-contained HTML chat page.
    html := `
    <!DOCTYPE html>
    <html>
    <head>
        <title>Go WebSocket 聊天室</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            #chat { width: 500px; height: 300px; border: 1px solid #ccc; padding: 10px; overflow-y: scroll; margin-bottom: 10px; }
            #message { width: 400px; padding: 5px; }
            button { padding: 5px 15px; }
            .system { color: #666; font-style: italic; }
            .user { color: #0066cc; font-weight: bold; }
            .message { margin: 5px 0; }
        </style>
    </head>
    <body>
        <h1>Go WebSocket 聊天室</h1>
        <div id="chat"></div>
        <input type="text" id="message" placeholder="输入消息...">
        <button onclick="sendMessage()">发送</button>
        <button onclick="clearChat()">清空</button>

        <script>
            let ws = new WebSocket('ws://' + window.location.host + '/ws');

            // Appends a span with the given class and text. textContent is used
            // so that text coming from other users is never parsed as HTML.
            function appendSpan(parent, className, text) {
                const span = document.createElement('span');
                span.className = className;
                span.textContent = text;
                parent.appendChild(span);
            }

            ws.onmessage = function(event) {
                const data = JSON.parse(event.data);
                const chat = document.getElementById('chat');

                let messageDiv = document.createElement('div');
                messageDiv.className = 'message';

                if (data.type === 'system') {
                    appendSpan(messageDiv, 'system', '系统: ' + data.message);
                } else if (data.type === 'user_join' || data.type === 'user_leave' || data.type === 'error') {
                    appendSpan(messageDiv, 'system', data.message);
                } else {
                    appendSpan(messageDiv, 'user', data.username + ':');
                    messageDiv.appendChild(document.createTextNode(' ' + data.message));
                }

                chat.appendChild(messageDiv);
                chat.scrollTop = chat.scrollHeight;
            };

            ws.onclose = function() {
                alert('连接已断开');
            };

            function sendMessage() {
                const input = document.getElementById('message');
                const message = input.value.trim();

                if (message) {
                    ws.send(JSON.stringify({
                        type: 'chat',
                        message: message,
                        timestamp: new Date().toISOString()
                    }));
                    input.value = '';
                }
            }

            function clearChat() {
                document.getElementById('chat').innerHTML = '';
            }

            document.getElementById('message').addEventListener('keypress', function(e) {
                if (e.key === 'Enter') {
                    sendMessage();
                }
            });
        </script>
    </body>
    </html>
    `
    w.Header().Set("Content-Type", "text/html; charset=utf-8")
    w.Write([]byte(html))
}

// handleWebSocket upgrades the request and registers the connection. It then
// reads JSON messages until the peer disconnects or sends an invalid frame.
// "chat" messages are validated and broadcast; "ping" is answered with a
// "pong". Join and leave notices are broadcast to everyone.
func (s *WebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
    conn, err := s.upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }
    defer conn.Close()
    conn.SetReadLimit(basicMaxMessageSize)

    // Register the client together with its own write lock.
    s.clientsMu.Lock()
    s.clients[conn] = &sync.Mutex{}
    clientCount := len(s.clients)
    s.clientsMu.Unlock()

    // Tell everyone (including the newcomer) that a user joined.
    s.broadcastToAll(map[string]interface{}{
        "type":    "user_join",
        "message": fmt.Sprintf("新用户加入,当前在线用户: %d", clientCount),
        "time":    time.Now().Format("15:04:05"),
    })

    // Greet the new user; this can race with broadcasts, hence sendJSON.
    if err := s.sendJSON(conn, map[string]interface{}{
        "type":    "system",
        "message": "欢迎加入聊天室!",
        "time":    time.Now().Format("15:04:05"),
    }); err != nil {
        log.Printf("发送消息错误: %v", err)
    }

    log.Printf("新客户端连接,总连接数: %d", clientCount)

    for {
        var message map[string]interface{}
        if err := conn.ReadJSON(&message); err != nil {
            log.Printf("读取消息错误: %v", err)
            break
        }

        log.Printf("收到消息: %v", message)

        switch message["type"] {
        case "chat":
            s.handleChatMessage(conn, message)
        case "ping":
            if err := s.sendJSON(conn, map[string]interface{}{
                "type": "pong",
                "time": time.Now().Unix(),
            }); err != nil {
                log.Printf("发送消息错误: %v", err)
            }
        default:
            // %v rather than %s: the value is an interface{} and may be nil or non-string.
            log.Printf("未知消息类型: %v", message["type"])
        }
    }

    // The client disconnected: unregister it (the broadcaster may already have).
    s.clientsMu.Lock()
    delete(s.clients, conn)
    clientCount = len(s.clients)
    s.clientsMu.Unlock()

    s.broadcastToAll(map[string]interface{}{
        "type":    "user_leave",
        "message": fmt.Sprintf("用户离开,剩余用户: %d", clientCount),
        "time":    time.Now().Format("15:04:05"),
    })

    log.Printf("客户端断开,剩余连接数: %d", clientCount)
}

// handleChatMessage validates a "chat" message and broadcasts it. The
// message text must be a string of at most 1000 runes; otherwise an "error"
// reply is sent to the originating connection only.
func (s *WebSocketServer) handleChatMessage(conn *websocket.Conn, message map[string]interface{}) {
    msgText, ok := message["message"].(string)
    if !ok || utf8.RuneCountInString(msgText) > 1000 {
        if err := s.sendJSON(conn, map[string]interface{}{
            "type":    "error",
            "message": "消息过长或格式错误",
        }); err != nil {
            log.Printf("发送消息错误: %v", err)
        }
        return
    }

    broadcastMsg := map[string]interface{}{
        "type":      "chat",
        "username":  fmt.Sprintf("用户%d", time.Now().Unix()%1000), // naive placeholder username
        "message":   msgText,
        "timestamp": time.Now().Format("15:04:05"),
    }

    s.broadcastToAll(broadcastMsg)
}

// broadcastToAll JSON-encodes message and queues it for the broadcaster. It
// blocks while the queue is full.
func (s *WebSocketServer) broadcastToAll(message map[string]interface{}) {
    data, err := json.Marshal(message)
    if err != nil {
        log.Printf("消息序列化错误: %v", err)
        return
    }

    s.broadcast <- data
}

// sendJSON encodes v and writes it to a single registered connection while
// holding that connection's write lock. It fails once conn is no longer
// registered (i.e. it has been dropped).
func (s *WebSocketServer) sendJSON(conn *websocket.Conn, v interface{}) error {
    data, err := json.Marshal(v)
    if err != nil {
        return err
    }

    s.clientsMu.RLock()
    writeMu, ok := s.clients[conn]
    s.clientsMu.RUnlock()
    if !ok {
        return fmt.Errorf("connection %p is no longer registered", conn)
    }
    return s.writeLocked(conn, writeMu, data)
}

// writeLocked sends data as one text frame under writeMu, with a write
// deadline. After a timed-out write the connection is unusable, so callers
// treat any error as fatal for that client.
func (s *WebSocketServer) writeLocked(conn *websocket.Conn, writeMu *sync.Mutex, data []byte) error {
    writeMu.Lock()
    defer writeMu.Unlock()
    if err := conn.SetWriteDeadline(time.Now().Add(basicWriteWait)); err != nil {
        return err
    }
    return conn.WriteMessage(websocket.TextMessage, data)
}

// broadcastMessages runs for the lifetime of the server. It delivers each
// queued message to every registered client. Writes are sequential, so each
// client receives messages in queue order; the per-write deadline limits
// how long a slow client can delay the others. A client whose write fails
// is closed and unregistered.
func (s *WebSocketServer) broadcastMessages() {
    type recipient struct {
        conn    *websocket.Conn
        writeMu *sync.Mutex
    }

    for message := range s.broadcast {
        // Snapshot recipients so the client map is not locked during network I/O.
        s.clientsMu.RLock()
        recipients := make([]recipient, 0, len(s.clients))
        for conn, writeMu := range s.clients {
            recipients = append(recipients, recipient{conn: conn, writeMu: writeMu})
        }
        s.clientsMu.RUnlock()

        for _, rc := range recipients {
            if err := s.writeLocked(rc.conn, rc.writeMu, message); err != nil {
                log.Printf("广播消息错误: %v", err)
                // Close is safe concurrently with the handler's read loop, which
                // will then observe an error and broadcast the leave notice.
                rc.conn.Close()
                s.clientsMu.Lock()
                delete(s.clients, rc.conn)
                s.clientsMu.Unlock()
            }
        }
    }
}

// 2. Real-time data push server

// DataPushServer streams simulated monitoring metrics to every connected
// WebSocket client.
type DataPushServer struct {
    upgrader  websocket.Upgrader
    clientsMu sync.RWMutex                  // guards clients
    clients   map[*websocket.Conn]chan bool // connection -> its stop channel
}

// NewDataPushServer returns a server with no clients whose upgrader accepts
// connections from any origin.
func NewDataPushServer() *DataPushServer {
    s := &DataPushServer{
        clients: make(map[*websocket.Conn]chan bool),
    }
    s.upgrader.CheckOrigin = func(*http.Request) bool { return true }
    return s
}

// Start mounts the WebSocket endpoint at "/data" and the metrics page at "/"
// on http.DefaultServeMux, then serves on port until the process exits.
func (s *DataPushServer) Start(port string) {
    mux := http.DefaultServeMux
    mux.HandleFunc("/data", s.handleDataWebSocket)
    mux.HandleFunc("/", s.serveDataPage)

    fmt.Printf("实时数据服务器启动在 http://localhost%s\n", port)
    log.Fatal(http.ListenAndServe(port, mux))
}

func (s *DataPushServer) serveDataPage(w http.ResponseWriter, r *http.Request) {
    html := `
    <!DOCTYPE html>
    <html>
    <head>
        <title>实时数据监控</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .metric { border: 1px solid #ddd; padding: 10px; margin: 10px 0; }
            .value { font-size: 24px; font-weight: bold; color: #0066cc; }
            .timestamp { color: #666; font-size: 12px; }
        </style>
    </head>
    <body>
        <h1>实时数据监控</h1>
        <div id="metrics"></div>

        <script>
            const ws = new WebSocket('ws://' + window.location.host + '/data');
            const metricsDiv = document.getElementById('metrics');

            ws.onmessage = function(event) {
                const data = JSON.parse(event.data);

                let metricDiv = document.getElementById('metric-' + data.name);
                if (!metricDiv) {
                    metricDiv = document.createElement('div');
                    metricDiv.className = 'metric';
                    metricDiv.id = 'metric-' + data.name;
                    metricDiv.innerHTML = `
                    `<h3>${data.name}</h3>
                    <div class="value">${data.value}</div>
                    <div class="timestamp">最后更新: <span id="time-${data.name}">${data.timestamp}</span></div>`;
                    metricsDiv.appendChild(metricDiv);
                } else {
                    metricDiv.querySelector('.value').textContent = data.value;
                    metricDiv.querySelector('.timestamp span').textContent = data.timestamp;
                }
            };
        </script>
    </body>
    </html>
    `
    w.Header().Set("Content-Type", "text/html; charset=utf-8")
    w.Write([]byte(html))
}

const (
    // dataWriteWait bounds each metric write so a stalled client is detected
    // instead of blocking its pusher goroutine forever.
    dataWriteWait = 10 * time.Second

    // dataMaxMessageSize caps incoming frames; the dashboard never sends data.
    dataMaxMessageSize = 512
)

// handleDataWebSocket upgrades the request and starts a goroutine that
// pushes metrics to the client. It then reads from the connection until the
// peer goes away. gorilla/websocket only processes close and ping frames
// while the application reads, and a read error is the disconnect signal.
// The pusher is stopped and awaited before the handler returns.
func (s *DataPushServer) handleDataWebSocket(w http.ResponseWriter, r *http.Request) {
    conn, err := s.upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }
    defer conn.Close()
    conn.SetReadLimit(dataMaxMessageSize)

    // Register the client; its stop channel is closed when it disconnects.
    stopChan := make(chan bool)
    s.clientsMu.Lock()
    s.clients[conn] = stopChan
    s.clientsMu.Unlock()

    log.Println("新的数据监控客户端连接")

    var pusher sync.WaitGroup
    pusher.Add(1)
    go func() {
        defer pusher.Done()
        s.pushDataToClient(conn, stopChan)
    }()

    // Discard anything the page sends. Stop on the first read error: the peer
    // closed, the network failed, or the pusher closed conn after a failed write.
    for {
        if _, _, err := conn.NextReader(); err != nil {
            break
        }
    }

    // Signal the pusher and make sure it does not outlive this handler.
    close(stopChan)
    pusher.Wait()

    s.clientsMu.Lock()
    delete(s.clients, conn)
    s.clientsMu.Unlock()

    log.Println("数据监控客户端断开")
}

// pushDataToClient sends one JSON update per simulated metric every two
// seconds until stopChan is closed or a write fails. On a write failure it
// closes conn, which makes the handler's read loop return and clean up.
// This goroutine is the only writer on conn.
func (s *DataPushServer) pushDataToClient(conn *websocket.Conn, stopChan chan bool) {
    ticker := time.NewTicker(2 * time.Second)
    defer ticker.Stop()

    metrics := []string{"CPU使用率", "内存使用", "网络流量", "磁盘IO"}

    for {
        select {
        case <-stopChan:
            return

        case <-ticker.C:
            for _, metric := range metrics {
                // Simulated value derived from the clock.
                now := time.Now()
                data := map[string]interface{}{
                    "name":      metric,
                    "value":     fmt.Sprintf("%.1f%%", float64(now.Unix()%100)+float64(now.Nanosecond())/1e9),
                    "timestamp": now.Format("15:04:05.000"),
                }

                err := conn.SetWriteDeadline(time.Now().Add(dataWriteWait))
                if err == nil {
                    err = conn.WriteJSON(data)
                }
                if err != nil {
                    log.Printf("推送数据错误: %v", err)
                    // Close is safe to call concurrently with the handler's read;
                    // it unblocks that read so the handler can clean up.
                    conn.Close()
                    return
                }
            }
        }
    }
}

// 3. Advanced chat server with room support

// ChatRoom is a named group of clients that receive each other's messages.
type ChatRoom struct {
    name    string
    clients map[*websocket.Conn]*Client
    mu      sync.RWMutex // guards clients
}

// Client is one connection that has joined a room. A connection joins at
// most one room; the page opens a new socket to switch rooms.
type Client struct {
    conn     *websocket.Conn
    username string
    room     *ChatRoom
    // writeMu serializes writes to conn. gorilla/websocket allows at most one
    // concurrent writer per connection, and broadcasts from other members'
    // goroutines race with replies sent to this client.
    writeMu sync.Mutex
}

const (
    // roomWriteWait bounds a single write so a stalled peer cannot block the
    // broadcasting goroutine forever.
    roomWriteWait = 10 * time.Second

    // roomMaxMessageSize caps one incoming frame (1000-rune chat text plus JSON framing).
    roomMaxMessageSize = 8192

    // roomMaxUsernameRunes caps the self-chosen username length.
    roomMaxUsernameRunes = 32
)

// writeRaw sends data as one text frame while holding the client's write
// lock, with a write deadline.
func (c *Client) writeRaw(data []byte) error {
    c.writeMu.Lock()
    defer c.writeMu.Unlock()
    if err := c.conn.SetWriteDeadline(time.Now().Add(roomWriteWait)); err != nil {
        return err
    }
    return c.conn.WriteMessage(websocket.TextMessage, data)
}

// sendJSON encodes v as JSON and sends it to this client only.
func (c *Client) sendJSON(v interface{}) error {
    data, err := json.Marshal(v)
    if err != nil {
        return err
    }
    return c.writeRaw(data)
}

// ChatServer hosts a fixed set of chat rooms, created in Start.
type ChatServer struct {
    upgrader websocket.Upgrader
    rooms    map[string]*ChatRoom
    roomsMu  sync.RWMutex // guards rooms
}

// NewChatServer returns a server with no rooms whose upgrader accepts any origin.
func NewChatServer() *ChatServer {
    return &ChatServer{
        upgrader: websocket.Upgrader{
            CheckOrigin: func(r *http.Request) bool { return true },
        },
        rooms: make(map[string]*ChatRoom),
    }
}

// Start creates the "general" and "tech" rooms, registers handlers on
// http.DefaultServeMux, and serves on port until the process exits.
func (s *ChatServer) Start(port string) {
    // Rooms are created before serving begins, so no lock is needed here.
    s.rooms["general"] = &ChatRoom{
        name:    "general",
        clients: make(map[*websocket.Conn]*Client),
    }
    s.rooms["tech"] = &ChatRoom{
        name:    "tech",
        clients: make(map[*websocket.Conn]*Client),
    }

    http.HandleFunc("/", s.serveChatHome)
    http.HandleFunc("/ws", s.handleChatWebSocket)

    fmt.Printf("高级聊天室启动在 http://localhost%s\n", port)
    log.Fatal(http.ListenAndServe(port, nil))
}

// serveChatHome serves the room-aware chat page. The page renders incoming
// fields with textContent so usernames and messages are never parsed as HTML.
func (s *ChatServer) serveChatHome(w http.ResponseWriter, r *http.Request) {
    html := `
    <!DOCTYPE html>
    <html>
    <head>
        <title>高级聊天室</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .container { display: flex; }
            .sidebar { width: 200px; margin-right: 20px; }
            .room { padding: 10px; margin: 5px 0; background: #f0f0f0; cursor: pointer; }
            .room.active { background: #0066cc; color: white; }
            .chat-area { flex: 1; }
            #messages { height: 400px; border: 1px solid #ccc; padding: 10px; overflow-y: scroll; margin-bottom: 10px; }
            .message { margin: 5px 0; }
            .system { color: #666; font-style: italic; }
            .user { color: #0066cc; font-weight: bold; }
        </style>
    </head>
    <body>
        <h1>高级聊天室</h1>
        <div class="container">
            <div class="sidebar">
                <h3>房间列表</h3>
                <div class="room active" data-room="general" onclick="joinRoom('general')">综合聊天</div>
                <div class="room" data-room="tech" onclick="joinRoom('tech')">技术讨论</div>
            </div>
            <div class="chat-area">
                <div id="messages"></div>
                <input type="text" id="messageInput" placeholder="输入消息..." style="width: 300px;">
                <button onclick="sendMessage()">发送</button>
                <button onclick="clearMessages()">清空</button>
            </div>
        </div>

        <script>
            let ws = null;
            let currentRoom = 'general';

            // Appends a span with the given class and text. textContent is used
            // so usernames and messages from other users are never parsed as HTML.
            function appendSpan(parent, className, text) {
                const span = document.createElement('span');
                span.className = className;
                span.textContent = text;
                parent.appendChild(span);
            }

            function joinRoom(roomName) {
                if (ws) {
                    // Detach the old handler so late messages from the previous room are not shown.
                    ws.onmessage = null;
                    ws.close();
                }

                currentRoom = roomName;
                // Highlight via data-room: the initial call runs outside any click
                // event, so there is no event.target to use.
                document.querySelectorAll('.room').forEach(function(r) {
                    r.classList.toggle('active', r.dataset.room === roomName);
                });
                document.getElementById('messages').innerHTML = '';

                const socket = new WebSocket('ws://' + window.location.host + '/ws?room=' + encodeURIComponent(roomName));
                ws = socket;

                socket.onmessage = function(event) {
                    const data = JSON.parse(event.data);
                    const messagesDiv = document.getElementById('messages');

                    let messageDiv = document.createElement('div');
                    messageDiv.className = 'message';

                    if (data.type === 'system' || data.type === 'user_join' || data.type === 'user_leave' || data.type === 'error') {
                        appendSpan(messageDiv, 'system', data.message);
                    } else {
                        appendSpan(messageDiv, 'user', data.username + ':');
                        messageDiv.appendChild(document.createTextNode(' ' + data.message));
                    }

                    messagesDiv.appendChild(messageDiv);
                    messagesDiv.scrollTop = messagesDiv.scrollHeight;
                };

                socket.onopen = function() {
                    // Send on this socket, not the global ws, which may already
                    // point at a newer connection if the user switched rooms quickly.
                    socket.send(JSON.stringify({
                        type: 'join',
                        room: roomName,
                        username: '用户' + Math.floor(Math.random() * 1000)
                    }));
                };
            }

            function sendMessage() {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    const input = document.getElementById('messageInput');
                    const message = input.value.trim();

                    if (message) {
                        ws.send(JSON.stringify({
                            type: 'chat',
                            message: message,
                            room: currentRoom
                        }));
                        input.value = '';
                    }
                }
            }

            function clearMessages() {
                document.getElementById('messages').innerHTML = '';
            }

            document.getElementById('messageInput').addEventListener('keypress', function(e) {
                if (e.key === 'Enter') {
                    sendMessage();
                }
            });

            // Join the default room on load.
            joinRoom('general');
        </script>
    </body>
    </html>
    `
    w.Header().Set("Content-Type", "text/html; charset=utf-8")
    w.Write([]byte(html))
}

// handleChatWebSocket upgrades the request and processes the client's JSON
// messages. The first valid "join" places the connection in a room, and
// later "chat" messages are broadcast to that room. All fields come from
// the network, so every type assertion is checked. Malformed input must not
// panic, because a panic would skip leaveRoom and leave a dead member in the
// room.
func (s *ChatServer) handleChatWebSocket(w http.ResponseWriter, r *http.Request) {
    conn, err := s.upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }
    defer conn.Close()
    conn.SetReadLimit(roomMaxMessageSize)

    var client *Client
    // Runs before conn.Close (defers are LIFO) on every exit path.
    defer func() {
        if client != nil {
            s.leaveRoom(client)
        }
    }()

    for {
        var message map[string]interface{}
        if err := conn.ReadJSON(&message); err != nil {
            log.Printf("读取消息错误: %v", err)
            return
        }

        switch message["type"] {
        case "join":
            if client != nil {
                // One room per connection; registering it twice would leave a
                // stale entry in the first room.
                if err := client.sendJSON(map[string]interface{}{
                    "type":    "error",
                    "message": "已经加入房间",
                }); err != nil {
                    log.Printf("发送消息错误: %v", err)
                }
                continue
            }

            roomName, roomOK := message["room"].(string)
            username, nameOK := message["username"].(string)
            if !roomOK || !nameOK || username == "" || utf8.RuneCountInString(username) > roomMaxUsernameRunes {
                // Not in any room yet, so nothing else writes to conn concurrently.
                if err := conn.WriteJSON(map[string]interface{}{
                    "type":    "error",
                    "message": "加入请求格式错误",
                }); err != nil {
                    log.Printf("发送消息错误: %v", err)
                }
                return
            }

            client = s.joinRoom(conn, roomName, username)
            if client == nil {
                return
            }

        case "chat":
            if client == nil {
                // Chat before join is ignored.
                continue
            }
            text, ok := message["message"].(string)
            if !ok || utf8.RuneCountInString(text) > 1000 {
                if err := client.sendJSON(map[string]interface{}{
                    "type":    "error",
                    "message": "消息过长或格式错误",
                }); err != nil {
                    log.Printf("发送消息错误: %v", err)
                }
                continue
            }
            s.broadcastToRoom(client.room, map[string]interface{}{
                "type":     "chat",
                "username": client.username,
                "message":  text,
                "time":     time.Now().Format("15:04:05"),
            })
        }
    }
}

// joinRoom adds conn to the named room under username. It announces the
// arrival to the room and sends a welcome message. It returns nil (after
// sending an error reply) if the room does not exist.
func (s *ChatServer) joinRoom(conn *websocket.Conn, roomName, username string) *Client {
    s.roomsMu.RLock()
    room, exists := s.rooms[roomName]
    s.roomsMu.RUnlock()

    if !exists {
        // conn is not in any room yet, so a direct (unlocked) write is safe.
        if err := conn.WriteJSON(map[string]interface{}{
            "type":    "error",
            "message": "房间不存在",
        }); err != nil {
            log.Printf("发送消息错误: %v", err)
        }
        return nil
    }

    client := &Client{
        conn:     conn,
        username: username,
        room:     room,
    }

    room.mu.Lock()
    room.clients[conn] = client
    clientCount := len(room.clients)
    room.mu.Unlock()

    s.broadcastToRoom(room, map[string]interface{}{
        "type":    "user_join",
        "message": fmt.Sprintf("%s 加入了房间 (在线: %d)", username, clientCount),
    })

    // From here on other members may write to conn, so use the client's lock.
    if err := client.sendJSON(map[string]interface{}{
        "type":    "system",
        "message": fmt.Sprintf("欢迎加入 %s 房间!", roomName),
    }); err != nil {
        log.Printf("发送消息错误: %v", err)
    }

    log.Printf("用户 %s 加入房间 %s", username, roomName)
    return client
}

// leaveRoom removes client from its room and notifies the remaining members.
func (s *ChatServer) leaveRoom(client *Client) {
    room := client.room

    room.mu.Lock()
    delete(room.clients, client.conn)
    clientCount := len(room.clients)
    room.mu.Unlock()

    s.broadcastToRoom(room, map[string]interface{}{
        "type":    "user_leave",
        "message": fmt.Sprintf("%s 离开了房间 (剩余: %d)", client.username, clientCount),
    })

    log.Printf("用户 %s 离开房间 %s", client.username, room.name)
}

// broadcastToRoom encodes message once and sends it to every current member
// of room, in order. The member list is snapshotted so the room lock is not
// held during network writes. A member whose write fails is closed; its own
// handler then sees the read error and removes it from the room.
func (s *ChatServer) broadcastToRoom(room *ChatRoom, message map[string]interface{}) {
    data, err := json.Marshal(message)
    if err != nil {
        log.Printf("消息序列化错误: %v", err)
        return
    }

    room.mu.RLock()
    members := make([]*Client, 0, len(room.clients))
    for _, c := range room.clients {
        members = append(members, c)
    }
    room.mu.RUnlock()

    for _, c := range members {
        if err := c.writeRaw(data); err != nil {
            log.Printf("广播消息错误: %v", err)
            c.conn.Close()
        }
    }
}

// main picks one of the three demo servers from the command line:
//
//	go run websocket.go <chat|data|advanced> [port]
//
// The port defaults to 8080 and may be given with or without a leading ':'.
// With no arguments the usage is printed; an unknown mode prints the usage
// and exits with status 2.
func main() {
    usage := func() {
        fmt.Println("WebSocket 示例用法:")
        fmt.Println("  基本聊天室: go run websocket.go chat")
        fmt.Println("  实时数据: go run websocket.go data")
        fmt.Println("  高级聊天室: go run websocket.go advanced")
    }

    if len(os.Args) < 2 {
        usage()
        return
    }

    mode := os.Args[1]
    port := ":8080"
    if len(os.Args) > 2 {
        port = os.Args[2]
        // Accept both "9090" and ":9090" (the latter used to become "::9090").
        if len(port) == 0 || port[0] != ':' {
            port = ":" + port
        }
    }

    switch mode {
    case "chat":
        NewWebSocketServer().Start(port)
    case "data":
        NewDataPushServer().Start(port)
    case "advanced":
        NewChatServer().Start(port)
    default:
        fmt.Println("未知模式")
        usage()
        os.Exit(2)
    }
}

WebSocket 关键特性

1. 核心概念

Upgrader:将 HTTP 连接升级为 WebSocket

消息类型:TextMessage(文本)、BinaryMessage(二进制),另有 CloseMessage、PingMessage、PongMessage 控制消息

连接管理:客户端连接跟踪和广播

2. 重要方法

Upgrade() - 升级 HTTP 连接到 WebSocket

ReadJSON() / WriteJSON() - JSON 消息读写

WriteMessage() - 原始消息写入

Close() - 关闭连接

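3. 最佳实践

错误处理:处理连接断开和消息错误;即使不需要客户端数据,也要持续读取连接,否则无法处理对端发来的关闭和 ping 控制消息

并发安全:使用互斥锁保护共享数据;gorilla/websocket 的每个连接同一时刻最多只允许一个读取者和一个写入者,多个 goroutine 写同一连接时必须加锁或通过单一写协程串行化

心跳检测:定期发送 ping/pong 保持连接

消息限制:用 SetReadLimit 限制消息大小,并限制发送频率

连接限制:控制最大连接数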

4. 适用场景

实时聊天应用

多人协作工具

实时数据监控

在线游戏

股票行情推送

WebSocket 为 Go 应用程序提供了强大的实时通信能力,适合构建各种需要低延迟、双向通信的应用。

通关密语:ws