package handler

import (
	"context"
	"encoding/json"
	"log/slog"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"github.com/denisovdennis/autohero/internal/model"
)

var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		// TODO: restrict origins in production.
		return true
	},
}

// Hub maintains active WebSocket connections, broadcasts envelopes,
// and routes incoming client messages.
type Hub struct {
	clients    map[*Client]bool // guarded by mu
	register   chan *Client
	unregister chan *Client
	broadcast  chan model.WSEnvelope
	Incoming   chan model.ClientMessage // inbound commands from clients
	mu         sync.RWMutex             // guards clients
	logger     *slog.Logger

	// OnConnect is called when a client finishes registration.
	// Set by the engine to push initial state. May be nil.
	OnConnect func(heroID int64)

	// OnDisconnect is called when a client is unregistered.
	// Set by the engine to persist state and remove movement. May be nil.
	OnDisconnect func(heroID int64)
}

// Client represents a single WebSocket connection.
type Client struct {
	hub  *Hub
	conn *websocket.Conn
	// send carries outbound envelopes to writePump. It is closed by the hub
	// (under Hub.mu) when the client is removed from Hub.clients.
	send   chan model.WSEnvelope
	heroID int64

	// writeMu serializes every write to conn. gorilla/websocket permits at
	// most one concurrent writer, and both readPump (pong replies) and
	// writePump write to the same connection.
	writeMu sync.Mutex
}

const (
	writeWait      = 10 * time.Second
	pongWait       = 60 * time.Second
	pingPeriod     = (pongWait * 9) / 10
	maxMessageSize = 4096
	sendBufSize    = 64
)

// NewHub creates a new WebSocket hub. Call Run in its own goroutine to start it.
func NewHub(logger *slog.Logger) *Hub {
	return &Hub{
		clients:    make(map[*Client]bool),
		register:   make(chan *Client),
		unregister: make(chan *Client),
		broadcast:  make(chan model.WSEnvelope, 256),
		Incoming:   make(chan model.ClientMessage, 256),
		logger:     logger,
	}
}

// Run starts the hub's event loop. Should be called as a goroutine.
// It never returns.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			h.mu.Unlock()

			h.logger.Info("client connected", "hero_id", client.heroID)
			// Notify engine of new connection (sends hero_state, route, etc.).
			if h.OnConnect != nil {
				go h.OnConnect(client.heroID)
			}

		case client := <-h.unregister:
			h.mu.Lock()
			_, registered := h.clients[client]
			if registered {
				delete(h.clients, client)
				close(client.send)
			}
			h.mu.Unlock()

			// The same client can be unregistered several times: readPump
			// always unregisters on exit, and the slow-consumer paths may
			// already have queued one or more unregisters. Only the call that
			// actually removed the client logs and notifies the engine.
			if registered {
				h.logger.Info("client disconnected", "hero_id", client.heroID)
				// NOTE(review): this fires per connection, so it also fires
				// while the hero may still have other open connections —
				// confirm the engine tolerates that.
				if h.OnDisconnect != nil {
					go h.OnDisconnect(client.heroID)
				}
			}

		case env := <-h.broadcast:
			h.mu.RLock()
			for client := range h.clients {
				select {
				case client.send <- env:
				default:
					// Slow consumer: unregister asynchronously, because Run
					// itself is the receiver of h.unregister.
					go func(c *Client) { h.unregister <- c }(client)
				}
			}
			h.mu.RUnlock()
		}
	}
}

// BroadcastEvent wraps a legacy CombatEvent in an envelope and broadcasts
// it to the hero's connections. This maintains backward compatibility during migration.
func (h *Hub) BroadcastEvent(event model.CombatEvent) {
	h.SendToHero(event.HeroID, event.Type, event)
}

// SendToHero sends a typed message to all WebSocket connections for a specific hero.
// The send is non-blocking: a connection whose buffer is full is scheduled
// for disconnect and does not receive the message.
func (h *Hub) SendToHero(heroID int64, msgType string, payload any) {
	env := model.NewWSEnvelope(msgType, payload)

	// Holding the read lock guarantees no client.send is closed mid-loop:
	// Run closes it only under the write lock, after removing the client.
	h.mu.RLock()
	defer h.mu.RUnlock()
	for client := range h.clients {
		if client.heroID != heroID {
			continue
		}
		select {
		case client.send <- env:
		default:
			// Slow consumer, schedule disconnect.
			go func(c *Client) { h.unregister <- c }(client)
		}
	}
}

// BroadcastAll sends an envelope to every connected client (rare: server announcements).
// If the hub's broadcast queue is full the event is dropped and a warning is logged.
func (h *Hub) BroadcastAll(msgType string, payload any) {
	env := model.NewWSEnvelope(msgType, payload)
	select {
	case h.broadcast <- env:
	default:
		h.logger.Warn("broadcast channel full, dropping event", "type", msgType)
	}
}

// ConnectionCount returns the number of active WebSocket connections.
func (h *Hub) ConnectionCount() int {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return len(h.clients)
}

// ConnectedHeroIDs returns the hero IDs that have active WebSocket connections.
// Each ID appears once even if the hero has several connections; order is unspecified.
func (h *Hub) ConnectedHeroIDs() []int64 {
	h.mu.RLock()
	defer h.mu.RUnlock()

	seen := make(map[int64]struct{}, len(h.clients))
	for c := range h.clients {
		seen[c.heroID] = struct{}{}
	}
	ids := make([]int64, 0, len(seen))
	for id := range seen {
		ids = append(ids, id)
	}
	return ids
}

