authHub/0003-fix-logout-register-error.patch

103 lines
4.2 KiB
Diff

From 565b442fa56c93a706d5b2f5224763854b8f42cc Mon Sep 17 00:00:00 2001
From: rearcher <123781007@qq.com>
Date: Fri, 20 Sep 2024 15:11:28 +0800
Subject: [PATCH] fix logout error, fix register error
---
oauth2_provider/app/core/account.py | 43 ++++++++++-------------------
1 file changed, 15 insertions(+), 28 deletions(-)
diff --git a/oauth2_provider/app/core/account.py b/oauth2_provider/app/core/account.py
index 3259704..16038fd 100644
--- a/oauth2_provider/app/core/account.py
+++ b/oauth2_provider/app/core/account.py
@@ -67,8 +67,8 @@ class UserProxy:
if not self._check_user_not_exist(username):
LOGGER.error(f"add user failed, username exists: {username}")
return DATA_EXIST
- self._add_user(username, password, email)
- callback_res = self._register_callback(username)
+ user_info = self._add_user(username, password, email)
+ callback_res = self._register_callback(user_info)
if callback_res != SUCCEED:
return callback_res
db.session.commit()
@@ -80,42 +80,25 @@ class UserProxy:
return DATABASE_INSERT_ERROR
return SUCCEED
- def _register_callback(self, username: str) -> str:
+ def _register_callback(self, user) -> str:
res = SUCCEED
for client in db.session.query(OAuth2Client).distinct(OAuth2Client.client_id).all():
- user_info = self._get_user_info(username, client.client_id)
+ scope = client.client_metadata["scope"].split()
+ user_info = dict()
+ if "username" in scope:
+ user_info["username"] = user.username
+ if "email" in scope:
+ user_info["email"] = user.email
for register_callback_uri in client.register_callback_uris:
response_data = BaseResponse.get_response(
method="Post", url=register_callback_uri, data=user_info, header=self.HEADERS
)
response_status = response_data.get("label")
if response_status != SUCCEED:
- LOGGER.error(f"register redirect failed: {client.client_id}, {username}")
+ LOGGER.error(f"register redirect failed: {client.client_id}, {user.username}")
res = PARTIAL_SUCCEED
return res
- def _get_user_info(self, username: str, client_id: str) -> dict:
- """
- Get user info.
-
- Args:
- username(str): username,
- client_id(str): client id
-
- Returns:
- dict: user info
- """
- client_scopes = db.session.query(OAuth2ClientScopes).filter_by(username=username, client_id=client_id).one()
- user = db.session.query(User).filter_by(username=username).one()
- user_info = dict()
- # user scope, e.g. ["email","username","openid","offline_access"]
- scopes = client_scopes.scopes.split()
- if "username" in scopes:
- user_info["username"] = user.username
- if "email" in scopes:
- user_info["email"] = user.email
- return user_info
-
def _check_user_not_exist(self, username: str) -> bool:
query_res = db.session.query(User).filter_by(username=username).count()
if query_res != 0:
@@ -133,10 +116,14 @@ class UserProxy:
"password": "xxx",
"email": "xxx@xxx.com"
}
+
+ Returns:
+ User: the newly created user object, added to the db session
"""
password_hash = User.hash_password(password)
user = User(username=username, password=password_hash, email=email)
db.session.add(user)
+ return user
def manager_login(self, data) -> Tuple[str, str]:
"""
@@ -283,7 +270,7 @@ class UserProxy:
encrypted_data = encrypted_data.encode('utf-8')
encoded_data = base64.b64encode(encrypted_data)
encrypted_string = encoded_data.decode('utf-8')
- logout_callback_uris = login_record.logout_url.split(",")
+ logout_callback_uris = list(filter(None, login_record.logout_url.split(',')))
for logout_callback_uri in logout_callback_uris:
response_data = BaseResponse.get_response(
method="Post",
--
Gitee