def __init__(self, config, experiment_name, model, generator,
                 validation_controller, new_training):
        """Prepare training state: data source, results history, optimizer and metrics.

        Args:
            config: experiment configuration; its 'training' and 'model'
                sections are read here.
            experiment_name: passed to FileController to locate experiment files.
            model: the model to be trained.
            generator: training data source, stored for later use.
            validation_controller: stored for later validation use.
            new_training: when truthy, start from an empty results history;
                otherwise load the previously saved training results.
        """
        train_cfg = config['training']
        model_cfg = config['model']

        self.generator = generator
        self.validation_controller = validation_controller

        self.file_controller = FileController(experiment_name)
        if new_training:
            self.results = []
        else:
            # Resuming: keep appending to the saved per-epoch history.
            self.results = self.file_controller.load_training()

        self.epochs = train_cfg['epochs']
        self.patience = train_cfg['patience']
        self.warmup = train_cfg['warmup']
        self.batches = train_cfg['batches']
        self.batch_size = train_cfg['batch_size']
        self.model = model

        schedule_scale = model_cfg['d_model'] * train_cfg['lr_mult']
        lr_schedule = CustomSchedule(schedule_scale)
        self.optimizer = tf.keras.optimizers.Adam(
            lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')
        # reduction='none' keeps per-element losses; how they are reduced is
        # decided by the caller of loss_object (not visible in this fragment).
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')
def make_report(experiment_name):
    """Compute the evaluation metrics of an experiment and save them.

    The metrics are obtained from EvaluationController in a fixed order and
    saved as one dict through FileController.save_evaluation. Any exception
    is reported through print_exception instead of being propagated.
    """
    # (report key, EvaluationController method name), evaluated in this order.
    # NOTE(review): 'best_editdistance' is filled by get_best_validation_loss,
    # which in EvaluationController (as seen in this file) returns a
    # (loss, index) tuple -- confirm storing the tuple is intended.
    metric_sources = (
        ('mean_accuracy', 'get_accuracy_mean'),
        ('mean_accuracy_per_bacteria', 'get_accuracy_mean_per_bacteria'),
        ('best_editdistance', 'get_best_validation_loss'),
        ('number_of_tested_reads', 'count_reads'),
        ('number_of_tested_reads_per_bacteria', 'count_reads_per_bacteria'),
        ('unmatched_reads', 'count_unmatched_reads'),
        ('unmatched_reads_per_bacteria', 'count_unmatched_reads_per_bacteria'),
        ('total_testing_time', 'get_total_testing_time'),
        ('total_training_time', 'get_total_training_time'),
        ('SMDI', 'get_SMDI'),
        ('SMDI_per_bacteria', 'get_SMDI_per_bacteria'),
    )
    try:
        files = FileController(experiment_name)
        evaluator = EvaluationController(experiment_name)
        report = {
            key: getattr(evaluator, method_name)()
            for key, method_name in metric_sources
        }
        files.save_evaluation(report)
    except Exception as err:
        print_exception(' ! Unable to create evaluation report.',
                        err,
                        show_trance=True)
Exemplo n.º 3
0
def get_trained_model(config, experiment_name):
    """Build a fresh model from ``config`` and load the experiment's trained weights.

    Args:
        config: experiment configuration passed to get_new_model.
        experiment_name: experiment whose saved model weights are loaded.

    Returns:
        The model built by get_new_model with the saved weights loaded.

    Raises:
        AssertionError: if no trained model exists for ``experiment_name``.
    """
    model = get_new_model(config)
    file_controller = FileController(experiment_name)
    # An explicit raise (rather than ``assert``) keeps this check active under
    # ``python -O``; AssertionError is kept so existing callers still catch it.
    if not file_controller.trained_model_exists():
        raise AssertionError(
            ' ! Unable to load trained model. Invalid experiment name.')
    # load_weights mutates ``model`` in place; its return value is not needed.
    model.load_weights(file_controller.get_model_filepath())
    return model
Exemplo n.º 4
0
    def __init__(self, config, experiment_name):
        """Store the config and set the default train/test mode flags.

        By default both training and testing are in "continue" mode:
        nothing is skipped and nothing is restarted from scratch.
        """
        self.config = config
        self.file_controller = FileController(experiment_name)

        # Training mode: skip / new / continue (continue by default).
        self.skip_training, self.new_training, self.continue_training = (
            False, False, True)

        # Testing mode: skip / new / continue (continue by default).
        self.skip_testing, self.new_testing, self.continue_testing = (
            False, False, True)
def plot_testing(experiment_name):
    """Save a plot of every tested read's accuracy (file suffix and title 'all').

    Failures are reported through print_exception instead of being raised.
    """
    try:
        plot_path = FileController(experiment_name).get_testing_plot_filepath(
            suffix='all')
        accuracies = EvaluationController(experiment_name).get_accuracy_list()
        PlottingController(experiment_name).save_testing_plot(
            accuracies, plot_path, title='all')
    except Exception as err:
        print_exception(' ! Unable to create testing plot.',
                        err,
                        show_trance=True)
def plot_learning_rate(experiment_name):
    """Save a plot of the learning rate recorded for each training epoch.

    Reads column 4 of the saved training rows. Failures are reported through
    print_exception instead of being raised.
    """
    try:
        files = FileController(experiment_name)
        plot_path = files.get_learning_rate_plot_filepath()
        history = np.array(files.load_training())
        learning_rates = history[:, 4]
        PlottingController(experiment_name).save_learning_rate_plot(
            learning_rates, plot_path)
    except Exception as err:
        print_exception(' ! Unable to create learning rate plot.',
                        err,
                        show_trance=True)
    def __init__(self, config, experiment_name, model, new_testing):
        """Prepare the testing run from the 'testing' and 'model' config sections.

        Args:
            config: experiment configuration.
            experiment_name: passed to FileController to locate experiment files.
            model: the model used for inference.
            new_testing: when truthy start with no results; otherwise load the
                previously saved testing results.
        """
        testing_cfg = config['testing']
        model_cfg = config['model']

        self.model = model
        self.reads = testing_cfg['reads']
        self.batch_size = testing_cfg['batch_size']

        # The assembler is enabled only when consecutive signal windows overlap
        # (stride smaller than the window size).
        window_stride = testing_cfg['signal_window_stride']
        window_size = model_cfg['signal_window_size']
        self.use_assembler = window_stride < window_size
        self.save_predictions = testing_cfg['save_predictions']
        self.inference_controller = InferenceController()

        self.file_controller = FileController(experiment_name)
        if new_testing:
            self.results = []
        else:
            self.results = self.file_controller.load_testing()
def plot_testing_per_bacteria(experiment_name):
    """Save one accuracy plot per bacteria, using the bacteria name as suffix and title.

    Failures are reported through print_exception instead of being raised.
    """
    try:
        files = FileController(experiment_name)
        evaluator = EvaluationController(experiment_name)
        plotter = PlottingController(experiment_name)
        accuracies_by_bacteria = evaluator.get_accuracy_list_per_bacteria()
        for bacteria, accuracies in accuracies_by_bacteria.items():
            plot_path = files.get_testing_plot_filepath(suffix=bacteria)
            plotter.save_testing_plot(accuracies, plot_path, title=bacteria)
    except Exception as err:
        print_exception(' ! Unable to create testing plot per bacteria.',
                        err,
                        show_trance=True)
def plot_validation(experiment_name):
    """Save a plot of the per-epoch validation loss.

    Uses column 2 of the saved training rows and leaves out the first row.
    Failures are reported through print_exception instead of being raised.
    """
    try:
        files = FileController(experiment_name)
        plot_path = files.get_validation_plot_filepath()
        history = np.array(files.load_training())
        # Column 2, every row except the first.
        validation_losses = history[1:, 2]
        PlottingController(experiment_name).save_validation_plot(
            validation_losses, plot_path)
    except Exception as err:
        print_exception(' ! Unable to create validation plot.',
                        err,
                        show_trance=True)
def plot_training(experiment_name):
    """Save a plot of the per-epoch training loss (column 0) and accuracy (column 1).

    Failures are reported through print_exception instead of being raised.
    """
    try:
        files = FileController(experiment_name)
        plot_path = files.get_training_plot_filepath()
        history = np.array(files.load_training())
        losses, accuracies = history[:, 0], history[:, 1]
        PlottingController(experiment_name).save_training_plot(
            losses, accuracies, plot_path)
    except Exception as err:
        print_exception(' ! Unable to create training plot.',
                        err,
                        show_trance=True)
Exemplo n.º 11
0
    def setUp(self) -> None:
        """Create a fresh MainModel, FileView and FileController for each test."""
        super().setUp()

        self.main_model = main_model = MainModel()
        file_model = main_model.file_model
        self.file_view_test = FileView(file_model)
        self.file_controller = FileController(file_model, self.file_view_test)
Exemplo n.º 12
0
def setup_experiment(experiment_name, config):
    """Create the experiment's directories, then save ``config`` for it."""
    files = FileController(experiment_name)
    # Directory creation steps, run in this order.
    for step in ('create_experiment_dir',
                 'create_assembly_directory',
                 'create_report_directory',
                 'create_prediction_directory'):
        getattr(files, step)()
    files.save_config(config)
Exemplo n.º 13
0
def discard_existing_evaluation(experiment_name):
    """Tear down the experiment's stored evaluation via FileController.teardown_evaluation."""
    FileController(experiment_name).teardown_evaluation()
Exemplo n.º 14
0
def discard_existing_testing(experiment_name):
    """Tear down the experiment's testing results, assemblies and evaluation, in that order."""
    files = FileController(experiment_name)
    for teardown in ('teardown_testing', 'teardown_assemblies',
                     'teardown_evaluation'):
        getattr(files, teardown)()
Exemplo n.º 15
0
 def __init__(self, experiment_name):
     """Attach a FileController for ``experiment_name``."""
     controller = FileController(experiment_name)
     self.file_controller = controller
class TestingController():
    """Run inference on test reads, align the predictions and record per-read results."""

    def __init__(self, config, experiment_name, model, new_testing):
        """Prepare the testing run from the 'testing' and 'model' config sections.

        Args:
            config: experiment configuration.
            experiment_name: passed to FileController to locate experiment files.
            model: the model used for inference.
            new_testing: when truthy start with no results; otherwise load the
                previously saved testing results and append to them.
        """
        test_config = config['testing']
        model_config = config['model']

        self.model = model
        self.reads = test_config['reads']
        self.batch_size = test_config['batch_size']

        # The assembler is used only when consecutive signal windows overlap.
        self.use_assembler = test_config['signal_window_stride'] < model_config['signal_window_size']
        self.save_predictions = test_config['save_predictions']
        self.inference_controller = InferenceController()

        self.file_controller = FileController(experiment_name)
        self.results = [] if new_testing else self.file_controller.load_testing()

    def pretty_print_progress(self, start, end, total):
        """Return a text progress bar marking sampled positions in [start, end) with 'x'.

        Positions ``range(0, total, step)`` are sampled, with
        ``step = max(total // 50, 1)``, so large totals give roughly 50 cells.
        Progress display is best-effort: on failure a placeholder string is
        returned instead of raising.
        """
        try:
            step = max(total // 50, 1)
            cells = ''.join('x' if start <= i < end else '-'
                            for i in range(0, total, step))
            return '[' + cells + ']'
        except Exception:
            # Narrower than a bare except so Ctrl-C still interrupts testing.
            return '[ Failed to get progress string ]'

    def get_assembly(self, y_pred, iteration, read_id, bacteria):
        """Turn the per-window predicted strings of one read into a single sequence.

        Without the assembler the strings are simply concatenated; otherwise
        assemble_and_output is called with an assembly file path for this read.
        """
        if not self.use_assembler:
            return ''.join(y_pred)
        assembly_path = self.file_controller.get_assembly_filepath(iteration, read_id, bacteria)
        return assemble_and_output(assembly_path, y_pred)

    def get_result(self, assembly, aligner, read_id, bacteria):
        """Align ``assembly`` and build a result dict from the first hit yielded.

        ``cigacc`` is ``1 - NM / blen``. If the aligner yields no hit or any
        other error occurs (e.g. ``blen == 0``), a result with every alignment
        field set to 0 is returned; elsewhere in this file ``cigacc == 0`` is
        counted as an unmatched read.
        NOTE(review): assumes aligner.map yields the best hit first -- confirm.
        """
        try:
            besthit = next(aligner.map(assembly))
            cigacc = 1 - (besthit.NM / besthit.blen)
            return self.get_result_dict(read_id, bacteria, besthit.ctg, besthit.r_st, besthit.r_en,
                                        besthit.NM, besthit.blen, besthit.cigar_str, cigacc)
        except Exception:
            # Best-effort: record the read as unmatched instead of aborting.
            # (Narrower than a bare except so Ctrl-C is not swallowed.)
            return self.get_result_dict(read_id, bacteria, 0, 0, 0, 0, 0, 0, 0)

    def save_prediction(self, prediction_str, bacteria, read_id, iteration):
        """Save the predicted sequence of one read when save_predictions is enabled."""
        if not self.save_predictions:
            return
        self.file_controller.save_prediction(prediction_str, bacteria, read_id, iteration)

    def get_result_dict(self, read_id, bacteria, ctg, r_st, r_en, nm, blen, cig, cigacc):
        """Pack the per-read result fields into the dict stored in self.results."""
        return {
            'read_id': read_id,
            'bacteria': bacteria,
            'ctg': ctg,
            'r_st': r_st,
            'r_en': r_en,
            'NM': nm,
            'blen': blen,
            'cig': cig,
            'cigacc': cigacc
        }

    def test(self, bacteria, generator, aligner):
        """Predict, assemble and align ``self.reads`` reads of one bacteria.

        Each read's result (plus its wall-clock 'time' in seconds) is appended
        to self.results, and the results are saved after every read. An error
        on one read is printed with its traceback and testing continues with
        the next read.
        NOTE(review): get_batched_read() is called anew for every read; this
        assumes each call yields the next read -- confirm against the generator.
        """
        print(f' - Testing {bacteria}.')
        for i in range(self.reads):
            try:
                x, read_id = next(generator.get_batched_read())
                start_time = time.time()
                y_pred = []
                for b in range(0, len(x), self.batch_size):
                    x_batch = x[b:b+self.batch_size]
                    progress_bar = self.pretty_print_progress(b, b+len(x_batch), len(x))
                    print(f"{i+1:02d}/{self.reads:02d} Predicting batch {progress_bar} {b:04d}-{b+len(x_batch):04d}/{len(x):04d}", end="\r")

                    y_batch_pred = self.inference_controller.predict_batch(x_batch, self.model)
                    y_batch_pred_strings = convert_to_base_strings(y_batch_pred)
                    y_pred.extend(y_batch_pred_strings)

                assembly = self.get_assembly(y_pred, i, read_id, bacteria)
                self.save_prediction(assembly, bacteria, read_id, i)

                result = self.get_result(assembly, aligner, read_id, bacteria)
                result['time'] = time.time() - start_time
                self.results.append(result)

                # 1-based counter, matching the "Predicting" line above;
                # 70 blanks overwrite the rest of the previous \r-terminated line.
                print(f"{i+1:02d}/{self.reads:02d} Done | CIG ACC: {result['cigacc']}"+" "*70)
                self.file_controller.save_testing(self.results)
            except Exception as e:
                print(e)
                traceback.print_exc()
Exemplo n.º 17
0
class UIController():
    """Interactive prompts that decide whether to (re)train and (re)test a model."""

    def __init__(self, config, experiment_name):
        """Store the config and set the default mode flags (continue training/testing)."""
        self.config = config
        self.file_controller = FileController(experiment_name)

        self.skip_training = False
        self.new_training = False
        self.continue_training = True

        self.skip_testing = False
        self.new_testing = False
        self.continue_testing = True

    def _prompt_choice(self, name, message, choices):
        """Show a single-choice list prompt and return the selected label.

        With inquirer's default settings, cancelling the prompt (Ctrl-C)
        makes ``inquirer.prompt`` return None instead of a dict. In that case
        exit, like declining in ask_parameters, rather than crash with a
        TypeError on ``None[name]``.
        """
        answer = inquirer.prompt([inquirer.List(name, message, choices)])
        if answer is None:
            sys.exit()
        return answer[name]

    def ask_retrain(self):
        """Set the training flags, asking the user only if a trained model exists."""
        if not self.file_controller.trained_model_exists():
            print(' - Trained model not found. Model will be trained.')
            self.continue_training = False
            self.skip_training = False
            self.new_training = True
            return
        message = 'A trained model already exists, would you like to retrain it?'
        choices = [
            'skip training', 'continue training existing model',
            'new training (discard existing)'
        ]
        choice = self._prompt_choice('retrain', message, choices)

        self.skip_training = choice == 'skip training'
        self.new_training = choice == 'new training (discard existing)'
        self.continue_training = choice == 'continue training existing model'

    def ask_retest(self):
        """Set the testing flags, asking the user only if testing results exist."""
        if not self.file_controller.testing_result_exists():
            print(' - Testing results not found. Model will be tested.')
            self.continue_testing = False
            self.skip_testing = False
            self.new_testing = True
            return
        message = 'Model testing results already exist, what would you like to do?'
        choices = self.get_retest_choices()
        choice = self._prompt_choice('retest', message, choices)

        self.skip_testing = choice == 'skip testing'
        self.new_testing = choice == 'new testing (discard existing)'
        self.continue_testing = choice == 'append to existing results'

    def get_retest_choices(self):
        """Return the retest options allowed by the current training flags.

        Appending to existing results is offered only when the model is not
        being newly trained or further trained.
        """
        if self.new_training or self.continue_training:
            # Do not allow appending test results if the model is retrained or improved.
            return ['skip testing', 'new testing (discard existing)']
        return [
            'skip testing', 'append to existing results',
            'new testing (discard existing)'
        ]

    def ask_parameters(self):
        """Ask whether to continue with the shown parameters; exit on 'no' or cancel."""
        message = 'Continue with these experiment parameters?'
        choice = self._prompt_choice('parameters', message, ['yes', 'no'])
        if choice == 'no':
            sys.exit()

    def print_parameters(self, key):
        """Print the ``key`` section of the config as sorted, indented JSON."""
        print(f' - Verify {key} parameters')
        print(json.dumps(self.config[key], indent=4, sort_keys=True))
class TrainingController():
    """Training loop with per-epoch validation, checkpointing and early stopping."""

    def __init__(self, config, experiment_name, model, generator,
                 validation_controller, new_training):
        """Prepare training state from the 'training' and 'model' config sections.

        Args:
            config: experiment configuration.
            experiment_name: passed to FileController to locate experiment files.
            model: the model to train.
            generator: provides training batches via get_batches.
            validation_controller: provides validate(model) after each epoch.
            new_training: when truthy start with an empty history; otherwise
                load and extend the saved training history.
        """
        training_config = config['training']
        model_config = config['model']

        self.generator = generator
        self.validation_controller = validation_controller

        self.file_controller = FileController(experiment_name)
        self.results = [] if new_training else self.file_controller.load_training(
        )

        self.epochs = training_config['epochs']
        self.patience = training_config['patience']
        self.warmup = training_config['warmup']
        self.batches = training_config['batches']
        self.batch_size = training_config['batch_size']
        self.model = model

        learning_rate = CustomSchedule(model_config['d_model'] *
                                       training_config['lr_mult'])
        self.optimizer = tf.keras.optimizers.Adam(learning_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.98,
                                                  epsilon=1e-9)
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')
        # Per-element losses; reduction happens in model.get_loss (see train_step).
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

    def train(self):
        """Train for up to ``self.epochs`` epochs with early stopping.

        After each epoch the model is validated and a row
        ``[train_loss, train_accuracy, validation_loss, timestamp, learning_rate]``
        is appended to self.results and saved. When the validation loss
        improves, the model is saved. Training stops once more than
        ``self.patience`` consecutive epochs fail to improve; the counter is
        reset at the start of each of the first ``self.warmup`` epochs.
        Negative validation losses never count as an improvement, matching
        get_best_validation_loss, which skips them.

        Returns:
            The model with the best weights seen in this call loaded (the
            starting weights if no epoch improved).
        """
        print(' - Training model.')
        waited = 0
        validation_loss = 1e10 if self.results == [] else self.get_best_validation_loss(
        )
        print(f' - Initial validation loss: {validation_loss}.')
        model_weights = self.model.get_weights()

        for epoch in range(self.epochs):
            waited = 0 if epoch < self.warmup else waited
            start_time = time.time()
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()

            batches = next(self.generator.get_batches(self.batches))
            for batch, (x, y) in enumerate(batches):
                x = tf.constant(x, dtype=tf.float32)
                y = tf.constant(y, dtype=tf.int32)
                self.train_step(x, y)
                print(
                    f' - - Epoch:{epoch+1}/{self.epochs} | Batch:{batch+1}/{len(batches)} | Loss:{self.train_loss.result():.4f} | Accuracy:{self.train_accuracy.result():.4f}',
                    end="\r")
            print()

            current_validation_loss = self.validation_controller.validate(
                self.model)
            lr = self.get_current_learning_rate()
            self.results.append([
                self.train_loss.result(),
                self.train_accuracy.result(), current_validation_loss,
                time.time(), lr
            ])
            self.file_controller.save_training(self.results)
            print(
                f' = = Epoch:{epoch+1}/{self.epochs} | Loss:{self.train_loss.result():.4f} | Accuracy:{self.train_accuracy.result():.4f} | Validation loss:{current_validation_loss} | Took:{time.time() - start_time} secs | Learning rate:{lr:.10}'
            )

            # Negative losses are skipped by get_best_validation_loss, so they
            # must not count as an improvement here either; otherwise one such
            # epoch would be saved as "best" and nothing could beat it.
            # NOTE(review): negatives appear to mark an invalid validation --
            # confirm against the validation controller.
            if 0 <= current_validation_loss < validation_loss:
                waited = 0
                validation_loss = current_validation_loss
                print(
                    ' - - Model validation accuracy improvement - saving model weights.'
                )
                self.file_controller.save_model(self.model)
                model_weights = self.model.get_weights()
            else:
                waited += 1
                if waited > self.patience:
                    print(
                        f' - Stopping training ( out of patience - model has not improved for {self.patience} epochs.'
                    )
                    break

        self.model.set_weights(model_weights)
        return self.model

    @tf.function
    def train_step(self, x, y):
        """Run one optimisation step on a batch.

        The target sequences are split into ``y[:, :-1]`` as decoder input and
        ``y[:, 1:]`` as labels (shifted by one position).
        """
        y_input = y[:, :-1]
        y_label = y[:, 1:]

        combined_mask = create_combined_mask(y_input)
        with tf.GradientTape() as tape:
            # Third positional argument is presumably the `training` flag -- TODO confirm.
            y_prediction, _ = self.model(x, y_input, True, combined_mask)
            loss = self.model.get_loss(y_label, y_prediction, self.loss_object)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))
        self.train_loss(loss)
        self.train_accuracy(y_label, y_prediction)

    def get_best_validation_loss(self):
        """Return the smallest non-negative validation loss (column 2) in self.results.

        Returns 1e10 when there is none.
        """
        validation_losses = np.array(self.results)[:, 2]
        # 1e10 first so it wins ties, exactly like the strict `<` scan it replaces.
        return min([1e10, *(loss for loss in validation_losses if loss >= 0)])

    def get_current_learning_rate(self):
        """Return the optimizer's current (schedule-decayed) learning rate as a float.

        NOTE(review): relies on the private ``_decayed_lr`` method of the
        legacy Keras optimizers; newer TF/Keras optimizers may not provide it.
        """
        lr = self.optimizer._decayed_lr("float32").numpy()
        return float(lr)
