import time
import thread, threading
import MySQLdb

class Database(object):
    _cache = {}
    def __new__(cls, config_filename='/mit/freeculture/.my.cnf'):
        key = (config_filename, thread.get_ident())
        if key not in cls._cache:
            print "Connecting to database", key
            cls._cache[key] = DatabaseImpl(config_filename)
        return cls._cache[key]

class DatabaseImpl(object):
    def __init__(self, config_filename='/mit/freeculture/.my.cnf'):
        self.config_filename = config_filename
        self.db_name = 'freeculture+youtomb'
        self.sites_by_name = {}
        self.sites_by_id = {}
        self.lock = threading.RLock()
        self._connect()

    def _connect(self):
        self.connection = MySQLdb.connect(
            read_default_file=self.config_filename, read_default_group="mysql",
            db=self.db_name,
            use_unicode=True, charset='utf8')
        self.connection.autocommit(True)

    def cursor(self):
        return self.connection.cursor(MySQLdb.cursors.DictCursor)
    
    def begin_transaction(self):
        retries = 3
        while True:
            try:
                return self.connection.autocommit(False)
                break
            except (MySQLdb.ProgrammingError, MySQLdb.OperationalError), e:
                # usually 'server has gone away' or 'lost connection to server';
                # sometimes the client automatically reconnects, sometimes it doesn't,
                # so we just reconnect.  might fix some other exns too, if they happen.
                retries = retries - 1
                if retries:
                    print "reconnecting after MySQL error %s" % (e,)
                    self._connect()
                else:
                    raise RuntimeError("got MySQL error %s, failed to recover" % (e,))
    
    def commit_transaction(self):
        try:
            r = self.connection.commit()
        finally:
            self.connection.autocommit(True)
        return r

    def _execute(self, query, params=()):
        try:
            self.lock.acquire()
            c = self.cursor()
            retries = 3
            while True:
                try:
                    n = c.execute(query, params)
                    break
                except (MySQLdb.ProgrammingError, MySQLdb.OperationalError), e:
                    # usually 'server has gone away' or 'lost connection to server';
                    # sometimes the client automatically reconnects, sometimes it doesn't,
                    # so we just reconnect.  might fix some other exns too, if they happen.
                    retries = retries - 1
                    if retries:
                        print "reconnecting after MySQL error %s" % (e,)
                        self._connect()
                        c = self.cursor()
                    else:
                        raise RuntimeError("got MySQL error %s, failed to recover" % (e,))
            return c, n
        finally:
            self.lock.release()

    def execute(self, query, params=()):
        _, n = self._execute(query, params)
        return n

    def allrows(self, query, params=()):
        c, _ = self._execute(query, params)
        return c.fetchall()

    def onerow(self, query, params=()):
        c, n = self._execute(query, params)
        if n != 1:
            raise RuntimeError("onerow query '%s' got %d != 1 rows" % (query, n))
        return c.fetchone()
        
    def site_id_for_name(self, site):
        if not self.sites_by_name:
            rows = self.allrows("SELECT id, name FROM sites")
            for r in rows:
                self.sites_by_name[r["name"].lower()] = r["id"]
        return self.sites_by_name[site.lower()]
    
    def site_name_for_id(self, site):
        if not self.sites_by_id:
            rows = self.allrows("SELECT name, id FROM sites")
            for r in rows:
                self.sites_by_id[r["id"]] = r["name"]
        return self.sites_by_id[site]
