103 lines
4.2 KiB
Diff
103 lines
4.2 KiB
Diff
From 565b442fa56c93a706d5b2f5224763854b8f42cc Mon Sep 17 00:00:00 2001
|
|
From: rearcher <123781007@qq.com>
|
|
Date: Fri, 20 Sep 2024 15:11:28 +0800
|
|
Subject: [PATCH] fix logout error, fix register error
|
|
|
|
---
|
|
oauth2_provider/app/core/account.py | 43 ++++++++++-------------------
|
|
1 file changed, 15 insertions(+), 28 deletions(-)
|
|
|
|
diff --git a/oauth2_provider/app/core/account.py b/oauth2_provider/app/core/account.py
|
|
index 3259704..16038fd 100644
|
|
--- a/oauth2_provider/app/core/account.py
|
|
+++ b/oauth2_provider/app/core/account.py
|
|
@@ -67,8 +67,8 @@ class UserProxy:
|
|
if not self._check_user_not_exist(username):
|
|
LOGGER.error(f"add user failed, username exists: {username}")
|
|
return DATA_EXIST
|
|
- self._add_user(username, password, email)
|
|
- callback_res = self._register_callback(username)
|
|
+ user_info = self._add_user(username, password, email)
|
|
+ callback_res = self._register_callback(user_info)
|
|
if callback_res != SUCCEED:
|
|
return callback_res
|
|
db.session.commit()
|
|
@@ -80,42 +80,25 @@ class UserProxy:
|
|
return DATABASE_INSERT_ERROR
|
|
return SUCCEED
|
|
|
|
- def _register_callback(self, username: str) -> str:
|
|
+ def _register_callback(self, user) -> str:
|
|
res = SUCCEED
|
|
for client in db.session.query(OAuth2Client).distinct(OAuth2Client.client_id).all():
|
|
- user_info = self._get_user_info(username, client.client_id)
|
|
+ scope = client.client_metadata["scope"].split()
|
|
+ user_info = dict()
|
|
+ if "username" in scope:
|
|
+ user_info["username"] = user.username
|
|
+ if "email" in scope:
|
|
+ user_info["email"] = user.email
|
|
for register_callback_uri in client.register_callback_uris:
|
|
response_data = BaseResponse.get_response(
|
|
method="Post", url=register_callback_uri, data=user_info, header=self.HEADERS
|
|
)
|
|
response_status = response_data.get("label")
|
|
if response_status != SUCCEED:
|
|
- LOGGER.error(f"register redirect failed: {client.client_id}, {username}")
|
|
+ LOGGER.error(f"register redirect failed: {client.client_id}, {user.username}")
|
|
res = PARTIAL_SUCCEED
|
|
return res
|
|
|
|
- def _get_user_info(self, username: str, client_id: str) -> dict:
|
|
- """
|
|
- Get user info.
|
|
-
|
|
- Args:
|
|
- username(str): username,
|
|
- client_id(str): client id
|
|
-
|
|
- Returns:
|
|
- dict: user info
|
|
- """
|
|
- client_scopes = db.session.query(OAuth2ClientScopes).filter_by(username=username, client_id=client_id).one()
|
|
- user = db.session.query(User).filter_by(username=username).one()
|
|
- user_info = dict()
|
|
- # user scope, e.g. ["email","username","openid","offline_access"]
|
|
- scopes = client_scopes.scopes.split()
|
|
- if "username" in scopes:
|
|
- user_info["username"] = user.username
|
|
- if "email" in scopes:
|
|
- user_info["email"] = user.email
|
|
- return user_info
|
|
-
|
|
def _check_user_not_exist(self, username: str) -> bool:
|
|
query_res = db.session.query(User).filter_by(username=username).count()
|
|
if query_res != 0:
|
|
@@ -133,10 +116,14 @@ class UserProxy:
|
|
"password": "xxx",
|
|
"email": "xxx@xxx.com"
|
|
}
|
|
+
|
|
+ Returns:
|
|
+ user: user
|
|
"""
|
|
password_hash = User.hash_password(password)
|
|
user = User(username=username, password=password_hash, email=email)
|
|
db.session.add(user)
|
|
+ return user
|
|
|
|
def manager_login(self, data) -> Tuple[str, str]:
|
|
"""
|
|
@@ -283,7 +270,7 @@ class UserProxy:
|
|
encrypted_data = encrypted_data.encode('utf-8')
|
|
encoded_data = base64.b64encode(encrypted_data)
|
|
encrypted_string = encoded_data.decode('utf-8')
|
|
- logout_callback_uris = login_record.logout_url.split(",")
|
|
+ logout_callback_uris = list(filter(None, login_record.logout_url.split(',')))
|
|
for logout_callback_uri in logout_callback_uris:
|
|
response_data = BaseResponse.get_response(
|
|
method="Post",
|
|
--
|
|
Gitee
|
|
|