You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
5.5 KiB
Go
205 lines
5.5 KiB
Go
package handler
|
|
|
|
import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"

	"github.com/denisovdennis/autohero/internal/storage"
)
|
|
|
|
// contextKey is a package-private named type for context keys. Using a distinct
// type instead of a bare string keeps these keys from colliding with keys set
// by other packages.
type contextKey string

// telegramIDKey is the context key under which TelegramAuthMiddleware stores
// the authenticated Telegram user ID (an int64).
const telegramIDKey = contextKey("telegram_id")
|
|
|
|
// TelegramIDFromContext extracts the Telegram user ID from the request context.
|
|
func TelegramIDFromContext(ctx context.Context) (int64, bool) {
|
|
id, ok := ctx.Value(telegramIDKey).(int64)
|
|
return id, ok
|
|
}
|
|
|
|
// AuthHandler handles Telegram authentication.
|
|
type AuthHandler struct {
|
|
botToken string
|
|
store *storage.HeroStore
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewAuthHandler creates a new auth handler.
|
|
func NewAuthHandler(botToken string, store *storage.HeroStore, logger *slog.Logger) *AuthHandler {
|
|
return &AuthHandler{
|
|
botToken: botToken,
|
|
store: store,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// TelegramAuth validates Telegram initData, creates hero if first time, returns hero ID.
|
|
// POST /api/v1/auth/telegram
|
|
func (h *AuthHandler) TelegramAuth(w http.ResponseWriter, r *http.Request) {
|
|
initData := r.Header.Get("X-Telegram-Init-Data")
|
|
if initData == "" {
|
|
initData = r.URL.Query().Get("initData")
|
|
}
|
|
if initData == "" {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{
|
|
"error": "missing initData",
|
|
})
|
|
return
|
|
}
|
|
|
|
telegramID, err := validateInitData(initData, h.botToken)
|
|
if err != nil {
|
|
h.logger.Warn("telegram auth failed", "error", err)
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{
|
|
"error": "invalid initData: " + err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
hero, err := h.store.GetOrCreate(r.Context(), telegramID, "Hero")
|
|
if err != nil {
|
|
h.logger.Error("failed to get or create hero", "telegram_id", telegramID, "error", err)
|
|
writeJSON(w, http.StatusInternalServerError, map[string]string{
|
|
"error": "failed to load hero",
|
|
})
|
|
return
|
|
}
|
|
|
|
h.logger.Info("telegram auth success", "telegram_id", telegramID, "hero_id", hero.ID)
|
|
|
|
writeJSON(w, http.StatusOK, map[string]any{
|
|
"heroId": hero.ID,
|
|
"hero": hero,
|
|
})
|
|
}
|
|
|
|
// TelegramAuthMiddleware validates the Telegram initData on every request
|
|
// and injects the telegram_id into the request context.
|
|
func TelegramAuthMiddleware(botToken string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
initData := r.Header.Get("X-Telegram-Init-Data")
|
|
if initData == "" {
|
|
initData = r.URL.Query().Get("initData")
|
|
}
|
|
if initData == "" {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{
|
|
"error": "missing initData",
|
|
})
|
|
return
|
|
}
|
|
|
|
telegramID, err := validateInitData(initData, botToken)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{
|
|
"error": "invalid initData",
|
|
})
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), telegramIDKey, telegramID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// validateInitData parses and validates Telegram Web App initData.
|
|
// Returns the telegram user ID on success.
|
|
func validateInitData(initData string, botToken string) (int64, error) {
|
|
if botToken == "" {
|
|
// In dev mode without a bot token, try to parse the user ID anyway.
|
|
return parseUserIDFromInitData(initData)
|
|
}
|
|
|
|
values, err := url.ParseQuery(initData)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("parse initData: %w", err)
|
|
}
|
|
|
|
hash := values.Get("hash")
|
|
if hash == "" {
|
|
return 0, fmt.Errorf("missing hash in initData")
|
|
}
|
|
|
|
// Build the data-check-string: sort all key=value pairs except "hash", join with \n.
|
|
var pairs []string
|
|
for k, v := range values {
|
|
if k == "hash" {
|
|
continue
|
|
}
|
|
pairs = append(pairs, k+"="+v[0])
|
|
}
|
|
sort.Strings(pairs)
|
|
dataCheckString := strings.Join(pairs, "\n")
|
|
|
|
// Derive the secret key: HMAC-SHA256 of bot token with key "WebAppData".
|
|
secretKeyMac := hmac.New(sha256.New, []byte("WebAppData"))
|
|
secretKeyMac.Write([]byte(botToken))
|
|
secretKey := secretKeyMac.Sum(nil)
|
|
|
|
// Compute HMAC-SHA256 of the data-check-string with the secret key.
|
|
mac := hmac.New(sha256.New, secretKey)
|
|
mac.Write([]byte(dataCheckString))
|
|
computedHash := hex.EncodeToString(mac.Sum(nil))
|
|
|
|
if !hmac.Equal([]byte(computedHash), []byte(hash)) {
|
|
return 0, fmt.Errorf("hash mismatch")
|
|
}
|
|
|
|
return parseUserIDFromInitData(initData)
|
|
}
|
|
|
|
// parseUserIDFromInitData extracts the Telegram user ID from an initData
// query string. It does NOT verify the signature; callers must do that.
//
// Two shapes are accepted:
//   - user=<JSON object>, e.g. user={"id":123456789,"first_name":"Name",...};
//     the object is decoded with encoding/json and its "id" field returned;
//   - a bare id=<decimal> parameter, consulted only when "user" is absent.
//
// It returns an error in any of these cases:
//   - initData is not a valid query string;
//   - neither parameter is present;
//   - the user JSON is malformed or lacks a numeric "id";
//   - the JSON id is negative;
//   - the bare id is not a valid integer.
func parseUserIDFromInitData(initData string) (int64, error) {
	values, err := url.ParseQuery(initData)
	if err != nil {
		return 0, fmt.Errorf("parse initData: %w", err)
	}

	userStr := values.Get("user")
	if userStr == "" {
		// No user object: fall back to a direct id parameter.
		idStr := values.Get("id")
		if idStr == "" {
			return 0, fmt.Errorf("no user data in initData")
		}
		id, err := strconv.ParseInt(idStr, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid id in initData: %w", err)
		}
		return id, nil
	}

	// Decode the real JSON instead of scanning for `"id":`. This accepts
	// whitespace around ':' and rejects truncated or garbage values such as
	// 12abc. A pointer distinguishes a missing or null "id" from an explicit 0.
	var user struct {
		ID *int64 `json:"id"`
	}
	if err := json.Unmarshal([]byte(userStr), &user); err != nil {
		return 0, fmt.Errorf("invalid user data: %w", err)
	}
	if user.ID == nil {
		return 0, fmt.Errorf("no id field in user data")
	}
	// User IDs are expected to be non-negative, and negative values were
	// rejected previously as well.
	if *user.ID < 0 {
		return 0, fmt.Errorf("invalid id in user data")
	}
	return *user.ID, nil
}
|