package handler import ( "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "fmt" "log/slog" "net/http" "net/url" "sort" "strconv" "strings" "github.com/denisovdennis/autohero/internal/storage" ) // contextKey is an unexported type for context keys in this package. type contextKey string const telegramIDKey contextKey = "telegram_id" // TelegramIDFromContext extracts the Telegram user ID from the request context. func TelegramIDFromContext(ctx context.Context) (int64, bool) { id, ok := ctx.Value(telegramIDKey).(int64) return id, ok } // AuthHandler handles Telegram authentication. type AuthHandler struct { botToken string store *storage.HeroStore logger *slog.Logger } // NewAuthHandler creates a new auth handler. func NewAuthHandler(botToken string, store *storage.HeroStore, logger *slog.Logger) *AuthHandler { return &AuthHandler{ botToken: botToken, store: store, logger: logger, } } // TelegramAuth validates Telegram initData, creates hero if first time, returns hero ID. // POST /api/v1/auth/telegram func (h *AuthHandler) TelegramAuth(w http.ResponseWriter, r *http.Request) { initData := r.Header.Get("X-Telegram-Init-Data") if initData == "" { initData = r.URL.Query().Get("initData") } if initData == "" { writeJSON(w, http.StatusUnauthorized, map[string]string{ "error": "missing initData", }) return } telegramID, err := validateInitData(initData, h.botToken) if err != nil { h.logger.Warn("telegram auth failed", "error", err) writeJSON(w, http.StatusUnauthorized, map[string]string{ "error": "invalid initData: " + err.Error(), }) return } hero, err := h.store.GetOrCreate(r.Context(), telegramID, "Hero") if err != nil { h.logger.Error("failed to get or create hero", "telegram_id", telegramID, "error", err) writeJSON(w, http.StatusInternalServerError, map[string]string{ "error": "failed to load hero", }) return } h.logger.Info("telegram auth success", "telegram_id", telegramID, "hero_id", hero.ID) writeJSON(w, http.StatusOK, map[string]any{ "heroId": hero.ID, "hero": hero, }) } // TelegramAuthMiddleware validates the Telegram initData on every request // and injects the telegram_id into the request context. func TelegramAuthMiddleware(botToken string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { initData := r.Header.Get("X-Telegram-Init-Data") if initData == "" { initData = r.URL.Query().Get("initData") } if initData == "" { writeJSON(w, http.StatusUnauthorized, map[string]string{ "error": "missing initData", }) return } telegramID, err := validateInitData(initData, botToken) if err != nil { writeJSON(w, http.StatusUnauthorized, map[string]string{ "error": "invalid initData", }) return } ctx := context.WithValue(r.Context(), telegramIDKey, telegramID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // validateInitData parses and validates Telegram Web App initData. // Returns the telegram user ID on success. func validateInitData(initData string, botToken string) (int64, error) { if botToken == "" { // In dev mode without a bot token, try to parse the user ID anyway. return parseUserIDFromInitData(initData) } values, err := url.ParseQuery(initData) if err != nil { return 0, fmt.Errorf("parse initData: %w", err) } hash := values.Get("hash") if hash == "" { return 0, fmt.Errorf("missing hash in initData") } // Build the data-check-string: sort all key=value pairs except "hash", join with \n. var pairs []string for k, v := range values { if k == "hash" { continue } pairs = append(pairs, k+"="+v[0]) } sort.Strings(pairs) dataCheckString := strings.Join(pairs, "\n") // Derive the secret key: HMAC-SHA256 of bot token with key "WebAppData". secretKeyMac := hmac.New(sha256.New, []byte("WebAppData")) secretKeyMac.Write([]byte(botToken)) secretKey := secretKeyMac.Sum(nil) // Compute HMAC-SHA256 of the data-check-string with the secret key. mac := hmac.New(sha256.New, secretKey) mac.Write([]byte(dataCheckString)) computedHash := hex.EncodeToString(mac.Sum(nil)) if !hmac.Equal([]byte(computedHash), []byte(hash)) { return 0, fmt.Errorf("hash mismatch") } return parseUserIDFromInitData(initData) } // parseUserIDFromInitData extracts the user ID from the initData query string. func parseUserIDFromInitData(initData string) (int64, error) { values, err := url.ParseQuery(initData) if err != nil { return 0, fmt.Errorf("parse initData: %w", err) } // The user field is a JSON string, but we just need the id. // Format: user={"id":123456789,"first_name":"Name",...} userStr := values.Get("user") if userStr == "" { // Try query_id based format or direct id parameter. idStr := values.Get("id") if idStr == "" { return 0, fmt.Errorf("no user data in initData") } return strconv.ParseInt(idStr, 10, 64) } // Simple extraction of "id" from JSON without full unmarshal. // Find "id": followed by a number. idx := strings.Index(userStr, `"id":`) if idx == -1 { return 0, fmt.Errorf("no id field in user data") } rest := userStr[idx+5:] // Trim any whitespace. rest = strings.TrimSpace(rest) // Read digits until non-digit. var numStr string for _, c := range rest { if c >= '0' && c <= '9' { numStr += string(c) } else { break } } if numStr == "" { return 0, fmt.Errorf("invalid id in user data") } return strconv.ParseInt(numStr, 10, 64) }