You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

205 lines
5.5 KiB
Go

package handler
import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"

	"github.com/denisovdennis/autohero/internal/storage"
)
// contextKey is an unexported type for context keys in this package.
type contextKey string
const telegramIDKey contextKey = "telegram_id"
// TelegramIDFromContext extracts the Telegram user ID from the request context.
func TelegramIDFromContext(ctx context.Context) (int64, bool) {
id, ok := ctx.Value(telegramIDKey).(int64)
return id, ok
}
// AuthHandler handles Telegram authentication.
// It bundles the dependencies used by the TelegramAuth endpoint.
type AuthHandler struct {
// botToken is passed to validateInitData to verify the initData signature.
botToken string
// store is used to fetch or create the hero for an authenticated Telegram user.
store *storage.HeroStore
// logger receives the auth failure, store error and auth success log records.
logger *slog.Logger
}
// NewAuthHandler returns an AuthHandler wired with the given bot token,
// hero store and logger. The arguments are stored as-is; nothing is validated.
func NewAuthHandler(botToken string, store *storage.HeroStore, logger *slog.Logger) *AuthHandler {
	handler := new(AuthHandler)
	handler.botToken = botToken
	handler.store = store
	handler.logger = logger
	return handler
}
// TelegramAuth authenticates a Telegram Web App user and returns their hero.
// POST /api/v1/auth/telegram
//
// The initData payload is read from the X-Telegram-Init-Data header, falling
// back to the initData query parameter. Responses:
//   - 401 {"error": ...} when initData is missing or fails validation
//   - 500 {"error": "failed to load hero"} when the store call fails
//   - 200 {"heroId": ..., "hero": ...} on success
func (h *AuthHandler) TelegramAuth(w http.ResponseWriter, r *http.Request) {
	// fail writes a JSON error body of the form {"error": msg}.
	fail := func(status int, msg string) {
		writeJSON(w, status, map[string]string{"error": msg})
	}

	payload := r.Header.Get("X-Telegram-Init-Data")
	if payload == "" {
		payload = r.URL.Query().Get("initData")
	}
	if payload == "" {
		fail(http.StatusUnauthorized, "missing initData")
		return
	}

	userID, authErr := validateInitData(payload, h.botToken)
	if authErr != nil {
		h.logger.Warn("telegram auth failed", "error", authErr)
		fail(http.StatusUnauthorized, "invalid initData: "+authErr.Error())
		return
	}

	// "Hero" is presumably the name given to a newly created hero — confirm
	// against storage.HeroStore.GetOrCreate.
	hero, loadErr := h.store.GetOrCreate(r.Context(), userID, "Hero")
	if loadErr != nil {
		h.logger.Error("failed to get or create hero", "telegram_id", userID, "error", loadErr)
		fail(http.StatusInternalServerError, "failed to load hero")
		return
	}

	h.logger.Info("telegram auth success", "telegram_id", userID, "hero_id", hero.ID)
	writeJSON(w, http.StatusOK, map[string]any{
		"heroId": hero.ID,
		"hero":   hero,
	})
}
// TelegramAuthMiddleware returns middleware that authenticates every request
// from its Telegram initData before passing it on.
//
// The initData payload is taken from the X-Telegram-Init-Data header, falling
// back to the initData query parameter, and checked with validateInitData
// using botToken. Requests with missing or invalid initData are answered with
// 401 and a JSON error body and never reach next. For valid requests the
// Telegram user ID is stored in the request context under telegramIDKey,
// where TelegramIDFromContext can read it.
func TelegramAuthMiddleware(botToken string) func(http.Handler) http.Handler {
	// deny answers with 401 and a JSON body of the form {"error": reason}.
	deny := func(w http.ResponseWriter, reason string) {
		writeJSON(w, http.StatusUnauthorized, map[string]string{"error": reason})
	}

	return func(next http.Handler) http.Handler {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			payload := r.Header.Get("X-Telegram-Init-Data")
			if payload == "" {
				payload = r.URL.Query().Get("initData")
			}
			if payload == "" {
				deny(w, "missing initData")
				return
			}

			userID, err := validateInitData(payload, botToken)
			if err != nil {
				// The validation error is deliberately not echoed to the client here.
				deny(w, "invalid initData")
				return
			}

			authed := r.WithContext(context.WithValue(r.Context(), telegramIDKey, userID))
			next.ServeHTTP(w, authed)
		}
		return http.HandlerFunc(guarded)
	}
}
// validateInitData checks the signature of a Telegram Web App initData query
// string and returns the Telegram user ID it carries.
//
// Verification steps:
//  1. Every key=value pair except "hash" is sorted and joined with "\n" to
//     form the data-check-string (only the first value of a repeated key is used).
//  2. The secret key is HMAC-SHA256(key="WebAppData", message=botToken).
//  3. The hex-encoded HMAC-SHA256 of the data-check-string under that secret
//     must equal the "hash" field.
//
// It returns an error when initData cannot be parsed, has no hash, the hash
// does not match, or no user ID can be extracted.
//
// NOTE(review): when botToken is empty the signature check is skipped
// entirely and the user ID is taken from initData as-is. This is the dev-mode
// path; any client can claim any user ID in that configuration.
func validateInitData(initData string, botToken string) (int64, error) {
	if botToken == "" {
		// Dev mode: no token to verify against, so only extract the user ID.
		return parseUserIDFromInitData(initData)
	}

	fields, err := url.ParseQuery(initData)
	if err != nil {
		return 0, fmt.Errorf("parse initData: %w", err)
	}

	wantHash := fields.Get("hash")
	if wantHash == "" {
		return 0, fmt.Errorf("missing hash in initData")
	}

	// Collect "key=value" lines for everything except the hash itself.
	lines := make([]string, 0, len(fields))
	for key, vals := range fields {
		if key != "hash" {
			lines = append(lines, key+"="+vals[0])
		}
	}
	sort.Strings(lines)
	checkString := strings.Join(lines, "\n")

	// sign returns HMAC-SHA256(key, msg).
	sign := func(key, msg []byte) []byte {
		m := hmac.New(sha256.New, key)
		m.Write(msg)
		return m.Sum(nil)
	}

	secret := sign([]byte("WebAppData"), []byte(botToken))
	gotHash := hex.EncodeToString(sign(secret, []byte(checkString)))

	// hmac.Equal compares in constant time.
	if !hmac.Equal([]byte(gotHash), []byte(wantHash)) {
		return 0, fmt.Errorf("hash mismatch")
	}
	return parseUserIDFromInitData(initData)
}
// parseUserIDFromInitData extracts the Telegram user ID from an initData
// query string.
//
// The ID is read from the top-level "id" member of the JSON object in the
// "user" field (user={"id":123456789,"first_name":"Name",...}). When there is
// no "user" field, a plain decimal "id" query parameter is accepted instead.
//
// The user object is decoded with encoding/json rather than scanned as text,
// so an "id" inside a nested object is never mistaken for the user's ID and
// whitespace around the colon is tolerated.
//
// It returns an error when initData is not a valid query string, when neither
// "user" nor "id" is present, when the user field is not valid JSON, or when
// its "id" member is missing, not an integer, or negative. The returned ID is
// 0 whenever the error is non-nil.
func parseUserIDFromInitData(initData string) (int64, error) {
	values, err := url.ParseQuery(initData)
	if err != nil {
		return 0, fmt.Errorf("parse initData: %w", err)
	}

	userStr := values.Get("user")
	if userStr == "" {
		// Fallback: a direct id parameter instead of a user object.
		idStr := values.Get("id")
		if idStr == "" {
			return 0, fmt.Errorf("no user data in initData")
		}
		id, parseErr := strconv.ParseInt(idStr, 10, 64)
		if parseErr != nil {
			return 0, fmt.Errorf("parse id parameter: %w", parseErr)
		}
		return id, nil
	}

	// A pointer distinguishes a missing (or null) id from a literal 0.
	var user struct {
		ID *int64 `json:"id"`
	}
	if decodeErr := json.Unmarshal([]byte(userStr), &user); decodeErr != nil {
		return 0, fmt.Errorf("parse user data: %w", decodeErr)
	}
	if user.ID == nil {
		return 0, fmt.Errorf("no id field in user data")
	}
	if *user.ID < 0 {
		// The previous digit scanner never accepted a sign; keep rejecting it.
		return 0, fmt.Errorf("invalid id in user data")
	}
	return *user.ID, nil
}