Source code for mindroot.coreplugins.jwt_auth.middleware

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import jwt
from datetime import datetime, timedelta
from lib.route_decorators import public_routes, public_route, public_static
from coreplugins.api_keys import api_key_manager
from lib.providers.services import service_manager
import os
from lib.session_files import load_session_data
from lib.utils.debug import debug_box
import secrets
from pathlib import Path
import re

[docs] def get_or_create_jwt_secret(): secret_key = os.environ.get("JWT_SECRET_KEY", None) if secret_key: print("JWT_SECRET_KEY found in environment variables") return secret_key # Check if .env file exists and contains JWT_SECRET_KEY env_path = Path.cwd() / ".env" if env_path.exists(): with open(env_path, 'r') as f: lines = f.readlines() for line in lines: if line.strip().startswith('JWT_SECRET_KEY='): # Extract the key value key_value = line.strip().split('=', 1)[1] if key_value: print(f"JWT_SECRET_KEY found in {env_path}") # Also set it in environment for current session os.environ['JWT_SECRET_KEY'] = key_value return key_value # If we get here, no key was found anywhere, so generate one print("JWT_SECRET_KEY not found, generating new key...") secret_key = secrets.token_urlsafe(32) # Save to .env file # Check if file exists and needs a newline before appending needs_newline = False if env_path.exists() and env_path.stat().st_size > 0: with open(env_path, 'rb') as f: f.seek(-1, 2) # Go to last byte last_char = f.read(1) needs_newline = last_char != b'\n' with open(env_path, 'a') as f: if needs_newline: f.write('\n') f.write(f"JWT_SECRET_KEY={secret_key}\n") print(f"Generated new JWT_SECRET_KEY and saved to {env_path}") os.environ['JWT_SECRET_KEY'] = secret_key return secret_key
# Get or create the secret key SECRET_KEY = get_or_create_jwt_secret() print(f"JWT_SECRET_KEY loaded successfully") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 10080 # 1 week security = HTTPBearer()
[docs] def create_access_token(data: dict): to_encode = data.copy() expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt
[docs] def decode_token(token: str): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) return payload except jwt.ExpiredSignatureError: print("Token has expired") return False except jwt.InvalidTokenError: print("Invalid token") return False
[docs] def path_matches_pattern(request_path: str, route_pattern: str) -> bool: """ Check if a request path matches a route pattern with parameters. Examples: - path_matches_pattern('/chat/embed/abc123', '/chat/embed/{token}') -> True - path_matches_pattern('/chat/widget/xyz/session', '/chat/widget/{token}/session') -> True - path_matches_pattern('/login', '/login') -> True """ # Handle exact matches first if request_path == route_pattern: return True # Convert FastAPI route pattern to regex # Replace {param} with regex pattern that matches any non-slash characters regex_pattern = re.sub(r'\{[^}]+\}', r'[^/]+', route_pattern) # Escape other regex special characters regex_pattern = regex_pattern.replace('.', '\\.') # Add start and end anchors regex_pattern = f'^{regex_pattern}$' try: return bool(re.match(regex_pattern, request_path)) except re.error: # If regex compilation fails, fall back to exact match return request_path == route_pattern
[docs] def is_public_route(request_path: str) -> bool: """ Check if a request path matches any registered public route pattern. """ # Check exact matches and pattern matches for route_pattern in public_routes: if path_matches_pattern(request_path, route_pattern): return True # Check special cases if request_path.startswith('/reset-password'): return True return False
[docs] async def middleware(request: Request, call_next): try: print('-------------------------- auth middleware ----------------------------') print('Request URL:', request.url.path) if request.url.path.endswith('events'): debug_box("events:" + request.url.path) # Check for API key in query parameters api_key = request.query_params.get('api_key') if api_key: print('Found API key in query parameters') key_data = api_key_manager.validate_key(api_key) if key_data: print("Validated API Key, key_data is", key_data) username = key_data['username'] print("Trying to get user data") user_data = await service_manager.get_user_data(username) if user_data: request.state.user = user_data # Create JWT token for persistent session token = create_access_token({"sub": username}) request.state.access_token = token # Get response and set cookie response = await call_next(request) return response else: print(f"User {username} for key {api_key} not found") return JSONResponse( status_code=403, content={"detail": f"User '{username}' for API key {api_key} not found"} ) else: print("Could not validate API key, key_data returned as", key_data) return JSONResponse( status_code=403, content={"detail": "Invalid API key"} ) if request.url.path.startswith("/imgs/") or request.url.path.startswith("/manual/"): return await call_next(request) try: path_parts = request.url.path.split('/') # filter empty "" strings path_parts = list(filter(None, path_parts)) print(f"Request path split is {path_parts}") plugin_name = path_parts[0] static_part = path_parts[1] filename = path_parts[-1] print(f"Checking for static file: {plugin_name} {static_part} {filename}") if static_part == 'static': if filename.endswith('.js') or filename.endswith('.css') or filename.endswith('.png') or filename.endswith('.mp4') or filename.endswith('.gif'): print('Static file requested:', filename) return await call_next(request) except Exception as e: print("Error checking for static file", e) pass print("Did not find static file") # Use the improved public route checking if is_public_route(request.url.path): print('Public route: ', request.url.path) return await call_next(request) elif any([request.url.path.startswith(path) for path in public_static]): return await call_next(request) else: print('Not a public route: ', request.url.path) print("public routes:", public_routes) # Check for token in cookies first token = request.cookies.get("access_token") #token = None if token: print("Trying to decode token..") payload = decode_token(token) if payload: # Get username from token username = payload['sub'] user_data = await service_manager.get_user_data(username) request.state.user = user_data if user_data: return await call_next(request) else: print("User data not found, redirecting to login..") return RedirectResponse(url="/login") else: print("Invalid or expired token, redirecting to login..") return RedirectResponse(url="/login") else: print("..Did not find token in cookies..") try: print("Trying bearer token..") token = await security(request) except HTTPException as e: print('Bearer header: No valid token found: ', e) print("Trying session context..") try: session_id = request.url.path.split('/')[-1] token = await load_session_data(session_id, "access_token") if token: print("Retrieved token from session file") print(token) else: print("No token found in session file") except Exception as e: print("Error loading session data") print(e) if token: if hasattr(token, 'credentials'): payload = decode_token(token.credentials) else: payload = decode_token(token) if payload: username = payload['sub'] user_data = await service_manager.get_user_data(username) if user_data: request.state.user = user_data return await call_next(request) else: print("User data not found, redirecting to login..") return RedirectResponse(url="/login") else: print("Invalid or expired token, redirecting to login..") return RedirectResponse(url="/login") print('No valid token found') return RedirectResponse(url="/login") except HTTPException as e: print('HTTPException:', e) return RedirectResponse(url="/login") except Exception as e: print('Error:', e) if 'does not exist' in str(e).lower() or 'not found' in str(e).lower(): return JSONResponse( status_code=404, content={"detail": f"Resource not found: {str(e)}"} ) import traceback traceback.print_exc() return JSONResponse( status_code=500, content={"detail": f"Internal server error: {str(e)}"} ) response = await call_next(request) return response