//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
//

package org.apache.cloudstack.saml;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.StringWriter;
import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.FactoryConfigurationError;

import org.apache.cloudstack.api.ApiConstants;
import org.apache.cloudstack.api.response.LoginCmdResponse;
import org.apache.cloudstack.utils.security.CertUtils;
import org.apache.cloudstack.utils.security.ParserUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Logger;
import org.bouncycastle.operator.OperatorCreationException;
import org.joda.time.DateTime;
import org.opensaml.Configuration;
import org.opensaml.DefaultBootstrap;
import org.opensaml.common.SAMLVersion;
import org.opensaml.common.xml.SAMLConstants;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.saml2.core.Attribute;
import org.opensaml.saml2.core.AttributeStatement;
import org.opensaml.saml2.core.AuthnContext;
import org.opensaml.saml2.core.AuthnContextClassRef;
import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.Issuer;
import org.opensaml.saml2.core.LogoutRequest;
import org.opensaml.saml2.core.NameID;
import org.opensaml.saml2.core.RequestedAuthnContext;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.core.impl.AuthnContextClassRefBuilder;
import org.opensaml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml2.core.impl.LogoutRequestBuilder;
import org.opensaml.saml2.core.impl.NameIDBuilder;
import org.opensaml.saml2.core.impl.RequestedAuthnContextBuilder;
import org.opensaml.xml.ConfigurationException;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.io.Marshaller;
import org.opensaml.xml.io.MarshallingException;
import org.opensaml.xml.io.Unmarshaller;
import org.opensaml.xml.io.UnmarshallerFactory;
import org.opensaml.xml.io.UnmarshallingException;
import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.signature.SignatureConstants;
import org.opensaml.xml.util.Base64;
import org.opensaml.xml.util.XMLHelper;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.xml.sax.SAXException;

import com.cloud.api.ApiServlet;
import com.cloud.utils.HttpUtils;
import com.cloud.utils.exception.CloudRuntimeException;

public class SAMLUtils {
    public static final Logger s_logger = Logger.getLogger(SAMLUtils.class);

    static final String charset = "abcdefghijklmnopqrstuvwxyz";

    public static String generateSecureRandomId() {
        return new BigInteger(160, new SecureRandom()).toString(32).replaceFirst("^[0-9]",
                String.valueOf(charset.charAt(new SecureRandom().nextInt(charset.length()))));
    }

    public static String getValueFromAttributeStatements(final List<AttributeStatement> attributeStatements, final String attributeKey) {
        if (attributeStatements == null || attributeStatements.size() < 1 || attributeKey == null) {
            return null;
        }
        for (AttributeStatement attributeStatement : attributeStatements) {
            if (attributeStatement == null || attributeStatements.size() < 1) {
                continue;
            }
            for (Attribute attribute : attributeStatement.getAttributes()) {
                if (attribute.getAttributeValues() != null && attribute.getAttributeValues().size() > 0) {
                    String value = attribute.getAttributeValues().get(0).getDOM().getTextContent();
                    s_logger.debug("SAML attribute name: " + attribute.getName() + " friendly-name:" + attribute.getFriendlyName() + " value:" + value);
                    if (attributeKey.equals(attribute.getName()) || attributeKey.equals(attribute.getFriendlyName())) {
                        return value;
                    }
                }
            }
        }
        return null;
    }

    public static String getValueFromAssertions(final List<Assertion> assertions, final String attributeKey) {
        if (assertions == null || attributeKey == null) {
            return null;
        }
        for (Assertion assertion : assertions) {
            String value = getValueFromAttributeStatements(assertion.getAttributeStatements(), attributeKey);
            if (value != null) {
                return value;
            }
        }
        return null;
    }

    public static String buildAuthnRequestUrl(final String authnId, final SAMLProviderMetadata spMetadata, final SAMLProviderMetadata idpMetadata, final String signatureAlgorithm) {
        String redirectUrl = "";
        try {
            DefaultBootstrap.bootstrap();
            AuthnRequest authnRequest = SAMLUtils.buildAuthnRequestObject(authnId, spMetadata.getEntityId(), idpMetadata.getSsoUrl(), spMetadata.getSsoUrl());
            PrivateKey privateKey = null;
            if (spMetadata.getKeyPair() != null) {
                privateKey = spMetadata.getKeyPair().getPrivate();
            }
            String appendOperator = idpMetadata.getSsoUrl().contains("?") ? "&" : "?";
            redirectUrl = idpMetadata.getSsoUrl() + appendOperator + SAMLUtils.generateSAMLRequestSignature("SAMLRequest=" + SAMLUtils.encodeSAMLRequest(authnRequest), privateKey, signatureAlgorithm);
        } catch (ConfigurationException | FactoryConfigurationError | MarshallingException | IOException | NoSuchAlgorithmException | InvalidKeyException | java.security.SignatureException e) {
            s_logger.error("SAML AuthnRequest message building error: " + e.getMessage());
        }
        return redirectUrl;
    }

