def plot_learning_rate(experiment_name):
    """Create and save the learning-rate-per-epoch plot for an experiment.

    Best effort: any failure is reported through print_exception instead
    of being raised to the caller.
    """
    try:
        files = FileController(experiment_name)
        target_path = files.get_learning_rate_plot_filepath()
        history = np.array(files.load_training())
        # Column 4 of each training record holds the learning rate.
        learning_rates = history[:, 4]
        plotter = PlottingController(experiment_name)
        plotter.save_learning_rate_plot(learning_rates, target_path)
    except Exception as error:
        # NOTE(review): 'show_trance' looks like a typo for 'show_trace',
        # but it must match print_exception's actual signature — confirm.
        print_exception(' ! Unable to create learning rate plot.',
                        error,
                        show_trance=True)
def plot_validation(experiment_name):
    """Create and save the validation-loss-per-epoch plot for an experiment.

    Best effort: any failure is reported through print_exception instead
    of being raised to the caller.
    """
    try:
        files = FileController(experiment_name)
        target_path = files.get_validation_plot_filepath()
        history = np.array(files.load_training())
        # Column 2 holds validation loss; the first record is dropped.
        validation_losses = history[1:, 2]
        plotter = PlottingController(experiment_name)
        plotter.save_validation_plot(validation_losses, target_path)
    except Exception as error:
        # NOTE(review): 'show_trance' looks like a typo for 'show_trace',
        # but it must match print_exception's actual signature — confirm.
        print_exception(' ! Unable to create validation plot.',
                        error,
                        show_trance=True)
def plot_training(experiment_name):
    """Create and save the training loss/accuracy plot for an experiment.

    Best effort: any failure is reported through print_exception instead
    of being raised to the caller.
    """
    try:
        files = FileController(experiment_name)
        target_path = files.get_training_plot_filepath()
        history = np.array(files.load_training())
        # Columns 0 and 1 hold training loss and accuracy respectively.
        losses = history[:, 0]
        accuracies = history[:, 1]
        plotter = PlottingController(experiment_name)
        plotter.save_training_plot(losses, accuracies, target_path)
    except Exception as error:
        # NOTE(review): 'show_trance' looks like a typo for 'show_trace',
        # but it must match print_exception's actual signature — confirm.
        print_exception(' ! Unable to create training plot.',
                        error,
                        show_trance=True)