Exemplo n.º 19
0
def discard_existing_training(experiment_name):
    """Tear down the experiment's training results and then its saved model."""
    files = FileController(experiment_name)
    for teardown in ('teardown_training', 'teardown_model'):
        getattr(files, teardown)()
Exemplo n.º 20
0
class EvaluationController():
    """Summarise an experiment's saved training and testing results.

    Testing results come from ``file_controller.load_testing()`` as a list of
    per-read dicts; the keys used here are 'bacteria', 'cigacc', 'blen',
    'cig' and 'time'. A 'cigacc' of 0 marks a read as unmatched. Training
    results come from ``file_controller.load_training()`` as rows whose
    column 2 is the validation loss and column 3 a timestamp.
    """

    # One CIGAR operation counted by the SMDI statistics, e.g. '12M'.
    _CIGAR_OP = re.compile(r'[\d]+[SMDI]')

    def __init__(self, experiment_name):
        self.file_controller = FileController(experiment_name)

    def count_unmatched_reads(self):
        """Return the number of tested reads whose accuracy is 0."""
        accuracies = self.get_accuracy_list(include_unmatched=True)
        return sum(1 for accuracy in accuracies if accuracy == 0)

    def count_unmatched_reads_per_bacteria(self):
        """Return ``{bacteria: number of reads whose accuracy is 0}``."""
        per_bacteria = self.get_accuracy_list_per_bacteria(include_unmatched=True)
        return {
            bacteria: sum(1 for accuracy in accuracies if accuracy == 0)
            for bacteria, accuracies in per_bacteria.items()
        }

    def count_reads(self):
        """Return the total number of tested reads."""
        return len(self.file_controller.load_testing())

    def count_reads_per_bacteria(self):
        """Return ``{bacteria: number of tested reads}``."""
        per_bacteria = self.get_accuracy_list_per_bacteria(include_unmatched=True)
        return {bacteria: len(accuracies) for bacteria, accuracies in per_bacteria.items()}

    def get_accuracy_list(self, include_unmatched=True):
        """Return per-read accuracies in percent (``cigacc * 100``).

        Args:
            include_unmatched: if False, reads with ``cigacc == 0`` are left out.
        """
        return [
            measurement['cigacc'] * 100
            for measurement in self.file_controller.load_testing()
            if include_unmatched or measurement['cigacc'] != 0
        ]

    def get_accuracy_list_per_bacteria(self, include_unmatched=True):
        """Return ``{bacteria: [accuracy in percent, ...]}`` over tested reads.

        Args:
            include_unmatched: if False, reads with ``cigacc == 0`` are left
                out. Every tested bacteria still gets a key, possibly with an
                empty list, so the key set does not depend on this flag.
        """
        accuracies = {}
        for measurement in self.file_controller.load_testing():
            bacteria_accuracies = accuracies.setdefault(measurement['bacteria'], [])
            # Previously this parameter was ignored; honour it like get_accuracy_list.
            if not include_unmatched and measurement['cigacc'] == 0:
                continue
            bacteria_accuracies.append(measurement['cigacc'] * 100)
        return accuracies

    def get_accuracy_mean(self, include_unmatched=True):
        """Return the mean accuracy in percent over tested reads (NumPy mean)."""
        return np.array(self.get_accuracy_list(include_unmatched)).mean()

    def get_accuracy_mean_per_bacteria(self, include_unmatched=True):
        """Return ``{bacteria: mean accuracy in percent}``.

        A bacteria whose list is empty gets NaN (NumPy mean of an empty array).
        """
        per_bacteria = self.get_accuracy_list_per_bacteria(include_unmatched)
        return {
            bacteria: np.array(accuracies).mean()
            for bacteria, accuracies in per_bacteria.items()
        }

    def get_total_testing_time(self):
        """Return the summed per-read testing time, floored to seconds, as 'H:MM:SS'."""
        total_time_seconds = sum(
            measurement['time'] for measurement in self.file_controller.load_testing())
        return str(datetime.timedelta(seconds=math.floor(total_time_seconds)))

    def get_total_training_time(self):
        """Return the time from the first training row to the best-validation row as 'H:MM:SS'."""
        data = np.array(self.file_controller.load_training())
        training_times = data[:, 3]
        # If no valid loss exists the index is -1, i.e. the last row.
        _, training_stop_idx = self.get_best_validation_loss()
        total_time_seconds = training_times[training_stop_idx] - training_times[0]
        return str(datetime.timedelta(seconds=math.floor(total_time_seconds)))

    def get_best_validation_loss(self):
        """Return ``(smallest non-negative validation loss, its row index)``.

        Returns ``(1e10, -1)`` when there is no non-negative loss.
        """
        data = np.array(self.file_controller.load_training())
        validation_losses = data[:, 2]
        min_validation_idx = -1
        min_validation_loss = 1e10
        for i, validation_loss in enumerate(validation_losses):
            if validation_loss < 0:
                continue
            if validation_loss < min_validation_loss:
                min_validation_loss = validation_loss
                min_validation_idx = i
        return min_validation_loss, min_validation_idx

    def _add_cigar_counts(self, counts, cigar_string):
        """Add the lengths of the S/M/D/I operations found in ``cigar_string`` to ``counts``."""
        for operation in self._CIGAR_OP.findall(cigar_string):  # e.g. ['6M', '5D', ...]
            counts[operation[-1]] += int(operation[:-1])

    @staticmethod
    def _normalize_counts(counts, total_length):
        """Divide each S/M/D/I count by ``total_length`` in place and return ``counts``.

        When ``total_length`` is 0 (e.g. no matched reads) the counts are
        returned as they are instead of raising ZeroDivisionError.
        """
        if total_length:
            for key in 'SMDI':
                counts[key] /= total_length
        return counts

    def get_SMDI(self):
        """Return S/M/D/I CIGAR operation totals over matched reads, divided by total 'blen'.

        Unmatched reads (``cigacc == 0``) are ignored.
        """
        smdi_dict = {"S": 0, "M": 0, "D": 0, "I": 0}
        total_length = 0
        for measurement in self.file_controller.load_testing():
            if measurement['cigacc'] == 0:
                continue
            total_length += measurement['blen']
            self._add_cigar_counts(smdi_dict, measurement['cig'])
        return self._normalize_counts(smdi_dict, total_length)

    def get_SMDI_per_bacteria(self):
        """Return ``{bacteria: S/M/D/I dict}`` like get_SMDI, computed per bacteria.

        Bacteria with no matched reads do not appear in the result.
        """
        smdi_dicts = {}
        lengths = {}
        for measurement in self.file_controller.load_testing():
            if measurement['cigacc'] == 0:
                continue
            bacteria = measurement['bacteria']
            if bacteria not in smdi_dicts:
                smdi_dicts[bacteria] = {"S": 0, "M": 0, "D": 0, "I": 0}
                lengths[bacteria] = 0
            lengths[bacteria] += measurement['blen']
            self._add_cigar_counts(smdi_dicts[bacteria], measurement['cig'])
        for bacteria, counts in smdi_dicts.items():
            self._normalize_counts(counts, lengths[bacteria])
        return smdi_dicts