第五章:JWT 认证中间件

博客 v1.0 系列教程 (Rust)

5.1 添加依赖(JWT 与 bcrypt)

# Crates introduced in this chapter.
# NOTE(review): the code in this chapter also uses chrono, serde, serde_json, axum,
# axum-extra and sqlx; presumably those were added in earlier chapters — confirm.
[dependencies]
# JWT signing and verification (`encode` / `decode` in src/common/auth.rs).
jsonwebtoken = "9"
# Password hashing and verification (`hash` / `verify`).
bcrypt = "0.16"

5.2 Token 生成与验证

// src/common/auth.rs
use axum::{
    extract::{FromRequestParts, Request, State},
    http::{request::Parts, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    Json, RequestPartsExt,
};
use axum_extra::headers::{authorization::Bearer, Authorization};
use axum_extra::TypedHeader;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};

/// JWT payload issued by [`generate_token`] and returned by [`validate_token`].
///
/// Field order is significant only for the serialized JSON layout; it is kept stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Claims {
    /// Subject: the user ID rendered as a decimal string.
    pub sub: String,
    /// Role value copied from the user record when the token is issued.
    pub role: i16,
    /// Expiry time as a Unix timestamp in seconds.
    pub exp: usize,
    /// Issued-at time as a Unix timestamp in seconds.
    pub iat: usize,
}

/// Default lifetime of a token issued by [`generate_token`], in seconds (24 hours).
pub const TOKEN_TTL_SECS: usize = 24 * 60 * 60;

/// Issues a signed JWT for `user_id` / `role` that expires [`TOKEN_TTL_SECS`] seconds
/// after issuance.
///
/// The token is signed with `Header::default()` and `secret` used as the HMAC key.
/// `secret` is used as-is; this function does not reject an empty secret.
///
/// # Errors
/// Returns the `jsonwebtoken` error if encoding or signing fails.
pub fn generate_token(
    user_id: i64,
    role: i16,
    secret: &str,
) -> Result<String, jsonwebtoken::errors::Error> {
    generate_token_with_ttl(user_id, role, secret, TOKEN_TTL_SECS)
}

/// Like [`generate_token`], but the token expires `ttl_secs` seconds after issuance.
///
/// # Errors
/// Returns the `jsonwebtoken` error if encoding or signing fails.
pub fn generate_token_with_ttl(
    user_id: i64,
    role: i16,
    secret: &str,
    ttl_secs: usize,
) -> Result<String, jsonwebtoken::errors::Error> {
    // A clock set before the Unix epoch yields a negative timestamp; `as usize` would
    // wrap it to a huge value (and `+ ttl` could then overflow). Clamp to 0 instead,
    // which produces an already-expired token rather than a near-eternal one.
    let now = usize::try_from(chrono::Utc::now().timestamp()).unwrap_or(0);

    let claims = Claims {
        sub: user_id.to_string(),
        role,
        exp: now.saturating_add(ttl_secs),
        iat: now,
    };

    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(secret.as_bytes()),
    )
}

/// Verifies `token` against `secret` and returns its decoded [`Claims`].
///
/// Validation rules come from `Validation::default()`, which includes checking the
/// `exp` claim.
///
/// # Errors
/// Returns the `jsonwebtoken` error when the signature, algorithm, or expiry check
/// fails, or when the payload cannot be deserialized into [`Claims`].
pub fn validate_token(
    token: &str,
    secret: &str,
) -> Result<Claims, jsonwebtoken::errors::Error> {
    let key = DecodingKey::from_secret(secret.as_bytes());
    let rules = Validation::default();
    decode::<Claims>(token, &key, &rules).map(|data| data.claims)
}

5.3 密码哈希

使用 bcrypt 进行密码哈希存储:

use bcrypt::{hash, verify, DEFAULT_COST};

/// Hashes `password` with bcrypt at `DEFAULT_COST`.
///
/// The returned string is what [`verify_password`] expects as its `hash` argument.
///
/// # Errors
/// Propagates any error reported by bcrypt.
pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> {
    let cost = DEFAULT_COST;
    hash(password, cost)
}

/// Checks a plaintext `password` against a bcrypt `hash` (as produced by
/// [`hash_password`]).
///
/// Returns `Ok(true)` on a match and `Ok(false)` on a mismatch.
///
/// # Errors
/// Propagates any error reported by bcrypt while verifying.
pub fn verify_password(password: &str, hash: &str) -> Result<bool, bcrypt::BcryptError> {
    let stored = hash;
    verify(password, stored)
}

5.4 Axum 认证中间件

