/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: t; c-basic-offset: 8 -*- */

/*
 *  Copyright (C) 2008 OMC Denmark ApS.
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#include "logging.h"
#include "globals.h"

#include <windows.h>
#include <winsock2.h>
#include <process.h>
#include <lm.h>
#include <dsgetdc.h>

static char*
get_domain_name(void)
{
        char *retv = NULL;
        PDOMAIN_CONTROLLER_INFO dc_info = NULL;

        if (NO_ERROR == DsGetDcName(NULL, NULL, NULL, NULL, DS_DIRECTORY_SERVICE_PREFERRED, &dc_info)) {
                retv = _strdup(dc_info->DomainName);
                NetApiBufferFree(dc_info);
        } else if (NO_ERROR == DsGetDcName(NULL, NULL, NULL, NULL, DS_DIRECTORY_SERVICE_PREFERRED | DS_FORCE_REDISCOVERY, &dc_info)) {
                retv = _strdup(dc_info->DomainName);
                NetApiBufferFree(dc_info);
        }

        return retv;
}

static char*
get_domain_controller(const char *host,
                      const char *domain)
{
        char *retv = NULL;
        PDOMAIN_CONTROLLER_INFO dc_info = NULL;

        if (NO_ERROR == DsGetDcName(host, domain, NULL, NULL, DS_DIRECTORY_SERVICE_PREFERRED, &dc_info)) {
                retv = _strdup(dc_info->DomainControllerName);
                NetApiBufferFree(dc_info);
        } else if (NO_ERROR == DsGetDcName(host, domain, NULL, NULL, DS_DIRECTORY_SERVICE_PREFERRED | DS_FORCE_REDISCOVERY, &dc_info)) {
                retv = _strdup(dc_info->DomainControllerName);
                NetApiBufferFree(dc_info);
        }

        return retv;
}

static void
get_file_info(const LPVOID fileinfo,
              const char *token,
              void **value)
{
        unsigned short *translation = NULL;
        UINT len = 0;

        *value = NULL;
        if (!VerQueryValue(fileinfo, "\\VarFileInfo\\Translation", (void **)&translation, &len))
                return;

        {
                DWORD lang = 0;
                char info_query[256] = {'\0'};
                char *versionInfo = NULL;

                if (translation && (4 == len)) {
                        memcpy(&lang, translation, 4);
                        sprintf_s(info_query,
                                  sizeof(info_query),
                                  "\\StringFileInfo\\%02X%02X%02X%02X\\%s",
                                  (lang & 0x0000ff00)>>8,
                                  (lang & 0x000000ff),
                                  (lang & 0xff000000)>>24,
                                  (lang & 0x00ff0000)>>16,
                                  token);
                } else
                        sprintf_s(info_query,
                                  sizeof(info_query),
                                  "\\StringFileInfo\\%04X04B0\\%s",
                                  GetUserDefaultLangID(),
                                  token);
                len = 0;
                VerQueryValue(fileinfo, info_query, value, &len);
        }
}

static bool
get_program_files_dir(char dest[], 
		      DWORD size)
{
	if (!dest || !size)
		return false;

	HKEY key;
	LONG ret = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
				"SOFTWARE\\Microsoft\\Windows\\CurrentVersion",
				0,
				KEY_READ, 
				&key);
	if (ERROR_SUCCESS != ret)
		return false;

	DWORD type;
	ret = RegQueryValueEx(key, 
			      "ProgramFilesDir", 
			      NULL,
			      &type,
			      (BYTE*)dest,
			      &size);
	RegCloseKey(key);

	return (ERROR_SUCCESS == ret) ? true : false;
}


static char*
get_mapi_version(void)
{
        char *retv = NULL;
        char mapi_version[256] = {'\0'};
        DWORD size = 0;
        DWORD tmp = 0;
        void *data = NULL;
        void *product = NULL;
        void *version = NULL;
        const char *dll_path_1 = "C:\\WINDOWS\\system32\\EMSMDB32.DLL";
        const char *dll_path_2 = "C:\\WINNT\\system32\\EMSMDB32.DLL";
        const char *dll_path_3 = "EMSMDB32.DLL";
	const char *dll_path = dll_path_1;

	char default_path[512] = { '\0' };
	if (get_program_files_dir(default_path, sizeof(default_path))) {
		strcat_s(default_path, sizeof(default_path), "\\ExchangeMapi\\EMSMDB32.DLL");
		size = GetFileVersionInfoSize(default_path, &tmp);
	}
	if (size) {
		dll_path = default_path;
	} else	{
		do {
			dll_path = dll_path_1;
			size = GetFileVersionInfoSize(dll_path_1, &tmp);
			if (size) 
				break;

			dll_path = dll_path_2;
			size = GetFileVersionInfoSize(dll_path_2, &tmp);
			if (size) 
				break;

			dll_path = dll_path_3;
			size = GetFileVersionInfoSize(dll_path_3, &tmp);
			if (size)
				break;

			return NULL;
		} while (0);
	}
        data = malloc((size_t)size);
        if (!data)
                return NULL;

        if (!GetFileVersionInfo(dll_path, 0, size, data))
                goto out;

        get_file_info(data, "ProductName", &product);
        if (!product)
                goto out;

        get_file_info(data, "ProductVersion", &version);
        if (!version)
                goto out;

        sprintf_s(mapi_version, sizeof(mapi_version), "%s %s", (char*)product, (char*)version);
        retv = _strdup(mapi_version);
out:
        free(data);
        return retv;
}

/*
 * A fresh LogManager owns no log file and has no output stream selected;
 * one of the redirectTo*() methods chooses the destination later.
 */
