import sys, os, tempfile, shutil
from decorator import decorator
import warnings
import errno
import threading, copy
import hashlib

import socket

from bazjunk.path import makedirs

from . import dependencies
from .benchmark import benchmarking

# Per-thread cache bookkeeping.  cache_hook.clear() assigns three dict
# attributes on it: `transient` (entries waiting to be written by
# cache_hook.commit()), `validated` (entries already checked against their
# dependencies) and `memo` (backing store for set_memo()/get_memo()).
cache_state = threading.local()

# Performs comparably to jsonlib and marshal in my testing.
from cPickle import dump, load
# Suffix of the per-entry info pickle ({'revision': ..., 'dependencies': ...}).
INFO_SUFFIX = '.pickle'
# Suffix appended after a format suffix for the per-format pickle
# ({'metadata': ...} plus 'value' when the value is not a string).
SPECIFIC_SUFFIX = '.pickle'

def get_suffix(format, cache_tag):
    """Build the file-name suffix for `format` under `cache_tag`.

    A format that starts with '.' (an extension such as '.html') is appended
    straight after the tag; any other format is joined to the tag with '>'.
    `cache_tag` must be empty or start with '^'.
    """
    assert not cache_tag or cache_tag.startswith('^'), cache_tag
    separator = '' if format.startswith('.') else '>'
    return cache_tag + separator + format

def format_from_suffix(suffix, cache_tag=None):
    """Extract the format from a `tag + format` suffix.

    If the suffix contains a '.', the format is the last '.'-delimited piece
    with the dot kept (e.g. '.html'); otherwise the suffix must contain '>'
    and the format is whatever follows the last '>'.

    When `cache_tag` is given, returns None unless the tag part of the
    suffix equals it.
    """
    if '.' not in suffix:
        assert '>' in suffix, suffix
        tag, fmt = suffix.rsplit('>', 1)
    else:
        tag, ext = suffix.rsplit('.', 1)
        fmt = '.' + ext
    if cache_tag is None or tag == cache_tag:
        return fmt
    return None


def cache_dir():
    """Return the cache root: 'cache' under custom.get_tmp_dir()."""
    # `custom` is imported at call time rather than at module import.
    from . import custom
    tmp_root = custom.get_tmp_dir()
    return "%s/cache" % (tmp_root)

# Caching stores data in the following files:
#
# foo.pickle: a pickled dict like:
#   {'revision': 17,
#    'dependencies': [(ename, propname, hash), ...]}
# foo.html, foo.tex, foo>html: the actual data
# foo.html.pickle, etc: a pickled dict like:
#   {'metadata': ...}
#
# foo.format.pickle will have a 'value' entry, and foo.format will not exist,
# if the value was not of type unicode/str.
#
# _metadata files' formats dictionary is more like metadata, of the
# form: {'final_map': {...}}
def cache_file_stem(ename, prop_name):
    """Return the path stem that cache file suffixes are appended to.

    With an element name the stem is <cache_dir>/<ename>/<prop_name or '_'>,
    encoded as UTF-8.  With ename None (prop_name must then be None too) the
    stem is <cache_dir>/_ and is returned without encoding.
    """
    if ename is not None:
        leaf = prop_name or '_'
        return os.path.join(cache_dir(), ename, leaf).encode('utf-8')
    assert prop_name is None
    return os.path.join(cache_dir(), '_')

# Returns None if invalid.  Otherwise, it returns
# (deps, needs_to_be_rewritten_to_disk)
def get_deps_info(ename, prop_name):
    """Load and validate the recorded dependencies of a cache entry.

    Returns None when the info pickle cannot be read or any dependency is
    stale.  Otherwise returns (deps, needs_rewrite): needs_rewrite is False
    when the stored revision equals db.get_revision() (no per-dependency
    check is done in that case) and True when every dependency had to be
    re-verified individually.
    """
    info_path = cache_file_stem(ename, prop_name) + INFO_SUFFIX
    try:
        with open(info_path) as fil:
            info = load(fil)
    except EOFError:
        benchmarking.info('EOF reading ' + info_path)
        return None
    except IOError:
        benchmarking.info('IOError reading ' + info_path)
        return None

    from . import db
    # Revision short-circuit.  <None> means no revision shortcutting is
    # possible; i.e., no version control hook.
    rev = db.get_revision()
    deps = info['dependencies']
    if rev is not None and info['revision'] == rev:
        return deps, False

    for dep in deps:
        if dep == dependencies.OMNISCIENCE:
            continue
        if dep == dependencies.REVISION:
            # The revision did not match above, so this dep is stale.
            return None
        assert len(dep) == 3, dep
        dename, dpname, hsh = dep
        from . import structure
        delement = structure.get_element(dename)
        if delement is None:
            # Only an '__exists' dep survives the element being absent.
            if dpname != '__exists':
                return None
            continue

        if dpname == '__parent':
            stale = hsh != dependencies.get_hash(
                delement.get_parent_ename())
        elif dpname == '__children':
            stale = hsh != dependencies.get_seq_hash(
                delement.get_children())
        elif dpname == '__propvals':
            stale = hsh != dependencies.get_seq_hash(
                delement.list_props())
        elif dpname in ('__fragile', '__exists'):
            # '__fragile' never validates; '__exists' fails here because
            # the element is present.
            stale = True
        else:
            dpropval = delement.get_propval(dpname)
            if dpropval is None:
                # None is the NoPropvalDep
                stale = hsh is not None
            else:
                stale = hsh != dependencies.get_hash(dpropval.value)
        if stale:
            return None
    return deps, True