class TrainingController():
    """Runs the epoch/batch training loop for a sequence model.

    Visible responsibilities: per-epoch loss/accuracy metrics, validation
    after each epoch, early stopping (patience is not consumed during the
    warmup period), and persisting results plus the best model weights
    via FileController.
    """
    def __init__(self, config, experiment_name, model, generator,
                 validation_controller, new_training):
        """Build optimizer and metrics; load or reset the result history.

        Args:
            config: dict with 'training' and 'model' sub-dicts; 'training'
                supplies epochs, patience, warmup, batches, batch_size and
                lr_mult; 'model' supplies d_model.
            experiment_name: identifier passed to FileController.
            model: object exposing get_loss, get_weights, set_weights and
                trainable_variables.
            generator: provides batch lists via get_batches(n).
            validation_controller: exposes validate(model) -> loss.
            new_training: when True start with an empty result history;
                otherwise resume from the previously saved one.
        """
        training_config = config['training']
        model_config = config['model']

        self.generator = generator
        self.validation_controller = validation_controller

        self.file_controller = FileController(experiment_name)
        # Resume previous per-epoch results unless this is a fresh run.
        self.results = [] if new_training else self.file_controller.load_training(
        )

        self.epochs = training_config['epochs']
        self.patience = training_config['patience']
        self.warmup = training_config['warmup']
        self.batches = training_config['batches']
        self.batch_size = training_config['batch_size']
        self.model = model

        # Learning-rate schedule scaled by d_model * lr_mult
        # (CustomSchedule is defined elsewhere in the project).
        learning_rate = CustomSchedule(model_config['d_model'] *
                                       training_config['lr_mult'])
        # beta/epsilon values match the transformer ("Attention Is All
        # You Need") training setup — presumably intentional; confirm.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.98,
                                                  epsilon=1e-9)
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')
        # reduction='none': per-element losses are presumably reduced
        # inside model.get_loss — TODO confirm against the model class.
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

    def train(self):
        """Train for up to self.epochs epochs with early stopping.

        Appends one result row per epoch
        ([loss, accuracy, validation_loss, timestamp, learning_rate]) and
        saves the history; checkpoints model weights whenever validation
        loss improves.

        Returns:
            The model with the best-validation weights restored.
        """
        print(' - Training model.')
        waited = 0
        # Best validation loss seen so far; resumed from history when present.
        validation_loss = 1e10 if self.results == [] else self.get_best_validation_loss(
        )
        print(f' - Initial validation loss: {validation_loss}.')
        model_weights = self.model.get_weights()

        for epoch in range(self.epochs):
            # Hold the patience counter at zero until warmup has passed.
            waited = 0 if epoch < self.warmup else waited
            start_time = time.time()
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()

            batches = next(self.generator.get_batches(self.batches))
            for batch, (x, y) in enumerate(batches):
                x = tf.constant(x, dtype=tf.float32)
                y = tf.constant(y, dtype=tf.int32)
                self.train_step(x, y)
                print(
                    f' - - Epoch:{epoch+1}/{self.epochs} | Batch:{batch+1}/{len(batches)} | Loss:{self.train_loss.result():.4f} | Accuracy:{self.train_accuracy.result():.4f}',
                    end="\r")
            print()

            current_validation_loss = self.validation_controller.validate(
                self.model)
            lr = self.get_current_learning_rate()
            # Row layout [loss, acc, val_loss, timestamp, lr] — these
            # columns are read back by the plot helpers and
            # EvaluationController.
            self.results.append([
                self.train_loss.result(),
                self.train_accuracy.result(), current_validation_loss,
                time.time(), lr
            ])
            self.file_controller.save_training(self.results)
            print(
                f' = = Epoch:{epoch+1}/{self.epochs} | Loss:{self.train_loss.result():.4f} | Accuracy:{self.train_accuracy.result():.4f} | Validation loss:{current_validation_loss} | Took:{time.time() - start_time} secs | Learning rate:{lr:.10}'
            )

            if current_validation_loss < validation_loss:
                waited = 0
                validation_loss = current_validation_loss
                print(
                    ' - - Model validation accuracy improvement - saving model weights.'
                )
                self.file_controller.save_model(self.model)
                model_weights = self.model.get_weights()
            else:
                waited += 1
                if waited > self.patience:
                    print(
                        f' - Stopping training ( out of patience - model has not improved for {self.patience} epochs.'
                    )
                    break

        # Restore the best weights seen, not necessarily the last epoch's.
        self.model.set_weights(model_weights)
        return self.model

    @tf.function
    def train_step(self, x, y):
        """Run one optimization step on a single batch (tf.function-compiled)."""
        # Shift-by-one targets: decoder input drops the last token, labels
        # drop the first, so position t is trained to predict token t+1.
        y_input = y[:, :-1]
        y_label = y[:, 1:]

        combined_mask = create_combined_mask(y_input)
        with tf.GradientTape() as tape:
            y_prediction, _ = self.model(x, y_input, True, combined_mask)
            loss = self.model.get_loss(y_label, y_prediction, self.loss_object)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))
        # Update the running epoch metrics.
        self.train_loss(loss)
        self.train_accuracy(y_label, y_prediction)

    def get_best_validation_loss(self):
        """Return the smallest non-negative validation loss in the history.

        Negative entries are skipped (presumably "no validation" sentinels
        — confirm against save format); returns 1e10 if no valid entry.
        """
        results = np.array(self.results)
        validation_losses = results[:, 2]
        min_validation_loss = 1e10
        for validation_loss in validation_losses:
            if validation_loss < 0:
                continue
            if validation_loss < min_validation_loss:
                min_validation_loss = validation_loss
        return min_validation_loss

    def get_current_learning_rate(self):
        """Return the optimizer's current decayed learning rate as a float.

        NOTE(review): relies on the private Keras API _decayed_lr, which
        may break across TensorFlow versions.
        """
        lr = self.optimizer._decayed_lr("float32").numpy()
        return float(lr)
