/*
 * The OpenSAML License, Version 1.
 * Copyright (c) 2002
 * University Corporation for Advanced Internet Development, Inc.
 * All rights reserved
 *
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution, if any, must include
 * the following acknowledgment: "This product includes software developed by
 * the University Corporation for Advanced Internet Development
 * <http://www.ucaid.edu>Internet2 Project. Alternately, this acknowledegement
 * may appear in the software itself, if and wherever such third-party
 * acknowledgments normally appear.
 *
 * Neither the name of OpenSAML nor the names of its contributors, nor
 * Internet2, nor the University Corporation for Advanced Internet Development,
 * Inc., nor UCAID may be used to endorse or promote products derived from this
 * software without specific prior written permission. For written permission,
 * please contact opensaml@opensaml.org
 *
 * Products derived from this software may not be called OpenSAML, Internet2,
 * UCAID, or the University Corporation for Advanced Internet Development, nor
 * may OpenSAML appear in their name, without prior written permission of the
 * University Corporation for Advanced Internet Development.
 *
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND WITH ALL FAULTS. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
 * PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED AND THE ENTIRE RISK
 * OF SATISFACTORY QUALITY, PERFORMANCE, ACCURACY, AND EFFORT IS WITH LICENSEE.
 * IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY
 * CORPORATION FOR ADVANCED INTERNET DEVELOPMENT, INC. BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */


/* SAMLPOSTProfile.cpp - implements basics of the SAML POST profile

   Scott Cantor
   8/12/02

   $History:$
*/

#include <xercesc/util/XMLChar.hpp>
#include <xercesc/util/Base64.hpp>

#include "internal.h"

#include <ctime>
#include <sstream>
#include <vector>
using namespace saml;
using namespace std;
using namespace log4cpp;

set<xstring> SAMLPOSTProfile::m_replayCache;
multimap<xstring,xstring> SAMLPOSTProfile::m_replayExpMap;

namespace {
// Bearer confirmation-method URI emitted by incompatible beta Shibboleth releases:
// "urn:oasis:names:tc:SAML:1.0:cm:Bearer". getSSOAssertion() only uses it to detect
// such peers and report a clearer error; it is never accepted as a valid method.
const XMLCh OLD_BEARER[] =
{
    chLatin_u, chLatin_r, chLatin_n, chColon,                           // "urn:"
    chLatin_o, chLatin_a, chLatin_s, chLatin_i, chLatin_s, chColon,     // "oasis:"
    chLatin_n, chLatin_a, chLatin_m, chLatin_e, chLatin_s, chColon,     // "names:"
    chLatin_t, chLatin_c, chColon,                                      // "tc:"
    chLatin_S, chLatin_A, chLatin_M, chLatin_L, chColon,                // "SAML:"
    chDigit_1, chPeriod, chDigit_0, chColon,                            // "1.0:"
    chLatin_c, chLatin_m, chColon,                                      // "cm:"
    chLatin_B, chLatin_e, chLatin_a, chLatin_r, chLatin_e, chLatin_r,   // "Bearer"
    chNull
};
}

