import sys, os, pwd, tempfile, shutil
from decorator import decorator
import warnings
import errno
import threading, copy

from bazjunk.path import makedirs

from . import dependencies
from .benchmark import benchmarking

# Per-thread cache bookkeeping.  Its attributes exist only after
# cache_hook.begin()/clear() has run on the current thread, and
# cache_hook.abort() deletes them again:
#   transient -- {(ename, prop_name): {format: entry}} of entries pending a
#                write to disk by cache_hook.commit()
#   validated -- {(ename, prop_name): cache data} accepted by the revision
#                short-circuit in cache_data_for()
#   memo      -- arbitrary key/value store used by set_memo()/get_memo()
cache_state = threading.local()

# Performs comparably to jsonlib and marshal in my testing.
from cPickle import dump, load
# File extension appended to every on-disk cache entry name
# (see cache_file_name).
SUFFIX = '.pickle'

def cache_dir():
    """Return the on-disk cache directory for this user and application.

    The path is /tmp/<login name>.<custom.APP_NAME>.bazki/cache.
    """
    from . import custom
    login = pwd.getpwuid(os.getuid()).pw_name
    app = custom.APP_NAME
    return "/tmp/%s.%s.bazki/cache" % (login, app)

# These files store data of the form:
# {'revision': 17,
#  'dependencies': [(ename, propname, hash), ...],
#  'formats': {'txt': {'value': ..., 'metadata': ..., 'dependencies': ...},
#              '.tex': {...}, ...}}
#
# _metadata files' formats dictionary is more like metadata, of the
# form: {'final_map': {...}}
# TODO(xavid): storing dependencies this way is redundant
def cache_file_name(ename, prop_name):
    """Return the path of the cache file for (ename, prop_name).

    With an element name, the file lives in a per-element subdirectory and
    is named after the property ('_' when prop_name is falsy); that path is
    UTF-8 encoded.  With ename None (prop_name must also be None), the
    top-level '_' file in the cache directory is returned unencoded.
    """
    if ename is not None:
        leaf = (prop_name or '_') + SUFFIX
        return os.path.join(cache_dir(), ename, leaf).encode('utf-8')
    assert prop_name is None
    return os.path.join(cache_dir(), '_' + SUFFIX)

def _dependencies_still_valid(deps):
    """Return True iff every recorded dependency in deps still matches.

    Each entry is either dependencies.OMNISCIENCE (skipped here) or an
    (ename, propname, hash) triple; special propnames ('__parent',
    '__children', '__propvals', '__fragile', '__exists') are checked
    against the element's structure rather than a property value.
    """
    from . import structure
    for dep in deps:
        if dep == dependencies.OMNISCIENCE:
            continue
        assert len(dep) == 3, dep
        dename, dpname, hsh = dep
        delement = structure.get_element(dename)
        if delement is None:
            # A missing element is only consistent with an '__exists'
            # dependency (presumably recorded while it was absent).
            if dpname != '__exists':
                return False
        elif dpname == '__parent':
            if hsh != dependencies.get_hash(delement.get_parent_ename()):
                return False
        elif dpname == '__children':
            if hsh != dependencies.get_seq_hash(delement.get_children()):
                return False
        elif dpname == '__propvals':
            if hsh != dependencies.get_seq_hash(delement.list_props()):
                return False
        elif dpname in ('__fragile', '__exists'):
            # Fragile entries never revalidate; an '__exists' dep is stale
            # once the element exists.
            return False
        else:
            dpropval = delement.get_propval(dpname)
            if dpropval is None:
                # None is the NoPropvalDep
                if hsh is not None:
                    return False
            elif hsh != dependencies.get_hash(dpropval.value):
                return False
    return True


