第五章:JWT 认证中间件
博客 v1.0 系列教程 (Rust)
5.1 依赖(JWT 与 bcrypt)
[dependencies]
jsonwebtoken = "9"
bcrypt = "0.16"
5.2 Token 生成与验证
// src/common/auth.rs
use axum::{
    extract::{FromRequestParts, Request, State},
    http::{header, request::Parts, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    Json, RequestPartsExt,
};
use axum_extra::headers::{authorization::Bearer, Authorization};
use axum_extra::TypedHeader;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
/// JWT payload written by `generate_token` and read back by `validate_token`.
///
/// Timestamps are Unix times in whole seconds (taken from
/// `chrono::Utc::now().timestamp()` when the token is issued).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    pub sub: String, // Subject: the user ID, stored as a decimal string
    pub role: i16, // User role code
    pub exp: usize, // Expiry time (Unix seconds)
    pub iat: usize, // Issued-at time (Unix seconds)
}
/// Lifetime of a token issued by [`generate_token`], in seconds (24 hours).
pub const TOKEN_TTL_SECS: usize = 24 * 60 * 60;

/// Issues a signed JWT for `user_id` carrying `role`, valid for
/// [`TOKEN_TTL_SECS`] seconds from now.
///
/// The token is signed with `secret` using the default `jsonwebtoken` header.
///
/// # Errors
/// Returns the `jsonwebtoken` error if encoding/signing fails.
pub fn generate_token(
    user_id: i64,
    role: i16,
    secret: &str,
) -> Result<String, jsonwebtoken::errors::Error> {
    generate_token_with_ttl(user_id, role, secret, TOKEN_TTL_SECS)
}

