class JointModelRunner(ModelRunner):
    """Runner that trains one parameter set on ``JointData`` and classifies with it.

    Counts are read through a ``JointCountsReader`` and results are written
    through a ``JointSnvMixWriter``.  ``self.model``, ``self.priors_parser``,
    ``self._subsample`` and ``self._write_priors`` are not defined in this
    class; they are expected to be supplied by ``ModelRunner`` or its set-up
    code (NOTE(review): confirm against the base class).
    """

    def run(self, args):
        """Create the reader and writer, then hand over to ``ModelRunner.run``.

        ``args`` must expose ``jcnt_file_name`` (input) and ``jsm_file_name``
        (output).
        """
        self.reader = JointCountsReader(args.jcnt_file_name)
        self.writer = JointSnvMixWriter(args.jsm_file_name)

        ModelRunner.run(self, args)

    def _train(self, args):
        """Fit the model and store the result in ``self.parameters``.

        Uses a subsample of the counts when ``args.subsample_size`` is
        positive, otherwise every count the reader returns.  Priors are loaded
        from ``args.priors_file`` and written out via ``self._write_priors()``
        before training starts.
        """
        if args.subsample_size > 0:
            counts = self._subsample(args.subsample_size)
        else:
            counts = self.reader.get_counts()

        self.priors_parser.load_from_file(args.priors_file)
        self.priors = self.priors_parser.to_dict()

        self._write_priors()

        data = JointData(counts)

        self.parameters = self.model.train(data, self.priors,
                                            args.max_iters, args.convergence_threshold)

    def _classify_chromosome(self, chr_name):
        """Classify every position of ``chr_name`` and write the results.

        The chromosome is processed in slices of at most 100000 rows so that
        only one slice of responsibilities is held at a time.
        """
        counts = self.reader.get_counts(chr_name)
        jcnt_rows = self.reader.get_rows(chr_name)

        end = self.reader.get_chr_size(chr_name)

        chunk_size = int(1e5)
        start = 0

        while start < end:
            # The last slice is clipped to the chromosome size.
            stop = min(start + chunk_size, end)

            sub_counts = counts[start:stop]
            sub_rows = jcnt_rows[start:stop]

            data = JointData(sub_counts)

            resp = self.model.classify(data, self.parameters)

            self.writer.write_data(chr_name, sub_rows, resp)

            start = stop
class ChromosomeModelRunner(ModelRunner):
    """Runner that trains and applies a separate parameter set per chromosome.

    ``self.parameters`` is a dict keyed by chromosome name.  ``self.model``,
    ``self.priors_parser``, ``self.data_class`` and ``self._write_priors`` are
    not defined in this class; they are expected to come from ``ModelRunner``
    or its set-up code (NOTE(review): confirm against the base class).
    """

    def run(self, args):
        """Create the reader and writer, then hand over to ``ModelRunner.run``.

        ``args`` must expose ``jcnt_file_name`` (input) and ``jsm_file_name``
        (output).
        """
        self.reader = JointCountsReader(args.jcnt_file_name)
        self.writer = JointSnvMixWriter(args.jsm_file_name)

        ModelRunner.run(self, args)

    def _train(self, args):
        """Fit one model per chromosome, visiting chromosomes in sorted order.

        Priors are loaded from ``args.priors_file`` and written out once.  The
        name of each chromosome is printed before it is trained.  A positive
        ``args.subsample_size`` trains on a random subsample of that
        chromosome's counts instead of all of them.
        """
        self.priors_parser.load_from_file(args.priors_file)
        self.priors = self.priors_parser.to_dict()

        self._write_priors()

        chromosomes = self.reader.get_chr_list()

        self.parameters = {}

        use_subsample = args.subsample_size > 0

        for chr_name in sorted(chromosomes):
            # Progress output; the parenthesised form prints the same text
            # as the bare print statement.
            print(chr_name)

            counts = (self._chrom_subsample(chr_name, args.subsample_size)
                      if use_subsample
                      else self.reader.get_counts(chr_name))

            fitted = self.model.train(self.data_class(counts), self.priors,
                                      args.max_iters, args.convergence_threshold)

            self.parameters[chr_name] = fitted

    def _classify_chromosome(self, chr_name):
        """Classify ``chr_name`` with its own parameters and write the results.

        The chromosome is processed in slices of at most 100000 rows.
        """
        counts = self.reader.get_counts(chr_name)
        jcnt_rows = self.reader.get_rows(chr_name)

        end = self.reader.get_chr_size(chr_name)

        chunk_size = int(1e5)
        start = 0

        while start < end:
            # The last slice is clipped to the chromosome size.
            stop = min(start + chunk_size, end)

            chunk_counts = counts[start:stop]
            chunk_rows = jcnt_rows[start:stop]

            data = self.data_class(chunk_counts)

            resp = self.model.classify(data, self.parameters[chr_name])

            self.writer.write_data(chr_name, chunk_rows, resp)

            start = stop

    def _chrom_subsample(self, chr_name, sample_size):
        """Return a random subsample of the counts of ``chr_name``.

        At most ``sample_size`` rows are drawn without replacement; a
        chromosome with fewer rows than that yields one index per row.  The
        rows are returned in the (unsorted) order produced by
        ``random.sample``.
        """
        chr_size = self.reader.get_chr_size(chr_name=chr_name)

        n_draws = min(chr_size, sample_size)
        picked = random.sample(xrange(chr_size), n_draws)

        # Indexing with a list of positions relies on the counts object
        # supporting that form of indexing.
        return self.reader.get_counts(chr_name)[picked]
