class KeggModel(object):
    """
    A stoichiometric model whose compounds are identified by KEGG compound IDs.

    Attributes set by the constructor:
        S      - the stoichiometric matrix (one row per compound, one column
                 per reaction)
        cids   - the list of compound IDs, in the same order as the rows of S
        rids   - an optional list of reaction IDs, in the same order as the
                 columns of S (or None)
        ccache - a CompoundCacher, dumped when the model is destroyed
    """

    def __del__(self):
        # __init__ may fail (e.g. on one of its assertions) before 'ccache'
        # is assigned; in that case there is nothing to dump.
        ccache = getattr(self, 'ccache', None)
        if ccache is not None:
            ccache.dump()

    def __init__(self, S, cids, rids=None):
        """
        Arguments:
           S    - the stoichiometric matrix (compounds x reactions)
           cids - a list of compound IDs, one per row of S. Note that this
                  list is stored as-is, and H+ (C00080) is popped from it
                  in-place if present.
           rids - an optional list of reaction IDs, one per column of S
        """
        self.S = S
        self.cids = cids
        self.rids = rids
        assert len(self.cids) == self.S.shape[0]
        if self.rids is not None:
            assert len(self.rids) == self.S.shape[1]
        self.ccache = CompoundCacher()

        # remove H+ from the stoichiometric matrix if it exists
        if 'C00080' in self.cids:
            i = self.cids.index('C00080')
            self.S = np.vstack((self.S[:i, :], self.S[i+1:, :]))
            self.cids.pop(i)

    @staticmethod
    def from_file(fname, arrow='<=>', format='kegg', has_reaction_ids=False):
        """
        reads a file containing reactions in KEGG format
        
        Arguments:
           fname            - the filename to read
           arrow            - the string used as the 'arrow' in each reaction (default: '<=>')
           format           - the text file format provided ('kegg', 'tsv' or 'csv')
           has_reaction_ids - a boolean flag indicating if there is a column of
                              reaction IDs (separated from the reaction with
                              whitespaces)
        
        Return a KeggModel

        Raises ValueError if 'format' is not one of the supported formats.
        """
        # validate before touching the file system, so that a typo in the
        # format name is reported as such
        if format not in ('kegg', 'tsv', 'csv'):
            raise ValueError("unknown file format %r (expected 'kegg', 'tsv' "
                             "or 'csv')" % (format,))

        # the context manager closes the file even if parsing fails
        with open(fname, 'r') as fd:
            if format == 'kegg':
                return KeggModel.from_formulas(fd.readlines(), arrow,
                                               has_reaction_ids)
            if format == 'tsv':
                return KeggModel.from_csv(fd,
                                          has_reaction_ids=has_reaction_ids,
                                          delimiter='\t')
            return KeggModel.from_csv(fd,
                                      has_reaction_ids=has_reaction_ids,
                                      delimiter=None)

    @staticmethod
    def from_csv(fd, has_reaction_ids=True, delimiter=None):
        """
        reads a stoichiometric matrix from an open delimited text file, where
        the first column holds the compound IDs and every other column holds
        the coefficients of one reaction

        Arguments:
           fd               - an open file (or any iterable of lines)
           has_reaction_ids - if True, the first row is a header whose cells
                              (except the first one) are the reaction IDs
           delimiter        - the column delimiter; None means the default
                              of the csv module (a comma)

        Return a KeggModel
        """
        # csv.reader does not accept delimiter=None, so only pass it along
        # when the caller actually chose one
        if delimiter is None:
            csv_reader = csv.reader(fd)
        else:
            csv_reader = csv.reader(fd, delimiter=delimiter)

        rids = None
        if has_reaction_ids:
            rids = next(csv_reader)[1:]

        S = []
        cids = []
        for row in csv_reader:
            if not row:
                # blank lines are returned as empty rows
                continue
            cids.append(row[0])
            S.append([float(x) for x in row[1:]])

        return KeggModel(np.array(S), cids, rids)

    @staticmethod
    def from_kegg_reactions(kegg_reactions, has_reaction_ids=False):
        """
        builds a model from a list of KeggReaction objects

        Arguments:
           kegg_reactions   - a list of KeggReaction objects
           has_reaction_ids - if True, the 'rid' of each reaction is used as
                              the reaction ID of the matching column

        Return a KeggModel
        """
        if has_reaction_ids:
            rids = [r.rid for r in kegg_reactions]
        else:
            rids = None

        cids = set()
        for reaction in kegg_reactions:
            cids.update(reaction.keys())

        # convert the list of reactions in sparse notation into a full
        # stoichiometric matrix, where the rows (compounds) are according to the
        # CID list 'cids'.
        cids = sorted(cids)
        S = np.matrix(np.zeros((len(cids), len(kegg_reactions))))
        for i, reaction in enumerate(kegg_reactions):
            S[:, i] = np.matrix(reaction.dense(cids))

        logging.debug('Successfully loaded %d reactions (involving %d unique compounds)' %
                      (S.shape[1], S.shape[0]))
        return KeggModel(S, cids, rids)

    @staticmethod
    def from_formulas(reaction_strings, arrow='<=>', has_reaction_ids=False,
                      raise_exception=False):
        """
        parses a list of reactions in KEGG format
        
        Arguments:
           reaction_strings - a list of reactions in KEGG format
           arrow            - the string used as the 'arrow' in each reaction (default: '<=>')
           has_reaction_ids - a boolean flag indicating if there is a column of
                              reaction IDs (separated from the reaction with
                              whitespaces)
           raise_exception  - if True, a ValueError is re-raised (and the flag
                              is passed on to the balance check); otherwise it
                              is logged and None is returned
        
        Return a KeggModel (or None, see 'raise_exception'). Reactions that
        cannot be parsed or are not balanced are replaced by empty reactions,
        so the columns stay aligned with 'reaction_strings'.
        """
        try:
            reactions = []
            not_balanced_count = 0
            for line in reaction_strings:
                rid = None
                if has_reaction_ids:
                    tokens = re.findall(r'(\w+)\s+(.*)', line.strip())[0]
                    rid = tokens[0]
                    line = tokens[1]
                try:
                    reaction = KeggReaction.parse_formula(line, arrow, rid)
                except KeggParseException as e:
                    logging.warning(str(e))
                    # keep the reaction ID, so that the list of IDs does not
                    # end up with a None in it
                    reaction = KeggReaction({}, rid=rid)
                if not reaction.is_balanced(fix_water=True, raise_exception=raise_exception):
                    not_balanced_count += 1
                    logging.warning('Model contains an unbalanced reaction: ' + line)
                    reaction = KeggReaction({}, rid=rid)
                reactions.append(reaction)
                logging.debug('Adding reaction: ' + reaction.write_formula())

            if not_balanced_count > 0:
                warning_str = '%d out of the %d reactions are not chemically balanced' % \
                              (not_balanced_count, len(reaction_strings))
                logging.debug(warning_str)
            return KeggModel.from_kegg_reactions(reactions, has_reaction_ids)

        except ValueError as e:
            if raise_exception:
                # a bare raise keeps the original traceback
                raise
            else:
                logging.debug(str(e))
                return None

    def add_thermo(self, cc):
        """
        stores the reaction energies and their covariance, as returned by
        cc.get_dG0_r_multi, in self.dG0 and self.cov_dG0

        Arguments:
           cc - an object providing get_dG0_r_multi(reactions)
        """
        Nc, Nr = self.S.shape
        reactions = []
        for j in range(Nr):
            # sparse representation of column j (non-zero coefficients only)
            sparse = {self.cids[i]: self.S[i, j] for i in range(Nc)
                      if self.S[i, j] != 0}
            reactions.append(KeggReaction(sparse))

        self.dG0, self.cov_dG0 = cc.get_dG0_r_multi(reactions)

    def get_transformed_dG0(self, pH, I, T):
        """
            returns the estimated dG0_prime, the standard deviation of
            each estimate (i.e. a measure for the uncertainty), and a
            square root of the covariance matrix (computed from its SVD).

            add_thermo() must be called first.
        """
        dG0_prime = self.dG0 + self._get_transform_ddG0(pH=pH, I=I, T=T)
        dG0_std = np.matrix(np.sqrt(np.diag(self.cov_dG0))).T
        U, s, V = np.linalg.svd(self.cov_dG0, full_matrices=True)
        sqrt_Sigma = np.matrix(U) * np.matrix(np.diag(s**0.5)) * np.matrix(V)
        return dG0_prime, dG0_std, sqrt_Sigma

    def _get_transform_ddG0(self, pH, I, T):
        """
        needed in order to calculate the transformed Gibbs energies of the 
        model reactions.
        
        Returns:
            an array (whose length is self.S.shape[1]) with the differences
            between DrG0_prime and DrG0. Therefore, one must add this array
            to the chemical Gibbs energies of reaction (DrG0) to get the 
            transformed values
        """
        ddG0_compounds = np.matrix(np.zeros((self.S.shape[0], 1)))
        for i, cid in enumerate(self.cids):
            comp = self.ccache.get_compound(cid)
            ddG0_compounds[i, 0] = comp.transform_pH7(pH, I, T)

        ddG0_forward = np.dot(self.S.T, ddG0_compounds)
        return ddG0_forward

    def check_S_balance(self):
        """
        zeroes the columns of S that belong to reactions which do not
        conserve all the elements, and logs a warning about them

        Returns self (to allow chaining)
        """
        elements, Ematrix = self.ccache.get_element_matrix(self.cids)
        conserved = np.dot(Ematrix.T, self.S)

        # flatten to a plain 1-D array of column indices, regardless of
        # whether 'conserved' is a matrix or an array
        rxnFil = np.asarray(np.any(conserved, axis=0)).ravel()
        unbalanced_ind = np.nonzero(rxnFil)[0]

        # NOTE: comparing an array to [] is not a reliable emptiness test
        # (it is falsy when there is exactly one index), so use the size
        if unbalanced_ind.size > 0:
            logging.warning('There are (%d) unbalanced reactions in S. ' 
                            'Setting their coefficients to 0.' % 
                            unbalanced_ind.size)
            if self.rids is not None:
                logging.warning('These are the unbalanced reactions: ' +
                                ', '.join([self.rids[i] for i in unbalanced_ind]))

            self.S[:, unbalanced_ind] = 0
        return self

    def write_reaction_by_index(self, r):
        """
        returns the formula (as written by KeggReaction.write_formula) of the
        reaction in column 'r' of S
        """
        sparse = {cid: self.S[i, r] for i, cid in enumerate(self.cids)
                  if self.S[i, r] != 0}
        if self.rids is not None:
            reaction = KeggReaction(sparse, rid=self.rids[r])
        else:
            reaction = KeggReaction(sparse)
        return reaction.write_formula()

    def get_unidirectional_S(self):
        """
        splits S into its negative and positive parts

        Returns:
            (S_minus, S_plus) - copies of S where the positive (respectively
            negative) coefficients are replaced by 0
        """
        S_plus = np.copy(self.S)
        S_minus = np.copy(self.S)
        S_plus[self.S < 0] = 0
        S_minus[self.S > 0] = 0
        return S_minus, S_plus
        
