python-oauthlib/backport-Ensure-expires_at-is-always-int.patch

118 lines
5.1 KiB
Diff
Raw Normal View History

From d4b6699f8ccb608152b764919e0bd3d38a7b171f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sindri=20Gu=C3=B0mundsson?= <sindrigudmundsson@gmail.com>
Date: Mon, 22 Aug 2022 16:32:14 +0000
Subject: [PATCH] Ensure expires_at is always int
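Round computed and server-supplied expires_at timestamps to whole seconds so that expires_at is always an int, as discussed in #745.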
---
oauthlib/oauth2/rfc6749/clients/base.py | 4 +--
oauthlib/oauth2/rfc6749/parameters.py | 5 +++-
tests/oauth2/rfc6749/clients/test_base.py | 33 ++++++++++++++++++++++
.../rfc6749/clients/test_service_application.py | 2 +-
4 files changed, 40 insertions(+), 4 deletions(-)
diff --git a/oauthlib/oauth2/rfc6749/clients/base.py b/oauthlib/oauth2/rfc6749/clients/base.py
index d5eb0cc..1d12638 100644
--- a/oauthlib/oauth2/rfc6749/clients/base.py
+++ b/oauthlib/oauth2/rfc6749/clients/base.py
@@ -589,11 +589,11 @@ class Client:
if 'expires_in' in response:
self.expires_in = response.get('expires_in')
- self._expires_at = time.time() + int(self.expires_in)
+ self._expires_at = round(time.time()) + int(self.expires_in)
if 'expires_at' in response:
try:
- self._expires_at = int(response.get('expires_at'))
+ self._expires_at = round(float(response.get('expires_at')))
except:
self._expires_at = None
diff --git a/oauthlib/oauth2/rfc6749/parameters.py b/oauthlib/oauth2/rfc6749/parameters.py
index 8f6ce2c..0f0f423 100644
--- a/oauthlib/oauth2/rfc6749/parameters.py
+++ b/oauthlib/oauth2/rfc6749/parameters.py
@@ -345,7 +345,7 @@ def parse_implicit_response(uri, state=None, scope=None):
params['scope'] = scope_to_list(params['scope'])
if 'expires_in' in params:
- params['expires_at'] = time.time() + int(params['expires_in'])
+ params['expires_at'] = round(time.time()) + int(params['expires_in'])
if state and params.get('state', None) != state:
raise ValueError("Mismatching or missing state in params.")
@@ -437,6 +437,9 @@ def parse_token_response(body, scope=None):
else:
params['expires_at'] = time.time() + int(params['expires_in'])
+ if isinstance(params.get('expires_at'), float):
+ params['expires_at'] = round(params['expires_at'])
+
params = OAuth2Token(params, old_scope=scope)
validate_token_parameters(params)
return params
diff --git a/tests/oauth2/rfc6749/clients/test_base.py b/tests/oauth2/rfc6749/clients/test_base.py
index 70a2283..7286b99 100644
--- a/tests/oauth2/rfc6749/clients/test_base.py
+++ b/tests/oauth2/rfc6749/clients/test_base.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import datetime
+from unittest.mock import patch
from oauthlib import common
from oauthlib.oauth2 import Client, InsecureTransportError, TokenExpiredError
@@ -353,3 +354,35 @@ class ClientTest(TestCase):
code_verifier = client.create_code_verifier(length=128)
code_challenge_s256 = client.create_code_challenge(code_verifier=code_verifier, code_challenge_method='S256')
self.assertEqual(code_challenge_s256, client.code_challenge)
+
+ def test_parse_token_response_expires_at_is_int(self):
+ expected_expires_at = 1661185149
+ token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
+ ' "token_type":"example",'
+ ' "expires_at":1661185148.6437678,'
+ ' "scope":"/profile",'
+ ' "example_parameter":"example_value"}')
+
+ client = Client(self.client_id)
+
+ response = client.parse_request_body_response(token_json, scope=["/profile"])
+
+ self.assertEqual(response['expires_at'], expected_expires_at)
+ self.assertEqual(client._expires_at, expected_expires_at)
+
+ @patch('time.time')
+ def test_parse_token_response_generated_expires_at_is_int(self, t):
+ t.return_value = 1661185148.6437678
+ expected_expires_at = round(t.return_value) + 3600
+ token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
+ ' "token_type":"example",'
+ ' "expires_in":3600,'
+ ' "scope":"/profile",'
+ ' "example_parameter":"example_value"}')
+
+ client = Client(self.client_id)
+
+ response = client.parse_request_body_response(token_json, scope=["/profile"])
+
+ self.assertEqual(response['expires_at'], expected_expires_at)
+ self.assertEqual(client._expires_at, expected_expires_at)
diff --git a/tests/oauth2/rfc6749/clients/test_service_application.py b/tests/oauth2/rfc6749/clients/test_service_application.py
index b97d855..84361d8 100644
--- a/tests/oauth2/rfc6749/clients/test_service_application.py
+++ b/tests/oauth2/rfc6749/clients/test_service_application.py
@@ -166,7 +166,7 @@ mfvGGg3xNjTMO7IdrwIDAQAB
@patch('time.time')
def test_parse_token_response(self, t):
t.return_value = time()
- self.token['expires_at'] = self.token['expires_in'] + t.return_value
+ self.token['expires_at'] = self.token['expires_in'] + round(t.return_value)
client = ServiceApplicationClient(self.client_id)
--
2.9.3.windows.1