/// 认证中间件 — 从 Authorization header 提取并验证 JWT
pub async fn auth_middleware(
    State(config): State<AppConfig>,
    mut req: Request,
    next: Next,
) -> Response {
    // 从请求头提取 Token
    let auth_header = req
        .headers()
        .get("Authorization")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "));

    let token = match auth_header {
        Some(t) => t,
        None => {
            return (StatusCode::UNAUTHORIZED, Json(serde_json::json!({
                "error": "Missing authorization token"
            }))).into_response();
        }
    };

    // 验证 Token
    match validate_token(token, &config.jwt_secret) {
        Ok(claims) => {
            // 将用户信息注入请求扩展
            req.extensions_mut().insert(claims);
            next.run(req).await
        }
        Err(_) => {
            (StatusCode::UNAUTHORIZED, Json(serde_json::json!({
                "error": "Invalid or expired token"
            }))).into_response()
        }
    }
}

5.5 提取认证用户

/// The authenticated caller, built from the [`Claims`] that `auth_middleware`
/// inserted into the request extensions.
///
/// Extraction fails with:
/// - `401 "Not authenticated"` when no `Claims` are present (typically because the
///   route is not wrapped by `auth_middleware`);
/// - `500 "Invalid user ID"` when the `sub` claim does not parse as an `i64`.
#[derive(Clone, Debug)]
pub struct AuthUser {
    /// User ID parsed from the `sub` claim.
    pub id: i64,
    /// Role copied from the `role` claim.
    pub role: i16,
}

impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Some(claims) = parts.extensions.get::<Claims>() else {
            return Err((StatusCode::UNAUTHORIZED, "Not authenticated"));
        };

        let id = match claims.sub.parse::<i64>() {
            Ok(id) => id,
            Err(_) => return Err((StatusCode::INTERNAL_SERVER_ERROR, "Invalid user ID")),
        };

        Ok(Self {
            id,
            role: claims.role,
        })
    }
}

5.6 登录接口

/// Login handler: looks up the user by username, checks the submitted password
/// against the stored bcrypt hash, and on success returns a newly issued JWT.
///
/// An unknown username, a user row without a password, and a wrong password all
/// produce the same `AppError::Unauthorized("Invalid credentials")`, so the response
/// does not reveal which check failed. A user without a role is issued a token with
/// role `1`.
///
/// NOTE(review): the unknown-username path returns before any bcrypt work, so it
/// is measurably faster than a wrong-password attempt (a username-enumeration timing
/// signal). `verify_password` also runs bcrypt directly on the async executor. Consider
/// a dummy verify plus `tokio::task::spawn_blocking`; this needs an `AppError` variant
/// for join failures, so confirm before changing.
async fn login(
    State(pool): State<PgPool>,
    State(config): State<AppConfig>,
    Json(credentials): Json<LoginDto>,
) -> Result<Json<ApiResult<LoginResponse>>, AppError> {
    let invalid_credentials = || AppError::Unauthorized("Invalid credentials".into());

    let found = sqlx::query_as::<_, User>("SELECT * FROM user_models WHERE username = $1")
        .bind(&credentials.username)
        .fetch_optional(&pool)
        .await?;
    let Some(user) = found else {
        return Err(invalid_credentials());
    };

    let Some(stored_hash) = user.password else {
        return Err(invalid_credentials());
    };

    let password_matches = verify_password(&credentials.password, &stored_hash)?;
    if !password_matches {
        return Err(invalid_credentials());
    }

    let role = user.role.unwrap_or(1);
    let token = generate_token(user.id, role, &config.jwt_secret)?;

    let response = LoginResponse {
        token,
        user_id: user.id,
        username: user.username,
    };
    Ok(Json(ApiResult::success(response)))
}

5.7 路由保护

/// Builds the application router from a public group (no token required) and a
/// protected group wrapped in `auth_middleware`.
///
/// NOTE(review): as written this does not compile. `config` is not defined in this
/// function; presumably it should be a parameter or come from shared app state.
/// Handlers such as `login` extract `State<PgPool>` / `State<AppConfig>`, but the
/// returned router never calls `.with_state(...)`. Confirm the intended signature
/// before fixing.
fn create_router() -> Router {
    let public_routes = Router::new()
        .route("/api/user/login", post(login))
        .route("/api/article", get(list_articles));

    // The auth layer is attached to `protected_routes` only, so the routes in
    // `public_routes` are reachable without a token.
    let protected_routes = Router::new()
        .route("/api/article", post(create_article))
        .route("/api/article/{id}", delete(delete_article))
        .layer(middleware::from_fn_with_state(
            config.clone(),
            auth_middleware,
        ));

    // Both groups register "/api/article" (GET public, POST protected); this relies on
    // `merge` combining the per-method handlers for that path — verify no method
    // overlaps between the two groups.
    Router::new()
        .merge(public_routes)
        .merge(protected_routes)
        // NOTE(review): `CorsLayer::permissive()` is a very open CORS policy —
        // confirm this is intended outside development.
        .layer(CorsLayer::permissive())
}

下一章将介绍 RESTful API 的设计与实现。

rust · jwt · auth · middleware · axum