// Returns the first assertion in the response that qualifies for browser SSO:
//  - it carries both NotBefore and NotOnOrAfter;
//  - NotBefore is not later than now + clock skew, and NotOnOrAfter is later than
//    now - clock skew;
//  - every condition is an audience restriction satisfied by one of `audiences`
//    (any other condition type disqualifies the assertion);
//  - it contains an authentication statement whose subject lists the bearer
//    confirmation method.
// The returned pointer refers to an assertion owned by `r`.
//
// Throws ExpiredAssertionException when the response's only assertion fails the
// time-window check. Throws FatalProfileException when that single assertion fails
// the condition check (the message lists the accepted audiences), when only the
// obsolete beta bearer URI was seen, or when no assertion qualified.
const SAMLAssertion* SAMLPOSTProfile::getSSOAssertion(const SAMLResponse& r, const Iterator<const XMLCh*>& audiences)
{
    // Initialized here because an empty response never enters the loop, yet the
    // flags are still read after it.
    bool bOldCode=false;
    bool bExpired=false;
    bool bAudience=false;
    SAMLConfig& config=SAMLConfig::getConfig();
    Category& log=Category::getInstance(SAML_LOGCAT".SAMLPOSTProfile");

    // The validity window does not depend on the assertion, so it is built once.
    // Timestamps are transcoded into fixed local buffers instead of heap strings held
    // by auto_ptr, which would free Xerces-allocated arrays with scalar delete.
    char timebuf[32];

    // Upper edge (now + skew): an assertion whose NotBefore is later is not yet valid.
    time_t now=time(NULL)+config.clock_skew_secs;
#ifndef HAVE_GMTIME_R
    struct tm* ptime=gmtime(&now);
#else
    struct tm res;
    struct tm* ptime=gmtime_r(&now,&res);
#endif
    strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
    XMLCh xbefore[32]={0};
    XMLString::transcode(timebuf,xbefore,31);
    XMLDateTime before(xbefore);
    before.parseDateTime();

    // Lower edge (now - skew): an assertion whose NotOnOrAfter is not later has expired.
    now=time(NULL)-config.clock_skew_secs;
#ifndef HAVE_GMTIME_R
    ptime=gmtime(&now);
#else
    ptime=gmtime_r(&now,&res);
#endif
    strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
    XMLCh xafter[32]={0};
    XMLString::transcode(timebuf,xafter,31);
    XMLDateTime after(xafter);
    after.parseDateTime();

    for (Iterator<SAMLAssertion*> assertions=r.getAssertions(); assertions.hasNext();)
    {
        // The flags describe only the assertion currently being examined.
        bExpired=bAudience=false;
        const SAMLAssertion* a=assertions.next();

        // A SSO assertion must be bounded front and back.
        const XMLDateTime* notBefore=a->getNotBefore();
        const XMLDateTime* notOnOrAfter=a->getNotOnOrAfter();
        if (!notBefore || !notOnOrAfter)
        {
            log.debug("getSSOAssertion() skipping assertion without time conditions...");
            continue;
        }

        if (XMLDateTime::compareOrder(&before,notBefore)==XMLDateTime::LESS_THAN)
        {
            bExpired=true;
            log.debug("getSSOAssertion() skipping assertion that's not yet valid...");
            continue;
        }

        if (XMLDateTime::compareOrder(notOnOrAfter,&after)!=XMLDateTime::GREATER_THAN)
        {
            bExpired=true;
            log.debug("getSSOAssertion() skipping assertion that's expired...");
            continue;
        }

        // Check conditions. The only type we know about is an audience restriction.
        bool valid=true;
        for (Iterator<SAMLCondition*> conditions=a->getConditions(); conditions.hasNext();)
        {
            const SAMLCondition* c=conditions.next();
            const SAMLAudienceRestrictionCondition* ac=dynamic_cast<const SAMLAudienceRestrictionCondition*>(c);
            audiences.reset();
            if (!ac || !ac->eval(audiences))
            {
                valid=false;
                bAudience=true;
                break;
            }
        }
        if (!valid)
        {
            log.debug("getSSOAssertion() skipping assertion with invalid Audiences...");
            continue;
        }

        // Look for an authentication statement with a bearer-confirmed subject.
        for (Iterator<SAMLStatement*> statements=a->getStatements(); statements.hasNext();)
        {
            const SAMLStatement* s=statements.next();
            const SAMLAuthenticationStatement* as=dynamic_cast<const SAMLAuthenticationStatement*>(s);
            if (!as)
                continue;

            const SAMLSubject* subject=as->getSubject();
            for (Iterator<const XMLCh*> methods=subject->getConfirmationMethods(); methods.hasNext();)
            {
                const XMLCh* m=methods.next();
                if (!XMLString::compareString(m,SAMLSubject::CONF_BEARER))
                    return a;
                // Remember the obsolete URI so the failure message can explain it.
                if (!XMLString::compareString(m,OLD_BEARER))
                    bOldCode=true;
            }
        }
    }

    // Specific diagnostics are only reported when the response held exactly one assertion.
    if (bExpired && r.getAssertions().size()==1)
        throw ExpiredAssertionException(SAMLException::RESPONDER,"SAMLPOSTProfile::getSSOAssertion() unable to start session because of clock skew or replay");
    else if (bAudience && r.getAssertions().size()==1)
    {
        // Build a comma-separated list of the audiences this caller accepts.
        xstring buf;
        audiences.reset();
        while (audiences.hasNext())
        {
            if (!buf.empty())
            {
                buf+=chComma;
                buf+=chSpace;
            }
            buf+=audiences.next();
        }
        if (buf.empty())
            throw FatalProfileException(SAMLException::RESPONDER,"SAMLPOSTProfile::getSSOAssertion() unable to start session due to policy mismatch (target policies: none)");

        // Release the transcoded string through Xerces before throwing.
        char* policies=XMLString::transcode(buf.c_str());
        string msg=string("SAMLPOSTProfile::getSSOAssertion() unable to start session due to policy mismatch (target policies: ") + policies + ")";
        XMLString::release(&policies);
        throw FatalProfileException(SAMLException::RESPONDER,msg);
    }

    if (bOldCode)
        throw FatalProfileException(SAMLException::RESPONDER,"We've detected an attempt to authenticate using an incompatible beta version of Shibboleth. Please inform your identity provider's administrative staff that they should upgrade to a recent release.");

    throw FatalProfileException(SAMLException::RESPONDER,"SAMLPOSTProfile::getSSOAssertion() unable to start session");
}

