You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

353 lines
8.7 KiB
Go

package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/denisovdennis/autohero/internal/model"
)
// upgrader turns incoming HTTP requests into WebSocket connections.
// Read and write buffers are 1024 bytes each.
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// CheckOrigin currently returns true for every request, so the Origin
// header is not validated at all.
CheckOrigin: func(r *http.Request) bool {
// TODO: restrict origins in production.
return true
},
}
// Hub maintains active WebSocket connections, broadcasts envelopes,
// and routes incoming client messages.
//
// The clients map is guarded by mu. In this file only Run adds or removes
// entries (under the write lock); the other methods iterate it under the
// read lock.
type Hub struct {
clients map[*Client]bool // set of registered connections; guarded by mu
register chan *Client // new clients are handed to Run through this channel
unregister chan *Client // clients to remove are handed to Run through this channel
broadcast chan model.WSEnvelope // envelopes queued by BroadcastAll for every client
Incoming chan model.ClientMessage // inbound commands from clients
mu sync.RWMutex // guards clients
logger *slog.Logger
// OnConnect is called when a client finishes registration.
// Set by the engine to push initial state. May be nil.
// Run invokes it in a separate goroutine.
OnConnect func(heroID int64)
// OnDisconnect is called when a client is unregistered.
// remainingSameHero is how many other WS clients for this hero are still connected.
// Set by the engine to persist state; may be nil.
// Run invokes it synchronously, inside the hub's event loop.
OnDisconnect func(heroID int64, remainingSameHero int)
}
// Client represents a single WebSocket connection.
type Client struct {
hub *Hub // hub this connection registers with and routes messages through
conn *websocket.Conn // underlying WebSocket connection
send chan model.WSEnvelope // outbound queue drained by writePump; closed by Hub.Run on unregister
heroID int64 // hero this connection belongs to; set once in HandleWS
}
// Connection timing and sizing parameters.
const (
// writeWait is the deadline applied before each write to the connection.
writeWait = 10 * time.Second
// pongWait is how far the read deadline is pushed out, initially and
// again every time a pong arrives.
pongWait = 60 * time.Second
// pingPeriod is the interval between server pings: 90% of pongWait, so a
// ping is sent before the read deadline can expire.
pingPeriod = (pongWait * 9) / 10
// maxMessageSize is the read limit, in bytes, set on each connection.
maxMessageSize = 4096
// sendBufSize is the capacity of each client's outbound send channel.
sendBufSize = 64
)
// NewHub builds a Hub with an empty client set and all of its channels
// allocated. The register and unregister channels are unbuffered; the
// broadcast and Incoming channels each buffer up to 256 items. OnConnect and
// OnDisconnect are left nil for the caller to assign.
func NewHub(logger *slog.Logger) *Hub {
	const queueDepth = 256

	hub := new(Hub)
	hub.logger = logger
	hub.clients = map[*Client]bool{}
	hub.register = make(chan *Client)
	hub.unregister = make(chan *Client)
	hub.broadcast = make(chan model.WSEnvelope, queueDepth)
	hub.Incoming = make(chan model.ClientMessage, queueDepth)
	return hub
}
// Run starts the hub's event loop. Should be called as a goroutine.
//
// The loop never returns. It serialises three kinds of events:
//
//   - register: adds the client to the set and, if OnConnect is set, invokes
//     it in a new goroutine.
//   - unregister: removes the client, closes its send channel, and, if
//     OnDisconnect is set, invokes it synchronously with the number of
//     connections still open for the same hero. Unregister requests for a
//     client that is no longer in the set are ignored.
//   - broadcast: offers the envelope to every client without blocking and
//     schedules an unregister for any client whose send buffer is full.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			h.mu.Unlock()
			h.logger.Info("client connected", "hero_id", client.heroID)
			// Notify engine of new connection (sends hero_state, route, etc.).
			if h.OnConnect != nil {
				go h.OnConnect(client.heroID)
			}

		case client := <-h.unregister:
			heroID := client.heroID
			remaining := 0

			h.mu.Lock()
			_, existed := h.clients[client]
			if existed {
				delete(h.clients, client)
				// Closing under the write lock guarantees that no sender
				// holding the read lock can hit a closed channel.
				close(client.send)
				for c := range h.clients {
					if c.heroID == heroID {
						remaining++
					}
				}
			}
			h.mu.Unlock()

			if !existed {
				// The same client can be unregistered more than once: every
				// slow-consumer drop schedules an unregister, and readPump
				// sends another when the socket closes. Only the first
				// request is a real disconnect; the rest must not be logged
				// or reported again.
				continue
			}

			h.logger.Info("client disconnected", "hero_id", heroID, "remaining_same_hero", remaining)
			// Always persist; engine drops in-memory movement only when remaining == 0.
			// Synchronous so a reconnect that loads from DB sees the latest save.
			if h.OnDisconnect != nil {
				h.OnDisconnect(heroID, remaining)
			}

		case env := <-h.broadcast:
			h.mu.RLock()
			for client := range h.clients {
				select {
				case client.send <- env:
				default:
					// Slow consumer: unregister from a separate goroutine,
					// because this loop is the only receiver on h.unregister
					// and would deadlock sending to itself.
					go func(c *Client) {
						h.unregister <- c
					}(client)
				}
			}
			h.mu.RUnlock()
		}
	}
}
// BroadcastEvent delivers a legacy CombatEvent to every connection of the
// hero named in the event. The event's Type becomes the envelope type and the
// event itself is the payload. Kept for backward compatibility during the
// migration to typed envelopes.
func (h *Hub) BroadcastEvent(event model.CombatEvent) {
	heroID, msgType := event.HeroID, event.Type
	h.SendToHero(heroID, msgType, event)
}
// SendToHero wraps payload in an envelope of type msgType and offers it to
// every connection registered for heroID. Delivery never blocks: a connection
// whose send buffer is full does not receive the message and is scheduled for
// unregistration instead.
func (h *Hub) SendToHero(heroID int64, msgType string, payload any) {
	env := model.NewWSEnvelope(msgType, payload)

	h.mu.RLock()
	defer h.mu.RUnlock()

	for target := range h.clients {
		if target.heroID != heroID {
			continue
		}
		select {
		case target.send <- env:
		default:
			// Slow consumer: hand it to the hub loop from a separate
			// goroutine so this call does not block on the unbuffered channel.
			stale := target
			go func() { h.unregister <- stale }()
		}
	}
}
// BroadcastAll queues an envelope for delivery to every connected client
// (rare: server announcements). The call never blocks; if the broadcast queue
// is full the envelope is dropped and a warning is logged.
func (h *Hub) BroadcastAll(msgType string, payload any) {
	select {
	case h.broadcast <- model.NewWSEnvelope(msgType, payload):
		return
	default:
	}
	h.logger.Warn("broadcast channel full, dropping event", "type", msgType)
}
// ConnectionCount reports how many WebSocket connections are currently
// registered with the hub.
func (h *Hub) ConnectionCount() int {
	h.mu.RLock()
	count := len(h.clients)
	h.mu.RUnlock()
	return count
}
// ConnectedHeroIDs lists each hero ID that has at least one registered
// connection. Every ID appears once; the order is unspecified. The result is
// an empty, non-nil slice when nobody is connected.
func (h *Hub) ConnectedHeroIDs() []int64 {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Collapse multiple connections of the same hero into one entry.
	unique := make(map[int64]struct{}, len(h.clients))
	for conn := range h.clients {
		unique[conn.heroID] = struct{}{}
	}

	out := make([]int64, len(unique))
	next := 0
	for heroID := range unique {
		out[next] = heroID
		next++
	}
	return out
}
// IsHeroConnected reports whether at least one registered connection belongs
// to heroID.
func (h *Hub) IsHeroConnected(heroID int64) bool {
	h.mu.RLock()
	defer h.mu.RUnlock()

	found := false
	for conn := range h.clients {
		if conn.heroID == heroID {
			found = true
			break
		}
	}
	return found
}
// WSHandler handles WebSocket upgrade requests.
type WSHandler struct {
hub *Hub // hub that new connections are registered with
heroStore heroStoreLookup // maps Telegram user IDs to hero IDs; HandleWS skips the lookup when nil
logger *slog.Logger
}
// heroStoreLookup is a minimal interface to avoid import cycle with storage package.
type heroStoreLookup interface {
// GetHeroIDByTelegramID returns the hero ID for the given Telegram user ID.
// HandleWS treats an error or a non-positive ID as "not resolved".
GetHeroIDByTelegramID(ctx context.Context, telegramID int64) (int64, error)
}
// NewWSHandler wires a WSHandler to the given hub, hero lookup and logger.
// heroStore may be nil, in which case HandleWS skips the Telegram-to-hero
// lookup.
func NewWSHandler(hub *Hub, heroStore heroStoreLookup, logger *slog.Logger) *WSHandler {
	handler := new(WSHandler)
	handler.hub = hub
	handler.heroStore = heroStore
	handler.logger = logger
	return handler
}
// HandleWS upgrades the HTTP connection to WebSocket, works out which hero
// the connection belongs to, registers a Client with the hub and starts its
// read and write pumps.
// GET /ws
//
// The Telegram user ID is taken from, in order: the initData query parameter,
// the telegramId query parameter, and finally the constant 1. If a hero store
// is configured the Telegram ID is then mapped to a hero ID; when that lookup
// fails the Telegram ID itself is used as the hero ID.
//
// NOTE(review): the telegramId parameter and the default of 1 are dev
// fallbacks that let a caller pick an identity without proof, and an
// unresolved Telegram ID is silently reused as a hero ID. Confirm these
// paths are not reachable in production.
func (h *WSHandler) HandleWS(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}

	query := r.URL.Query()

	// Preferred source: Telegram initData.
	var telegramID int64
	if initData := query.Get("initData"); initData != "" {
		if parsed, parseErr := parseUserIDFromInitData(initData); parseErr == nil {
			telegramID = parsed
		}
	}

	// Dev fallback: explicit telegramId query parameter.
	if telegramID == 0 {
		if rawID := query.Get("telegramId"); rawID != "" {
			if parsed, parseErr := strconv.ParseInt(rawID, 10, 64); parseErr == nil {
				telegramID = parsed
			}
		}
	}

	// Last-resort fallback for dev.
	if telegramID == 0 {
		telegramID = 1
	}

	// Map the Telegram user ID to the DB hero ID when a store is available.
	heroID := telegramID
	if h.heroStore != nil {
		dbID, lookupErr := h.heroStore.GetHeroIDByTelegramID(r.Context(), telegramID)
		if lookupErr != nil || dbID <= 0 {
			h.logger.Warn("ws: could not resolve telegram ID to hero ID", "telegram_id", telegramID, "error", lookupErr)
		} else {
			heroID = dbID
		}
	}

	client := &Client{
		hub:    h.hub,
		conn:   conn,
		send:   make(chan model.WSEnvelope, sendBufSize),
		heroID: heroID,
	}

	// Register first so the hub knows the client before either pump runs.
	h.hub.register <- client
	go client.writePump()
	go client.readPump()
}
// readPump reads frames from the client until the connection fails, then
// unregisters the client from the hub and closes the socket.
//
// Each message is handled as follows:
//   - the plain string "ping" (legacy clients) or an envelope of type "ping"
//     is answered with a "pong" envelope;
//   - any other valid envelope is forwarded to hub.Incoming without blocking,
//     and dropped with a warning when that channel is full;
//   - messages that are not valid JSON are logged at debug level and skipped.
//
// readPump never writes to the connection itself. gorilla/websocket permits
// only one concurrent writer, and writePump already owns that role, so pong
// replies are queued on c.send like every other outbound message.
func (c *Client) readPump() {
	defer func() {
		c.hub.unregister <- c
		c.conn.Close()
	}()

	c.conn.SetReadLimit(maxMessageSize)
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		// Any pong from the peer proves it is alive; extend the deadline.
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	// sendPong queues an application-level pong for writePump to deliver.
	// The hub read lock plus the membership check ensure c.send has not been
	// closed: Hub.Run removes the client and closes the channel under the
	// write lock. A pong is dropped if the client is not (or not yet)
	// registered or its send buffer is full.
	// NOTE(review): assumes NewWSEnvelope encodes an empty struct payload as
	// {} so the frame matches the previous {"type":"pong","payload":{}} reply;
	// confirm against the model package.
	sendPong := func() {
		pong := model.NewWSEnvelope("pong", struct{}{})

		c.hub.mu.RLock()
		defer c.hub.mu.RUnlock()

		if !c.hub.clients[c] {
			return
		}
		select {
		case c.send <- pong:
		default:
			c.hub.logger.Debug("send buffer full, dropping pong", "hero_id", c.heroID)
		}
	}

	for {
		_, msg, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.hub.logger.Warn("websocket read error", "error", err)
			}
			break
		}

		// Backward compat: plain "ping" string.
		if string(msg) == "ping" {
			sendPong()
			continue
		}

		// Parse as JSON envelope.
		var env model.WSEnvelope
		if err := json.Unmarshal(msg, &env); err != nil {
			c.hub.logger.Debug("invalid ws message", "error", err, "hero_id", c.heroID)
			continue
		}

		// Handle ping envelope.
		if env.Type == "ping" {
			sendPong()
			continue
		}

		// Route to hub's incoming channel for the engine to process.
		select {
		case c.hub.Incoming <- model.ClientMessage{
			HeroID:  c.heroID,
			Type:    env.Type,
			Payload: env.Payload,
		}:
		default:
			c.hub.logger.Warn("incoming channel full, dropping client message",
				"type", env.Type, "hero_id", c.heroID)
		}
	}
}
// writePump delivers queued envelopes to the client and sends a ping every
// pingPeriod. It returns, stopping the ticker and closing the connection,
// when a write fails or when the send channel is closed; in the latter case
// it first sends a close frame. Every write is preceded by a fresh writeWait
// deadline.
func (c *Client) writePump() {
	pinger := time.NewTicker(pingPeriod)
	// Deferred calls run in reverse order: stop the ticker, then close.
	defer c.conn.Close()
	defer pinger.Stop()

	for {
		select {
		case env, open := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !open {
				// The hub closed the queue: say goodbye to the peer and stop.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if writeErr := c.conn.WriteJSON(env); writeErr != nil {
				return
			}

		case <-pinger.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if pingErr := c.conn.WriteMessage(websocket.PingMessage, nil); pingErr != nil {
				return
			}
		}
	}
}