Go TCP

CP(Transmission Control Protocol)是一种面向连接的、可靠的、基于字节流的传输层通信协议。Go 语言在标准库 net 包中提供了完善的 TCP 网络编程支持。

下面是一个示例:

package main

import (
    "bufio"
    "encoding/json"
    "fmt"
    "io"
    "log"
    "net"
    "os"
    "strconv"
    "strings"
    "sync"
    "time"
)

// 1. 基本 TCP 服务器和客户端
func basicTCPServer(port string) {
    fmt.Printf("启动基本 TCP 服务器 :%s\n", port)

    listener, err := net.Listen("tcp", ":"+port)
    if err != nil {
        log.Fatalf("服务器启动失败: %v", err)
    }
    defer listener.Close()

    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Printf("接受连接失败: %v", err)
            continue
        }

        go handleBasicConnection(conn)
    }
}

func handleBasicConnection(conn net.Conn) {
    defer conn.Close()

    remoteAddr := conn.RemoteAddr().String()
    fmt.Printf("客户端连接: %s\n", remoteAddr)

    // 设置读取超时
    conn.SetReadDeadline(time.Now().Add(30 * time.Second))

    // 读取客户端数据
    reader := bufio.NewReader(conn)
    for {
        message, err := reader.ReadString('\n')
        if err != nil {
            if err == io.EOF {
                fmt.Printf("客户端断开连接: %s\n", remoteAddr)
            } else {
                log.Printf("读取错误: %v", err)
            }
            return
        }

        message = strings.TrimSpace(message)
        fmt.Printf("收到来自 %s 的消息: %s\n", remoteAddr, message)

        if message == "quit" {
            conn.Write([]byte("再见!\n"))
            return
        }

        // 回显消息
        response := fmt.Sprintf("服务器回复: %s\n", strings.ToUpper(message))
        conn.Write([]byte(response))
    }
}

func basicTCPClient(serverAddr string) {
    fmt.Printf("连接 TCP 服务器: %s\n", serverAddr)

    conn, err := net.Dial("tcp", serverAddr)
    if err != nil {
        log.Fatalf("连接服务器失败: %v", err)
    }
    defer conn.Close()

    go receiveMessages(conn)

    // 发送消息
    scanner := bufio.NewScanner(os.Stdin)
    fmt.Println("输入消息 (输入 'quit' 退出):")

    for scanner.Scan() {
        message := scanner.Text()

        _, err := conn.Write([]byte(message + "\n"))
        if err != nil {
            log.Printf("发送失败: %v", err)
            break
        }

        if message == "quit" {
            break
        }
    }
}

func receiveMessages(conn net.Conn) {
    reader := bufio.NewReader(conn)
    for {
        message, err := reader.ReadString('\n')
        if err != nil {
            if err == io.EOF {
                fmt.Println("服务器断开连接")
            } else {
                log.Printf("接收错误: %v", err)
            }
            return
        }
        fmt.Printf("服务器回复: %s", message)
    }
}

// 2. 并发 TCP 服务器 with 连接管理
type TCPServer struct {
    listener   net.Listener
    clients    map[net.Conn]string
    clientsMux sync.RWMutex
    wg         sync.WaitGroup
}

func NewTCPServer(port string) *TCPServer {
    listener, err := net.Listen("tcp", ":"+port)
    if err != nil {
        log.Fatal(err)
    }

    return &TCPServer{
        listener: listener,
        clients:  make(map[net.Conn]string),
    }
}

func (s *TCPServer) Start() {
    fmt.Printf("并发 TCP 服务器启动 :%s\n", s.listener.Addr().(*net.TCPAddr).Port)

    go s.acceptConnections()

    // 等待中断信号
    <-make(chan os.Signal, 1)
    s.Stop()
}

func (s *TCPServer) acceptConnections() {
    for {
        conn, err := s.listener.Accept()
        if err != nil {
            log.Printf("接受连接错误: %v", err)
            continue
        }

        s.wg.Add(1)
        go s.handleClient(conn)
    }
}