// Finds the authentication statement in `a` whose subject lists the bearer
// confirmation method, i.e. the statement an SSO session is established from.
// Returns the first such statement (owned by `a`); throws FatalProfileException
// if none of the assertion's statements qualifies.
const SAMLAuthenticationStatement* SAMLPOSTProfile::getSSOStatement(const SAMLAssertion& a)
{
    Iterator<SAMLStatement*> stmts=a.getStatements();
    while (stmts.hasNext())
    {
        const SAMLAuthenticationStatement* authn=dynamic_cast<const SAMLAuthenticationStatement*>(stmts.next());
        if (authn)
        {
            const SAMLSubject* subj=authn->getSubject();
            Iterator<const XMLCh*> confMethods=subj->getConfirmationMethods();
            while (confMethods.hasNext())
            {
                if (XMLString::compareString(confMethods.next(),SAMLSubject::CONF_BEARER)==0)
                    return authn;
            }
        }
    }
    throw FatalProfileException(SAMLException::RESPONDER,"SAMLPOSTProfile::getSSOStatement() unable to find an SSO statement");
}

bool SAMLPOSTProfile::checkReplayCache(const SAMLAssertion& a)
{
    SAMLConfig& config=SAMLConfig::getConfig();

    // Garbage collect any expired entries.
    time_t now=time(NULL)-config.clock_skew_secs;
#ifndef HAVE_GMTIME_R
    struct tm* ptime=gmtime(&now);
#else
    struct tm res;
    struct tm* ptime=gmtime_r(&now,&res);
#endif
    char timebuf[32];
    strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
    auto_ptr<XMLCh> timeptr(XMLString::transcode(timebuf));

    config.saml_lock();

    try
    {
        multimap<xstring,xstring>::iterator stop=m_replayExpMap.upper_bound(timeptr.get());
        for (multimap<xstring,xstring>::iterator i=m_replayExpMap.begin(); i!=stop; m_replayExpMap.erase(i++))
            m_replayCache.erase(i->second);

        // If it's already been seen, bail.
        if (!m_replayCache.insert(a.getId()).second)
        {
            config.saml_unlock();
            return false;
        }

        // Add the pair to the expiration map.
        auto_ptr<XMLCh> expptr(a.getNotOnOrAfter()->toString());
        m_replayExpMap.insert(multimap<xstring,xstring>::value_type(expptr.get(),a.getId()));
    }
    catch(...)
    {
        config.saml_unlock();
        Category::getInstance(SAML_LOGCAT".SAMLPOSTProfile").error("checkReplayCache() caught an exception");
        return false;
    }
    config.saml_unlock();
    return true;
}