class IndependentModelRunner(ModelRunner):
    """Runner that fits one model per genome and combines them when classifying.

    ``self.parameters`` and ``self.priors`` are indexed by the entries of
    ``constants.genomes``.  ``self.model``, ``self.priors_parser``,
    ``self._subsample`` and ``self._write_priors`` are not defined in this
    class; they are expected to come from ``ModelRunner`` or its set-up code
    (NOTE(review): confirm against the base class).
    """

    def run(self, args):
        """Create the reader and writer, then hand over to ``ModelRunner.run``.

        ``args`` must expose ``jcnt_file_name`` (input) and ``jsm_file_name``
        (output).
        """
        self.reader = JointCountsReader(args.jcnt_file_name)
        self.writer = JointSnvMixWriter(args.jsm_file_name)

        ModelRunner.run(self, args)

    def _train(self, args):
        """Fit the model once per genome on the same set of counts.

        Uses a subsample of the counts when ``args.subsample_size`` is
        positive, otherwise every count the reader returns.  Priors are loaded
        from ``args.priors_file`` and written out before training starts.
        """
        want_subsample = args.subsample_size > 0
        counts = (self._subsample(args.subsample_size)
                  if want_subsample
                  else self.reader.get_counts())

        self.priors_parser.load_from_file(args.priors_file)
        self.priors = self.priors_parser.to_dict()

        self._write_priors()

        self.parameters = {}

        for genome in constants.genomes:
            genome_data = IndependentData(counts, genome)

            fitted = self.model.train(genome_data, self.priors[genome],
                                      args.max_iters, args.convergence_threshold)

            self.parameters[genome] = fitted

    def _classify_chromosome(self, chr_name):
        """Classify ``chr_name`` per genome and write the joint responsibilities.

        The chromosome is processed in slices of at most 100000 rows.  Each
        slice is classified once per genome, and the per-genome results are
        combined by ``_get_joint_responsibilities`` before being written.
        """
        counts = self.reader.get_counts(chr_name)
        jcnt_rows = self.reader.get_rows(chr_name)

        end = self.reader.get_chr_size(chr_name)

        chunk_size = int(1e5)
        start = 0

        while start < end:
            # The last slice is clipped to the chromosome size.
            stop = min(start + chunk_size, end)

            chunk_counts = counts[start:stop]
            chunk_rows = jcnt_rows[start:stop]

            per_genome = {}

            for genome in constants.genomes:
                genome_data = IndependentData(chunk_counts, genome)

                per_genome[genome] = self.model.classify(genome_data,
                                                         self.parameters[genome])

            joint_resp = self._get_joint_responsibilities(per_genome)

            self.writer.write_data(chr_name, chunk_rows, joint_resp)

            start = stop

    def _get_joint_responsibilities(self, resp):
        """Combine per-genome responsibilities into joint responsibilities.

        ``resp`` must hold two-dimensional arrays under the keys ``'normal'``
        and ``'tumour'``.  Every normal column is combined with every tumour
        column by adding logs and exponentiating, i.e. their element-wise
        product.  Columns are ordered with the normal class as the slow index:
        first normal class 0 against all tumour classes, then normal class 1,
        and so on.
        """
        log_normal = np.log(resp['normal'])
        log_tumour = np.log(resp['tumour'])

        n_rows = log_normal.shape[0]
        n_normal_classes = log_normal.shape[1]

        # One block of columns per normal class; the (n_rows, 1) column
        # broadcasts across all tumour columns.
        blocks = [
            log_normal[:, k].reshape((n_rows, 1)) + log_tumour
            for k in range(n_normal_classes)
        ]

        return np.exp(np.hstack(blocks))