# Returns something like:
# {'value': ..., 'metadata': ..., 'dependencies': ...}
def cache_data_for(ename, prop_name, suffix, deps_info=False):
    """Load one cached entry from disk, validating its dependencies.

    Parameters:
      ename, prop_name: identify the cache entry.
      suffix: tag+format suffix, as built by get_suffix().
      deps_info: False (the default) to look the dependency info up with
        get_deps_info(); None if the caller already knows the entry is
        invalid; otherwise a (deps, rewrite) pair as returned by
        get_deps_info().

    Returns a dict {'value': ..., 'metadata': ..., 'dependencies': ...}, or
    None if the entry is invalid or its per-format pickle cannot be read.
    A successful result is remembered in cache_state.transient when the
    entry needs rewriting and in cache_state.validated otherwise.
    """
    pvt = (ename, prop_name)
    # See if we have something prevalidated
    if pvt in cache_state.validated and suffix in cache_state.validated[pvt]:
        return cache_state.validated[pvt][suffix]
    # TODO(xavid): check transient

    if deps_info is False:
        deps_info = get_deps_info(ename, prop_name)
    if deps_info is None:
        return None

    deps, rewrite = deps_info
    ret = {'dependencies': deps}
    stem = cache_file_stem(ename, prop_name)

    # String values live in their own file; a missing file just means the
    # value is stored inside the per-format pickle instead.
    try:
        with open(stem + suffix) as fil:
            val = fil.read()
    except IOError:
        pass
    else:
        # Explicit relative import, matching cache_hook.commit().
        from . import conversion
        if format_from_suffix(suffix) in conversion.UNICODE_FORMATS:
            val = unicode(val, 'utf-8')
        ret['value'] = val

    specific_path = stem + suffix + SPECIFIC_SUFFIX
    try:
        with open(specific_path) as fil:
            specific = load(fil)
    except EOFError:
        # A truncated pickle is a cache miss, as in get_deps_info().
        benchmarking.info('EOF reading ' + specific_path)
        return None
    except IOError:
        benchmarking.info('IOError ' + specific_path)
        return None

    ret['metadata'] = specific['metadata']
    if 'value' not in ret:
        ret['value'] = specific['value']
    else:
        assert 'value' not in specific, specific

    if rewrite:
        # Put in transient so it'll get rewritten with the new revision.
        bucket = cache_state.transient
    else:
        bucket = cache_state.validated
    bucket.setdefault(pvt, {})[suffix] = ret
    return ret

def cached_formats(ename, prop_name, cache_tag):
    """List the formats cached for an entry under `cache_tag`.

    Returns (formats, deps_info).  `formats` is a set of format names
    ('.ext' style or bare names) found on disk -- only consulted when the
    entry's dependencies are valid -- plus any matching formats pending in
    cache_state.transient.  `deps_info` is the get_deps_info() result, so
    callers can pass it on instead of validating again.
    """
    import glob
    assert not cache_tag or cache_tag[0] == '^', cache_tag
    ret = set()
    deps_info = get_deps_info(ename, prop_name)
    if deps_info is not None:
        # Keep the path prefix in its own variable: it is needed for both
        # globs, so the loops must not rebind it while splitting names.
        prefix = cache_file_stem(ename, prop_name) + cache_tag
        for fname in glob.glob(prefix + '.*' + SPECIFIC_SUFFIX):
            base = os.path.basename(fname)[:-len(SPECIFIC_SUFFIX)]
            _, suf = base.split('.', 1)
            if '.' not in suf and '.' + suf != INFO_SUFFIX:
                ret.add('.' + suf)
        for fname in glob.glob(prefix + '>*' + SPECIFIC_SUFFIX):
            base = os.path.basename(fname)[:-len(SPECIFIC_SUFFIX)]
            _, suf = base.split('>', 1)
            if '.' not in suf:
                ret.add(suf)
    pvt = (ename, prop_name)
    for suffix in cache_state.transient.get(pvt, ()):
        fmt = format_from_suffix(suffix, cache_tag)
        if fmt:
            ret.add(fmt)
    return ret, deps_info