// Decodes a base64-encoded, NUL-terminated POST payload into a SAMLResponse.
// Whitespace is stripped before decoding. When `process` is true the response is
// also checked by SAMLPOSTProfile::process() against `receiver` and `ttlSeconds`.
// Throws SAMLException if the data cannot be base64-decoded; exceptions from
// parsing or processing propagate. The caller owns the returned response.
SAMLResponse* SAMLPOSTProfile::accept(const XMLByte* buf, const XMLCh* receiver, int ttlSeconds, bool process)
{
    // Collect the non-whitespace bytes into an owned, NUL-terminated buffer. The
    // vector releases itself if decoding throws and is not limited by an
    // unsigned int length counter.
    vector<XMLByte> normalized;
    if (buf)
    {
        for (const XMLByte* p=buf; *p; ++p)
            if (!XMLChar1_0::isWhitespace(*p))
                normalized.push_back(*p);
    }
    normalized.push_back(0);

    unsigned int len=0;
    // NOTE(review): the decoded buffer is released by auto_ptr (scalar delete); confirm
    // this matches how the Xerces version in use allocates Base64::decode output.
    auto_ptr<XMLByte> decoded(Base64::decode(&normalized[0],&len));
    if (!decoded.get())
        throw SAMLException(SAMLException::RESPONDER,"SAMLPOSTProfile::accept() unable to decode base64 data");

    Category::getInstance(SAML_LOGCAT".SAMLPOSTProfile").debug("accept: decoded assertion:\n%s",decoded.get());

    // Use the reported length rather than relying on a trailing NUL in the decoded data.
    stringstream str(string(reinterpret_cast<const char*>(decoded.get()),len));

    auto_ptr<SAMLResponse> r(new SAMLResponse(str));
    if (process)
        SAMLPOSTProfile::process(*r,receiver,ttlSeconds);
    return r.release();
}

// Validates a POSTed response against the receiving endpoint and a freshness window.
// Throws InvalidAssertionException if `receiver` or the response's Recipient is
// missing or empty, or if the two differ. Throws ExpiredAssertionException if the
// response's IssueInstant is earlier than now - ttlSeconds - clock skew.
void SAMLPOSTProfile::process(SAMLResponse& r, const XMLCh* receiver, int ttlSeconds)
{
    const XMLCh* recipient=r.getRecipient();
    if (!receiver || !*receiver || !recipient || !*recipient || XMLString::compareString(receiver,recipient))
        throw InvalidAssertionException(SAMLException::REQUESTER, "SAMLPOSTProfile::process() detected recipient mismatch");

    SAMLConfig& config=SAMLConfig::getConfig();

    // Oldest acceptable issue instant.
    time_t now=time(NULL)-ttlSeconds-config.clock_skew_secs;
#ifndef HAVE_GMTIME_R
    struct tm* ptime=gmtime(&now);
#else
    struct tm res;
    struct tm* ptime=gmtime_r(&now,&res);
#endif
    char timebuf[32];
    strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
    // Fixed buffer: avoids freeing a Xerces-allocated array through auto_ptr's scalar delete.
    XMLCh xoldest[32]={0};
    XMLString::transcode(timebuf,xoldest,31);
    XMLDateTime before(xoldest);
    before.parseDateTime();

    if (XMLDateTime::compareOrder(r.getIssueInstant(),&before)==XMLDateTime::LESS_THAN)
        throw ExpiredAssertionException(SAMLException::RESPONDER, "SAMLPOSTProfile::process() detected expired response");
}

