Skip to content
72 changes: 53 additions & 19 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,25 @@ def get_token_struct(token: str) -> bytes:

@staticmethod
def get_token(auth_type: str) -> bytes:
"""Get token using the specified authentication type"""
"""Get DDBC token struct for the specified authentication type."""
token_struct, _ = AADAuth._acquire_token(auth_type)
return token_struct

@staticmethod
def get_raw_token(auth_type: str) -> str:
"""Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy).

This deliberately does NOT cache the credential or token — each call
creates a new Azure Identity credential instance and requests a token.
A fresh acquisition avoids expired-token errors when bulkcopy() is
called long after the original DDBC connect().
"""
_, raw_token = AADAuth._acquire_token(auth_type)
return raw_token

@staticmethod
def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
"""Internal: acquire token and return (ddbc_struct, raw_jwt)."""
# Import Azure libraries inside method to support test mocking
# pylint: disable=import-outside-toplevel
try:
Expand All @@ -53,30 +71,27 @@ def get_token(auth_type: str) -> bytes:
"interactive": InteractiveBrowserCredential,
}

credential_class = credential_map[auth_type]
credential_class = credential_map.get(auth_type)
if not credential_class:
raise ValueError(
f"Unsupported auth_type '{auth_type}'. " f"Supported: {', '.join(credential_map)}"
)
logger.info(
"get_token: Starting Azure AD authentication - auth_type=%s, credential_class=%s",
auth_type,
credential_class.__name__,
)