def cache_data_for(ename, prop_name):
    """Load and validate the on-disk cache data for (ename, prop_name).

    Returns the unpickled dict (see the format comment above
    cache_file_name), or None when there is no readable cache file or a
    recorded dependency no longer matches.

    If the stored revision equals db.get_revision(), the data is trusted
    and remembered in cache_state.validated.  Otherwise each dependency is
    rechecked; on success the formats are merged into
    cache_state.transient so cache_hook.commit() rewrites the file with the
    new revision.
    """
    pvt = (ename, prop_name)
    # See if we have something prevalidated
    if pvt in cache_state.validated:
        return cache_state.validated[pvt]
    # TODO(xavid): check cache_state.transient

    try:
        with open(cache_file_name(ename, prop_name), 'rb') as fil:
            ret = load(fil)
    except IOError:
        return None

    from . import db
    # SVN revision short-circuit.  <None> means no revision
    # shortcutting possible; i.e., no version control hook.
    rev = db.get_revision()
    if rev is not None and ret['revision'] == rev:
        cache_state.validated[pvt] = ret
        return ret

    if not _dependencies_still_valid(ret['dependencies']):
        return None

    # Put in transient so it'll get rewritten with the new revision.  Store a
    # copy so later cache_propval() writes into transient don't mutate the
    # dict handed back to the caller.
    if pvt in cache_state.transient:
        pending = cache_state.transient[pvt]
        for f in ret['formats']:
            if f not in pending:
                pending[f] = ret['formats'][f]
    else:
        cache_state.transient[pvt] = dict(ret['formats'])
    return ret

def cached_formats(ename, prop_name, cache_data):
    """Return the set of format keys cached for (ename, prop_name).

    Combines the formats in cache_data (may be None) with any formats
    pending in this thread's cache_state.transient.
    """
    if cache_data is None:
        available = set()
    else:
        available = set(cache_data['formats'])
    key = (ename, prop_name)
    if key in cache_state.transient:
        available.update(cache_state.transient[key])
    return available

def get_from_cache(ename, prop_name, format, cache_data=None):
    """Return the cached entry dict for one format, or None if absent.

    Entries pending in cache_state.transient take precedence; otherwise
    cache_data (loaded via cache_data_for() when not supplied) is
    consulted.  An entry holds 'value', 'metadata' and 'dependencies'.
    """
    key = (ename, prop_name)
    # TODO(xavid): move into cache_data_for()
    pending = cache_state.transient.get(key, {})
    if format in pending:
        return pending[format]
    if cache_data is None:
        cache_data = cache_data_for(ename, prop_name)
    if cache_data is not None and format in cache_data['formats']:
        return cache_data['formats'][format]
    return None

def cache_propval(ename, prop_name, format, value, deps, metadata,
                  cache_data=None, cache_tag=''):
    """Record a computed value for (ename, prop_name) in the given format.

    The entry is stored in cache_state.transient under the key
    format + cache_tag, as dict(value=..., metadata=...,
    dependencies=frozenset(deps)); cache_hook.commit() writes it to disk.

    deps: iterable of (ename, propname, hash) triples or
        dependencies.OMNISCIENCE.
    cache_data: cache data for (ename, prop_name) as returned by
        cache_data_for(); loaded on demand when None.

    Raises AssertionError, after wiping the whole cache, if deps contains a
    malformed entry.
    """
    for d in deps:
        if d != dependencies.OMNISCIENCE and len(d) != 3:
            # Something produced a malformed dependency, so cached entries
            # can't be trusted; wipe everything before failing.  Raise
            # explicitly so the check still fires under `python -O`.
            invalidate_cache()
            raise AssertionError(d)
    pvt = (ename, prop_name)
    if pvt not in cache_state.transient:
        # Seed the pending entry with the formats already on disk so
        # commit() doesn't drop them when it rewrites the file.  This must
        # also happen when the caller supplied cache_data; previously that
        # path left transient[pvt] unset and the assignment below raised
        # KeyError.
        if cache_data is None:
            cache_data = cache_data_for(ename, prop_name)
        if cache_data is None:
            cache_state.transient[pvt] = {}
        else:
            cache_state.transient[pvt] = dict(cache_data['formats'])
    cache_state.transient[pvt][format + cache_tag] = dict(
        value=value, metadata=metadata, dependencies=frozenset(deps))

