class FisherRunner(object):
    """Label every position of a joint counts file and write the labels as TSV.

    Chromosomes are processed in sorted order, each in chunks of at most
    ``chunk_size`` rows. Every input row is written with the chromosome name
    prepended and the class name of its label appended.

    NOTE(review): ``self.model`` is not assigned anywhere in this class;
    presumably the caller sets it before ``run`` -- confirm.
    """

    # Maximum number of rows handed to the model per classify() call.
    chunk_size = 100000

    def __init__(self):
        self.data_class = JointData

        # Class names, indexed by the integer labels returned by model.classify().
        self.classes = ('Reference', 'Germline', 'Somatic', 'LOH', 'Unknown')

    def run(self, args):
        """Classify all chromosomes of ``args.jcnt_file_name`` into ``args.tsv_file_name``.

        The output file and the reader are always closed, even when
        classification fails part-way through.
        """
        self.reader = JointCountsReader(args.jcnt_file_name)
        try:
            # The with-block guarantees the TSV is flushed and closed.
            with open(args.tsv_file_name, 'w') as out_file:
                self.writer = csv.writer(out_file, delimiter='\t')

                for chr_name in sorted(self.reader.get_chr_list()):
                    self._classify_chromosome(chr_name)
        finally:
            self.reader.close()

    def _classify_chromosome(self, chr_name):
        """Classify one chromosome chunk by chunk and write its rows.

        Args:
            chr_name: chromosome whose counts and rows are read from ``self.reader``.
        """
        counts = self.reader.get_counts(chr_name)
        jcnt_rows = self.reader.get_rows(chr_name)

        end = self.reader.get_chr_size(chr_name)

        start = 0
        while start < end:
            stop = min(start + self.chunk_size, end)

            sub_counts = counts[start:stop]
            sub_rows = jcnt_rows[start:stop]

            data = self.data_class(sub_counts)

            labels = self.model.classify(data)

            self._write_rows(chr_name, sub_rows, labels)

            start = stop

    def _write_rows(self, chr_name, rows, labels):
        """Write ``[chr_name] + row + [class name]`` for each row.

        Args:
            chr_name: value placed in the first output column.
            rows: iterable of row sequences; ``rows[i]`` pairs with ``labels[i]``.
            labels: class indices, converted with ``int()`` and looked up in
                ``self.classes``.
        """
        for i, row in enumerate(rows):
            out_row = [chr_name]
            out_row.extend(row)

            class_name = self.classes[int(labels[i])]

            out_row.append(class_name)

            # Somatic calls are also echoed to stdout.
            if class_name == 'Somatic':
                print(out_row)

            self.writer.writerow(out_row)
class JointModelRunner(ModelRunner):
    """Runner that trains one joint model on JointData and classifies with it."""

    def run(self, args):
        """Attach the joint counts reader and JSM writer, then defer to ModelRunner.run."""
        self.reader = JointCountsReader(args.jcnt_file_name)
        self.writer = JointSnvMixWriter(args.jsm_file_name)

        ModelRunner.run(self, args)

    def _train(self, args):
        """Fit ``self.parameters`` on all counts, or on a subsample when requested.

        A positive ``args.subsample_size`` trains on ``self._subsample(...)``;
        otherwise every count from the reader is used. Priors are loaded from
        ``args.priors_file`` and recorded via ``_write_priors`` before training.
        """
        training_counts = (self._subsample(args.subsample_size)
                           if args.subsample_size > 0
                           else self.reader.get_counts())

        self.priors_parser.load_from_file(args.priors_file)
        self.priors = self.priors_parser.to_dict()
        self._write_priors()

        self.parameters = self.model.train(JointData(training_counts),
                                           self.priors,
                                           args.max_iters,
                                           args.convergence_threshold)

    def _classify_chromosome(self, chr_name):
        """Write responsibilities for ``chr_name``, processed in 100000-row blocks."""
        all_counts = self.reader.get_counts(chr_name)
        all_rows = self.reader.get_rows(chr_name)
        chr_size = self.reader.get_chr_size(chr_name)

        block_size = int(1e5)

        block_lo = 0
        while block_lo < chr_size:
            block_hi = min(block_lo + block_size, chr_size)

            block_counts = all_counts[block_lo:block_hi]
            block_rows = all_rows[block_lo:block_hi]

            block_resp = self.model.classify(JointData(block_counts), self.parameters)
            self.writer.write_data(chr_name, block_rows, block_resp)

            block_lo = block_hi