/// Same as [`generate_token`], but with a caller-chosen lifetime of
/// `ttl_secs` seconds (e.g. for short-lived or long-lived tokens).
///
/// # Errors
/// Returns the `jsonwebtoken` error if encoding/signing fails.
pub fn generate_token_with_ttl(
    user_id: i64,
    role: i16,
    secret: &str,
    ttl_secs: usize,
) -> Result<String, jsonwebtoken::errors::Error> {
    // `timestamp()` is an i64; the cast assumes the system clock is not set
    // before the Unix epoch (a negative value would wrap).
    let now = chrono::Utc::now().timestamp() as usize;
    let claims = Claims {
        sub: user_id.to_string(),
        role,
        // Saturate so an oversized caller-supplied TTL cannot overflow.
        exp: now.saturating_add(ttl_secs),
        iat: now,
    };
    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(secret.as_bytes()),
    )
}
/// Decodes `token`, verifying its signature with `secret`, and returns the
/// embedded [`Claims`].
///
/// Validation rules come from `jsonwebtoken`'s `Validation::default()`
/// (which, among other checks, rejects expired tokens).
///
/// # Errors
/// Any decoding, signature, or claim-validation failure from `jsonwebtoken`.
pub fn validate_token(
    token: &str,
    secret: &str,
) -> Result<Claims, jsonwebtoken::errors::Error> {
    let key = DecodingKey::from_secret(secret.as_bytes());
    let rules = Validation::default();
    decode::<Claims>(token, &key, &rules).map(|data| data.claims)
}
5.3 密码哈希
使用 bcrypt 对密码进行加盐哈希后再存储。注意:bcrypt 刻意设计为计算开销较大,在异步处理函数中直接调用会阻塞运行时的工作线程;请求量较大时可考虑将其放入 `tokio::task::spawn_blocking` 中执行:
use bcrypt::{hash, verify, DEFAULT_COST};
/// Produces a salted bcrypt hash of `password` at bcrypt's `DEFAULT_COST`
/// work factor; the result is what gets stored and later passed to
/// `verify_password`.
///
/// NOTE(review): bcrypt only considers the first 72 bytes of its input —
/// confirm how the `bcrypt` crate version in use treats longer passwords.
pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> {
    let work_factor = DEFAULT_COST;
    hash(password, work_factor)
}
/// Checks a plaintext `password` against a stored bcrypt `hash` (as produced
/// by `hash_password`).
///
/// Returns `Ok(true)` on a match and `Ok(false)` on a mismatch; errors raised
/// by bcrypt itself (e.g. an unparsable stored hash) are returned as `Err`.
pub fn verify_password(password: &str, hash: &str) -> Result<bool, bcrypt::BcryptError> {
    // Inside this body the `hash` parameter shadows the imported
    // `bcrypt::hash` function; bind it to a clearer name.
    let stored_digest = hash;
    verify(password, stored_digest)
}
5.4 Axum 认证中间件
/// 认证中间件 — 从 Authorization header 提取并验证 JWT
pub async fn auth_middleware(
State(config): State<AppConfig>,
mut req: Request,
next: Next,
) -> Response {
// 从请求头提取 Token
let auth_header = req
.headers()
.get("Authorization")
.and_then(|value| value.to_str().ok())
.and_then(|value| value.strip_prefix("Bearer "));
let token = match auth_header {
Some(t) => t,
None => {
return (StatusCode::UNAUTHORIZED, Json(serde_json::json!({
"error": "Missing authorization token"
}))).into_response();
}
};
// 验证 Token
match validate_token(token, &config.jwt_secret) {
Ok(claims) => {
// 将用户信息注入请求扩展
req.extensions_mut().insert(claims);
next.run(req).await
}
Err(_) => {
(StatusCode::UNAUTHORIZED, Json(serde_json::json!({
"error": "Invalid or expired token"
}))).into_response()
}
}
}
5.5 提取认证用户
/// The currently authenticated user, extracted from the request.
///
/// Built by its `FromRequestParts` impl from the [`Claims`] that
/// `auth_middleware` inserts into the request extensions.
#[derive(Debug, Clone)]
pub struct AuthUser {
    /// User ID, parsed from the token's `sub` claim.
    pub id: i64,
    /// Role value copied from the token's `role` claim.
    pub role: i16,
}
impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    /// Reads the [`Claims`] stored in the request extensions (by
    /// `auth_middleware`) and converts them into an [`AuthUser`].
    ///
    /// # Errors
    /// - `401 "Not authenticated"` if no `Claims` are present, e.g. the route
    ///   is not behind the auth middleware.
    /// - `500 "Invalid user ID"` if the `sub` claim is not a valid `i64`.
    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Some(claims) = parts.extensions.get::<Claims>() else {
            return Err((StatusCode::UNAUTHORIZED, "Not authenticated"));
        };

        let id = match claims.sub.parse::<i64>() {
            Ok(id) => id,
            Err(_) => {
                return Err((StatusCode::INTERNAL_SERVER_ERROR, "Invalid user ID"));
            }
        };

        Ok(Self {
            id,
            role: claims.role,
        })
    }
}
5.6 登录接口
/// Login handler: looks the user up by username, checks the password against
/// the stored bcrypt hash, and on success returns a signed JWT together with
/// the user's ID and username.
///
/// Every authentication failure (unknown username, user without a stored
/// password hash, wrong password) produces the same
/// `AppError::Unauthorized("Invalid credentials")`, so the response body does
/// not reveal which check failed. Users without a role get role `1`.
///
/// NOTE(review): the unknown-user path returns before bcrypt runs, so response
/// timing may still hint whether a username exists. Also, bcrypt verification
/// runs synchronously inside this async handler, blocking its executor
/// thread. Consider a dummy verify and/or `spawn_blocking` if this matters.
async fn login(
    State(pool): State<PgPool>,
    State(config): State<AppConfig>,
    Json(credentials): Json<LoginDto>,
) -> Result<Json<ApiResult<LoginResponse>>, AppError> {
    let invalid = || AppError::Unauthorized("Invalid credentials".into());

    let found: Option<User> =
        sqlx::query_as::<_, User>("SELECT * FROM user_models WHERE username = $1")
            .bind(&credentials.username)
            .fetch_optional(&pool)
            .await?;
    let Some(user) = found else {
        return Err(invalid());
    };

    let Some(stored_hash) = user.password else {
        return Err(invalid());
    };

    let password_ok = verify_password(&credentials.password, &stored_hash)?;
    if !password_ok {
        return Err(invalid());
    }

    let role = user.role.unwrap_or(1);
    let token = generate_token(user.id, role, &config.jwt_secret)?;
    let body = LoginResponse {
        token,
        user_id: user.id,
        username: user.username,
    };
    Ok(Json(ApiResult::success(body)))
}
5.7 路由保护
/// Builds the application router. Login and article listing are public;
/// article creation and deletion sit behind `auth_middleware`. Both groups are
/// merged and wrapped in a permissive CORS layer.
///
/// NOTE(review): as shown, `config` is neither a parameter nor a local of this
/// function, so it must come from elsewhere (likely it should be passed in).
/// The handlers also extract `State<PgPool>` / `State<AppConfig>`, so the
/// router presumably needs shared state supplied via `.with_state(...)`.
/// TODO: confirm against the full application setup.
fn create_router() -> Router {
    // Routes reachable without a token.
    let public_routes = Router::new()
        .route("/api/user/login", post(login))
        .route("/api/article", get(list_articles));
    // Routes that require a valid bearer token: the layer runs
    // `auth_middleware` before each handler in this router.
    let protected_routes = Router::new()
        .route("/api/article", post(create_article))
        .route("/api/article/{id}", delete(delete_article))
        .layer(middleware::from_fn_with_state(
            config.clone(),
            auth_middleware,
        ));
    // GET /api/article stays public while POST /api/article is protected;
    // the two route groups are merged into one router here.
    Router::new()
        .merge(public_routes)
        .merge(protected_routes)
        .layer(CorsLayer::permissive())
}
下一章将介绍 RESTful API 的设计与实现。
标签:rust · jwt · auth · middleware · axum