def invalidate_cache():
    """Discard all cached data, both per-thread and on disk.

    Resets cache_state.transient and cache_state.validated, then removes
    the cache directory.  A missing directory is not an error; any other
    OSError propagates.
    """
    cache_state.transient = {}
    cache_state.validated = {}
    try:
        shutil.rmtree(cache_dir())
    except OSError as err:
        if err.errno != errno.ENOENT:
            raise

# Keyspace defined by structure.py and wiki.py
def set_memo(key, value):
    """Store value under key in this thread's memo dict."""
    memo = cache_state.memo
    memo[key] = value

def get_memo(key, default=None):
    """Return this thread's memoized value for key, or default if unset."""
    memo = cache_state.memo
    if key in memo:
        return memo[key]
    return default

class cache_hook(object):
    """Transaction hooks for the per-thread cache state.

    begin()/clear() (re)initialize cache_state, commit() writes every
    pending transient entry to disk, and abort() deletes the per-thread
    attributes.
    """

    @staticmethod
    def begin():
        cache_hook.clear()

    @staticmethod
    def clear():
        cache_state.transient = {}
        cache_state.validated = {}
        cache_state.memo = {}

    @staticmethod
    def commit():
        """Write each entry in cache_state.transient to its cache file.

        Each file is pickled into a temporary file created in the
        destination directory and then renamed into place.  Keeping both on
        the same filesystem makes os.rename() atomic and avoids EXDEV when
        the default temp dir lives elsewhere.  On failure the temporary file
        is removed and the error propagates.
        """
        if not cache_state.transient:
            return
        from . import db
        revision = db.get_revision()
        for pvt in cache_state.transient:
            ename, prop_name = pvt
            destfile = cache_file_name(ename, prop_name)
            deps = set()
            formats = {}
            for format, entry in cache_state.transient[pvt].items():
                formats[format] = dict(entry)
                deps.update(entry['dependencies'])
            assert dependencies.DISCORDIA not in deps
            destdir = os.path.dirname(destfile)
            makedirs(destdir)
            fil = tempfile.NamedTemporaryFile(dir=destdir, delete=False)
            renamed = False
            try:
                dump({'dependencies': deps,
                      'formats': formats,
                      'revision': revision}, fil.file)
                fil.close()
                os.rename(fil.name, destfile)
                renamed = True
            finally:
                if not renamed:
                    # Don't leave a stray temp file behind (delete=False).
                    fil.close()
                    try:
                        os.unlink(fil.name)
                    except OSError:
                        pass

    @staticmethod
    def abort():
        del cache_state.transient
        del cache_state.validated
        del cache_state.memo

# Pseudo-property prefix under which metadata_cached() stores its results.
METADATA = '_metadata_'
# Format key used for those cache entries.
METADATA_FMT = 'md'
def metadata_cached(key):
    """Return a decorator caching func(element, deps) per element.

    The result is stored as property METADATA + key, format METADATA_FMT,
    of element.ename.  On a miss the wrapped function runs with a fresh
    dependencies.Dependencies() and its value is cached with those deps.
    Callers must not pass deps themselves.
    """
    def metadata_helper(func, element, deps=None):
        # Name and signature kept as-is: the `decorator` package builds the
        # returned decorator from this caller function.
        assert deps is None
        pseudo_prop = METADATA + key
        hit = get_from_cache(element.ename, pseudo_prop, METADATA_FMT)
        if hit is not None:
            return hit['value']
        deps = dependencies.Dependencies()
        label = 'calculating %s metadata for %s' % (key, element.ename)
        with benchmarking(label):
            result = func(element, deps)
            cache_propval(element.ename, pseudo_prop, METADATA_FMT,
                          result, deps, {})
        return result
    return decorator(metadata_helper)