LogManager::LogManager()
        : log_stream_(NULL), output_stream_(NULL)
{
        /* nothing else to set up */
}

/*
 * Close and delete the log file opened by redirectToFile(), if any.
 *
 * redirectToFile() hands log_stream_ to ACE as its message ostream. Before the
 * stream is destroyed, ACE is detached from it on the calling thread (OSTREAM
 * output disabled, ostream pointer cleared). Otherwise a later log call would
 * write through a dangling pointer.
 */
LogManager::~LogManager()
{
        if (log_stream_) {
                if (ACE_LOG_MSG->msg_ostream() == log_stream_) {
                        ACE_LOG_MSG->clr_flags(ACE_Log_Msg::OSTREAM);
                        ACE_LOG_MSG->msg_ostream(NULL);
                }
                if (output_stream_ == log_stream_)
                        output_stream_ = NULL;
                log_stream_->close();
        }
        delete log_stream_;
        log_stream_ = NULL;
}

void
LogManager::redirectToSyslog(const ACE_TCHAR *prog_name)
{
        ACE_LOG_MSG->open(prog_name, ACE_Log_Msg::SYSLOG, prog_name);

        this->writeHeader();
}

/*
 * Send ACE log output to `output` (not owned by this object): remember the
 * stream, register it with ACE, turn off STDERR and LOGGER output, turn on
 * OSTREAM output, and write the startup header to the new destination.
 */
void
LogManager::redirectToOStream(ACE_OSTREAM_TYPE *output)
{
        ACE_Log_Msg *const msg = ACE_LOG_MSG;

        output_stream_ = output;
        msg->msg_ostream(output_stream_);

        /* Exactly one destination: the ostream. */
        msg->clr_flags(ACE_Log_Msg::STDERR | ACE_Log_Msg::LOGGER);
        msg->set_flags(ACE_Log_Msg::OSTREAM);

        writeHeader();
}

void
LogManager::redirectToFile(const char *filename)
{
        log_stream_ = new std::ofstream();
        log_stream_->open(filename, ios::out | ios::app);
        this->redirectToOStream(log_stream_);

        this->writeHeader();
}

void
LogManager::redirectToStderr(void)
{
        ACE_LOG_MSG->clr_flags(ACE_Log_Msg::OSTREAM | ACE_Log_Msg::LOGGER);
        ACE_LOG_MSG->set_flags(ACE_Log_Msg::STDERR);

        this->writeHeader();
}

/*
 * Write the startup banner to the current log destination.
 *
 * The banner covers the server version, host name, domain name, domain
 * controller (only looked up when both host and domain name are known) and
 * the MAPI provider version. Every lookup is best-effort: a failure is
 * reported as "could not be determined" and the banner continues.
 *
 * The domain lookups can block when the domain controller misbehaves, so a
 * warning is logged before they start.
 */
void
LogManager::writeHeader(void)
{
        ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Brutus Server version \"%s\"\n"), BRUTUS_VERSION_STRING));

        char host_name[HOST_NAME_MAX] = { '\0' };
        if (gethostname(host_name, HOST_NAME_MAX)) {
                host_name[0] = '\0';
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Host name could not be determined\n")));
        } else {
                /* Guarantee termination even if the name was truncated. */
                host_name[HOST_NAME_MAX - 1] = '\0';
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Host name \"%s\"\n"), host_name));
        }

        ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Brutus is now trying to determine your domain name. \
If this is the last message you hear from Brutus, then you most \
likely have problems with your domain controller.\n")));

        char *domain_name = get_domain_name();
        if (!domain_name)
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Domain name could not be determined\n")));
        else
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Domain name \"%s\"\n"), domain_name));

        char *domain_controller = NULL;
        if (host_name[0] && domain_name)
                domain_controller = get_domain_controller(host_name, domain_name);
        if (!domain_controller)
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Domain controller could not be determined\n")));
        else
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - Domain controller \"%s\"\n"), domain_controller));

        char *mapi_version = get_mapi_version();
        if (!mapi_version)
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - MAPI version could not be determined\n")));
        else
                ACE_DEBUG((LM_INFO, ACE_TEXT("<%D>%N:%l - MAPI version \"%s\"\n"), mapi_version));

        /* free(NULL) is a no-op, so no guards are needed. */
        free(domain_name);
        free(domain_controller);
        free(mapi_version);
}