class TrainingData(object):
    """
    The training set: measured reaction energies (from TECRDB, formation
    energies and reduction potentials) together with the stoichiometric
    matrix of the measured reactions and the conditions of each measurement.
    """

    # a dictionary of the filenames of the training data and the relative 
    # weight of each one    
    FNAME_DICT = {'TECRDB' : ('../data/TECRDB.tsv', 1.0),
                  'FORMATION' : ('../data/formation_energies_transformed.tsv', 1.0),
                  'REDOX' : ('../data/redox.tsv', 1.0)}

    def __del__(self):
        # __init__ may fail before 'ccache' is assigned; in that case there
        # is nothing to dump.
        ccache = getattr(self, 'ccache', None)
        if ccache is not None:
            ccache.dump()

    def __init__(self):
        """
            loads all the training data files, builds the stoichiometric
            matrix, drops the unbalanced reactions and calculates the
            chemical (untransformed) reaction energies.
        """
        self.ccache = CompoundCacher()

        thermo_params, self.cids_that_dont_decompose = TrainingData.get_all_thermo_params()

        cids = set()
        for d in thermo_params:
            cids.update(d['reaction'].keys())
        cids = sorted(cids)

        # row lookup table, so that filling S does not need a linear search
        # of 'cids' for every single coefficient
        row_of_cid = {cid: i for i, cid in enumerate(cids)}

        # convert the list of reactions in sparse notation into a full
        # stoichiometric matrix, where the rows (compounds) are according to the
        # CID list 'cids'.
        self.S = np.zeros((len(cids), len(thermo_params)))
        for k, d in enumerate(thermo_params):
            for cid, coeff in d['reaction'].iteritems():
                self.S[row_of_cid[cid], k] = coeff

        self.cids = cids

        self.dG0_prime = np.array([d['dG\'0'] for d in thermo_params])
        self.T = np.array([d['T'] for d in thermo_params])
        self.I = np.array([d['I'] for d in thermo_params])
        self.pH = np.array([d['pH'] for d in thermo_params])
        self.pMg = np.array([d['pMg'] for d in thermo_params])
        self.weight = np.array([d['weight'] for d in thermo_params])
        self.reference = [d['reference'] for d in thermo_params]
        self.description = [d['description'] for d in thermo_params]
        rxn_inds_to_balance = [i for i, d in enumerate(thermo_params)
                               if d['balance']]

        self.balance_reactions(rxn_inds_to_balance)

        self.reverse_transform()

    def savemat(self, fname):
        """
            writes the training data arrays and the compound IDs to 'fname'
            using the module-level savemat function
        """
        d = {'dG0_prime': self.dG0_prime,
             'dG0': self.dG0,
             'T': self.T,
             'I': self.I,
             'pH': self.pH,
             'pMg': self.pMg,
             'weight': self.weight,
             'cids': self.cids}
        # inside the method body this name refers to the module-level
        # function, not to this method
        savemat(fname, d, oned_as='row')

    def savecsv(self, fname):
        """
            writes one row per reaction (formula, conditions, reference and
            energies) to the CSV file 'fname'
        """
        n_cpds, n_rxns = self.S.shape
        # the context manager makes sure the file is flushed and closed
        with open(fname, 'w') as fp:
            csv_output = csv.writer(fp)
            csv_output.writerow(['reaction', 'T', 'I', 'pH', 'reference', 'dG0', 'dG0_prime'])
            for j in range(n_rxns):
                sparse = {self.cids[i]: self.S[i, j] for i in range(n_cpds)}
                r_string = KeggReaction(sparse).write_formula()
                csv_output.writerow([r_string, self.T[j], self.I[j], self.pH[j],
                                     self.reference[j], self.dG0[j], self.dG0_prime[j]])

    @staticmethod
    def str2double(s):
        """
            casts a string to float, but if the string is empty return NaN
        """
        if s == '':
            return np.nan
        return float(s)

    @staticmethod
    def read_tecrdb(fname, weight):
        """
            Read the raw data of TECRDB (NIST)

            Arguments:
                fname  - the tab-delimited file to read (without a header row)
                weight - the weight given to each of the measurements

            Returns:
                a list of dictionaries, one per measurement

            Raises:
                ValueError if a numeric field of a row cannot be parsed
        """
        thermo_params = [] # columns are: reaction, dG'0, T, I, pH, pMg, weight, balance?

        headers = ["URL", "REF_ID", "METHOD", "EVAL", "EC", "ENZYME NAME",
                   "REACTION IN KEGG IDS", "REACTION IN COMPOUND NAMES",
                   "K", "K'", "T", "I", "pH", "pMg"]

        with open(fname, 'r') as fp:
            for row_list in csv.reader(fp, delimiter='\t'):
                if row_list == []:
                    continue
                row = dict(zip(headers, row_list))
                # rows that lack K', T or pH cannot be used
                if (row['K\''] == '') or (row['T'] == '') or (row['pH'] == ''):
                    continue

                # parse the reaction
                reaction = KeggReaction.parse_formula(row['REACTION IN KEGG IDS'], arrow='=')

                try:
                    # all the numeric conversions are inside the 'try', so
                    # that any malformed number is reported with its row
                    T = TrainingData.str2double(row['T'])

                    # calculate dG'0
                    dG0_prime = -R * T * np.log(TrainingData.str2double(row['K\'']))

                    thermo_params.append({'reaction': reaction,
                                          'dG\'0' : dG0_prime,
                                          'T': T,
                                          'I': TrainingData.str2double(row['I']),
                                          'pH': TrainingData.str2double(row['pH']),
                                          'pMg': TrainingData.str2double(row['pMg']),
                                          'weight': weight,
                                          'balance': True,
                                          'reference': row['REF_ID'],
                                          'description': row['REACTION IN COMPOUND NAMES']})
                except ValueError:
                    raise ValueError('Cannot parse row: ' + str(row))

        logging.debug('Successfully added %d reactions from TECRDB' % len(thermo_params))
        return thermo_params

    @staticmethod
    def read_formations(fname, weight):
        """
            Read the Formation Energy data

            Arguments:
                fname  - the tab-delimited file to read (with a header row)
                weight - the weight given to each of the measurements

            Returns:
                (thermo_params, cids_that_dont_decompose) - a list of
                dictionaries (one per compound that has a dG'0 value), and
                the set of compound IDs whose 'decompose' field is 0
        """
        # columns are: reaction, dG'0, T, I, pH, pMg, weight, balance?
        thermo_params = []
        cids_that_dont_decompose = set()

        # fields are: cid, name, dG'0, pH, I, pMg, T, decompose?,
        #             compound_ref, remark
        with open(fname, 'r') as fp:
            for row in csv.DictReader(fp, delimiter='\t'):
                if int(row['decompose']) == 0:
                    cids_that_dont_decompose.add(row['cid'])
                if row['dG\'0'] != '':
                    rxn = KeggReaction({row['cid'] : 1})
                    thermo_params.append({'reaction': rxn,
                                          'dG\'0' : TrainingData.str2double(row['dG\'0']),
                                          'T': TrainingData.str2double(row['T']), 
                                          'I': TrainingData.str2double(row['I']),
                                          'pH': TrainingData.str2double(row['pH']),
                                          'pMg': TrainingData.str2double(row['pMg']),
                                          'weight': weight,
                                          'balance': False,
                                          'reference': row['compound_ref'],
                                          'description': row['name'] + ' formation'})

        logging.debug('Successfully added %d formation energies' % len(thermo_params))
        return thermo_params, cids_that_dont_decompose

    @staticmethod
    def read_redox(fname, weight):
        """
            Read the Reduction potential data

            Arguments:
                fname  - the tab-delimited file to read (with a header row)
                weight - the weight given to each of the measurements

            Returns:
                a list of dictionaries, one per half-reaction
        """
        # columns are: reaction, dG'0, T, I, pH, pMg, weight, balance?
        thermo_params = []

        # fields are: name, CID_ox, nH_ox, charge_ox, CID_red,
        #             nH_red, charge_red, E'0, pH, I, pMg, T, ref
        with open(fname, 'r') as fp:
            for row in csv.DictReader(fp, delimiter='\t'):
                delta_nH = TrainingData.str2double(row['nH_red']) - \
                           TrainingData.str2double(row['nH_ox'])
                delta_charge = TrainingData.str2double(row['charge_red']) - \
                               TrainingData.str2double(row['charge_ox'])
                delta_e = delta_nH - delta_charge
                dG0_prime = -F * TrainingData.str2double(row['E\'0']) * delta_e
                rxn = KeggReaction({row['CID_ox'] : -1, row['CID_red'] : 1})
                thermo_params.append({'reaction': rxn,
                                      'dG\'0' : dG0_prime,
                                      'T': TrainingData.str2double(row['T']), 
                                      'I': TrainingData.str2double(row['I']),
                                      'pH': TrainingData.str2double(row['pH']),
                                      'pMg': TrainingData.str2double(row['pMg']),
                                      'weight': weight,
                                      'balance': False,
                                      'reference': row['ref'],
                                      'description': row['name'] + ' redox'})

        logging.debug('Successfully added %d redox potentials' % len(thermo_params))
        return thermo_params

    @staticmethod
    def get_all_thermo_params():
        """
            reads the three training data files listed in FNAME_DICT (their
            paths are relative to the directory of this source file)

            Returns:
                (thermo_params, cids_that_dont_decompose)
        """
        base_path = os.path.dirname(os.path.realpath(__file__))

        fname, weight = TrainingData.FNAME_DICT['TECRDB']
        fname = os.path.join(base_path, fname)
        tecrdb_params = TrainingData.read_tecrdb(fname, weight)

        fname, weight = TrainingData.FNAME_DICT['FORMATION']
        fname = os.path.join(base_path, fname)
        formation_params, cids_that_dont_decompose = TrainingData.read_formations(fname, weight)

        fname, weight = TrainingData.FNAME_DICT['REDOX']
        fname = os.path.join(base_path, fname)
        redox_params = TrainingData.read_redox(fname, weight)

        thermo_params = tecrdb_params + formation_params + redox_params
        return thermo_params, cids_that_dont_decompose

    def balance_reactions(self, rxn_inds_to_balance):
        """
            use the chemical formulas from the InChIs to verify that each and every
            reaction is balanced

            Arguments:
                rxn_inds_to_balance - the column indices of the reactions
                                      that should be checked

            Reactions that remain unbalanced (after adding water to fix the
            oxygen balance) are removed from S and from all the per-reaction
            arrays and lists.
        """
        elements, Ematrix = self.ccache.get_element_matrix(self.cids)
        cpd_inds_without_formula = list(np.nonzero(np.any(np.isnan(Ematrix), 1))[0].flat)
        Ematrix[np.isnan(Ematrix)] = 0

        # reactions that involve a compound without a formula cannot be checked
        S_without_formula = self.S[cpd_inds_without_formula, :]
        rxn_inds_without_formula = np.nonzero(np.any(S_without_formula != 0, 0))[0]
        rxn_inds_to_balance = set(rxn_inds_to_balance).difference(rxn_inds_without_formula)

        # need to check that all elements are balanced (except H, but including e-)
        # if only O is not balanced, add water molecules
        if 'O' in elements:
            i_H2O = self.cids.index('C00001')
            j_O = elements.index('O')
            conserved = np.dot(Ematrix.T, self.S)
            for k in rxn_inds_to_balance:
                self.S[i_H2O, k] = self.S[i_H2O, k] - conserved[j_O, k]

        # recalculate conservation matrix (np.dot is a matrix product for
        # both arrays and matrices, unlike the '*' operator)
        conserved = np.dot(Ematrix.T, self.S)

        rxn_inds_to_remove = [k for k in rxn_inds_to_balance 
                              if np.any(conserved[:, k] != 0, 0)]

        for k in rxn_inds_to_remove:
            sprs = {}
            for i in np.nonzero(self.S[:, k])[0]:
                sprs[self.cids[i]] = self.S[i, k]
            reaction = KeggReaction(sprs)
            logging.debug('unbalanced reaction #%d: %s' %
                          (k, reaction.write_formula()))
            for j in np.where(conserved[:, k])[0].flat:
                logging.debug('there are %d more %s atoms on the right-hand side' %
                              (conserved[j, k], elements[j]))

        rxn_inds_to_keep = \
            set(range(self.S.shape[1])).difference(rxn_inds_to_remove)

        rxn_inds_to_keep = sorted(rxn_inds_to_keep)

        self.S = self.S[:, rxn_inds_to_keep]
        self.dG0_prime = self.dG0_prime[rxn_inds_to_keep]
        self.T = self.T[rxn_inds_to_keep]
        self.I = self.I[rxn_inds_to_keep]
        self.pH = self.pH[rxn_inds_to_keep]
        self.pMg = self.pMg[rxn_inds_to_keep]
        self.weight = self.weight[rxn_inds_to_keep]
        self.reference = [self.reference[i] for i in rxn_inds_to_keep]
        self.description = [self.description[i] for i in rxn_inds_to_keep]

        logging.debug('After removing %d unbalanced reactions, the stoichiometric '
                      'matrix contains: '
                      '%d compounds and %d reactions' %
                      (len(rxn_inds_to_remove), self.S.shape[0], self.S.shape[1]))

    def reverse_transform(self):
        """
            Calculate the reverse transform for all reactions in training_data.

            The result is stored in self.dG0. Missing (NaN) values of I and
            pMg are replaced in-place by their defaults.
        """
        n_rxns = self.S.shape[1]
        reverse_ddG0 = np.zeros(n_rxns)
        self.I[np.isnan(self.I)] = 0.25 # default ionic strength is 0.25M
        self.pMg[np.isnan(self.pMg)] = 14 # default pMg is 14
        for i in range(n_rxns):
            for j in np.nonzero(self.S[:, i])[0]:
                cid = self.cids[j]
                if cid == 'C00080': # H+ should be ignored in the Legendre transform
                    continue
                comp = self.ccache.get_compound(cid)
                ddG0 = comp.transform_pH7(self.pH[i], self.I[i], self.T[i])
                reverse_ddG0[i] = reverse_ddG0[i] + ddG0 * self.S[j, i]

        self.dG0 = self.dG0_prime - reverse_ddG0
