/*
 * The OpenSAML License, Version 1.
 * Copyright (c) 2002
 * University Corporation for Advanced Internet Development, Inc.
 * All rights reserved
 *
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution, if any, must include
 * the following acknowledgment: "This product includes software developed by
 * the University Corporation for Advanced Internet Development
 * <http://www.ucaid.edu>Internet2 Project. Alternately, this acknowledegement
 * may appear in the software itself, if and wherever such third-party
 * acknowledgments normally appear.
 *
 * Neither the name of OpenSAML nor the names of its contributors, nor
 * Internet2, nor the University Corporation for Advanced Internet Development,
 * Inc., nor UCAID may be used to endorse or promote products derived from this
 * software without specific prior written permission. For written permission,
 * please contact opensaml@opensaml.org
 *
 * Products derived from this software may not be called OpenSAML, Internet2,
 * UCAID, or the University Corporation for Advanced Internet Development, nor
 * may OpenSAML appear in their name, without prior written permission of the
 * University Corporation for Advanced Internet Development.
 *
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND WITH ALL FAULTS. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
 * PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED AND THE ENTIRE RISK
 * OF SATISFACTORY QUALITY, PERFORMANCE, ACCURACY, AND EFFORT IS WITH LICENSEE.
 * IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY
 * CORPORATION FOR ADVANCED INTERNET DEVELOPMENT, INC. BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */


/* SAMLConfig.cpp - SAML runtime configuration

   Scott Cantor
   2/20/02

   $History:$
*/

#include "internal.h"

#ifdef HAVE_DLFCN_H
# include <dlfcn.h>
#endif

#include <curl/curl.h>
#include <log4cpp/Category.hh>
#include <log4cpp/PropertyConfigurator.hh>
#include <xercesc/util/PlatformUtils.hpp>
#include <xsec/framework/XSECProvider.hpp>

using namespace saml;
using namespace log4cpp;
using namespace std;

// Exception factories for the library's built-in exception types.
// NOTE(review): SAML_EXCEPTION_FACTORY is defined outside this file; presumably it
// emits a factory function per exception type. SAMLInternalConfig::init() registers
// each of these same types with REGISTER_EXCEPTION_FACTORY under "org.opensaml",
// so this list and that one must be kept in sync.
SAML_EXCEPTION_FACTORY(MalformedException);
SAML_EXCEPTION_FACTORY(UnsupportedExtensionException);
SAML_EXCEPTION_FACTORY(InvalidCryptoException);
SAML_EXCEPTION_FACTORY(TrustException);
SAML_EXCEPTION_FACTORY(BindingException);
SAML_EXCEPTION_FACTORY(SOAPException);
SAML_EXCEPTION_FACTORY(ContentTypeException);
SAML_EXCEPTION_FACTORY(UnknownAssertionException);
SAML_EXCEPTION_FACTORY(ProfileException);
SAML_EXCEPTION_FACTORY(FatalProfileException);
SAML_EXCEPTION_FACTORY(RetryableProfileException);
SAML_EXCEPTION_FACTORY(ExpiredAssertionException);
SAML_EXCEPTION_FACTORY(InvalidAssertionException);
SAML_EXCEPTION_FACTORY(ReplayedAssertionException);

/*
 * Query factory that builds a SAMLAttributeQuery from a DOM element.
 * init() registers it under both the AttributeQuery and AttributeQueryType
 * QNames in the SAMLP namespace. Presumably the caller takes ownership of
 * the returned object.
 */
extern "C" SAMLQuery* SAMLAttributeQueryFactory(DOMElement* e)
{
    SAMLQuery* query=new SAMLAttributeQuery(e);
    return query;
}

/*
 * Statement factory that builds a SAMLAttributeStatement from a DOM element.
 * init() registers it under both the AttributeStatement and
 * AttributeStatementType QNames in the SAML namespace. Presumably the caller
 * takes ownership of the returned object.
 */
extern "C" SAMLStatement* SAMLAttributeStatementFactory(DOMElement* e)
{
    SAMLStatement* statement=new SAMLAttributeStatement(e);
    return statement;
}

/*
 * Statement factory that builds a SAMLAuthenticationStatement from a DOM
 * element. init() registers it under both the AuthenticationStatement and
 * AuthenticationStatementType QNames in the SAML namespace. Presumably the
 * caller takes ownership of the returned object.
 */
extern "C" SAMLStatement* SAMLAuthenticationStatementFactory(DOMElement* e)
{
    SAMLStatement* statement=new SAMLAuthenticationStatement(e);
    return statement;
}

/*
 * Condition factory that builds a SAMLAudienceRestrictionCondition from a DOM
 * element. init() registers it under both the AudienceRestrictionCondition and
 * AudienceRestrictionConditionType QNames in the SAML namespace. Presumably the
 * caller takes ownership of the returned object.
 */
extern "C" SAMLCondition* SAMLAudienceConditionFactory(DOMElement* e)
{
    SAMLCondition* condition=new SAMLAudienceRestrictionCondition(e);
    return condition;
}