func (s *TCPServer) handleClient(conn net.Conn) {
    defer s.wg.Done()
    defer conn.Close()

    clientID := conn.RemoteAddr().String()

    // 注册客户端
    s.clientsMux.Lock()
    s.clients[conn] = clientID
    s.clientsMux.Unlock()

    // 清理客户端
    defer func() {
        s.clientsMux.Lock()
        delete(s.clients, clientID)
        s.clientsMux.Unlock()
        fmt.Printf("客户端断开: %s\n", clientID)
    }()

    fmt.Printf("新客户端连接: %s, 当前客户端数: %d\n", clientID, len(s.clients))

    // 发送欢迎消息
    welcome := fmt.Sprintf("欢迎!你的客户端ID: %s\n", clientID)
    conn.Write([]byte(welcome))

    reader := bufio.NewReader(conn)
    for {
        conn.SetReadDeadline(time.Now().Add(5 * time.Minute))

        message, err := reader.ReadString('\n')
        if err != nil {
            if err == io.EOF {
                fmt.Printf("客户端主动断开: %s\n", clientID)
            } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
                fmt.Printf("客户端超时: %s\n", clientID)
            } else {
                log.Printf("读取错误: %v", err)
            }
            return
        }

        message = strings.TrimSpace(message)
        fmt.Printf("来自 %s: %s\n", clientID, message)

        // 处理命令
        if strings.HasPrefix(message, "/") {
            s.handleCommand(conn, clientID, message)
        } else {
            s.broadcastMessage(clientID, message)
        }
    }
}

