Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions tests/test_audit_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def test_succeeds(self, capture_and_mock_http_client_request):
)

response = self.audit_logs.create_event(
organization_id, event, "test_123456"
organization_id=organization_id,
event=event,
idempotency_key="test_123456",
)

assert request_kwargs["json"] == {
Expand All @@ -97,7 +99,9 @@ def test_sends_idempotency_key(
)

response = self.audit_logs.create_event(
organization_id, mock_audit_log_event, idempotency_key
organization_id=organization_id,
event=mock_audit_log_event,
idempotency_key=idempotency_key,
)

assert request_kwargs["headers"]["idempotency-key"] == idempotency_key
Expand All @@ -116,7 +120,9 @@ def test_throws_unauthorized_exception(
)

with pytest.raises(AuthenticationException) as excinfo:
self.audit_logs.create_event(organization_id, mock_audit_log_event)
self.audit_logs.create_event(
organization_id=organization_id, event=mock_audit_log_event
)
assert "(message=Unauthorized, request_id=a-request-id)" == str(
excinfo.value
)
Expand All @@ -138,7 +144,9 @@ def test_throws_badrequest_excpetion(
)

with pytest.raises(BadRequestException) as excinfo:
self.audit_logs.create_event(organization_id, mock_audit_log_event)
self.audit_logs.create_event(
organization_id=organization_id, event=mock_audit_log_event
)
assert excinfo.code == "invalid_audit_log"
assert excinfo.errors == ["error in a field"]
assert (
Expand All @@ -165,7 +173,9 @@ def test_succeeds(self, mock_http_client_with_response):
mock_http_client_with_response(self.http_client, expected_payload, 201)

response = self.audit_logs.create_export(
organization_id, range_start, range_end
organization_id=organization_id,
range_start=range_start,
range_end=range_end,
)

assert response.dict() == expected_payload
Expand Down Expand Up @@ -216,7 +226,11 @@ def test_throws_unauthorized_excpetion(self, mock_http_client_with_response):
)

with pytest.raises(AuthenticationException) as excinfo:
self.audit_logs.create_export(organization_id, range_start, range_end)
self.audit_logs.create_export(
organization_id=organization_id,
range_start=range_start,
range_end=range_end,
)
assert "(message=Unauthorized, request_id=a-request-id)" == str(
excinfo.value
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_directory_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_list_users_with_group(self, mock_users, mock_http_client_with_response)
http_client=self.http_client, status_code=200, response_dict=mock_users
)

users = self.directory_sync.list_users(group="directory_grp_id")
users = self.directory_sync.list_users(group_id="directory_grp_id")

assert list_data_to_dicts(users.data) == mock_users["data"]

Expand Down
10 changes: 5 additions & 5 deletions tests/test_mfa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from workos.mfa import Mfa
import pytest

from workos.utils.http_client import SyncHTTPClient


Expand Down Expand Up @@ -152,7 +151,7 @@ def test_enroll_factor_sms_success(
mock_http_client_with_response(
self.http_client, mock_enroll_factor_response_sms, 200
)
enroll_factor = self.mfa.enroll_factor("sms", None, None, "9204448888")
enroll_factor = self.mfa.enroll_factor(type="sms", phone_number="9204448888")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seemed silly to continue to pass the totp fields here. Removed them.

assert enroll_factor.dict() == mock_enroll_factor_response_sms

def test_enroll_factor_totp_success(
Expand All @@ -162,7 +161,7 @@ def test_enroll_factor_totp_success(
self.http_client, mock_enroll_factor_response_totp, 200
)
enroll_factor = self.mfa.enroll_factor(
"totp", totp_issuer="testissuer", totp_user="testuser"
type="totp", totp_issuer="testissuer", totp_user="testuser"
)
assert enroll_factor.dict() == mock_enroll_factor_response_totp

Expand Down Expand Up @@ -196,7 +195,7 @@ def test_challenge_success(
self.http_client, mock_challenge_factor_response, 200
)
challenge_factor = self.mfa.challenge_factor(
"auth_factor_01FXNWW32G7F3MG8MYK5D1HJJM"
authentication_factor_id="auth_factor_01FXNWW32G7F3MG8MYK5D1HJJM"
)
assert challenge_factor.dict() == mock_challenge_factor_response

Expand All @@ -207,6 +206,7 @@ def test_verify_success(
self.http_client, mock_verify_challenge_response, 200
)
verify_challenge = self.mfa.verify_challenge(
"auth_challenge_01FXNXH8Y2K3YVWJ10P139A6DT", "093647"
authentication_challenge_id="auth_challenge_01FXNXH8Y2K3YVWJ10P139A6DT",
code="093647",
)
assert verify_challenge.dict() == mock_verify_challenge_response
2 changes: 0 additions & 2 deletions tests/test_organizations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import datetime

import pytest

from tests.utils.list_resource import list_data_to_dicts, list_response_of
from workos.organizations import Organizations
from tests.utils.fixtures.mock_organization import MockOrganization
Expand Down
1 change: 0 additions & 1 deletion tests/test_passwordless.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from workos.passwordless import Passwordless
from workos.utils.http_client import SyncHTTPClient

Expand Down
12 changes: 8 additions & 4 deletions tests/test_portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def mock_portal_link(self):
def test_generate_link_sso(self, mock_portal_link, mock_http_client_with_response):
mock_http_client_with_response(self.http_client, mock_portal_link, 201)

response = self.portal.generate_link("sso", "org_01EHQMYV6MBK39QC5PZXHY59C3")
response = self.portal.generate_link(
intent="sso", organization_id="org_01EHQMYV6MBK39QC5PZXHY59C3"
)

assert response.link == "https://id.workos.com/portal/launch?secret=secret"

Expand All @@ -28,7 +30,9 @@ def test_generate_link_dsync(
):
mock_http_client_with_response(self.http_client, mock_portal_link, 201)

response = self.portal.generate_link("dsync", "org_01EHQMYV6MBK39QC5PZXHY59C3")
response = self.portal.generate_link(
intent="dsync", organization_id="org_01EHQMYV6MBK39QC5PZXHY59C3"
)

assert response.link == "https://id.workos.com/portal/launch?secret=secret"

Expand All @@ -38,7 +42,7 @@ def test_generate_link_audit_logs(
mock_http_client_with_response(self.http_client, mock_portal_link, 201)

response = self.portal.generate_link(
"audit_logs", "org_01EHQMYV6MBK39QC5PZXHY59C3"
intent="audit_logs", organization_id="org_01EHQMYV6MBK39QC5PZXHY59C3"
)

assert response.link == "https://id.workos.com/portal/launch?secret=secret"
Expand All @@ -49,7 +53,7 @@ def test_generate_link_log_streams(
mock_http_client_with_response(self.http_client, mock_portal_link, 201)

response = self.portal.generate_link(
"log_streams", "org_01EHQMYV6MBK39QC5PZXHY59C3"
intent="log_streams", organization_id="org_01EHQMYV6MBK39QC5PZXHY59C3"
)

assert response.link == "https://id.workos.com/portal/launch?secret=secret"
2 changes: 1 addition & 1 deletion tests/test_user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_update_user(self, mock_user, capture_and_mock_http_client_request):
"password": "password",
}
user = self.user_management.update_user(
"user_01H7ZGXFP5C6BBQY6Z7277ZCT0", **params
user_id="user_01H7ZGXFP5C6BBQY6Z7277ZCT0", **params
)

assert request_kwargs["url"].endswith("users/user_01H7ZGXFP5C6BBQY6Z7277ZCT0")
Expand Down
41 changes: 22 additions & 19 deletions tests/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def test_unable_to_extract_timestamp(
):
with pytest.raises(ValueError) as err:
self.webhooks.verify_event(
mock_event_body.encode("utf-8"),
mock_header_no_timestamp,
mock_secret,
180,
payload=mock_event_body.encode("utf-8"),
event_signature=mock_header_no_timestamp,
secret=mock_secret,
tolerance=180,
)
assert "Unable to extract timestamp and signature hash from header" in str(
err.value
Expand All @@ -61,19 +61,22 @@ def test_timestamp_outside_threshold(
):
with pytest.raises(ValueError) as err:
self.webhooks.verify_event(
mock_event_body.encode("utf-8"), mock_header, mock_secret, 0
payload=mock_event_body.encode("utf-8"),
event_signature=mock_header,
secret=mock_secret,
tolerance=0,
)
assert "Timestamp outside the tolerance zone" in str(err.value)

def test_sig_hash_does_not_match_expected_sig_length(self, mock_sig_hash):
result = self.webhooks.constant_time_compare(
result = self.webhooks._constant_time_compare(
mock_sig_hash,
"df25b6efdd39d82e7b30e75ea19655b306860ad5cde3eeaeb6f1dfea029ea25",
)
assert result == False

def test_sig_hash_does_not_match_expected_sig_value(self, mock_sig_hash):
result = self.webhooks.constant_time_compare(
result = self.webhooks._constant_time_compare(
mock_sig_hash,
"df25b6efdd39d82e7b30e75ea19655b306860ad5cde3eeaeb6f1dfea029ea252",
)
Expand All @@ -84,10 +87,10 @@ def test_passed_expected_event_validation(
):
try:
webhook = self.webhooks.verify_event(
mock_event_body.encode("utf-8"),
mock_header,
mock_secret,
99999999999999,
payload=mock_event_body.encode("utf-8"),
event_signature=mock_header,
secret=mock_secret,
tolerance=99999999999999,
)
assert type(webhook).__name__ == "ConnectionActivatedWebhook"
except BaseException:
Expand All @@ -100,10 +103,10 @@ def test_sign_hash_does_not_match_expected_sig_hash_verify_header(
):
with pytest.raises(ValueError) as err:
self.webhooks.verify_header(
mock_event_body.encode("utf-8"),
mock_header,
mock_bad_secret,
99999999999999,
event_body=mock_event_body.encode("utf-8"),
event_signature=mock_header,
secret=mock_bad_secret,
tolerance=99999999999999,
)
assert (
"Signature hash does not match the expected signature hash for payload"
Expand All @@ -114,10 +117,10 @@ def test_unrecognized_webhook_type_returns_untyped_webhook(
self, mock_unknown_webhook_body, mock_unknown_webhook_header, mock_secret
):
result = self.webhooks.verify_event(
mock_unknown_webhook_body.encode("utf-8"),
mock_unknown_webhook_header,
mock_secret,
99999999999999,
payload=mock_unknown_webhook_body.encode("utf-8"),
event_signature=mock_unknown_webhook_header,
secret=mock_secret,
tolerance=99999999999999,
)
assert type(result).__name__ == "UntypedWebhook"
assert result.dict() == json.loads(mock_unknown_webhook_body)
2 changes: 1 addition & 1 deletion workos/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AsyncClient(BaseClient):
_user_management: AsyncUserManagement
_webhooks: WebhooksModule

def __init__(self, base_url: str, version: str, timeout: int):
def __init__(self, *, base_url: str, version: str, timeout: int):
self._http_client = AsyncHTTPClient(
base_url=base_url, version=version, timeout=timeout
)
Expand Down
4 changes: 4 additions & 0 deletions workos/audit_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
class AuditLogsModule(Protocol):
def create_event(
self,
*,
organization_id: str,
event: AuditLogEvent,
idempotency_key: Optional[str] = None,
) -> None: ...

def create_export(
self,
*,
organization_id: str,
range_start: str,
range_end: str,
Expand All @@ -44,6 +46,7 @@ def __init__(self, http_client: SyncHTTPClient):

def create_event(
self,
*,
organization_id: str,
event: AuditLogEvent,
idempotency_key: Optional[str] = None,
Expand Down Expand Up @@ -71,6 +74,7 @@ def create_event(

def create_export(
self,
*,
organization_id: str,
range_start: str,
range_end: str,
Expand Down
2 changes: 1 addition & 1 deletion workos/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SyncClient(BaseClient):
_user_management: UserManagement
_webhooks: Webhooks

def __init__(self, base_url: str, version: str, timeout: int):
def __init__(self, *, base_url: str, version: str, timeout: int):
self._http_client = SyncHTTPClient(
base_url=base_url, version=version, timeout=timeout
)
Expand Down
Loading