namespace
{
    // The one configuration object for this library, visible only to this
    // translation unit; all access goes through SAMLConfig::getConfig().
    SAMLInternalConfig g_config;
}

/*
 * Returns the library's runtime configuration, exposed through the public
 * SAMLConfig interface.
 */
SAMLConfig& SAMLConfig::getConfig()
{
    SAMLConfig& instance=g_config;
    return instance;
}

bool SAMLInternalConfig::init()
{
    try
    {
        if (!log_config.empty())
            PropertyConfigurator::configure(log_config);
        m_log=&(Category::getInstance(SAML_LOGCAT".SAMLInternalConfig"));
        saml::NDC ndc("init");
        m_log->debug("library initialization started");

        if (curl_global_init(CURL_GLOBAL_ALL))
        {
            m_log->fatal("init: failed to initialize libcurl, SSL, or Winsock");
            return false;
        }
        m_log->debug("libcurl initialization complete");

        XMLPlatformUtils::Initialize();
        m_log->debug("Xerces initialization complete");

        XSECPlatformUtils::Initialise();
        m_xsec=new XSECProvider();
        m_log->debug("XSEC initialization complete");
        
        if (schema_dir[schema_dir.length()-1]!='/')
            schema_dir+='/';
        auto_ptr<XMLCh> temp(XMLString::transcode(schema_dir.c_str()));
        wide_schema_dir=temp.get();

        if (inclusive_namespace_prefixes.empty())
            inclusive_namespace_prefixes="#default saml samlp ds xsd xsi code kind rw typens";

        auto_ptr<XMLCh> temp2(XMLString::transcode(inclusive_namespace_prefixes.c_str()));
        wide_inclusive_namespace_prefixes=temp2.get();

        m_pool=new XML::ParserPool();
        m_pool->registerSchema(XML::SAML_NS,compatibility_mode ? XML::SAML_SCHEMA_ID : XML::SAML11_SCHEMA_ID);
        m_pool->registerSchema(XML::SAMLP_NS,compatibility_mode ? XML::SAMLP_SCHEMA_ID : XML::SAMLP11_SCHEMA_ID);
        m_pool->registerSchema(XML::SOAP11ENV_NS,XML::SOAP11ENV_SCHEMA_ID);
        m_pool->registerSchema(XML::XMLSIG_NS,XML::XMLSIG_SCHEMA_ID);
        m_pool->registerSchema(XML::XPATH2_NS,XML::XPATH2_SCHEMA_ID);
        m_pool->registerSchema(XML::XML_NS,XML::XML_SCHEMA_ID);
        m_log->debug("SAML schema registration complete");


        m_lock=XMLPlatformUtils::makeMutex();

        // Register built-in SAML type factories.
        saml::QName q1(XML::SAMLP_NS,L(AttributeQueryType));
        saml::QName q2(XML::SAMLP_NS,L(AttributeQuery));
        SAMLQuery::regFactory(q1,&SAMLAttributeQueryFactory);
        SAMLQuery::regFactory(q2,&SAMLAttributeQueryFactory);

        saml::QName s1(XML::SAML_NS,L(AttributeStatementType));
        saml::QName s2(XML::SAML_NS,L(AttributeStatement));
        saml::QName s3(XML::SAML_NS,L(AuthenticationStatementType));
        saml::QName s4(XML::SAML_NS,L(AuthenticationStatement));
        SAMLStatement::regFactory(s1,&SAMLAttributeStatementFactory);
        SAMLStatement::regFactory(s2,&SAMLAttributeStatementFactory);
        SAMLStatement::regFactory(s3,&SAMLAuthenticationStatementFactory);
        SAMLStatement::regFactory(s4,&SAMLAuthenticationStatementFactory);

        saml::QName c1(XML::SAML_NS,L(AudienceRestrictionConditionType));
        saml::QName c2(XML::SAML_NS,L(AudienceRestrictionCondition));
        SAMLCondition::regFactory(c1,&SAMLAudienceConditionFactory);
        SAMLCondition::regFactory(c2,&SAMLAudienceConditionFactory);

        REGISTER_EXCEPTION_FACTORY(org.opensaml,MalformedException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,UnsupportedExtensionException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,InvalidCryptoException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,TrustException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,BindingException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,SOAPException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,ContentTypeException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,UnknownAssertionException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,ProfileException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,FatalProfileException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,RetryableProfileException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,ExpiredAssertionException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,InvalidAssertionException);
        REGISTER_EXCEPTION_FACTORY(org.opensaml,ReplayedAssertionException);

        m_log->debug("SAML type factory registration complete");
    }
    catch(ConfigureFailure& e)
    {
        cerr << "SAMLConfig::init() caught exception while initializing log4cpp: " << e.what() << endl;
        return false;
    }
    catch (const XMLException&)
    {
        saml::NDC ndc("init");
        m_log->fatal("init: caught exception while initializing Xerces");
        curl_global_cleanup();
        return false;
    }

    saml::NDC ndc("init");
    m_log->info("library initialization complete");
    return true;
}