class ChromosomeModelRunner(ModelRunner):
    """Runner that fits and applies a separate parameter set per chromosome.

    After training, ``self.parameters`` is a dict keyed by chromosome name.
    """

    def run(self, args):
        """Attach the joint counts reader and JSM writer, then defer to ModelRunner.run."""
        self.reader = JointCountsReader(args.jcnt_file_name)
        self.writer = JointSnvMixWriter(args.jsm_file_name)

        ModelRunner.run(self, args)

    def _train(self, args):
        """Load and record priors, then train one parameter set per chromosome.

        Chromosomes are visited in sorted order and their names printed as
        progress. With a positive ``args.subsample_size``, each chromosome is
        trained on a random subsample of at most that many positions.
        """
        self.priors_parser.load_from_file(args.priors_file)
        self.priors = self.priors_parser.to_dict()
        self._write_priors()

        chromosomes = self.reader.get_chr_list()

        self.parameters = {}

        for chrom in sorted(chromosomes):
            print(chrom)

            if args.subsample_size > 0:
                chrom_counts = self._chrom_subsample(chrom, args.subsample_size)
            else:
                chrom_counts = self.reader.get_counts(chrom)

            self.parameters[chrom] = self.model.train(self.data_class(chrom_counts),
                                                      self.priors,
                                                      args.max_iters,
                                                      args.convergence_threshold)

    def _classify_chromosome(self, chr_name):
        """Write responsibilities for ``chr_name`` using that chromosome's own parameters.

        Rows are processed in 100000-row blocks.
        """
        all_counts = self.reader.get_counts(chr_name)
        all_rows = self.reader.get_rows(chr_name)
        chr_size = self.reader.get_chr_size(chr_name)

        block_size = int(1e5)

        block_lo = 0
        while block_lo < chr_size:
            block_hi = min(block_lo + block_size, chr_size)

            block_counts = all_counts[block_lo:block_hi]
            block_rows = all_rows[block_lo:block_hi]

            block_resp = self.model.classify(self.data_class(block_counts),
                                             self.parameters[chr_name])
            self.writer.write_data(chr_name, block_rows, block_resp)

            block_lo = block_hi

    def _chrom_subsample(self, chr_name, sample_size):
        """Return counts at up to ``sample_size`` randomly chosen positions of ``chr_name``.

        The positions are drawn without replacement via ``random.sample``.
        """
        chr_size = self.reader.get_chr_size(chr_name=chr_name)

        picked = random.sample(xrange(chr_size), min(chr_size, sample_size))

        return self.reader.get_counts(chr_name)[picked]
class IndependentModelRunner(ModelRunner):
    """Runner that trains one model per genome and combines them at classification.

    ``self.parameters`` is a dict keyed by the entries of ``constants.genomes``.
    """

    def run(self, args):
        """Attach the joint counts reader and JSM writer, then defer to ModelRunner.run."""
        self.reader = JointCountsReader(args.jcnt_file_name)
        self.writer = JointSnvMixWriter(args.jsm_file_name)

        ModelRunner.run(self, args)

    def _train(self, args):
        """Train a separate parameter set for every genome in ``constants.genomes``.

        A positive ``args.subsample_size`` trains on ``self._subsample(...)``;
        otherwise all counts from the reader are used. Each genome is trained
        with its own entry of the loaded priors.
        """
        if args.subsample_size > 0:
            counts = self._subsample(args.subsample_size)
        else:
            counts = self.reader.get_counts()

        self.priors_parser.load_from_file(args.priors_file)
        self.priors = self.priors_parser.to_dict()

        self._write_priors()

        self.parameters = {}

        for genome in constants.genomes:
            data = IndependentData(counts, genome)

            self.parameters[genome] = self.model.train(data, self.priors[genome],
                                                        args.max_iters, args.convergence_threshold)

    def _classify_chromosome(self, chr_name):
        """Write joint responsibilities for ``chr_name`` in 100000-row chunks."""
        counts = self.reader.get_counts(chr_name)
        jcnt_rows = self.reader.get_rows(chr_name)

        end = self.reader.get_chr_size(chr_name)

        n = int(1e5)

        start = 0
        while start < end:
            stop = min(start + n, end)

            sub_counts = counts[start:stop]
            sub_rows = jcnt_rows[start:stop]

            indep_resp = {}

            for genome in constants.genomes:
                data = IndependentData(sub_counts, genome)

                indep_resp[genome] = self.model.classify(data, self.parameters[genome])

            joint_resp = self._get_joint_responsibilities(indep_resp)

            self.writer.write_data(chr_name, sub_rows, joint_resp)

            start = stop

    def _get_joint_responsibilities(self, resp):
        """Combine per-genome responsibilities into joint responsibilities.

        ``resp['normal']`` is indexed ``[position, class]``. For each normal
        class ``i``, its column is multiplied (with broadcasting) by
        ``resp['tumour']``. The resulting blocks are stacked horizontally, so
        normal class ``i`` occupies columns ``i * n_tumour`` through
        ``(i + 1) * n_tumour - 1``.

        NOTE(review): assumes ``constants.genomes`` contains exactly 'normal'
        and 'tumour' -- confirm.
        """
        normal_resp = np.asarray(resp['normal'])
        tumour_resp = np.asarray(resp['tumour'])

        column_shape = (normal_resp.shape[0], 1)

        # Multiply directly rather than via exp(log(a) + log(b)). The result is
        # mathematically identical, but this avoids log(0) divide-by-zero
        # warnings and the rounding error of the log/exp round trip.
        blocks = [normal_resp[:, i].reshape(column_shape) * tumour_resp
                  for i in range(normal_resp.shape[1])]

        return np.hstack(blocks)