# --- "Esempio n. 5" (extraction artifact; kept only as a section separator) ---
class EvaluationController():
    """Computes summary statistics over saved testing/training results.

    Testing records are dicts with keys 'cigacc' (CIGAR accuracy in
    [0, 1]; 0 means the read was unmatched), 'bacteria', 'time', 'cig'
    (CIGAR string) and 'blen' (alignment length).
    Training records are rows [loss, acc, val_loss, timestamp, lr].
    """

    def __init__(self, experiment_name):
        self.file_controller = FileController(experiment_name)

    def count_unmatched_reads(self):
        """Return the number of test reads with zero accuracy (unmatched)."""
        data = self.get_accuracy_list(include_unmatched=True)
        return sum(1 for accuracy in data if accuracy == 0)

    def count_unmatched_reads_per_bacteria(self):
        """Return {bacteria: count of zero-accuracy (unmatched) reads}."""
        data = self.get_accuracy_list_per_bacteria(include_unmatched=True)
        return {
            bacteria: sum(1 for accuracy in accuracies if accuracy == 0)
            for bacteria, accuracies in data.items()
        }

    def count_reads(self):
        """Return the total number of test measurements."""
        return len(self.file_controller.load_testing())

    def count_reads_per_bacteria(self):
        """Return {bacteria: number of test measurements}."""
        data = self.get_accuracy_list_per_bacteria(include_unmatched=True)
        return {bacteria: len(accuracies)
                for bacteria, accuracies in data.items()}

    def get_accuracy_list(self, include_unmatched=True):
        """Return read accuracies as percentages.

        Args:
            include_unmatched: when False, zero-accuracy reads are dropped.
        """
        return [
            measurement['cigacc'] * 100
            for measurement in self.file_controller.load_testing()
            if include_unmatched or measurement['cigacc'] != 0
        ]

    def get_accuracy_list_per_bacteria(self, include_unmatched=True):
        """Return {bacteria: [accuracy %, ...]}.

        Fix: previously the include_unmatched flag was accepted but
        ignored; it now filters zero-accuracy reads when False, matching
        get_accuracy_list (so get_accuracy_mean_per_bacteria honors it).
        """
        accuracies = {}
        for measurement in self.file_controller.load_testing():
            if not include_unmatched and measurement['cigacc'] == 0:
                continue
            accuracies.setdefault(measurement['bacteria'],
                                  []).append(measurement['cigacc'] * 100)
        return accuracies

    def get_accuracy_mean(self, include_unmatched=True):
        """Return the mean accuracy percentage over all (selected) reads."""
        data = np.array(self.get_accuracy_list(include_unmatched))
        return data.mean()

    def get_accuracy_mean_per_bacteria(self, include_unmatched=True):
        """Return {bacteria: mean accuracy percentage}."""
        data = self.get_accuracy_list_per_bacteria(include_unmatched)
        return {
            bacteria: np.array(values).mean()
            for bacteria, values in data.items()
        }

    def get_total_testing_time(self):
        """Return the summed testing time formatted as H:MM:SS."""
        total_seconds = math.floor(
            sum(m['time'] for m in self.file_controller.load_testing()))
        return str(datetime.timedelta(seconds=total_seconds))

    def get_total_training_time(self):
        """Return wall time from the first epoch to the best-validation
        epoch, formatted as H:MM:SS.

        Uses the per-epoch timestamps in column 3 of the training history.
        """
        data = np.array(self.file_controller.load_training())
        timestamps = data[:, 3]
        _, training_stop_idx = self.get_best_validation_loss()
        elapsed = math.floor(timestamps[training_stop_idx] - timestamps[0])
        return str(datetime.timedelta(seconds=elapsed))

    def get_best_validation_loss(self):
        """Return (smallest non-negative validation loss, its epoch index).

        Negative losses are skipped as "no validation" sentinels; returns
        (1e10, -1) when no valid entry exists.
        """
        data = np.array(self.file_controller.load_training())
        best_loss = 1e10
        best_idx = -1
        for i, loss in enumerate(data[:, 2]):
            if 0 <= loss < best_loss:
                best_loss = loss
                best_idx = i
        return best_loss, best_idx

    def get_SMDI(self):
        """Return soft-clip/match/deletion/insertion base fractions over
        all matched reads, normalized by total alignment length.

        Raises ZeroDivisionError if no read matched (as before).
        """
        smdi_dict = {"S": 0, "M": 0, "D": 0, "I": 0}
        total_length = 0
        for measurement in self.file_controller.load_testing():
            if measurement['cigacc'] == 0:
                continue  # unmatched reads have no CIGAR to count
            total_length += measurement['blen']
            # e.g. '6M5D...' -> ['6M', '5D', ...]
            for op in re.findall(r'[\d]+[SMDI]', measurement['cig']):
                smdi_dict[op[-1]] += int(op[:-1])
        for key in 'SMDI':
            smdi_dict[key] /= total_length
        return smdi_dict

    def get_SMDI_per_bacteria(self):
        """Return {bacteria: SMDI fractions} (see get_SMDI)."""
        smdi_dicts = {}
        lengths = {}
        for measurement in self.file_controller.load_testing():
            if measurement['cigacc'] == 0:
                continue
            bacteria = measurement['bacteria']
            if bacteria not in smdi_dicts:
                smdi_dicts[bacteria] = {"S": 0, "M": 0, "D": 0, "I": 0}
                lengths[bacteria] = 0
            lengths[bacteria] += measurement['blen']
            for op in re.findall(r'[\d]+[SMDI]', measurement['cig']):
                smdi_dicts[bacteria][op[-1]] += int(op[:-1])
        for bacteria, counts in smdi_dicts.items():
            for key in 'SMDI':
                counts[key] /= lengths[bacteria]
        return smdi_dicts