    public static AuthnRequest buildAuthnRequestObject(final String authnId, final String spId, final String idpUrl, final String consumerUrl) {
        // Issuer object
        IssuerBuilder issuerBuilder = new IssuerBuilder();
        Issuer issuer = issuerBuilder.buildObject();
        issuer.setValue(spId);

        // AuthnContextClass
        AuthnContextClassRefBuilder authnContextClassRefBuilder = new AuthnContextClassRefBuilder();
        AuthnContextClassRef authnContextClassRef = authnContextClassRefBuilder.buildObject(
                SAMLConstants.SAML20_NS,
                "AuthnContextClassRef", "saml");
        authnContextClassRef.setAuthnContextClassRef(AuthnContext.PPT_AUTHN_CTX);

        // AuthnContext
        RequestedAuthnContextBuilder requestedAuthnContextBuilder = new RequestedAuthnContextBuilder();
        RequestedAuthnContext requestedAuthnContext = requestedAuthnContextBuilder.buildObject();
        requestedAuthnContext.setComparison(AuthnContextComparisonTypeEnumeration.EXACT);
        requestedAuthnContext.getAuthnContextClassRefs().add(authnContextClassRef);

        // Creation of AuthRequestObject
        AuthnRequestBuilder authRequestBuilder = new AuthnRequestBuilder();
        AuthnRequest authnRequest = authRequestBuilder.buildObject();
        authnRequest.setID(authnId);
        authnRequest.setDestination(idpUrl);
        authnRequest.setVersion(SAMLVersion.VERSION_20);
        authnRequest.setForceAuthn(false);
        authnRequest.setIsPassive(false);
        authnRequest.setIssueInstant(new DateTime());
        authnRequest.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI);
        authnRequest.setAssertionConsumerServiceURL(consumerUrl);
        authnRequest.setProviderName(spId);
        authnRequest.setIssuer(issuer);
        authnRequest.setRequestedAuthnContext(requestedAuthnContext);