func (s *TCPServer) handleCommand(conn net.Conn, clientID, command string) {
    parts := strings.Fields(command)
    if len(parts) == 0 {
        return
    }

    switch parts[0] {
    case "/time":
        conn.Write([]byte(fmt.Sprintf("服务器时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))))
    case "/clients":
        s.clientsMux.RLock()
        count := len(s.clients)
        conn.Write([]byte(fmt.Sprintf("在线客户端数: %d\n", count)))
        s.clientsMux.RUnlock()
    case "/whisper":
        if len(parts) >= 3 {
            // 简化版私聊
            conn.Write([]byte("私聊功能开发中...\n"))
        }
    default:
        conn.Write([]byte("未知命令。可用命令: /time, /clients, /whisper\n"))
    }
}

func (s *TCPServer) broadcastMessage(senderID, message string) {
    s.clientsMux.RLock()
    defer s.clientsMux.RUnlock()

    broadcastMsg := fmt.Sprintf("[%s]: %s\n", senderID, message)

    for conn := range s.clients {
        if conn.RemoteAddr().String() != senderID { // 不发给发送者自己
            go func(c net.Conn) {
                _, err := c.Write([]byte(broadcastMsg))
                if err != nil {
                    log.Printf("广播失败: %v", err)
                }
            }(conn)
        }
    }
}

func (s *TCPServer) Stop() {
    fmt.Println("正在关闭服务器...")

    // 通知所有客户端
    s.clientsMux.RLock()
    for conn := range s.clients {
        conn.Write([]byte("服务器即将关闭\n"))
        conn.Close()
    }
    s.clientsMux.RUnlock()

    s.listener.Close()
    s.wg.Wait()
    fmt.Println("服务器已关闭")
}

// 3. 协议化的 TCP 通信(JSON 格式)
type Message struct {
    Type    string      `json:"type"`
    From    string      `json:"from,omitempty"`
    To      string      `json:"to,omitempty"`
    Content interface{} `json:"content"`
    Time    int64       `json:"time"`
}

type ProtocolTCPServer struct {
    listener net.Listener
    clients  sync.Map // 使用 sync.Map 替代 mutex
}

func NewProtocolTCPServer(port string) *ProtocolTCPServer {
    listener, err := net.Listen("tcp", ":"+port)
    if err != nil {
        log.Fatal(err)
    }

    return &ProtocolTCPServer{
        listener: listener,
    }
}

func (s *ProtocolTCPServer) Start() {
    fmt.Printf("协议化 TCP 服务器启动 :%s\n", s.listener.Addr().(*net.TCPAddr).Port)

    for {
        conn, err := s.listener.Accept()
        if err != nil {
            log.Printf("接受连接错误: %v", err)
            continue
        }

        go s.handleProtocolClient(conn)
    }
}

func (s *ProtocolTCPServer) handleProtocolClient(conn net.Conn) {
    defer conn.Close()

    clientID := conn.RemoteAddr().String()
    s.clients.Store(conn, clientID)
    defer s.clients.Delete(conn)

    fmt.Printf("协议客户端连接: %s\n", clientID)

    // 发送欢迎消息
    welcomeMsg := Message{
        Type:    "welcome",
        Content: fmt.Sprintf("连接成功,你的ID: %s", clientID),
        Time:    time.Now().Unix(),
    }
    s.sendMessage(conn, welcomeMsg)

    decoder := json.NewDecoder(conn)
    for {
        var msg Message
        if err := decoder.Decode(&msg); err != nil {
            if err == io.EOF {
                fmt.Printf("客户端断开: %s\n", clientID)
            } else {
                log.Printf("协议解析错误: %v", err)
            }
            return
        }

        msg.From = clientID
        msg.Time = time.Now().Unix()

        fmt.Printf("收到协议消息: %+v\n", msg)

        // 处理不同类型的消息
        switch msg.Type {
        case "chat":
            s.broadcastProtocolMessage(msg)
        case "ping":
            s.sendMessage(conn, Message{Type: "pong", Content: "pong"})
        case "info":
            var count int
            s.clients.Range(func(key, value interface{}) bool {
                count++
                return true
            })
            s.sendMessage(conn, Message{Type: "info", Content: fmt.Sprintf("在线客户端: %d", count)})
        default:
            s.sendMessage(conn, Message{Type: "error", Content: "未知消息类型"})
        }
    }
}

func (s *ProtocolTCPServer) broadcastProtocolMessage(msg Message) {
    s.clients.Range(func(key, value interface{}) bool {
        conn := key.(net.Conn)
        if conn.RemoteAddr().String() != msg.From { // 不发给发送者
            go s.sendMessage(conn, msg)
        }
        return true
    })
}

func (s *ProtocolTCPServer) sendMessage(conn net.Conn, msg Message) error {
    data, err := json.Marshal(msg)
    if err != nil {
        return err
    }
    data = append(data, '\n') // 添加分隔符
    _, err = conn.Write(data)
    return err
}

// 4. 协议化 TCP 客户端
func protocolTCPClient(serverAddr string) {
    conn, err := net.Dial("tcp", serverAddr)
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    fmt.Printf("连接到协议服务器: %s\n", serverAddr)

    // 接收消息的 goroutine
    go func() {
        decoder := json.NewDecoder(conn)
        for {
            var msg Message
            if err := decoder.Decode(&msg); err != nil {
                if err == io.EOF {
                    fmt.Println("服务器断开连接")
                } else {
                    log.Printf("接收错误: %v", err)
                }
                return
            }
            fmt.Printf("服务器消息 [%s]: %v\n", msg.Type, msg.Content)
        }
    }()

    // 发送消息
    scanner := bufio.NewScanner(os.Stdin)
    fmt.Println("输入消息类型 (chat/ping/info) 和内容:")

    for scanner.Scan() {
        input := strings.TrimSpace(scanner.Text())
        if input == "quit" {
            break
        }

        parts := strings.SplitN(input, " ", 2)
        msgType := "chat"
        content := input

        if len(parts) >= 1 {
            msgType = parts[0]
            if len(parts) > 1 {
                content = parts[1]
            }
        }

        msg := Message{
            Type:    msgType,
            Content: content,
            Time:    time.Now().Unix(),
        }

        data, err := json.Marshal(msg)
        if err != nil {
            log.Printf("序列化错误: %v", err)
            continue
        }
        data = append(data, '\n')

        if _, err := conn.Write(data); err != nil {
            log.Printf("发送错误: %v", err)
            break
        }
    }
}

// 5. TCP 文件传输示例
func fileTransferServer(port string) {
    listener, err := net.Listen("tcp", ":"+port)
    if err != nil {
        log.Fatal(err)
    }
    defer listener.Close()

    fmt.Printf("文件传输服务器启动 :%s\n", port)

    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Printf("接受连接错误: %v", err)
            continue
        }

        go handleFileTransfer(conn)
    }
}

func handleFileTransfer(conn net.Conn) {
    defer conn.Close()

    // 读取文件名和大小
    reader := bufio.NewReader(conn)
    filename, err := reader.ReadString('\n')
    if err != nil {
        log.Printf("读取文件名错误: %v", err)
        return
    }
    filename = strings.TrimSpace(filename)

    sizeStr, err := reader.ReadString('\n')
    if err != nil {
        log.Printf("读取文件大小错误: %v", err)
        return
    }
    sizeStr = strings.TrimSpace(sizeStr)
    fileSize, err := strconv.ParseInt(sizeStr, 10, 64)
    if err != nil {
        log.Printf("解析文件大小错误: %v", err)
        return
    }

    fmt.Printf("接收文件: %s, 大小: %d bytes\n", filename, fileSize)

    // 创建文件
    file, err := os.Create("received_" + filename)
    if err != nil {
        log.Printf("创建文件错误: %v", err)
        conn.Write([]byte("ERROR: 无法创建文件\n"))
        return
    }
    defer file.Close()

    // 接收文件数据
    bytesReceived, err := io.CopyN(file, reader, fileSize)
    if err != nil {
        log.Printf("接收文件错误: %v", err)
        return
    }

    fmt.Printf("文件接收完成: %s, 接收字节: %d\n", filename, bytesReceived)
    conn.Write([]byte("SUCCESS: 文件接收成功\n"))
}

func main() {
    if len(os.Args) < 2 {
        fmt.Println("用法:")
        fmt.Println("  服务器: go run tcp.go server [port]")
        fmt.Println("  基本客户端: go run tcp.go client [host:port]")
        fmt.Println("  并发服务器: go run tcp.go concurrent [port]")
        fmt.Println("  协议客户端: go run tcp.go protocol [host:port]")
        return
    }

    mode := os.Args[1]

    switch mode {
    case "server":
        port := "8080"
        if len(os.Args) > 2 {
            port = os.Args[2]
        }
        basicTCPServer(port)

    case "client":
        addr := "localhost:8080"
        if len(os.Args) > 2 {
            addr = os.Args[2]
        }
        basicTCPClient(addr)

    case "concurrent":
        port := "8081"
        if len(os.Args) > 2 {
            port = os.Args[2]
        }
        server := NewTCPServer(port)
        server.Start()

    case "protocol":
        addr := "localhost:8082"
        if len(os.Args) > 2 {
            addr = os.Args[2]
        }
        protocolTCPClient(addr)

    default:
        fmt.Println("未知模式")
    }
}

TCP 编程关键要点

  1. ​ 核心组件 ​

net.Listen()- 创建监听器

Listener.Accept()- 接受连接

net.Dial()- 建立连接

Conn.Read()/Conn.Write()- 数据读写

  1. ​ 重要特性 ​

​ 并发处理 ​:每个连接使用 goroutine

​ 超时控制 ​:SetDeadline(), SetReadDeadline(), SetWriteDeadline()

​ 连接管理 ​:记录活跃连接,优雅关闭

​ 错误处理 ​:网络错误、超时、断开连接

  1. ​ 最佳实践 ​

​ 总是处理错误 ​:检查每个网络操作的错误

​ 使用 goroutine​:为每个连接创建独立的 goroutine

​ 资源清理 ​:使用 defer 关闭连接和文件

​ 超时设置 ​:防止连接长时间挂起

​ 协议设计 ​:定义清晰的数据格式(如 JSON)

Go 的 TCP 编程非常强大且简洁,适合构建各种网络应用,从简单的客户端/服务器到复杂的分布式系统。

通关密语:tcp