try:
logger.debug(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed unnecessary log statements, there is an info log above which captures these

"get_token: Creating credential instance - credential_class=%s",
credential_class.__name__,
)
credential = credential_class()
logger.debug(
"get_token: Requesting token from Azure AD - scope=https://database.windows.net/.default"
)
token = credential.get_token("https://database.windows.net/.default").token
raw_token = credential.get_token("https://database.windows.net/.default").token
logger.info(
"get_token: Azure AD token acquired successfully - token_length=%d chars",
len(token),
len(raw_token),
)
return AADAuth.get_token_struct(token)
token_struct = AADAuth.get_token_struct(raw_token)
return token_struct, raw_token
except ClientAuthenticationError as e:
# Re-raise with more specific context about Azure AD authentication failure
logger.error(
"get_token: Azure AD authentication failed - credential_class=%s, error=%s",
credential_class.__name__,
Expand All @@ -88,7 +103,6 @@ def get_token(auth_type: str) -> bytes:
f"user cancellation, network issues, or unsupported configuration."
) from e
except Exception as e:
# Catch any other unexpected exceptions
logger.error(
"get_token: Unexpected error during credential creation - credential_class=%s, error=%s",
credential_class.__name__,
Expand Down Expand Up @@ -180,7 +194,7 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:


def get_auth_token(auth_type: str) -> Optional[bytes]:
"""Get authentication token based on auth type"""
"""Get DDBC authentication token struct based on auth type."""
logger.debug("get_auth_token: Starting - auth_type=%s", auth_type)
if not auth_type:
logger.debug("get_auth_token: No auth_type specified, returning None")
Expand All @@ -202,17 +216,37 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
return None


def extract_auth_type(connection_string: str) -> Optional[str]:
"""Extract Entra ID auth type from a connection string.

Used as a fallback when process_connection_string does not propagate
auth_type (e.g. Windows Interactive where DDBC handles auth natively).
Bulkcopy still needs the auth type to acquire a token via Azure Identity.
"""
auth_map = {
AuthType.INTERACTIVE.value: "interactive",
AuthType.DEVICE_CODE.value: "devicecode",
AuthType.DEFAULT.value: "default",
}
for part in connection_string.split(";"):
key, _, value = part.strip().partition("=")
if key.strip().lower() == "authentication":
return auth_map.get(value.strip().lower())
return None


def process_connection_string(
connection_string: str,
) -> Tuple[str, Optional[Dict[int, bytes]]]:
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]:
"""
Process connection string and handle authentication.

Args:
connection_string: The connection string to process

Returns:
Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed
Tuple[str, Optional[Dict], Optional[str]]: Processed connection string,
attrs_before dict if needed, and auth_type string for bulk copy token acquisition

Raises:
ValueError: If the connection string is invalid or empty
Expand Down Expand Up @@ -259,7 +293,7 @@ def process_connection_string(
"process_connection_string: Token authentication configured successfully - auth_type=%s",
auth_type,
)
return ";".join(modified_parameters) + ";", {1256: token_struct}
return ";".join(modified_parameters) + ";", {1256: token_struct}, auth_type
else:
logger.warning(
"process_connection_string: Token acquisition failed, proceeding without token"
Expand All @@ -269,4 +303,4 @@ def process_connection_string(
"process_connection_string: Connection string processing complete - has_auth=%s",
bool(auth_type),
)
return ";".join(modified_parameters) + ";", None
return ";".join(modified_parameters) + ";", None, auth_type
11 changes: 10 additions & 1 deletion mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
ProgrammingError,
NotSupportedError,
)
from mssql_python.auth import process_connection_string
from mssql_python.auth import extract_auth_type, process_connection_string
from mssql_python.constants import ConstantsDDBC, GetInfoConstants
from mssql_python.connection_string_parser import _ConnectionStringParser
from mssql_python.connection_string_builder import _ConnectionStringBuilder
Expand Down Expand Up @@ -263,6 +263,11 @@ def __init__(
},
}

# Auth type for acquiring fresh tokens at bulk copy time.
# We intentionally do NOT cache the token — a fresh one is acquired
# each time bulkcopy() is called to avoid expired-token errors.
self._auth_type = None

# Check if the connection string contains authentication parameters
# This is important for processing the connection string correctly.
# If authentication is specified, it will be processed to handle
Expand All @@ -272,6 +277,10 @@ def __init__(
self.connection_str = connection_result[0]
if connection_result[1]:
self._attrs_before.update(connection_result[1])
# Store auth type so bulkcopy() can acquire a fresh token later.
# On Windows Interactive, process_connection_string returns None
# (DDBC handles auth natively), so fall back to the connection string.
self._auth_type = connection_result[2] or extract_auth_type(self.connection_str)

self._closed = False
self._timeout = timeout
Expand Down
29 changes: 22 additions & 7 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2607,15 +2607,30 @@ def _bulkcopy(
context = {
"server": params.get("server"),
"database": params.get("database"),
"user_name": params.get("uid", ""),
"trust_server_certificate": trust_cert,
"encryption": encryption,
}

# Extract password separately to avoid storing it in generic context that may be logged
password = params.get("pwd", "")
# Build pycore_context with appropriate authentication.
# For Azure AD: acquire a FRESH token right now instead of reusing
# the one from connect() time — avoids expired-token errors when
# bulkcopy() is called long after the original connection.
pycore_context = dict(context)
pycore_context["password"] = password

if self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection
from mssql_python.auth import AADAuth

raw_token = AADAuth.get_raw_token(self.connection._auth_type)
pycore_context["access_token"] = raw_token
logger.debug(
"Bulk copy: acquired fresh Azure AD token for auth_type=%s",
self.connection._auth_type,
)
else:
# SQL Server authentication — use uid/password from connection string
pycore_context["user_name"] = params.get("uid", "")
pycore_context["password"] = params.get("pwd", "")

pycore_connection = None
pycore_cursor = None
Expand Down Expand Up @@ -2653,10 +2668,10 @@ def _bulkcopy(

finally:
# Clear sensitive data to minimize memory exposure
password = ""
if pycore_context:
pycore_context["password"] = ""
pycore_context["user_name"] = ""
pycore_context.pop("password", None)
pycore_context.pop("user_name", None)
pycore_context.pop("access_token", None)
# Clean up bulk copy resources
for resource in (pycore_cursor, pycore_connection):
if resource and hasattr(resource, "close"):
Expand Down
57 changes: 53 additions & 4 deletions tests/test_008_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import pytest
import platform
import sys
from unittest.mock import patch, MagicMock
from mssql_python.auth import (
AADAuth,
process_auth_parameters,
remove_sensitive_params,
get_auth_token,
process_connection_string,
extract_auth_type,
)
from mssql_python.constants import AuthType
import secrets
Expand Down Expand Up @@ -82,6 +84,11 @@ def test_get_token_struct(self):
assert isinstance(token_struct, bytes)
assert len(token_struct) > 4

def test_get_raw_token_default(self):
raw_token = AADAuth.get_raw_token("default")
assert isinstance(raw_token, str)
assert raw_token == SAMPLE_TOKEN

def test_get_token_default(self):
token_struct = AADAuth.get_token("default")
assert isinstance(token_struct, bytes)
Expand Down Expand Up @@ -281,7 +288,7 @@ def test_interactive_auth_windows(self, monkeypatch):
params = ["Authentication=ActiveDirectoryInteractive", "Server=test"]
modified_params, auth_type = process_auth_parameters(params)
assert "Authentication=ActiveDirectoryInteractive" in modified_params
assert auth_type == None
assert auth_type is None

def test_interactive_auth_non_windows(self, monkeypatch):
monkeypatch.setattr(platform, "system", lambda: "Darwin")
Expand Down Expand Up @@ -326,34 +333,37 @@ def test_remove_sensitive_parameters(self):
class TestProcessConnectionString:
def test_process_connection_string_with_default_auth(self):
conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb"
result_str, attrs = process_connection_string(conn_str)
result_str, attrs, auth_type = process_connection_string(conn_str)

assert "Server=test" in result_str
assert "Database=testdb" in result_str
assert attrs is not None
assert 1256 in attrs
assert isinstance(attrs[1256], bytes)
assert auth_type == "default"

def test_process_connection_string_no_auth(self):
conn_str = "Server=test;Database=testdb;UID=user;PWD=password"
result_str, attrs = process_connection_string(conn_str)
result_str, attrs, auth_type = process_connection_string(conn_str)

assert "Server=test" in result_str
assert "Database=testdb" in result_str
assert "UID=user" in result_str
assert "PWD=password" in result_str
assert attrs is None
assert auth_type is None

def test_process_connection_string_interactive_non_windows(self, monkeypatch):
monkeypatch.setattr(platform, "system", lambda: "Darwin")
conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb"
result_str, attrs = process_connection_string(conn_str)
result_str, attrs, auth_type = process_connection_string(conn_str)

assert "Server=test" in result_str
assert "Database=testdb" in result_str
assert attrs is not None
assert 1256 in attrs
assert isinstance(attrs[1256], bytes)
assert auth_type == "interactive"


def test_error_handling():
Expand All @@ -368,3 +378,42 @@ def test_error_handling():
# Test non-string input
with pytest.raises(ValueError, match="Connection string must be a string"):
process_connection_string(None)


class TestExtractAuthType:
def test_interactive(self):
assert (
extract_auth_type("Server=test;Authentication=ActiveDirectoryInteractive;")
== "interactive"
)

def test_default(self):
assert extract_auth_type("Server=test;Authentication=ActiveDirectoryDefault;") == "default"

def test_devicecode(self):
assert (
extract_auth_type("Server=test;Authentication=ActiveDirectoryDeviceCode;")
== "devicecode"
)

def test_no_auth(self):
assert extract_auth_type("Server=test;Database=db;") is None

def test_unsupported_auth(self):
assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None


def test_acquire_token_unsupported_auth_type():
with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"):
AADAuth._acquire_token("bogus")


class TestConnectionAuthType:
@patch("mssql_python.connection.ddbc_bindings.Connection")
def test_auth_type_stored_on_connection(self, mock_ddbc_conn):
mock_ddbc_conn.return_value = MagicMock()
from mssql_python import connect

conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault")
assert conn._auth_type == "default"
conn.close()
Loading