class KeggModel(object):
    """
    A stoichiometric model whose compounds are identified by KEGG compound IDs.

    Attributes set by the constructor:
        S      - the stoichiometric matrix (one row per compound, one column
                 per reaction)
        cids   - the list of compound IDs, in the same order as the rows of S
        rids   - an optional list of reaction IDs, in the same order as the
                 columns of S (or None)
        ccache - a CompoundCacher, dumped when the model is destroyed
    """

    def __del__(self):
        # if __init__ raised before assigning 'ccache' (e.g. on a failed
        # assertion) there is no cache to dump
        ccache = getattr(self, 'ccache', None)
        if ccache is not None:
            ccache.dump()

    def __init__(self, S, cids, rids=None):
        """
        Arguments:
           S    - the stoichiometric matrix (compounds x reactions)
           cids - a list of compound IDs, one per row of S. Note that this
                  list is stored as-is, and H+ (C00080) is popped from it
                  in-place if present.
           rids - an optional list of reaction IDs, one per column of S
        """
        self.S = S
        self.cids = cids
        self.rids = rids
        assert len(self.cids) == self.S.shape[0]
        if self.rids is not None:
            assert len(self.rids) == self.S.shape[1]
        self.ccache = CompoundCacher()

        # remove H+ from the stoichiometric matrix if it exists
        if 'C00080' in self.cids:
            i = self.cids.index('C00080')
            self.S = np.vstack((self.S[:i, :], self.S[i + 1:, :]))
            self.cids.pop(i)

    @staticmethod
    def from_file(fname, arrow='<=>', format='kegg', has_reaction_ids=False):
        """
        reads a file containing reactions in KEGG format
        
        Arguments:
           fname            - the filename to read
           arrow            - the string used as the 'arrow' in each reaction (default: '<=>')
           format           - the text file format provided ('kegg', 'tsv' or 'csv')
           has_reaction_ids - a boolean flag indicating if there is a column of
                              reaction IDs (separated from the reaction with
                              whitespaces)
        
        Return a KeggModel

        Raises ValueError if 'format' is not one of the supported formats.
        """
        # reject an unsupported format up front, instead of failing later
        # on an unassigned local variable
        if format not in ('kegg', 'tsv', 'csv'):
            raise ValueError("unknown file format %r (expected 'kegg', 'tsv' "
                             "or 'csv')" % (format,))

        # 'with' closes the file even when the parser raises
        with open(fname, 'r') as fd:
            if format == 'kegg':
                model = KeggModel.from_formulas(fd.readlines(), arrow,
                                                has_reaction_ids)
            elif format == 'tsv':
                model = KeggModel.from_csv(fd,
                                           has_reaction_ids=has_reaction_ids,
                                           delimiter='\t')
            else:
                model = KeggModel.from_csv(fd,
                                           has_reaction_ids=has_reaction_ids,
                                           delimiter=None)
        return model

    @staticmethod
    def from_csv(fd, has_reaction_ids=True, delimiter=None):
        """
        reads a stoichiometric matrix from an open delimited text file, where
        the first column holds the compound IDs and every other column holds
        the coefficients of one reaction

        Arguments:
           fd               - an open file (or any iterable of lines)
           has_reaction_ids - if True, the first row is a header whose cells
                              (except the first one) are the reaction IDs
           delimiter        - the column delimiter; None means the default
                              of the csv module (a comma)

        Return a KeggModel
        """
        # csv.reader rejects delimiter=None, so fall back to its own default
        reader_kwargs = {}
        if delimiter is not None:
            reader_kwargs['delimiter'] = delimiter
        csv_reader = csv.reader(fd, **reader_kwargs)

        if has_reaction_ids:
            rids = next(csv_reader)[1:]
        else:
            rids = None

        S = []
        cids = []
        for row in csv_reader:
            if not row:
                # the csv module yields an empty list for a blank line
                continue
            cids.append(row[0])
            S.append([float(x) for x in row[1:]])
        S = np.array(S)

        return KeggModel(S, cids, rids)

    @staticmethod
    def from_kegg_reactions(kegg_reactions, has_reaction_ids=False):
        """
        builds a model from a list of KeggReaction objects

        Arguments:
           kegg_reactions   - a list of KeggReaction objects
           has_reaction_ids - if True, the 'rid' of each reaction is used as
                              the reaction ID of the matching column

        Return a KeggModel
        """
        if has_reaction_ids:
            rids = [r.rid for r in kegg_reactions]
        else:
            rids = None

        cids = set()
        for reaction in kegg_reactions:
            cids.update(reaction.keys())

        # convert the list of reactions in sparse notation into a full
        # stoichiometric matrix, where the rows (compounds) are according to the
        # CID list 'cids'.
        cids = sorted(cids)
        S = np.matrix(np.zeros((len(cids), len(kegg_reactions))))
        for i, reaction in enumerate(kegg_reactions):
            S[:, i] = np.matrix(reaction.dense(cids))

        logging.debug(
            'Successfully loaded %d reactions (involving %d unique compounds)'
            % (S.shape[1], S.shape[0]))
        return KeggModel(S, cids, rids)

    @staticmethod
    def from_formulas(reaction_strings,
                      arrow='<=>',
                      has_reaction_ids=False,
                      raise_exception=False):
        """
        parses a list of reactions in KEGG format
        
        Arguments:
           reaction_strings - a list of reactions in KEGG format
           arrow            - the string used as the 'arrow' in each reaction (default: '<=>')
           has_reaction_ids - a boolean flag indicating if there is a column of
                              reaction IDs (separated from the reaction with
                              whitespaces)
           raise_exception  - if True, a ValueError is re-raised (and the flag
                              is passed on to the balance check); otherwise it
                              is logged and None is returned
        
        Return a KeggModel (or None, see 'raise_exception'). Reactions that
        cannot be parsed or are not balanced are replaced by empty reactions,
        so the columns stay aligned with 'reaction_strings'.
        """
        try:
            reactions = []
            not_balanced_count = 0
            for line in reaction_strings:
                rid = None
                if has_reaction_ids:
                    tokens = re.findall(r'(\w+)\s+(.*)', line.strip())[0]
                    rid = tokens[0]
                    line = tokens[1]
                try:
                    reaction = KeggReaction.parse_formula(line, arrow, rid)
                except KeggParseException as e:
                    logging.warning(str(e))
                    # the placeholder keeps the reaction ID, so the list of
                    # IDs never contains a None
                    reaction = KeggReaction({}, rid=rid)
                if not reaction.is_balanced(fix_water=True,
                                            raise_exception=raise_exception):
                    not_balanced_count += 1
                    logging.warning('Model contains an unbalanced reaction: ' +
                                    line)
                    reaction = KeggReaction({}, rid=rid)
                reactions.append(reaction)
                logging.debug('Adding reaction: ' + reaction.write_formula())

            if not_balanced_count > 0:
                warning_str = '%d out of the %d reactions are not chemically balanced' % \
                              (not_balanced_count, len(reaction_strings))
                logging.debug(warning_str)
            return KeggModel.from_kegg_reactions(reactions, has_reaction_ids)

        except ValueError as e:
            if raise_exception:
                # re-raise as-is, keeping the original traceback
                raise
            else:
                logging.debug(str(e))
                return None

    def add_thermo(self, cc):
        """
        stores the reaction energies and their covariance, as returned by
        cc.get_dG0_r_multi, in self.dG0 and self.cov_dG0

        Arguments:
           cc - an object providing get_dG0_r_multi(reactions)
        """
        Nc, Nr = self.S.shape
        reactions = []
        for j in range(Nr):
            # sparse representation of column j (non-zero coefficients only)
            sparse = {
                self.cids[i]: self.S[i, j]
                for i in range(Nc) if self.S[i, j] != 0
            }
            reaction = KeggReaction(sparse)
            reactions.append(reaction)

        self.dG0, self.cov_dG0 = cc.get_dG0_r_multi(reactions)

    def get_transformed_dG0(self, pH, I, T):
        """
            returns the estimated dG0_prime, the standard deviation of
            each estimate (i.e. a measure for the uncertainty), and a
            square root of the covariance matrix (computed from its SVD).

            add_thermo() must be called first.
        """
        dG0_prime = self.dG0 + self._get_transform_ddG0(pH=pH, I=I, T=T)
        dG0_std = np.matrix(np.sqrt(np.diag(self.cov_dG0))).T
        U, s, V = np.linalg.svd(self.cov_dG0, full_matrices=True)
        sqrt_Sigma = np.matrix(U) * np.matrix(np.diag(s**0.5)) * np.matrix(V)
        return dG0_prime, dG0_std, sqrt_Sigma

    def _get_transform_ddG0(self, pH, I, T):
        """
        needed in order to calculate the transformed Gibbs energies of the 
        model reactions.
        
        Returns:
            an array (whose length is self.S.shape[1]) with the differences
            between DrG0_prime and DrG0. Therefore, one must add this array
            to the chemical Gibbs energies of reaction (DrG0) to get the 
            transformed values
        """
        ddG0_compounds = np.matrix(np.zeros((self.S.shape[0], 1)))
        for i, cid in enumerate(self.cids):
            comp = self.ccache.get_compound(cid)
            ddG0_compounds[i, 0] = comp.transform_pH7(pH, I, T)

        ddG0_forward = np.dot(self.S.T, ddG0_compounds)
        return ddG0_forward

    def check_S_balance(self, fix_water=False):
        """
        zeroes the columns of S that belong to reactions which do not
        conserve all the elements, and logs a warning about them

        Arguments:
           fix_water - if True, first add water (C00001) to each reaction in
                       order to cancel its oxygen imbalance; a water row is
                       appended to S and to cids if the model has none

        Returns self (to allow chaining)
        """
        elements, Ematrix = self.ccache.get_element_matrix(self.cids)
        conserved = Ematrix.T * self.S

        if fix_water:
            # This part only looks for imbalanced oxygen and uses extra
            # H2O molecules (on either side of the reaction equation) to
            # balance them. Keep in mind that also the e- balance is affected
            # by the water (and hydrogen is not counted at all).
            if 'C00001' not in self.cids:
                self.S = np.vstack([self.S, np.zeros((1, self.S.shape[1]))])
                self.cids.append('C00001')
                elements, Ematrix = self.ccache.get_element_matrix(self.cids)

            i_h2o = self.cids.index('C00001')
            add_water = -conserved[elements.index('O'), :]
            self.S[i_h2o, :] += add_water
            conserved += Ematrix[i_h2o, :].T * add_water

        # flatten to a plain 1-D array of column indices
        rxnFil = np.asarray(np.any(conserved, axis=0)).ravel()
        unbalanced_ind = np.nonzero(rxnFil)[0]

        # NOTE: comparing an array to [] is not a reliable emptiness test
        # (it is falsy when there is exactly one index), so use the size
        if unbalanced_ind.size > 0:
            logging.warning('There are (%d) unbalanced reactions in S. '
                            'Setting their coefficients to 0.' %
                            unbalanced_ind.size)
            if self.rids is not None:
                logging.warning(
                    'These are the unbalanced reactions: ' +
                    ', '.join([self.rids[i] for i in unbalanced_ind]))

            self.S[:, unbalanced_ind] = 0
        return self

    def write_reaction_by_index(self, r):
        """
        returns the formula (as written by KeggReaction.write_formula) of the
        reaction in column 'r' of S
        """
        sparse = {cid: self.S[i, r] for i, cid in enumerate(self.cids)
                  if self.S[i, r] != 0}
        if self.rids is not None:
            reaction = KeggReaction(sparse, rid=self.rids[r])
        else:
            reaction = KeggReaction(sparse)
        return reaction.write_formula()

    def get_unidirectional_S(self):
        """
        splits S into its negative and positive parts

        Returns:
            (S_minus, S_plus) - copies of S where the positive (respectively
            negative) coefficients are replaced by 0
        """
        S_plus = np.copy(self.S)
        S_minus = np.copy(self.S)
        S_plus[self.S < 0] = 0
        S_minus[self.S > 0] = 0
        return S_minus, S_plus
