class TestingController():
    """Run inference over test reads, align the predictions and record metrics.

    Results are accumulated in ``self.results`` as dicts produced by
    ``get_result_dict`` (plus a ``'time'`` entry) and are saved through the
    ``FileController`` after every read.
    """

    def __init__(self, config, experiment_name, model, new_testing):
        """Read testing settings from ``config`` and set up collaborators.

        Args:
            config: mapping with ``'testing'`` and ``'model'`` sections.
            experiment_name: name handed to ``FileController``.
            model: model passed to ``InferenceController.predict_batch``.
            new_testing: when truthy, start with an empty result list;
                otherwise resume from ``file_controller.load_testing()``.
        """
        test_config = config['testing']
        model_config = config['model']

        self.model = model
        self.reads = test_config['reads']
        self.batch_size = test_config['batch_size']

        # A stride smaller than the window size means consecutive windows
        # overlap, so predictions go through the assembler instead of being
        # simply concatenated (see get_assembly).
        self.use_assembler = test_config['signal_window_stride'] < model_config['signal_window_size']
        self.save_predictions = test_config['save_predictions']
        self.inference_controller = InferenceController()

        self.file_controller = FileController(experiment_name)
        self.results = [] if new_testing else self.file_controller.load_testing()

    def pretty_print_progress(self, start, end, total):
        """Return a text progress bar marking the index range ``[start, end)``.

        One character is emitted for every ``max(total // 50, 1)``-th index in
        ``range(0, total)``: ``'x'`` inside ``[start, end)``, ``'-'`` outside.
        Returns a fixed placeholder string if the bar cannot be built.
        """
        try:
            step = max(total // 50, 1)
            cells = ['x' if start <= i < end else '-' for i in range(0, total, step)]
            return '[' + ''.join(cells) + ']'
        except Exception:
            # Best effort: a cosmetic progress bar must never stop testing.
            return '[ Failed to get progress string ]'

    def get_assembly(self, y_pred, iteration, read_id, bacteria):
        """Combine the per-window predicted base strings into one sequence.

        Without the assembler the strings are concatenated; otherwise
        ``assemble_and_output`` is called with the path returned by
        ``file_controller.get_assembly_filepath``.
        """
        if not self.use_assembler:
            return ''.join(y_pred)
        assembly_path = self.file_controller.get_assembly_filepath(iteration, read_id, bacteria)
        return assemble_and_output(assembly_path, y_pred)

    def get_result(self, assembly, aligner, read_id, bacteria):
        """Align ``assembly`` and build a result dict from the first hit.

        ``cigacc`` is ``1 - NM / blen`` of the first hit yielded by
        ``aligner.map``. If there is no hit, or the hit cannot be evaluated,
        an all-zero result is returned so the read is recorded as unmatched.
        """
        try:
            besthit = next(aligner.map(assembly))
            cigacc = 1 - (besthit.NM / besthit.blen)
            return self.get_result_dict(read_id, bacteria, besthit.ctg, besthit.r_st, besthit.r_en, besthit.NM, besthit.blen, besthit.cigar_str, cigacc)
        except Exception:
            # No hit (StopIteration) or an unusable one (e.g. blen == 0).
            # ``except Exception`` rather than a bare except so that
            # KeyboardInterrupt/SystemExit are not mistaken for "unmatched".
            return self.get_result_dict(read_id, bacteria, 0, 0, 0, 0, 0, 0, 0)

    def save_prediction(self, prediction_str, bacteria, read_id, iteration):
        """Persist ``prediction_str`` via the file controller when enabled."""
        if not self.save_predictions:
            return
        self.file_controller.save_prediction(prediction_str, bacteria, read_id, iteration)

    def get_result_dict(self, read_id, bacteria, ctg, r_st, r_en, nm, blen, cig, cigacc):
        """Pack one read's alignment fields into the dict stored in ``self.results``."""
        return {
            'read_id': read_id,
            'bacteria': bacteria,
            'ctg': ctg,
            'r_st': r_st,
            'r_en': r_en,
            'NM': nm,
            'blen': blen,
            'cig': cig,
            'cigacc': cigacc
        }

    def test(self, bacteria, generator, aligner):
        """Predict, assemble, align and record ``self.reads`` reads for ``bacteria``.

        Each read is predicted in batches of ``self.batch_size`` items. After
        every read the accumulated ``self.results`` are saved. An error on one
        read is printed (with traceback) and testing continues with the next.
        """
        print(f' - Testing {bacteria}.')
        for i in range(self.reads):
            try:
                # NOTE(review): get_batched_read() is called anew for every
                # read; this assumes it yields the next read each time rather
                # than restarting — confirm against the generator class.
                x, read_id = next(generator.get_batched_read())
                start_time = time.time()
                y_pred = []
                for b in range(0, len(x), self.batch_size):
                    x_batch = x[b:b+self.batch_size]
                    progress_bar = self.pretty_print_progress(b, b+len(x_batch), len(x))
                    print(f"{i+1:02d}/{self.reads:02d} Predicting batch {progress_bar} {b:04d}-{b+len(x_batch):04d}/{len(x):04d}", end="\r")

                    y_batch_pred = self.inference_controller.predict_batch(x_batch, self.model)
                    y_pred.extend(convert_to_base_strings(y_batch_pred))

                assembly = self.get_assembly(y_pred, i, read_id, bacteria)
                self.save_prediction(assembly, bacteria, read_id, i)

                result = self.get_result(assembly, aligner, read_id, bacteria)
                result['time'] = time.time() - start_time
                self.results.append(result)

                # Same 1-based, zero-padded counter as the progress line above;
                # the 70 trailing blanks overwrite what is left of that line.
                print(f"{i+1:02d}/{self.reads:02d} Done | CIG ACC: {result['cigacc']}"+" "*70)
                self.file_controller.save_testing(self.results)
            except Exception as e:
                # Best effort: report the failing read and carry on.
                print(e)
                traceback.print_exc()
# Example No. 2
# 0
class EvaluationController():
    """Summarise the saved testing and training logs of one experiment.

    ``file_controller.load_testing()`` is expected to return a list of result
    dicts (keys read here: ``'bacteria'``, ``'cigacc'``, ``'blen'``, ``'cig'``,
    ``'time'``). ``file_controller.load_training()`` is expected to return rows
    whose column 2 is a validation loss and column 3 a time in seconds.
    """

    def __init__(self, experiment_name):
        self.file_controller = FileController(experiment_name)

    def count_unmatched_reads(self):
        """Return how many reads have an accuracy of 0 (no usable alignment)."""
        accuracies = self.get_accuracy_list(include_unmatched=True)
        return sum(1 for accuracy in accuracies if accuracy == 0)

    def count_unmatched_reads_per_bacteria(self):
        """Return ``{bacteria: number of reads with accuracy 0}``."""
        data = self.get_accuracy_list_per_bacteria(include_unmatched=True)
        return {
            bacteria: sum(1 for accuracy in accuracies if accuracy == 0)
            for bacteria, accuracies in data.items()
        }

    def count_reads(self):
        """Return the total number of recorded test reads."""
        return len(self.file_controller.load_testing())

    def count_reads_per_bacteria(self):
        """Return ``{bacteria: number of recorded reads}``."""
        data = self.get_accuracy_list_per_bacteria(include_unmatched=True)
        return {bacteria: len(accuracies) for bacteria, accuracies in data.items()}

    def get_accuracy_list(self, include_unmatched=True):
        """Return every read's ``cigacc`` as a percentage, in load order.

        Args:
            include_unmatched: when false, reads with ``cigacc == 0`` are omitted.
        """
        return [
            measurement['cigacc'] * 100
            for measurement in self.file_controller.load_testing()
            if include_unmatched or measurement['cigacc'] != 0
        ]

    def get_accuracy_list_per_bacteria(self, include_unmatched=True):
        """Return ``{bacteria: [cigacc * 100, ...]}`` in load order.

        Args:
            include_unmatched: when false, reads with ``cigacc == 0`` are
                omitted, so a bacteria without any matched read is absent.
        """
        accuracies = {}
        for measurement in self.file_controller.load_testing():
            # Honour the flag the same way get_accuracy_list does; previously
            # it was ignored and unmatched reads were always included.
            if not include_unmatched and measurement['cigacc'] == 0:
                continue
            accuracies.setdefault(measurement['bacteria'], []).append(measurement['cigacc'] * 100)
        return accuracies

    def get_accuracy_mean(self, include_unmatched=True):
        """Return the mean of ``get_accuracy_list`` (NaN if it is empty)."""
        return np.array(self.get_accuracy_list(include_unmatched)).mean()

    def get_accuracy_mean_per_bacteria(self, include_unmatched=True):
        """Return ``{bacteria: mean accuracy percentage}``."""
        data = self.get_accuracy_list_per_bacteria(include_unmatched)
        return {bacteria: np.array(accuracies).mean() for bacteria, accuracies in data.items()}

    def get_total_testing_time(self):
        """Return the summed per-read ``'time'`` values, floored to whole
        seconds and formatted with ``str(datetime.timedelta)``."""
        total_time_seconds = sum(measurement['time'] for measurement in self.file_controller.load_testing())
        return str(datetime.timedelta(seconds=math.floor(total_time_seconds)))

    def get_total_training_time(self):
        """Return the time from the first training row to the row with the best
        validation loss, floored and formatted with ``str(datetime.timedelta)``.

        If no validation loss qualifies, ``get_best_validation_loss`` returns
        index -1, so the last row is used instead.
        """
        data = np.array(self.file_controller.load_training())
        training_times = data[:, 3]
        _, training_stop_idx = self.get_best_validation_loss()
        total_time_seconds = training_times[training_stop_idx] - training_times[0]
        return str(datetime.timedelta(seconds=math.floor(total_time_seconds)))

    def get_best_validation_loss(self):
        """Return ``(min_loss, row_index)`` over the non-negative values in
        column 2 of the training log; the earliest row wins on ties.

        Negative entries are skipped (presumably epochs without validation —
        confirm against the training logger). Returns ``(1e10, -1)`` when no
        value qualifies.
        """
        data = np.array(self.file_controller.load_training())
        validation_losses = data[:, 2]
        min_validation_idx = -1
        min_validation_loss = 1e10
        for i, validation_loss in enumerate(validation_losses):
            if validation_loss < 0:
                continue
            if validation_loss < min_validation_loss:
                min_validation_loss = validation_loss
                min_validation_idx = i
        return min_validation_loss, min_validation_idx

    @staticmethod
    def _add_cigar_ops(counts, cigar_string):
        """Add the lengths of the S/M/D/I operations in ``cigar_string`` to
        ``counts`` (other CIGAR operations are ignored)."""
        for amount, op in re.findall(r'(\d+)([SMDI])', cigar_string):
            counts[op] += int(amount)

    def get_SMDI(self):
        """Return CIGAR S/M/D/I operation lengths per aligned base.

        Only reads with ``cigacc != 0`` are used. Each value is that
        operation's total length divided by the summed ``blen``. When there
        are no such reads, all four values are 0 instead of raising
        ZeroDivisionError.
        """
        smdi_dict = {"S": 0, "M": 0, "D": 0, "I": 0}
        total_length = 0
        for measurement in self.file_controller.load_testing():
            if measurement['cigacc'] == 0:
                continue
            total_length += measurement['blen']
            self._add_cigar_ops(smdi_dict, measurement['cig'])
        if total_length == 0:
            # Nothing was aligned, so there is nothing to normalise by.
            return smdi_dict
        return {key: count / total_length for key, count in smdi_dict.items()}

    def get_SMDI_per_bacteria(self):
        """Return ``{bacteria: {S, M, D, I rates}}`` computed like ``get_SMDI``
        but separately for each bacteria; bacteria without a read where
        ``cigacc != 0`` are absent."""
        smdi_dicts = {}
        lengths = {}
        for measurement in self.file_controller.load_testing():
            if measurement['cigacc'] == 0:
                continue
            bacteria = measurement['bacteria']
            if bacteria not in smdi_dicts:
                smdi_dicts[bacteria] = {"S": 0, "M": 0, "D": 0, "I": 0}
                lengths[bacteria] = 0
            lengths[bacteria] += measurement['blen']
            self._add_cigar_ops(smdi_dicts[bacteria], measurement['cig'])
        for bacteria, counts in smdi_dicts.items():
            if lengths[bacteria] == 0:
                # Same zero-length guard as get_SMDI.
                continue
            for key in counts:
                counts[key] /= lengths[bacteria]
        return smdi_dicts