void SAMLInternalConfig::term()
{
    saml::NDC ndc("term");
    for (vector<void*>::reverse_iterator i=m_libhandles.rbegin(); i!=m_libhandles.rend(); i++)
    {
#if defined(WIN32)
        FARPROC fn=GetProcAddress(static_cast<HMODULE>(*i),"saml_extension_term");
        if (fn)
            fn();
        FreeLibrary(static_cast<HMODULE>(*i));
#elif defined(HAVE_DLFCN_H)
        void (*fn)()=(void (*)())dlsym(*i,"saml_extension_term");
        if (fn)
            fn();
        dlclose(*i);
#else
# error "Don't know about dynamic loading on this platform!"
#endif
    }
    
    delete m_xsec;
    XSECPlatformUtils::Terminate();
    XMLPlatformUtils::closeMutex(m_lock);
    delete m_pool;
    XMLPlatformUtils::Terminate();
    curl_global_cleanup();

    m_log->info("library shutdown complete");
}

void SAMLInternalConfig::saml_lock() const
{
    XMLPlatformUtils::lockMutex(m_lock);
}

void SAMLInternalConfig::saml_unlock() const
{
    XMLPlatformUtils::unlockMutex(m_lock);
}

/*
 * Loads an extension library and runs its "saml_extension_init" entry point.
 *
 * On success, the library handle is retained so that term() can later call
 * the library's optional "saml_extension_term" entry point and unload it.
 *
 * path    - filesystem path of the extension library. On Windows, forward
 *           slashes are converted to backslashes before loading.
 * context - opaque pointer passed unchanged to saml_extension_init.
 *
 * Throws SAMLException if:
 *  - the library cannot be loaded,
 *  - it does not export saml_extension_init, or
 *  - saml_extension_init returns non-zero.
 * A library that was loaded is unloaded before the exception propagates.
 */
void SAMLInternalConfig::saml_register_extension(const char* path, void* context) const
{
    saml::NDC ndc("saml_register_extension");
    m_log->info("loading extension: %s",path);

#if defined(WIN32)
    HMODULE handle=NULL;
    char* fixed=const_cast<char*>(path);
    if (strchr(fixed,'/'))
    {
        // LoadLibraryEx requires backslash separators, so convert on a private copy.
        fixed=strdup(path);
        if (!fixed)
            throw SAMLException(string("SAMLConfig::saml_register_extension() unable to copy extension path: ") + path);
        for (char* p=strchr(fixed,'/'); p!=NULL; p=strchr(p,'/'))
            *p='\\';
    }

    // Suppress system error dialog boxes while loading the library.
    UINT em=SetErrorMode(SEM_FAILCRITICALERRORS);
    try
    {
        handle=LoadLibraryEx(fixed,NULL,LOAD_WITH_ALTERED_SEARCH_PATH);
        if (!handle)
             handle=LoadLibraryEx(fixed,NULL,0);
        if (!handle)
            throw SAMLException(string("SAMLConfig::saml_register_extension() unable to load extension library: ") + fixed);
        FARPROC fn=GetProcAddress(handle,"saml_extension_init");
        if (!fn)
            throw SAMLException(string("SAMLConfig::saml_register_extension() unable to locate saml_extension_init entry point: ") + fixed);
        if (reinterpret_cast<int(*)(void*)>(fn)(context)!=0)
            throw SAMLException(string("SAMLConfig::saml_register_extension() detected error in saml_extension_init: ") + fixed);
        if (fixed!=path)
            free(fixed);
        SetErrorMode(em);
    }
    catch(...)
    {
        if (handle)
            FreeLibrary(handle);
        SetErrorMode(em);
        if (fixed!=path)
            free(fixed);
        throw;
    }

#elif defined(HAVE_DLFCN_H)
    void* handle=dlopen(path,RTLD_LAZY);
    if (!handle)
    {
        const char* err=dlerror();
        throw SAMLException(string("SAMLConfig::saml_register_extension unable to load extension library '") + path + "': " + (err ? err : "unknown error"));
    }

    // Clear any stale error before calling dlsym. Each dlerror() call resets
    // the error state, and dlclose() may overwrite it, so the failure text is
    // read exactly once and copied before the handle is closed.
    dlerror();
    int (*fn)(void*)=reinterpret_cast<int (*)(void*)>(dlsym(handle,"saml_extension_init"));
    if (!fn)
    {
        const char* err=dlerror();
        string msg=string("SAMLConfig::saml_register_extension unable to locate saml_extension_init entry point in '") + path + "': " + (err ? err : "unknown error");
        dlclose(handle);
        throw SAMLException(msg);
    }
    try
    {
        if (fn(context)!=0)
            throw SAMLException(string("SAMLConfig::saml_register_extension() detected error in saml_extension_init in ") + path);
    }
    catch(...)
    {
        dlclose(handle);
        throw;
    }
#else
# error "Don't know about dynamic loading on this platform!"
#endif
    m_libhandles.push_back(handle);
    m_log->info("loaded extension: %s",path);
}