class KeggReaction(object):
    """
        A single reaction, stored as a sparse dictionary that maps compound
        IDs to stoichiometric coefficients. Compounds with negative
        coefficients are written on the left-hand side of the formula and
        compounds with positive coefficients on the right-hand side.
    """

    def __init__(self, sparse, arrow='<=>', rid=None):
        """
            Arguments:
                sparse - a dictionary mapping compound IDs to coefficients
                         (int or float); entries with a zero coefficient
                         are dropped
                arrow  - the string used as the 'arrow' when writing the
                         formula (default: '<=>')
                rid    - an optional reaction ID

            Raises:
                ValueError if one of the coefficients is not an int or float
        """
        for coeff in sparse.values():
            if not isinstance(coeff, (float, int)):
                raise ValueError(
                    'All values in KeggReaction must be integers or floats')
        # keep only the compounds that actually participate in the reaction
        self.sparse = dict((k, v) for (k, v) in sparse.items() if v)
        self.arrow = arrow
        self.rid = rid
        self.ccache = CompoundCacher()

    def keys(self):
        """Return the compound IDs of the reaction."""
        return self.sparse.keys()

    def iteritems(self):
        """Return an iterator over (compound ID, coefficient) pairs."""
        return iter(self.sparse.items())

    def __str__(self):
        return self.write_formula()

    def reverse(self):
        """
            reverse the direction of the reaction by negating all stoichiometric
            coefficients
        """
        self.sparse = dict((k, -v) for (k, v) in self.sparse.items())

    @staticmethod
    def parse_reaction_formula_side(s):
        """
            Parses the side formula, e.g. '2 C00001 + C00002 + 3 C00003'

            Returns:
                A dictionary mapping each compound ID to its total amount on
                this side (a missing coefficient counts as 1). The result is
                empty if the side is 'null' or contains no compounds.

            Raises:
                KeggParseException if a coefficient is not a number
        """
        if s.strip() == "null":
            return {}

        compound_bag = {}
        for member in re.split(r'\s+\+\s+', s):
            tokens = member.split(None, 1)
            if len(tokens) == 0:
                continue
            if len(tokens) == 1:
                amount = 1
                # use the token (not the raw member) so that surrounding
                # whitespace never becomes part of the compound ID
                key = tokens[0]
            else:
                try:
                    amount = float(tokens[0])
                except ValueError:
                    raise KeggParseException(
                        "Non-specific reaction: %s" % s)
                # split(None, 1) keeps trailing whitespace in the remainder
                key = tokens[1].strip()

            compound_bag[key] = compound_bag.get(key, 0) + amount

        return compound_bag

    @staticmethod
    def parse_formula(formula, arrow='<=>', rid=None):
        """
            Parses a two-sided formula such as: 2 C00001 => C00002 + C00003

            Arguments:
                formula - the reaction formula
                arrow   - the string separating the two sides
                rid     - an optional reaction ID

            Return:
                A KeggReaction in which the compounds on the left of the
                arrow have negative coefficients and those on the right have
                positive ones

            Raises:
                KeggParseException if the arrow does not appear exactly once
                or if one of the sides cannot be parsed
        """
        tokens = formula.split(arrow)
        if len(tokens) < 2:
            raise KeggParseException('Reaction does not contain the arrow sign (%s): %s'
                                     % (arrow, formula))
        if len(tokens) > 2:
            raise KeggParseException('Reaction contains more than one arrow sign (%s): %s'
                                     % (arrow, formula))

        left = tokens[0].strip()
        right = tokens[1].strip()

        sparse_reaction = {}
        for cid, count in KeggReaction.parse_reaction_formula_side(left).items():
            sparse_reaction[cid] = sparse_reaction.get(cid, 0) - count

        for cid, count in KeggReaction.parse_reaction_formula_side(right).items():
            sparse_reaction[cid] = sparse_reaction.get(cid, 0) + count

        return KeggReaction(sparse_reaction, arrow, rid=rid)

    @staticmethod
    def write_compound_and_coeff(compound_id, coeff):
        """Format one term of a formula; a coefficient of 1 is omitted."""
        if coeff == 1:
            return compound_id
        else:
            return "%g %s" % (coeff, compound_id)

    def write_formula(self):
        """String representation, with the compounds of each side sorted by ID."""
        left = []
        right = []
        for cid, coeff in sorted(self.sparse.items()):
            if coeff < 0:
                left.append(KeggReaction.write_compound_and_coeff(cid, -coeff))
            elif coeff > 0:
                right.append(KeggReaction.write_compound_and_coeff(cid, coeff))
        return "%s %s %s" % (' + '.join(left), self.arrow, ' + '.join(right))

    def _get_reaction_atom_bag(self, raise_exception=False):
        """
            Use for checking if all elements are conserved.

            Arguments:
                raise_exception - if True, a ValueError raised while
                                  computing the balance is propagated;
                                  otherwise it is logged and None is returned

            Returns:
                An atom_bag of the differences between the sides of the reaction.
                E.g. if there is one extra C on the left-hand side, the result will
                be {'C': -1}. Returns None if the balance cannot be computed
                and raise_exception is False.
        """
        try:
            cids = list(self.keys())
            coeffs = np.matrix([self.sparse[cid] for cid in cids])

            cached_cids = set(map(str, self.ccache.compound_id2inchi.keys()))
            missing_cids = set(cids).difference(cached_cids)
            if missing_cids:
                warning_str = 'The following compound IDs are not in the cache, ' + \
                              'make sure they appear in kegg_additions.tsv and ' + \
                              'then run compound_cacher.py: ' + \
                              ', '.join(sorted(missing_cids))
                raise ValueError(warning_str)

            elements, Ematrix = self.ccache.get_element_matrix(cids)
            conserved = coeffs * Ematrix

            if np.any(np.isnan(conserved)):
                warning_str = 'cannot test reaction balancing because of unspecific ' + \
                              'compound formulas: %s' % self.write_formula()
                raise ValueError(warning_str)

            atom_bag = {}
            if np.any(conserved != 0):
                logging.debug('unbalanced reaction: %s', self.write_formula())
                for j, c in enumerate(conserved.flat):
                    if c != 0:
                        # %g (rather than %d) so that fractional imbalances
                        # are not truncated in the log
                        logging.debug('there are %g more %s atoms on the right-hand side',
                                      c, elements[j])
                        atom_bag[str(elements[j])] = c
            return atom_bag

        except ValueError as e:
            if raise_exception:
                # bare raise keeps the original traceback
                raise
            else:
                logging.debug(str(e))
                return None

    def is_balanced(self, fix_water=False, raise_exception=False):
        """
            Check whether all elements are conserved by the reaction.

            Arguments:
                fix_water       - if True and oxygen is not balanced, the
                                  coefficient of C00001 in this reaction is
                                  changed (in place) to cancel the oxygen
                                  imbalance before the balance is re-checked
                raise_exception - passed on to _get_reaction_atom_bag

            Returns:
                True if no element is unbalanced, False otherwise (also when
                the balance cannot be computed)
        """
        reaction_atom_bag = self._get_reaction_atom_bag(raise_exception)

        if reaction_atom_bag is None: # this means some compound formulas are missing
            return False

        if fix_water and 'O' in reaction_atom_bag:
            self.sparse.setdefault('C00001', 0)
            self.sparse['C00001'] += -reaction_atom_bag['O']
            if self.sparse['C00001'] == 0:
                del self.sparse['C00001']
            reaction_atom_bag = self._get_reaction_atom_bag(raise_exception)
            # the re-check can fail as well (e.g. after C00001 was added),
            # which must not end in len(None)
            if reaction_atom_bag is None:
                return False

        return len(reaction_atom_bag) == 0

    def is_empty(self):
        """Return True if the reaction has no compounds."""
        return len(self.sparse) == 0

    def dense(self, cids):
        """
            Arguments:
                cids - a list of compound IDs that defines the row order

            Returns:
                A (len(cids), 1) matrix holding the coefficient of each
                compound of the reaction at its index in cids (0 elsewhere).

            Raises:
                ValueError if a compound of the reaction is not in cids
        """
        s = np.matrix(np.zeros((len(cids), 1)))
        for cid, coeff in self.iteritems():
            s[cids.index(cid), 0] = coeff
        return s

    def get_transform_ddG0(self, pH, I, T):
        """
        needed in order to calculate the transformed Gibbs energies of
        reactions.

        Returns:
            The difference between DrG0_prime and DrG0 for this reaction.
            Therefore, this value must be added to the chemical Gibbs
            energy of reaction (DrG0) to get the transformed value.
        """
        ddG0_forward = 0
        for compound_id, coeff in self.iteritems():
            comp = self.ccache.get_compound(compound_id)
            ddG0_forward += coeff * comp.transform_pH7(pH, I, T)
        return ddG0_forward
