From 565b442fa56c93a706d5b2f5224763854b8f42cc Mon Sep 17 00:00:00 2001 From: rearcher <123781007@qq.com> Date: Fri, 20 Sep 2024 15:11:28 +0800 Subject: [PATCH] fix logout error, fix register error --- oauth2_provider/app/core/account.py | 43 ++++++++++------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/oauth2_provider/app/core/account.py b/oauth2_provider/app/core/account.py index 3259704..16038fd 100644 --- a/oauth2_provider/app/core/account.py +++ b/oauth2_provider/app/core/account.py @@ -67,8 +67,8 @@ class UserProxy: if not self._check_user_not_exist(username): LOGGER.error(f"add user failed, username exists: {username}") return DATA_EXIST - self._add_user(username, password, email) - callback_res = self._register_callback(username) + user_info = self._add_user(username, password, email) + callback_res = self._register_callback(user_info) if callback_res != SUCCEED: return callback_res db.session.commit() @@ -80,42 +80,25 @@ class UserProxy: return DATABASE_INSERT_ERROR return SUCCEED - def _register_callback(self, username: str) -> str: + def _register_callback(self, user) -> str: res = SUCCEED for client in db.session.query(OAuth2Client).distinct(OAuth2Client.client_id).all(): - user_info = self._get_user_info(username, client.client_id) + scope = client.client_metadata["scope"].split() + user_info = dict() + if "username" in scope: + user_info["username"] = user.username + if "email" in scope: + user_info["email"] = user.email for register_callback_uri in client.register_callback_uris: response_data = BaseResponse.get_response( method="Post", url=register_callback_uri, data=user_info, header=self.HEADERS ) response_status = response_data.get("label") if response_status != SUCCEED: - LOGGER.error(f"register redirect failed: {client.client_id}, {username}") + LOGGER.error(f"register redirect failed: {client.client_id}, {user.username}") res = PARTIAL_SUCCEED return res - def _get_user_info(self, username: str, client_id: str) -> dict: - """ - Get user info. - - Args: - username(str): username, - client_id(str): client id - - Returns: - dict: user info - """ - client_scopes = db.session.query(OAuth2ClientScopes).filter_by(username=username, client_id=client_id).one() - user = db.session.query(User).filter_by(username=username).one() - user_info = dict() - # user scope, e.g. ["email","username","openid","offline_access"] - scopes = client_scopes.scopes.split() - if "username" in scopes: - user_info["username"] = user.username - if "email" in scopes: - user_info["email"] = user.email - return user_info - def _check_user_not_exist(self, username: str) -> bool: query_res = db.session.query(User).filter_by(username=username).count() if query_res != 0: @@ -133,10 +116,14 @@ class UserProxy: "password": "xxx", "email": "xxx@xxx.com" } + + Returns: + user: user """ password_hash = User.hash_password(password) user = User(username=username, password=password_hash, email=email) db.session.add(user) + return user def manager_login(self, data) -> Tuple[str, str]: """ @@ -283,7 +270,7 @@ class UserProxy: encrypted_data = encrypted_data.encode('utf-8') encoded_data = base64.b64encode(encrypted_data) encrypted_string = encoded_data.decode('utf-8') - logout_callback_uris = login_record.logout_url.split(",") + logout_callback_uris = list(filter(None, login_record.logout_url.split(','))) for logout_callback_uri in logout_callback_uris: response_data = BaseResponse.get_response( method="Post", -- Gitee