// Builds a POST-profile response addressed to `recipient`. It contains one assertion
// from `issuer`, and that assertion holds one authentication statement about the
// given subject. The assertion's NotBefore is now and its NotOnOrAfter is now plus
// the configured clock skew. The subject uses the bearer confirmation method. An
// audience restriction condition is added only if `audiences` is non-empty.
// Throws FatalProfileException if `recipient` is null or empty. The caller owns the
// returned response.
SAMLResponse* SAMLPOSTProfile::prepare(const XMLCh* recipient,
                                       const XMLCh* issuer,
                                       const Iterator<const XMLCh*>& audiences,
                                       const XMLCh* name,
                                       const XMLCh* nameQualifier,
                                       const XMLCh* format,
                                       const XMLCh* subjectIP,
                                       const XMLCh* authMethod,
                                       const XMLDateTime& authInstant,
                                       const Iterator<SAMLAuthorityBinding*>& bindings)
{
    if (!recipient || !*recipient)
        throw FatalProfileException(SAMLException::RESPONDER, "SAMLPOSTProfile::prepare() requires recipient");

    // Read the clock once so both bounds derive from the same instant.
    time_t now=time(NULL);
    time_t later=now+SAMLConfig::getConfig().clock_skew_secs;
    char timebuf[32];

#ifndef HAVE_GMTIME_R
    struct tm* ptime=gmtime(&now);
#else
    struct tm res;
    struct tm* ptime=gmtime_r(&now,&res);
#endif
    strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
    // Fixed buffers avoid freeing Xerces-allocated arrays through auto_ptr's scalar delete.
    XMLCh xlower[32]={0};
    XMLString::transcode(timebuf,xlower,31);
    XMLDateTime lower(xlower);
    lower.parseDateTime();

#ifndef HAVE_GMTIME_R
    ptime=gmtime(&later);
#else
    ptime=gmtime_r(&later,&res);
#endif
    strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
    XMLCh xupper[32]={0};
    XMLString::transcode(timebuf,xupper,31);
    XMLDateTime upper(xupper);
    upper.parseDateTime();

    const XMLCh* confirmationMethods[] = {SAMLSubject::CONF_BEARER};

    // Each component stays owned by its auto_ptr until the enclosing object has been
    // constructed around it, then is released (the container keeps it), so a throwing
    // constructor does not leak the parts built so far.
    auto_ptr<SAMLSubject> subject(new SAMLSubject(name, nameQualifier, format, ArrayIterator<const XMLCh*>(confirmationMethods)));
    auto_ptr<SAMLAuthenticationStatement>
        statement(new SAMLAuthenticationStatement(subject.get(), authMethod, authInstant, subjectIP, NULL, bindings));
    subject.release();

    auto_ptr<SAMLAudienceRestrictionCondition> condition;
    if (audiences.hasNext())
        condition.reset(new SAMLAudienceRestrictionCondition(audiences));

    SAMLCondition* conditions[] = { condition.get() };
    SAMLStatement* statements[] = { statement.get() };

    auto_ptr<SAMLAssertion> assertion(new SAMLAssertion(issuer, &lower, &upper,
                                      condition.get() ? ArrayIterator<SAMLCondition*>(conditions) : Iterator<SAMLCondition*>(),
                                      ArrayIterator<SAMLStatement*>(statements)));
    condition.release();
    statement.release();

    SAMLAssertion* assertions[] = { assertion.get() };

    SAMLResponse* response=new SAMLResponse(NULL, recipient, ArrayIterator<SAMLAssertion*>(assertions));
    assertion.release();

    return response;
}