        return authnRequest;
    }

    public static LogoutRequest buildLogoutRequest(String logoutUrl, String spId, String nameIdString) {
        Issuer issuer = new IssuerBuilder().buildObject();
        issuer.setValue(spId);
        NameID nameID = new NameIDBuilder().buildObject();
        nameID.setValue(nameIdString);
        LogoutRequest logoutRequest = new LogoutRequestBuilder().buildObject();
        logoutRequest.setID(generateSecureRandomId());
        logoutRequest.setDestination(logoutUrl);
        logoutRequest.setVersion(SAMLVersion.VERSION_20);
        logoutRequest.setIssueInstant(new DateTime());
        logoutRequest.setIssuer(issuer);
        logoutRequest.setNameID(nameID);
        return logoutRequest;
    }

    public static String encodeSAMLRequest(XMLObject authnRequest)
            throws MarshallingException, IOException {
        Marshaller marshaller = Configuration.getMarshallerFactory()
                .getMarshaller(authnRequest);
        Element authDOM = marshaller.marshall(authnRequest);
        StringWriter requestWriter = new StringWriter();
        XMLHelper.writeNode(authDOM, requestWriter);
        String requestMessage = requestWriter.toString();
        Deflater deflater = new Deflater(Deflater.DEFLATED, true);
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(byteArrayOutputStream, deflater);
        deflaterOutputStream.write(requestMessage.getBytes(Charset.forName("UTF-8")));
        deflaterOutputStream.close();
        String encodedRequestMessage = Base64.encodeBytes(byteArrayOutputStream.toByteArray(), Base64.DONT_BREAK_LINES);
        encodedRequestMessage = URLEncoder.encode(encodedRequestMessage, HttpUtils.UTF_8).trim();
        return encodedRequestMessage;
    }

    public static Response decodeSAMLResponse(String responseMessage)
            throws ConfigurationException, ParserConfigurationException,
            SAXException, IOException, UnmarshallingException {
        DocumentBuilderFactory documentBuilderFactory = ParserUtils.getSaferDocumentBuilderFactory();
        documentBuilderFactory.setNamespaceAware(true);
        DocumentBuilder docBuilder = documentBuilderFactory.newDocumentBuilder();
        byte[] base64DecodedResponse = Base64.decode(responseMessage);
        Document document = docBuilder.parse(new ByteArrayInputStream(base64DecodedResponse));
        Element element = document.getDocumentElement();
        UnmarshallerFactory unmarshallerFactory = Configuration.getUnmarshallerFactory();
        Unmarshaller unmarshaller = unmarshallerFactory.getUnmarshaller(element);
        return (Response) unmarshaller.unmarshall(element);
    }

    public static String generateSAMLRequestSignature(final String urlEncodedString, final PrivateKey signingKey, final String sigAlgorithmName)
            throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, UnsupportedEncodingException {
        if (signingKey == null) {
            return urlEncodedString;
        }

        String opensamlAlgoIdSignature = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1;
        String javaSignatureAlgorithmName = "SHA1withRSA";

        if (sigAlgorithmName.equalsIgnoreCase("SHA256")) {
            opensamlAlgoIdSignature = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256;
            javaSignatureAlgorithmName = "SHA256withRSA";
        } else if (sigAlgorithmName.equalsIgnoreCase("SHA384")) {
            opensamlAlgoIdSignature = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA384;
            javaSignatureAlgorithmName = "SHA384withRSA";
        } else if (sigAlgorithmName.equalsIgnoreCase("SHA512")) {
            opensamlAlgoIdSignature = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA512;
            javaSignatureAlgorithmName = "SHA512withRSA";
        }

        String url = urlEncodedString + "&SigAlg=" + URLEncoder.encode(opensamlAlgoIdSignature, HttpUtils.UTF_8);
        Signature signature = Signature.getInstance(javaSignatureAlgorithmName);
        signature.initSign(signingKey);
        signature.update(url.getBytes(Charset.forName("UTF-8")));
        String signatureString = Base64.encodeBytes(signature.sign(), Base64.DONT_BREAK_LINES);
        if (signatureString != null) {
            return url + "&Signature=" + URLEncoder.encode(signatureString, HttpUtils.UTF_8);
        }
        return url;
    }

    public static void setupSamlUserCookies(final LoginCmdResponse loginResponse, final HttpServletResponse resp) throws IOException {
        resp.addCookie(new Cookie("userid", URLEncoder.encode(loginResponse.getUserId(), HttpUtils.UTF_8)));
        resp.addCookie(new Cookie("domainid", URLEncoder.encode(loginResponse.getDomainId(), HttpUtils.UTF_8)));
        resp.addCookie(new Cookie("role", URLEncoder.encode(loginResponse.getType(), HttpUtils.UTF_8)));
        resp.addCookie(new Cookie("username", URLEncoder.encode(loginResponse.getUsername(), HttpUtils.UTF_8)));
        resp.addCookie(new Cookie("account", URLEncoder.encode(loginResponse.getAccount(), HttpUtils.UTF_8)));
        resp.addCookie(new Cookie("isSAML", URLEncoder.encode("true", HttpUtils.UTF_8)));
        resp.addCookie(new Cookie("twoFaEnabled", URLEncoder.encode(loginResponse.is2FAenabled(), HttpUtils.UTF_8)));
        String providerFor2FA = loginResponse.getProviderFor2FA();
        if (StringUtils.isNotEmpty(providerFor2FA)) {
            resp.addCookie(new Cookie("twoFaProvider", URLEncoder.encode(loginResponse.getProviderFor2FA(), HttpUtils.UTF_8)));
        }
        String timezone = loginResponse.getTimeZone();
        if (timezone != null) {
            resp.addCookie(new Cookie("timezone", URLEncoder.encode(timezone, HttpUtils.UTF_8)));
        }
        resp.addCookie(new Cookie("userfullname", URLEncoder.encode(loginResponse.getFirstName() + " " + loginResponse.getLastName(), HttpUtils.UTF_8).replace("+", "%20")));

        String redirectUrl = SAML2AuthManager.SAMLCloudStackRedirectionUrl.value();
        String path = SAML2AuthManager.SAMLUserSessionKeyPathAttribute.value();
        String domain = null;
        try {
            URI redirectUri = new URI(redirectUrl);
            domain = redirectUri.getHost();
            if (StringUtils.isBlank(path)) {
                path = redirectUri.getPath();
            }
            if (StringUtils.isBlank(path)) {
                path = "/";
            }
        } catch (URISyntaxException ex) {
            throw new CloudRuntimeException("Invalid URI: " + redirectUrl);
        }
        String sameSite = ApiServlet.getApiSessionKeySameSite();
        String sessionKeyCookie = String.format("%s=%s;Domain=%s;Path=%s;%s", ApiConstants.SESSIONKEY, loginResponse.getSessionKey(), domain, path, sameSite);
        s_logger.debug("Adding sessionkey cookie to response: " + sessionKeyCookie);
        resp.addHeader("SET-COOKIE", sessionKeyCookie);
    }

    /**
     * Returns base64 encoded PublicKey
     * @param key PublicKey
     * @return public key encoded string
     */
    public static String encodePublicKey(PublicKey key) {
        try {
            KeyFactory keyFactory = CertUtils.getKeyFactory();
            if (keyFactory == null) return null;
            X509EncodedKeySpec spec = keyFactory.getKeySpec(key, X509EncodedKeySpec.class);
            return new String(org.bouncycastle.util.encoders.Base64.encode(spec.getEncoded()), Charset.forName("UTF-8"));
        } catch (InvalidKeySpecException e) {
            s_logger.error("Unable to get KeyFactory:" + e.getMessage());
        }
        return null;
    }

    /**
     * Returns base64 encoded PrivateKey
     * @param key PrivateKey
     * @return privatekey encoded string
     */
    public static String encodePrivateKey(PrivateKey key) {
        try {
            KeyFactory keyFactory = CertUtils.getKeyFactory();
            if (keyFactory == null) return null;
            PKCS8EncodedKeySpec spec = keyFactory.getKeySpec(key,
                    PKCS8EncodedKeySpec.class);
            return new String(org.bouncycastle.util.encoders.Base64.encode(spec.getEncoded()), Charset.forName("UTF-8"));
        } catch (InvalidKeySpecException e) {
            s_logger.error("Unable to get KeyFactory:" + e.getMessage());
        }
        return null;
    }

    /**
     * Decodes base64 encoded public key to PublicKey
     * @param publicKey encoded public key string
     * @return returns PublicKey
     */
    public static PublicKey decodePublicKey(String publicKey) {
        byte[] sigBytes = org.bouncycastle.util.encoders.Base64.decode(publicKey);
        X509EncodedKeySpec x509KeySpec = new X509EncodedKeySpec(sigBytes);
        KeyFactory keyFactory = CertUtils.getKeyFactory();
        if (keyFactory == null)
            return null;
        try {
            return keyFactory.generatePublic(x509KeySpec);
        } catch (InvalidKeySpecException e) {
            s_logger.error("Unable to create PublicKey from PublicKey string:" + e.getMessage());
        }
        return null;
    }

    /**
     * Decodes base64 encoded private key to PrivateKey
     * @param privateKey encoded private key string
     * @return returns PrivateKey
     */
    public static PrivateKey decodePrivateKey(String privateKey) {
        byte[] sigBytes = org.bouncycastle.util.encoders.Base64.decode(privateKey);
        PKCS8EncodedKeySpec pkscs8KeySpec = new PKCS8EncodedKeySpec(sigBytes);
        KeyFactory keyFactory = CertUtils.getKeyFactory();
        if (keyFactory == null)
            return null;
        try {
            return keyFactory.generatePrivate(pkscs8KeySpec);
        } catch (InvalidKeySpecException e) {
            s_logger.error("Unable to create PrivateKey from privateKey string:" + e.getMessage());
        }
        return null;
    }

    public static X509Certificate generateRandomX509Certificate(KeyPair keyPair) throws NoSuchAlgorithmException, NoSuchProviderException, CertificateException, SignatureException, InvalidKeyException, OperatorCreationException {
        return CertUtils.generateV1Certificate(keyPair,
                "CN=ApacheCloudStack", "CN=ApacheCloudStack",
                3, "SHA256WithRSA");
    }

    public static BasicParserPool getSaferParserPool() {
        final Map<String, Boolean> features = new HashMap<>();
        features.put(XMLConstants.FEATURE_SECURE_PROCESSING, true);
        features.put("http://apache.org/xml/features/disallow-doctype-decl", true);
        features.put("http://xml.org/sax/features/external-general-entities", false);
        features.put("http://xml.org/sax/features/external-parameter-entities", false);
        features.put("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
        final BasicParserPool parserPool = new BasicParserPool();
        parserPool.setXincludeAware(false);
        parserPool.setIgnoreComments(true);
        parserPool.setExpandEntityReferences(false);
        parserPool.setBuilderFeatures(features);
        return parserPool;
    }
}