// IsHeroConnected returns true if the given hero has at least one active WS connection.
func (h *Hub) IsHeroConnected(heroID int64) bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	for c := range h.clients {
		if c.heroID == heroID {
			return true
		}
	}
	return false
}

// WSHandler handles WebSocket upgrade requests.
type WSHandler struct {
	hub       *Hub
	heroStore heroStoreLookup
	logger    *slog.Logger
}

// heroStoreLookup is a minimal interface to avoid import cycle with storage package.
type heroStoreLookup interface {
	GetHeroIDByTelegramID(ctx context.Context, telegramID int64) (int64, error)
}

// NewWSHandler creates a new WebSocket handler. heroStore may be nil, in
// which case the resolved Telegram ID is used directly as the hero ID.
func NewWSHandler(hub *Hub, heroStore heroStoreLookup, logger *slog.Logger) *WSHandler {
	return &WSHandler{hub: hub, heroStore: heroStore, logger: logger}
}

// HandleWS upgrades the HTTP connection to WebSocket.
// GET /ws
//
// The Telegram user ID is taken from the initData query param, else from the
// telegramId query param, else it defaults to 1. It is then mapped to a DB
// hero ID through heroStore when possible.
func (h *WSHandler) HandleWS(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}

	// Resolve hero from Telegram initData or telegramId query param.
	// NOTE(review): the telegramId param and the final fallback to 1 accept
	// unauthenticated callers; whether parseUserIDFromInitData verifies the
	// initData signature is not visible here — confirm before production.
	query := r.URL.Query()
	var heroID int64
	if initData := query.Get("initData"); initData != "" {
		if tid, err := parseUserIDFromInitData(initData); err == nil {
			heroID = tid
		}
	}
	if heroID == 0 {
		// Dev fallback: accept telegramId query param.
		if tidStr := query.Get("telegramId"); tidStr != "" {
			if tid, err := strconv.ParseInt(tidStr, 10, 64); err == nil {
				heroID = tid
			}
		}
	}
	if heroID == 0 {
		heroID = 1 // last-resort fallback for dev
	}

	// heroID at this point is the Telegram user ID. Resolve to DB hero ID.
	// NOTE(review): on lookup failure the Telegram ID is kept and used as the
	// hero ID as-is.
	if h.heroStore != nil {
		if dbID, err := h.heroStore.GetHeroIDByTelegramID(r.Context(), heroID); err == nil && dbID > 0 {
			heroID = dbID
		} else {
			h.logger.Warn("ws: could not resolve telegram ID to hero ID", "telegram_id", heroID, "error", err)
		}
	}

	client := &Client{
		hub:    h.hub,
		conn:   conn,
		send:   make(chan model.WSEnvelope, sendBufSize),
		heroID: heroID,
	}
	h.hub.register <- client

	go client.writePump()
	go client.readPump()
}

// writeMessage writes one frame to the connection, holding writeMu so it
// never overlaps another write, and applying the standard write deadline.
func (c *Client) writeMessage(messageType int, data []byte) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
	return c.conn.WriteMessage(messageType, data)
}

// writeJSON JSON-encodes v as one message under writeMu with the standard
// write deadline.
func (c *Client) writeJSON(v any) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
	return c.conn.WriteJSON(v)
}

// writePong answers an application-level ping with a JSON pong text frame.
// Best effort: a write error is ignored here and left to the read/write
// loops to surface.
func (c *Client) writePong() {
	_ = c.writeMessage(websocket.TextMessage, []byte(`{"type":"pong","payload":{}}`))
}

// readPump reads messages from the connection until it fails or closes.
// It answers pings directly and forwards every other envelope to
// Hub.Incoming (dropping it if that channel is full). On exit it
// unregisters the client and closes the connection.
func (c *Client) readPump() {
	defer func() {
		c.hub.unregister <- c
		c.conn.Close()
	}()

	c.conn.SetReadLimit(maxMessageSize)
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		// Any protocol-level pong extends the read deadline.
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		_, msg, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.hub.logger.Warn("websocket read error", "error", err)
			}
			return
		}

		// Backward compat: plain "ping" string.
		if string(msg) == "ping" {
			c.writePong()
			continue
		}

		// Parse as JSON envelope.
		var env model.WSEnvelope
		if err := json.Unmarshal(msg, &env); err != nil {
			c.hub.logger.Debug("invalid ws message", "error", err, "hero_id", c.heroID)
			continue
		}

		// Handle ping envelope.
		if env.Type == "ping" {
			c.writePong()
			continue
		}

		// Route to hub's incoming channel for the engine to process.
		select {
		case c.hub.Incoming <- model.ClientMessage{
			HeroID:  c.heroID,
			Type:    env.Type,
			Payload: env.Payload,
		}:
		default:
			c.hub.logger.Warn("incoming channel full, dropping client message", "type", env.Type, "hero_id", c.heroID)
		}
	}
}

// writePump drains c.send to the connection and sends a protocol ping every
// pingPeriod. When the hub closes c.send it sends a close frame and returns;
// any write error also ends it. On exit it closes the connection.
func (c *Client) writePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.conn.Close()
	}()

	for {
		select {
		case env, ok := <-c.send:
			if !ok {
				// The hub unregistered this client.
				_ = c.writeMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.writeJSON(env); err != nil {
				return
			}
		case <-ticker.C:
			if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}