class KeggReaction(object):
    """
        A single reaction, stored as a sparse dictionary that maps compound
        IDs to stoichiometric coefficients. Compounds with negative
        coefficients are written on the left-hand side of the formula and
        compounds with positive coefficients on the right-hand side.
    """

    def __init__(self, sparse, arrow='<=>', rid=None):
        """
            Arguments:
                sparse - a dictionary mapping compound IDs to coefficients
                         (int or float); entries with a zero coefficient
                         are dropped
                arrow  - the string used as the 'arrow' when writing the
                         formula (default: '<=>')
                rid    - an optional reaction ID

            Raises:
                ValueError if one of the coefficients is not an int or float
        """
        for coeff in sparse.values():
            if not isinstance(coeff, (float, int)):
                raise ValueError(
                    'All values in KeggReaction must be integers or floats')
        # keep only the compounds that actually participate in the reaction
        self.sparse = dict((k, v) for (k, v) in sparse.items() if v)
        self.arrow = arrow
        self.rid = rid
        self.ccache = CompoundCacher()

    def keys(self):
        """Return the compound IDs of the reaction."""
        return self.sparse.keys()

    def iteritems(self):
        """Return an iterator over (compound ID, coefficient) pairs."""
        return iter(self.sparse.items())

    def __str__(self):
        return self.write_formula()

    def reverse(self):
        """
            reverse the direction of the reaction by negating all stoichiometric
            coefficients
        """
        self.sparse = dict((k, -v) for (k, v) in self.sparse.items())

    @staticmethod
    def parse_reaction_formula_side(s):
        """
            Parses the side formula, e.g. '2 C00001 + C00002 + 3 C00003'

            Returns:
                A dictionary mapping each compound ID to its total amount on
                this side (a missing coefficient counts as 1). The result is
                empty if the side is 'null' or contains no compounds.

            Raises:
                KeggParseException if a coefficient is not a number
        """
        if s.strip() == "null":
            return {}

        compound_bag = {}
        for member in re.split(r'\s+\+\s+', s):
            tokens = member.split(None, 1)
            if len(tokens) == 0:
                continue
            if len(tokens) == 1:
                amount = 1
                # use the token (not the raw member) so that surrounding
                # whitespace never becomes part of the compound ID
                key = tokens[0]
            else:
                try:
                    amount = float(tokens[0])
                except ValueError:
                    raise KeggParseException("Non-specific reaction: %s" % s)
                # split(None, 1) keeps trailing whitespace in the remainder
                key = tokens[1].strip()

            compound_bag[key] = compound_bag.get(key, 0) + amount

        return compound_bag

    @staticmethod
    def parse_formula(formula, arrow='<=>', rid=None):
        """
            Parses a two-sided formula such as: 2 C00001 => C00002 + C00003

            Arguments:
                formula - the reaction formula
                arrow   - the string separating the two sides
                rid     - an optional reaction ID

            Return:
                A KeggReaction in which the compounds on the left of the
                arrow have negative coefficients and those on the right have
                positive ones

            Raises:
                KeggParseException if the arrow does not appear exactly once
                or if one of the sides cannot be parsed
        """
        tokens = formula.split(arrow)
        if len(tokens) < 2:
            raise KeggParseException(
                'Reaction does not contain the arrow sign (%s): %s' %
                (arrow, formula))
        if len(tokens) > 2:
            raise KeggParseException(
                'Reaction contains more than one arrow sign (%s): %s' %
                (arrow, formula))

        left = tokens[0].strip()
        right = tokens[1].strip()

        sparse_reaction = {}
        for cid, count in KeggReaction.parse_reaction_formula_side(
                left).items():
            sparse_reaction[cid] = sparse_reaction.get(cid, 0) - count

        for cid, count in KeggReaction.parse_reaction_formula_side(
                right).items():
            sparse_reaction[cid] = sparse_reaction.get(cid, 0) + count

        return KeggReaction(sparse_reaction, arrow, rid=rid)

    @staticmethod
    def write_compound_and_coeff(compound_id, coeff):
        """Format one term of a formula; a coefficient of 1 is omitted."""
        if coeff == 1:
            return compound_id
        else:
            return "%g %s" % (coeff, compound_id)

    def write_formula(self):
        """String representation, with the compounds of each side sorted by ID."""
        left = []
        right = []
        for cid, coeff in sorted(self.sparse.items()):
            if coeff < 0:
                left.append(KeggReaction.write_compound_and_coeff(cid, -coeff))
            elif coeff > 0:
                right.append(KeggReaction.write_compound_and_coeff(cid, coeff))
        return "%s %s %s" % (' + '.join(left), self.arrow, ' + '.join(right))

    def _get_reaction_atom_bag(self, raise_exception=False):
        """
            Use for checking if all elements are conserved.

            Arguments:
                raise_exception - if True, a ValueError raised while
                                  computing the balance is propagated;
                                  otherwise it is logged and None is returned

            Returns:
                An atom_bag of the differences between the sides of the reaction.
                E.g. if there is one extra C on the left-hand side, the result will
                be {'C': -1}. Returns None if the balance cannot be computed
                and raise_exception is False.
        """
        try:
            cids = list(self.keys())
            coeffs = np.matrix([self.sparse[cid] for cid in cids])

            cached_cids = set(map(str, self.ccache.compound_id2inchi.keys()))
            missing_cids = set(cids).difference(cached_cids)
            if missing_cids:
                warning_str = 'The following compound IDs are not in the cache, ' + \
                              'make sure they appear in kegg_additions.tsv and ' + \
                              'then run compound_cacher.py: ' + \
                              ', '.join(sorted(missing_cids))
                raise ValueError(warning_str)

            elements, Ematrix = self.ccache.get_element_matrix(cids)
            conserved = coeffs * Ematrix

            if np.any(np.isnan(conserved)):
                warning_str = 'cannot test reaction balancing because of unspecific ' + \
                              'compound formulas: %s' % self.write_formula()
                raise ValueError(warning_str)

            atom_bag = {}
            if np.any(conserved != 0):
                logging.debug('unbalanced reaction: %s', self.write_formula())
                for j, c in enumerate(conserved.flat):
                    if c != 0:
                        # %g (rather than %d) so that fractional imbalances
                        # are not truncated in the log
                        logging.debug(
                            'there are %g more %s atoms on the right-hand side',
                            c, elements[j])
                        atom_bag[str(elements[j])] = c
            return atom_bag

        except ValueError as e:
            if raise_exception:
                # bare raise keeps the original traceback
                raise
            else:
                logging.debug(str(e))
                return None

    def is_balanced(self, fix_water=False, raise_exception=False):
        """
            Check whether all elements are conserved by the reaction.

            Arguments:
                fix_water       - if True and oxygen is not balanced, the
                                  coefficient of C00001 in this reaction is
                                  changed (in place) to cancel the oxygen
                                  imbalance before the balance is re-checked
                raise_exception - passed on to _get_reaction_atom_bag

            Returns:
                True if no element is unbalanced, False otherwise (also when
                the balance cannot be computed)
        """
        reaction_atom_bag = self._get_reaction_atom_bag(raise_exception)

        if reaction_atom_bag is None:  # this means some compound formulas are missing
            return False

        if fix_water and 'O' in reaction_atom_bag:
            self.sparse.setdefault('C00001', 0)
            self.sparse['C00001'] += -reaction_atom_bag['O']
            if self.sparse['C00001'] == 0:
                del self.sparse['C00001']
            reaction_atom_bag = self._get_reaction_atom_bag(raise_exception)
            # the re-check can fail as well (e.g. after C00001 was added),
            # which must not end in len(None)
            if reaction_atom_bag is None:
                return False

        return len(reaction_atom_bag) == 0

    def is_empty(self):
        """Return True if the reaction has no compounds."""
        return len(self.sparse) == 0

    def dense(self, cids):
        """
            Arguments:
                cids - a list of compound IDs that defines the row order

            Returns:
                A (len(cids), 1) matrix holding the coefficient of each
                compound of the reaction at its index in cids (0 elsewhere).

            Raises:
                ValueError if a compound of the reaction is not in cids
        """
        s = np.matrix(np.zeros((len(cids), 1)))
        for cid, coeff in self.iteritems():
            s[cids.index(cid), 0] = coeff
        return s

    def get_transform_ddG0(self, pH, I, T):
        """
        needed in order to calculate the transformed Gibbs energies of
        reactions.

        Returns:
            The difference between DrG0_prime and DrG0 for this reaction.
            Therefore, this value must be added to the chemical Gibbs
            energy of reaction (DrG0) to get the transformed value.
        """
        ddG0_forward = 0
        for compound_id, coeff in self.iteritems():
            comp = self.ccache.get_compound(compound_id)
            ddG0_forward += coeff * comp.transform_pH7(pH, I, T)
        return ddG0_forward