/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pulsar.broker.authentication.oidc;

import com.auth0.jwk.InvalidPublicKeyException;
import com.auth0.jwk.Jwk;
import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.AlgorithmMismatchException;
import com.auth0.jwt.exceptions.InvalidClaimException;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.exceptions.SignatureVerificationException;
import com.auth0.jwt.exceptions.TokenExpiredException;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import io.kubernetes.client.openapi.ApiClient;
import io.kubernetes.client.util.Config;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.net.SocketAddress;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
import org.apache.commons.lang3.StringUtils;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
import org.apache.pulsar.broker.authentication.AuthenticationProvider;
import org.apache.pulsar.broker.authentication.AuthenticationProviderToken;
import org.apache.pulsar.broker.authentication.AuthenticationState;
import org.apache.pulsar.broker.authentication.metrics.AuthenticationMetrics;
import org.apache.pulsar.broker.authentication.oidc.AuthenticationExceptionCode;
import org.apache.pulsar.broker.authentication.oidc.AuthenticationStateOpenID;
import org.apache.pulsar.broker.authentication.oidc.ConfigUtils;
import org.apache.pulsar.broker.authentication.oidc.FallbackDiscoveryMode;
import org.apache.pulsar.broker.authentication.oidc.JwksCache;
import org.apache.pulsar.broker.authentication.oidc.OpenIDProviderMetadataCache;
import org.apache.pulsar.common.api.AuthData;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.AsyncHttpClientConfig;
import org.asynchttpclient.DefaultAsyncHttpClient;
import org.asynchttpclient.DefaultAsyncHttpClientConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AuthenticationProviderOpenID
implements AuthenticationProvider {
    private static final Logger log = LoggerFactory.getLogger(AuthenticationProviderOpenID.class);
    private static final String SIMPLE_NAME = AuthenticationProviderOpenID.class.getSimpleName();
    private static final String AUTH_METHOD_NAME = "token";
    private final JWT jwtLibrary = new JWT();
    private Set<String> issuers;
    private OpenIDProviderMetadataCache openIDProviderMetadataCache;
    private JwksCache jwksCache;
    private volatile AsyncHttpClient httpClient;
    private static final String ALG_RS256 = "RS256";
    private static final String ALG_RS384 = "RS384";
    private static final String ALG_RS512 = "RS512";
    private static final String ALG_ES256 = "ES256";
    private static final String ALG_ES384 = "ES384";
    private static final String ALG_ES512 = "ES512";
    private long acceptedTimeLeewaySeconds;
    private FallbackDiscoveryMode fallbackDiscoveryMode;
    private String roleClaim = "sub";
    private boolean isRoleClaimNotSubject;
    static final String ALLOWED_TOKEN_ISSUERS = "openIDAllowedTokenIssuers";
    static final String ISSUER_TRUST_CERTS_FILE_PATH = "openIDTokenIssuerTrustCertsFilePath";
    static final String FALLBACK_DISCOVERY_MODE = "openIDFallbackDiscoveryMode";
    static final String ALLOWED_AUDIENCES = "openIDAllowedAudiences";
    static final String ROLE_CLAIM = "openIDRoleClaim";
    static final String ROLE_CLAIM_DEFAULT = "sub";
    static final String ACCEPTED_TIME_LEEWAY_SECONDS = "openIDAcceptedTimeLeewaySeconds";
    static final int ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT = 0;
    static final String CACHE_SIZE = "openIDCacheSize";
    static final int CACHE_SIZE_DEFAULT = 5;
    static final String CACHE_REFRESH_AFTER_WRITE_SECONDS = "openIDCacheRefreshAfterWriteSeconds";
    static final int CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT = 64800;
    static final String CACHE_EXPIRATION_SECONDS = "openIDCacheExpirationSeconds";
    static final int CACHE_EXPIRATION_SECONDS_DEFAULT = 86400;
    static final String KEY_ID_CACHE_MISS_REFRESH_SECONDS = "openIDKeyIdCacheMissRefreshSeconds";
    static final int KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT = 300;
    static final String HTTP_CONNECTION_TIMEOUT_MILLIS = "openIDHttpConnectionTimeoutMillis";
    static final int HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT = 10000;
    static final String HTTP_READ_TIMEOUT_MILLIS = "openIDHttpReadTimeoutMillis";
    static final int HTTP_READ_TIMEOUT_MILLIS_DEFAULT = 10000;
    static final String REQUIRE_HTTPS = "openIDRequireIssuersUseHttps";
    static final boolean REQUIRE_HTTPS_DEFAULT = true;
    private String[] allowedAudiences;

    public void initialize(ServiceConfiguration config) throws IOException {
        this.allowedAudiences = this.validateAllowedAudiences(ConfigUtils.getConfigValueAsSet(config, ALLOWED_AUDIENCES));
        this.roleClaim = ConfigUtils.getConfigValueAsString(config, ROLE_CLAIM, ROLE_CLAIM_DEFAULT);
        this.isRoleClaimNotSubject = !ROLE_CLAIM_DEFAULT.equals(this.roleClaim);
        this.acceptedTimeLeewaySeconds = ConfigUtils.getConfigValueAsInt(config, ACCEPTED_TIME_LEEWAY_SECONDS, 0);
        boolean requireHttps = ConfigUtils.getConfigValueAsBoolean(config, REQUIRE_HTTPS, true);
        this.fallbackDiscoveryMode = FallbackDiscoveryMode.valueOf(ConfigUtils.getConfigValueAsString(config, FALLBACK_DISCOVERY_MODE, FallbackDiscoveryMode.DISABLED.name()));
        this.issuers = this.validateIssuers(ConfigUtils.getConfigValueAsSet(config, ALLOWED_TOKEN_ISSUERS), requireHttps, this.fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED);
        int connectionTimeout = ConfigUtils.getConfigValueAsInt(config, HTTP_CONNECTION_TIMEOUT_MILLIS, 10000);
        int readTimeout = ConfigUtils.getConfigValueAsInt(config, HTTP_READ_TIMEOUT_MILLIS, 10000);
        String trustCertsFilePath = ConfigUtils.getConfigValueAsString(config, ISSUER_TRUST_CERTS_FILE_PATH, null);
        SslContext sslContext = null;
        if (StringUtils.isNotBlank((CharSequence)trustCertsFilePath)) {
            sslContext = SslContextBuilder.forClient().trustManager(new File(trustCertsFilePath)).build();
        }
        DefaultAsyncHttpClientConfig clientConfig = new DefaultAsyncHttpClientConfig.Builder().setCookieStore(null).setConnectTimeout(connectionTimeout).setReadTimeout(readTimeout).setSslContext(sslContext).build();
        this.httpClient = new DefaultAsyncHttpClient((AsyncHttpClientConfig)clientConfig);
        ApiClient k8sApiClient = this.fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED ? Config.defaultClient() : null;
        this.openIDProviderMetadataCache = new OpenIDProviderMetadataCache(config, this.httpClient, k8sApiClient);
        this.jwksCache = new JwksCache(config, this.httpClient, k8sApiClient);
    }

    public String getAuthMethodName() {
        return AUTH_METHOD_NAME;
    }

    public CompletableFuture<String> authenticateAsync(AuthenticationDataSource authData) {
        return this.authenticateTokenAsync(authData).thenApply(this::getRole);
    }

    CompletableFuture<DecodedJWT> authenticateTokenAsync(AuthenticationDataSource authData) {
        String token;
        try {
            token = AuthenticationProviderToken.getToken((AuthenticationDataSource)authData);
        }
        catch (AuthenticationException e2) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            return CompletableFuture.failedFuture(e2);
        }
        return this.authenticateToken(token).whenComplete((jwt, e) -> {
            if (jwt != null) {
                AuthenticationMetrics.authenticateSuccess((String)this.getClass().getSimpleName(), (String)this.getAuthMethodName());
            }
        });
    }

    String getRole(DecodedJWT jwt) {
        try {
            Claim roleClaim = jwt.getClaim(this.roleClaim);
            if (roleClaim.isNull()) {
                return null;
            }
            String role = roleClaim.asString();
            if (role != null) {
                return role;
            }
            List roles = jwt.getClaim(this.roleClaim).asList(String.class);
            if (roles == null || roles.size() == 0) {
                return null;
            }
            if (roles.size() == 1) {
                return (String)roles.get(0);
            }
            log.debug("JWT for subject [{}] has multiple roles; using the first one.", (Object)jwt.getSubject());
            return (String)roles.get(0);
        }
        catch (JWTDecodeException e) {
            log.error("Exception while retrieving role from JWT", (Throwable)e);
            return null;
        }
    }

    DecodedJWT decodeJWT(String token) throws AuthenticationException {
        if (token == null) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            throw new AuthenticationException("Invalid token: cannot be null");
        }
        try {
            return this.jwtLibrary.decodeJwt(token);
        }
        catch (JWTDecodeException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            throw new AuthenticationException("Unable to decode JWT: " + e.getMessage());
        }
    }

    private CompletableFuture<DecodedJWT> authenticateToken(String token) {
        DecodedJWT jwt;
        if (token == null) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            return CompletableFuture.failedFuture(new AuthenticationException("JWT cannot be null"));
        }
        try {
            jwt = this.decodeJWT(token);
        }
        catch (AuthenticationException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            return CompletableFuture.failedFuture(e);
        }
        return this.verifyIssuerAndGetJwk(jwt).thenCompose(jwk -> {
            try {
                if (jwk.getAlgorithm() != null && !jwt.getAlgorithm().equals(jwk.getAlgorithm())) {
                    AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
                    return CompletableFuture.failedFuture(new AuthenticationException("JWK's alg [" + jwk.getAlgorithm() + "] does not match JWT's alg [" + jwt.getAlgorithm() + "]"));
                }
                return CompletableFuture.completedFuture(this.verifyJWT(jwk.getPublicKey(), jwt.getAlgorithm(), jwt));
            }
            catch (InvalidPublicKeyException e) {
                AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.INVALID_PUBLIC_KEY);
                return CompletableFuture.failedFuture(new AuthenticationException("Invalid public key: " + e.getMessage()));
            }
            catch (AuthenticationException e) {
                return CompletableFuture.failedFuture(e);
            }
        });
    }

    private CompletableFuture<Jwk> verifyIssuerAndGetJwk(DecodedJWT jwt) {
        if (jwt.getIssuer() == null) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ISSUER);
            return CompletableFuture.failedFuture(new AuthenticationException("Issuer cannot be null"));
        }
        if (this.issuers.contains(jwt.getIssuer())) {
            return this.openIDProviderMetadataCache.getOpenIDProviderMetadataForIssuer(jwt.getIssuer()).thenCompose(metadata -> this.jwksCache.getJwk(metadata.getJwksUri(), jwt.getKeyId()));
        }
        if (this.fallbackDiscoveryMode == FallbackDiscoveryMode.KUBERNETES_DISCOVER_TRUSTED_ISSUER) {
            return ((CompletableFuture)this.openIDProviderMetadataCache.getOpenIDProviderMetadataForKubernetesApiServer(jwt.getIssuer()).thenCompose(metadata -> this.openIDProviderMetadataCache.getOpenIDProviderMetadataForIssuer(metadata.getIssuer()))).thenCompose(metadata -> this.jwksCache.getJwk(metadata.getJwksUri(), jwt.getKeyId()));
        }
        if (this.fallbackDiscoveryMode == FallbackDiscoveryMode.KUBERNETES_DISCOVER_PUBLIC_KEYS) {
            return this.openIDProviderMetadataCache.getOpenIDProviderMetadataForKubernetesApiServer(jwt.getIssuer()).thenCompose(__ -> this.jwksCache.getJwkFromKubernetesApiServer(jwt.getKeyId()));
        }
        AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ISSUER);
        return CompletableFuture.failedFuture(new AuthenticationException("Issuer not allowed: " + jwt.getIssuer()));
    }

    public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteAddress, SSLSession sslSession) throws AuthenticationException {
        return new AuthenticationStateOpenID(this, remoteAddress, sslSession);
    }

    public void close() throws IOException {
        this.httpClient.close();
    }

    DecodedJWT verifyJWT(PublicKey publicKey, String publicKeyAlg, DecodedJWT jwt) throws AuthenticationException {
        Algorithm alg;
        if (publicKeyAlg == null) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ALGORITHM);
            throw new AuthenticationException("PublicKey algorithm cannot be null");
        }
        try {
            switch (publicKeyAlg) {
                case "RS256": {
                    alg = Algorithm.RSA256((RSAPublicKey)((RSAPublicKey)publicKey), null);
                    break;
                }
                case "RS384": {
                    alg = Algorithm.RSA384((RSAPublicKey)((RSAPublicKey)publicKey), null);
                    break;
                }
                case "RS512": {
                    alg = Algorithm.RSA512((RSAPublicKey)((RSAPublicKey)publicKey), null);
                    break;
                }
                case "ES256": {
                    alg = Algorithm.ECDSA256((ECPublicKey)((ECPublicKey)publicKey), null);
                    break;
                }
                case "ES384": {
                    alg = Algorithm.ECDSA384((ECPublicKey)((ECPublicKey)publicKey), null);
                    break;
                }
                case "ES512": {
                    alg = Algorithm.ECDSA512((ECPublicKey)((ECPublicKey)publicKey), null);
                    break;
                }
                default: {
                    AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ALGORITHM);
                    throw new AuthenticationException("Unsupported algorithm: " + publicKeyAlg);
                }
            }
        }
        catch (ClassCastException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
            throw new AuthenticationException("Expected PublicKey alg [" + publicKeyAlg + "] does match actual alg.");
        }
        Verification verifierBuilder = JWT.require((Algorithm)alg).acceptLeeway(this.acceptedTimeLeewaySeconds).withAnyOfAudience(this.allowedAudiences).withClaimPresence("iat").withClaimPresence("exp").withClaimPresence(ROLE_CLAIM_DEFAULT);
        if (this.isRoleClaimNotSubject) {
            verifierBuilder = verifierBuilder.withClaimPresence(this.roleClaim);
        }
        JWTVerifier verifier = verifierBuilder.build();
        try {
            return verifier.verify(jwt);
        }
        catch (TokenExpiredException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.EXPIRED_JWT);
            throw new AuthenticationException("JWT expired: " + e.getMessage());
        }
        catch (SignatureVerificationException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_VERIFYING_JWT_SIGNATURE);
            throw new AuthenticationException("JWT signature verification exception: " + e.getMessage());
        }
        catch (InvalidClaimException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.INVALID_JWT_CLAIM);
            throw new AuthenticationException("JWT contains invalid claim: " + e.getMessage());
        }
        catch (AlgorithmMismatchException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
            throw new AuthenticationException("JWT algorithm does not match Public Key algorithm: " + e.getMessage());
        }
        catch (JWTDecodeException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            throw new AuthenticationException("Error while decoding JWT: " + e.getMessage());
        }
        catch (JWTVerificationException | IllegalArgumentException e) {
            AuthenticationProviderOpenID.incrementFailureMetric(AuthenticationExceptionCode.ERROR_VERIFYING_JWT);
            throw new AuthenticationException("JWT verification failed: " + e.getMessage());
        }
    }

    static void incrementFailureMetric(AuthenticationExceptionCode code) {
        AuthenticationMetrics.authenticateFailure((String)SIMPLE_NAME, (String)AUTH_METHOD_NAME, (Enum)code);
    }

    private Set<String> validateIssuers(Set<String> allowedIssuers, boolean requireHttps, boolean allowEmptyIssuers) {
        if (allowedIssuers == null || allowedIssuers.isEmpty() && !allowEmptyIssuers) {
            throw new IllegalArgumentException("Missing configured value for: openIDAllowedTokenIssuers");
        }
        for (String issuer : allowedIssuers) {
            if (issuer.toLowerCase().startsWith("https://")) continue;
            log.warn("Allowed issuer is not using https scheme: {}", (Object)issuer);
            if (!requireHttps) continue;
            throw new IllegalArgumentException("Issuer URL does not use https, but must: " + issuer);
        }
        return allowedIssuers;
    }

    String[] validateAllowedAudiences(Set<String> allowedAudiences) {
        if (allowedAudiences == null || allowedAudiences.isEmpty()) {
            throw new IllegalArgumentException("Missing configured value for: openIDAllowedAudiences");
        }
        return allowedAudiences.toArray(new String[0]);
    }
}