def get_from_cache(ename, prop_name, format, cache_tag='', deps_info=False):
    """Return the cache dict for an entry/format, or None on a miss.

    Entries pending in cache_state.transient take precedence; otherwise the
    lookup is delegated to cache_data_for(), passing `deps_info` through.
    """
    suffix = get_suffix(format, cache_tag)
    # TODO(xavid): move into cache_data_for()
    pending = cache_state.transient.get((ename, prop_name), {})
    if suffix in pending:
        cache_data = pending[suffix]
        assert 'metadata' in cache_data, cache_data
        return cache_data
    cache_data = cache_data_for(ename, prop_name, suffix,
                                deps_info=deps_info)
    assert cache_data is None or 'metadata' in cache_data, cache_data
    return cache_data

def cache_propval(ename, prop_name, format, value, deps, metadata,
                  cache_data=None, cache_tag=''):
    """Queue a computed value for writing to the cache at commit time.

    Parameters:
      ename, prop_name: identify the cache entry.
      format, cache_tag: combined by get_suffix() into the entry's suffix.
      value, metadata: the data to store.
      deps: iterable of dependencies; each must be a 3-item dependency or
        one of dependencies.OMNISCIENCE / dependencies.REVISION.
      cache_data: previously loaded cache data for this entry, if the
        caller has it; when given, the disk lookup is skipped.

    Raises AssertionError (after invalidating the whole cache) if a
    malformed dependency is found.
    """
    for d in deps:
        if len(d) != 3 and d not in (dependencies.OMNISCIENCE,
                                     dependencies.REVISION):
            invalidate_cache()
            # Raised explicitly so the check also fires under `python -O`.
            raise AssertionError(d)
    pvt = (ename, prop_name)
    suffix = get_suffix(format, cache_tag)
    if pvt not in cache_state.transient and cache_data is None:
        # Called for its side effects only (it may record the on-disk entry
        # in cache_state.validated or cache_state.transient); whatever it
        # returns for this suffix is superseded by the assignment below.
        cache_data_for(ename, prop_name, suffix)
    # setdefault guarantees the per-entry dict exists even when the caller
    # supplied cache_data, which previously left it missing (KeyError).
    cache_state.transient.setdefault(pvt, {})[suffix] = dict(
        value=value, metadata=metadata, dependencies=frozenset(deps))

def invalidate_cache():
    """Drop all cached data: in-memory state and the on-disk cache tree.

    Resets cache_state.transient and cache_state.validated, then removes
    cache_dir() recursively.  A cache directory that does not exist is not
    an error; any other OSError propagates.
    """
    cache_state.transient, cache_state.validated = {}, {}
    try:
        shutil.rmtree(cache_dir())
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise

# Keyspace defined by structure.py and wiki.py
def set_memo(key, value):
    """Store `value` under `key` in the thread-local memo dict."""
    memo = cache_state.memo
    memo[key] = value

def get_memo(key, default=None):
    """Return the memo stored under `key`, or `default` if there is none."""
    memo = cache_state.memo
    if key in memo:
        return memo[key]
    return default

