From 589110961a270733e4e3fb6add75117be124d984 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 5 Mar 2025 14:02:19 +0200
Subject: [PATCH] Isolate redis-entraid dependency for tests (#3521)

Index: tests/test_asyncio/conftest.py
--- tests/test_asyncio/conftest.py.orig
+++ tests/test_asyncio/conftest.py
@@ -1,7 +1,5 @@
-import os
 import random
 from contextlib import asynccontextmanager as _asynccontextmanager
-from datetime import datetime, timezone
 from enum import Enum
 from typing import Union
 
@@ -9,34 +7,14 @@ import jwt
 import pytest
 import pytest_asyncio
 import redis.asyncio as redis
-from mock.mock import Mock
 from packaging.version import Version
 from redis.asyncio import Sentinel
 from redis.asyncio.client import Monitor
 from redis.asyncio.connection import Connection, parse_url
 from redis.asyncio.retry import Retry
-from redis.auth.idp import IdentityProviderInterface
-from redis.auth.token import JWToken
-from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
 from redis.backoff import NoBackoff
 from redis.credentials import CredentialProvider
-from redis_entraid.cred_provider import (
-    DEFAULT_DELAY_IN_MS,
-    DEFAULT_EXPIRATION_REFRESH_RATIO,
-    DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
-    DEFAULT_MAX_ATTEMPTS,
-    DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
-    EntraIdCredentialsProvider,
-)
-from redis_entraid.identity_provider import (
-    ManagedIdentityIdType,
-    ManagedIdentityProviderConfig,
-    ManagedIdentityType,
-    ServicePrincipalIdentityProviderConfig,
-    _create_provider_from_managed_identity,
-    _create_provider_from_service_principal,
-)
-from tests.conftest import REDIS_INFO
+from tests.conftest import REDIS_INFO, get_credential_provider
 
 from .compat import mock
 
@@ -244,135 +222,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwa
     )
     for mocked in _gen_cluster_mock_resp(r, response):
         yield mocked
-
-
-def mock_identity_provider() -> IdentityProviderInterface:
-    mock_provider = Mock(spec=IdentityProviderInterface)
-    token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"}
-    encoded = jwt.encode(token, "secret", algorithm="HS256")
-    jwt_token = JWToken(encoded)
-    mock_provider.request_token.return_value = jwt_token
-    return mock_provider
-
-
-def identity_provider(request) -> IdentityProviderInterface:
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-    else:
-        kwargs = {}
-
-    if request.param.get("mock_idp", None) is not None:
-        return mock_identity_provider()
-
-    auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
-    config = get_identity_provider_config(request=request)
-
-    if auth_type == "MANAGED_IDENTITY":
-        return _create_provider_from_managed_identity(config)
-
-    return _create_provider_from_service_principal(config)
-
-
-def get_identity_provider_config(
-    request,
-) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-    else:
-        kwargs = {}
-
-    auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
-
-    if auth_type == AuthType.MANAGED_IDENTITY:
-        return _get_managed_identity_provider_config(request)
-
-    return _get_service_principal_provider_config(request)
-
-
-def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
-    resource = os.getenv("AZURE_RESOURCE")
-    id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
-
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-    else:
-        kwargs = {}
-
-    identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
-    id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
-
-    return ManagedIdentityProviderConfig(
-        identity_type=identity_type,
-        resource=resource,
-        id_type=id_type,
-        id_value=id_value,
-        kwargs=kwargs,
-    )
-
-
-def _get_service_principal_provider_config(
-    request,
-) -> ServicePrincipalIdentityProviderConfig:
-    client_id = os.getenv("AZURE_CLIENT_ID")
-    client_credential = os.getenv("AZURE_CLIENT_SECRET")
-    tenant_id = os.getenv("AZURE_TENANT_ID")
-    scopes = os.getenv("AZURE_REDIS_SCOPES", None)
-
-    if hasattr(request, "param"):
-        kwargs = request.param.get("idp_kwargs", {})
-        token_kwargs = request.param.get("token_kwargs", {})
-        timeout = request.param.get("timeout", None)
-    else:
-        kwargs = {}
-        token_kwargs = {}
-        timeout = None
-
-    if isinstance(scopes, str):
-        scopes = scopes.split(",")
-
-    return ServicePrincipalIdentityProviderConfig(
-        client_id=client_id,
-        client_credential=client_credential,
-        scopes=scopes,
-        timeout=timeout,
-        token_kwargs=token_kwargs,
-        tenant_id=tenant_id,
-        app_kwargs=kwargs,
-    )
-
-
-def get_credential_provider(request) -> CredentialProvider:
-    cred_provider_class = request.param.get("cred_provider_class")
-    cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
-
-    if cred_provider_class != EntraIdCredentialsProvider:
-        return cred_provider_class(**cred_provider_kwargs)
-
-    idp = identity_provider(request)
-    expiration_refresh_ratio = cred_provider_kwargs.get(
-        "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
-    )
-    lower_refresh_bound_millis = cred_provider_kwargs.get(
-        "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
-    )
-    max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
-    delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
-
-    token_mgr_config = TokenManagerConfig(
-        expiration_refresh_ratio=expiration_refresh_ratio,
-        lower_refresh_bound_millis=lower_refresh_bound_millis,
-        token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,  # noqa
-        retry_policy=RetryPolicy(
-            max_attempts=max_attempts,
-            delay_in_ms=delay_in_ms,
-        ),
-    )
-
-    return EntraIdCredentialsProvider(
-        identity_provider=idp,
-        token_manager_config=token_mgr_config,
-        initial_delay_in_ms=delay_in_ms,
-    )
 
 
 @pytest_asyncio.fixture()