class cache_hook(object):
    """Lifecycle callbacks for the thread-local cache state.

    begin()/clear() reset the state, commit() writes pending entries to
    disk, abort() discards the state.
    NOTE(review): presumably driven by a transaction manager elsewhere --
    confirm against the caller.
    """

    @staticmethod
    def begin():
        """Start with empty cache state; equivalent to clear()."""
        cache_hook.clear()

    @staticmethod
    def clear():
        """Reset `transient`, `validated` and `memo` to empty dicts."""
        for attr in ('transient', 'validated', 'memo'):
            setattr(cache_state, attr, {})

    @staticmethod
    def commit():
        """Write every entry in cache_state.transient to the cache directory.

        For each (ename, prop_name), each format is written as an optional
        value file (string values only) plus a per-format pickle holding the
        metadata (and the value, when it is not a string).  Afterwards an
        info pickle with the union of all dependencies and the current
        revision is written.  Every file is first written to a temporary
        file in custom.get_tmp_dir() and then renamed into place.
        """
        from . import db, conversion, custom
        for pvt in cache_state.transient:
            ename, prop_name = pvt
            deststem = cache_file_stem(ename, prop_name)
            deps = set()
            formats = {}
            for suffix, pending in cache_state.transient[pvt].items():
                formats[suffix] = dict(pending)
                deps.update(formats[suffix]['dependencies'])
            assert dependencies.DISCORDIA not in deps
            makedirs(os.path.dirname(deststem))

            for suffix, entry in formats.items():
                dest = deststem + suffix
                val = entry['value']
                specdic = {'metadata': entry['metadata']}
                if isinstance(val, (unicode, str)):
                    # String values get their own file, UTF-8 encoded when
                    # they are unicode.
                    if isinstance(val, unicode):
                        assert (format_from_suffix(suffix)
                                in conversion.UNICODE_FORMATS), (val, suffix)
                        val = val.encode('utf-8')
                    else:
                        assert (format_from_suffix(suffix)
                                not in conversion.UNICODE_FORMATS), (
                            val, suffix)
                    with benchmarking('writing ' + dest):
                        fil = tempfile.NamedTemporaryFile(
                            delete=False, dir=custom.get_tmp_dir())
                        fil.write(val)
                        fil.close()
                else:
                    # Non-string values travel inside the per-format pickle.
                    fil = None
                    specdic['value'] = val
                with benchmarking('dumping ' + dest + SPECIFIC_SUFFIX):
                    specific = tempfile.NamedTemporaryFile(
                        delete=False, dir=custom.get_tmp_dir())
                    dump(specdic, specific.file)
                    specific.close()
                if fil is None:
                    # No value file for this format: drop any old one.
                    try:
                        os.remove(dest)
                    except OSError:
                        pass
                else:
                    os.rename(fil.name, dest)
                os.rename(specific.name, dest + SPECIFIC_SUFFIX)

            info_path = deststem + INFO_SUFFIX
            with benchmarking('dumping ' + info_path):
                fil = tempfile.NamedTemporaryFile(delete=False,
                                                  dir=custom.get_tmp_dir())
                dump({'dependencies': deps,
                      'revision': db.get_revision()}, fil.file)
                fil.close()
            os.rename(fil.name, info_path)

    @staticmethod
    def abort():
        """Delete the `transient`, `validated` and `memo` attributes."""
        for attr in ('transient', 'validated', 'memo'):
            delattr(cache_state, attr)

# Prefix that metadata_cached() puts in front of its key to form the
# property name of the cache entry.
METADATA = '_metadata_'
# Format name under which metadata_cached() stores and looks up its values.
METADATA_FMT = 'md'
def metadata_cached(key):
    """Decorator factory caching a per-element computation under `key`.

    The decorated function is called as func(element, deps) with a fresh
    dependencies.Dependencies() only on a cache miss; the result is queued
    with cache_propval() under the property name METADATA + key.  Callers
    of the decorated function must leave `deps` as None.
    """
    def metadata_helper(func, element, deps=None):
        assert deps is None
        prop_name = METADATA + key
        cache_entry = get_from_cache(element.ename, prop_name, METADATA_FMT)
        if cache_entry is not None:
            return cache_entry['value']
        deps = dependencies.Dependencies()
        description = 'calculating %s metadata for %s' % (key, element.ename)
        with benchmarking(description):
            value = func(element, deps)
            cache_propval(element.ename, prop_name, METADATA_FMT,
                          value, deps, {})
        return value
    return decorator(metadata_helper)

def args_cached(key, serialize=lambda i:i, unserialize=lambda i:i):
    """Decorator factory caching a function's result per argument list.

    Parameters:
      key: name for this cache family; must not contain ':'.
      serialize: applied to the computed value before it is cached.
      unserialize: applied to a cached value before it is returned.

    The cache key is '_' + key + ':' followed by the first 16 hex digits of
    the MD5 of the repr() of the arguments.  Entries are stored under the
    global (None, None) cache entry with a revision dependency.
    """
    assert ':' not in key
    def args_helper(func, *args, **kw):
        # This is technically lossy, but easier to debug than something pickled
        # and is unlikely to collide in practice.
        # TODO(xavid): not really easier to debug, we're hashing it
        argbit = ', '.join(repr(a) for a in args)
        if kw:
            # Sort the keyword names: dict iteration order is arbitrary, so
            # without this the same call could hash to different keys.
            argbit += ', ' + ', '.join(k + '=' + repr(kw[k])
                                       for k in sorted(kw))
        cachekey = '_' + key + ':' + hashlib.md5(argbit).hexdigest()[:16]
        cache_entry = get_from_cache(None, None, format=cachekey)

        if cache_entry is None:
            with benchmarking('calculating %s args_cached metadata'
                              % (cachekey)):
                value = func(*args, **kw)
                deps = dependencies.Dependencies()
                deps.addRevisionDep()
                cache_propval(None, None, cachekey, serialize(value), deps, {})
        else:
            value = unserialize(cache_entry['value'])
        return value
    return decorator(args_helper)
            
