Example #1
    def validate(self):
        self.eval()

        for inputs, labels, data in self.val_dl:
            loss, output = self.forward(inputs, labels)
            self.meter_val.update(labels, output.detach().cpu(), loss.item())

        info = self.meter_val.log_metric()
        selection_metric = info["acc"]  # not using loss but pos loss

        if selection_metric <= self.best_metric:
            self.best_metric = selection_metric
            print(f'>>> Saving best model metric={selection_metric:.4f}')
            checkpoint = {'model': self.model}
            torch.save(checkpoint, 'checkpoints/best_model.pth')
            train_info = self.meter_train.log_metric(write_scalar=False)
            self.history_best = {"train_"+key: value for key, value in train_info.items()}
            for key, value in info.items():
                self.history_best["val_"+key] = value
            self.history_best["epoch"] = self.current_epoch
            if settings.USE_FOUNDATIONS:
                foundations.save_artifact('checkpoints/best_model.pth', key='best_model_checkpoint')

        try:
            inputs, labels, data = next(self.visual_iter)
        except StopIteration:
            self.visual_iter = iter(self.val_dl)
            inputs, labels, data = next(self.visual_iter)

        _, output = self.forward(inputs, labels)
        output = torch.sigmoid(output.detach().cpu())
        inputs = inputs.view((-1, ) + inputs.shape[-3:])
        self.writer.add_images(f'validate/{self.current_epoch}_inputs.png', self.unnorm(inputs)[:8], self.current_epoch)
        print(f'Epoch {self.current_epoch}: val loss={info["loss"]:.4f} | val acc={info["acc"]:.4f}')
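self.unnorm isn't defined in this snippet; it presumably reverses the dataset's input normalization before images go to TensorBoard. A minimal sketch, assuming torchvision-style per-channel mean/std normalization (class name and signature are hypothetical):

import torch

class UnNormalize:
    """Reverse Normalize ((x - mean) / std) so images display correctly."""
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, batch):
        # batch: (N, C, H, W) normalized tensor -> de-normalized copy
        return batch * self.std + self.mean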
Example #2
def plot_results(model, imgs_validation, msks_validation,
                 img_no, png_directory):
    """
    Calculate the Dice score and plot the predicted mask for image #img_no
    """

    img = imgs_validation[img_no: img_no+1]
    msk = msks_validation[img_no: img_no+1]

    pred_mask = model.predict(img)

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(img[0, :, :, 0], cmap="bone", origin="lower")
    plt.title("MRI")
    plt.axis("off")
    plt.subplot(1, 3, 2)
    plt.imshow(msk[0, :, :, 0], origin="lower")
    plt.title("Ground Truth")
    plt.axis("off")
    plt.subplot(1, 3, 3)
    plt.imshow(pred_mask[0, :, :, 0], origin="lower")
    plt.title("Prediction\n(Dice = {:.4f})".format(calc_dice(msk, pred_mask)))
    plt.axis("off")

    png_filename = os.path.join(png_directory, "pred_{}.png".format(img_no))
    plt.savefig(png_filename, bbox_inches="tight", pad_inches=0)

    foundations.save_artifact(png_filename)
    
    print("Dice {:.4f}, Soft Dice {:.4f}, Saved png file to: {}".format(
        calc_dice(msk, pred_mask), calc_soft_dice(msk, pred_mask), png_filename))
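calc_dice and calc_soft_dice aren't shown in this example. A minimal sketch of both, assuming binary ground-truth masks and sigmoid-activated predictions (names and defaults are assumptions):

import numpy as np

def calc_dice(y_true, y_pred, threshold=0.5, smooth=1e-6):
    # Hard Dice: binarize the prediction, then 2*|A∩B| / (|A| + |B|)
    y_pred = (y_pred > threshold).astype(np.float64)
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

def calc_soft_dice(y_true, y_pred, smooth=1e-6):
    # Soft Dice: same ratio computed on raw probabilities, no threshold
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)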
Example #3
def train_model(classifier, feature_vector_train, label, feature_vector_valid,
                valid_y):
    # fit the training dataset on the classifier
    callbacks = []
    tb = TensorBoard(log_dir='tflogs', write_graph=True, write_grads=False)
    callbacks.append(tb)
    es = EarlyStopping(monitor='val_loss',
                       mode='min',
                       patience=5,
                       min_delta=0.0001,
                       verbose=1)
    callbacks.append(es)

    rp = ReduceLROnPlateau(monitor='val_loss',
                           factor=0.6,
                           patience=1,
                           verbose=1)
    callbacks.append(rp)
    history = classifier.fit(feature_vector_train,
                             label,
                             epochs=model_params_conv['epochs'],
                             validation_split=0.2,
                             batch_size=model_params_conv['batch_size'],
                             callbacks=callbacks)

    history_dict = history.history
    print(f"These are the keys of history_dict: {list(history_dict.keys())}")
    train_loss = history_dict['loss']
    val_loss = history_dict['val_loss']
    train_acc = history_dict['accuracy']
    val_acc = history_dict['val_accuracy']
    inp_list = [train_loss, val_loss, train_acc, val_acc]
    fig_name = ['train_loss', 'val_loss', 'train_acc', 'val_acc']
    save_plot_all(inp_list, fig_name)

    try:
        foundations.save_artifact('performance_plots.png',
                                  key='performance_plots')
    except Exception as e:
        print(e)

    # predict the labels on validation dataset
    predictions = classifier.predict(feature_vector_valid)

    # round predicted probabilities to hard 0/1 labels
    predictions = predictions.round()

    return metrics.accuracy_score(valid_y,
                                  predictions), predictions, classifier
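save_plot_all isn't defined here; given the artifact saved right after it, it presumably writes all four curves into performance_plots.png. A minimal sketch under that assumption:

import matplotlib.pyplot as plt

def save_plot_all(inp_list, fig_name, out_path='performance_plots.png'):
    # One subplot per metric curve, all saved into a single figure
    fig, axes = plt.subplots(1, len(inp_list), figsize=(4 * len(inp_list), 4))
    for ax, values, name in zip(axes, inp_list, fig_name):
        ax.plot(values)
        ax.set_title(name)
        ax.set_xlabel('epoch')
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)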
Example #4
    def train(self, xtrain, ytrain, xval, yval):
        callbacks = []
        tb = TensorBoard(log_dir='tflogs', write_graph=True, write_grads=False)
        callbacks.append(tb)

        try:
            foundations.set_tensorboard_logdir('tflogs')
        except Exception:
            print("foundations command not found")

        es = EarlyStopping(monitor='val_loss', mode='min', patience=5, min_delta=0.0001,
                           verbose=1)
        callbacks.append(es)

        rp = ReduceLROnPlateau(monitor='val_loss', factor=0.6, patience=2,
                               verbose=1)
        callbacks.append(rp)

        f1_callback = f1_score_callback(xval, yval, model_save_filename=self.model_save_filename)
        callbacks.append(f1_callback)

        class_weights = {1: 5, 0: 1}

        train_generator = DataGenerator(xtrain, ytrain)
        validation_generator = DataGenerator(xval, yval)
        self.model.fit_generator(train_generator,
                                 steps_per_epoch=len(train_generator),
                                 epochs=model_params['epochs'],
                                 validation_data=validation_generator,
                                 callbacks=callbacks,
                                 shuffle=False,
                                 use_multiprocessing=True,
                                 verbose=1,
                                 class_weight=class_weights)

        self.model = load_model(self.model_save_filename, custom_objects={'customPooling': customPooling})

        try:
            foundations.save_artifact(self.model_save_filename, key='trained_model.h5')
        except Exception:
            print("foundations command not found")
Example #5
    def _save_artifacts(klass):
        import foundations

        klass._set_job_id(klass._one_artifact)
        foundations.save_artifact(
            filepath=klass._artifact_fixture_path('image_file.png'))

        klass._set_job_id(klass._some_artifacts)
        foundations.save_artifact(
            filepath=klass._artifact_fixture_path('no_extension'))
        foundations.save_artifact(
            filepath=klass._artifact_fixture_path('other_file.other'))
        foundations.save_artifact(
            filepath=klass._artifact_fixture_path('audio_file.mp3'),
            key='audio_artifact')
Example #6
    def validate(self):
        self.eval()

        for inputs, labels, data in self.val_dl:
            loss, output = self.forward(inputs, labels)
            output = output.detach().cpu()
            self.meter_val.update(labels, output, loss.item())

        dices, iou, loss = self.meter_val.log_metric()
        selection_metric = loss

        if selection_metric <= self.best_metric:
            self.best_metric = selection_metric
            print(f'>>> Saving best model metric={selection_metric:.4f}')
            checkpoint = {'model': self.model}
            torch.save(checkpoint, 'checkpoints/best_model.pth')
            if settings.USE_FOUNDATIONS:
                foundations.save_artifact('checkpoints/best_model.pth', key='best_model_checkpoint')

                foundations.log_metric("train_loss", float(np.mean(self.meter_train.losses)))
                foundations.log_metric("val_loss", float(loss))
                foundations.log_metric("val_dice", float(dices[0]))
                foundations.log_metric("val_iou", float(iou))

        try:
            inputs, labels, data = next(self.visual_iter)
        except StopIteration:
            self.visual_iter = iter(self.val_dl)
            inputs, labels, data = next(self.visual_iter)

        _, output = self.forward(inputs, labels)
        output = torch.sigmoid(output.detach().cpu())
        self.writer.add_images(f'validate/{self.current_epoch}_inputs.png', self.unnorm(inputs), self.current_epoch)
        self.writer.add_images(f'validate/{self.current_epoch}_mask.png', labels, self.current_epoch)
        self.writer.add_images(f'validate/{self.current_epoch}_predict.png',  output, self.current_epoch)
        print(f'Epoch {self.current_epoch}: val loss={loss:.4f} | val iou={iou:.4f}')
Example #7
"""
This sample main.py shows basic Atlas functionality.
In this script, we will log some arbitrary values & artifacts that can be viewed in the Atlas GUI
"""

import foundations

depth = 3
epochs = 5
batch_size = 256
lrate = 1e-3


# Log some hyper-parameters
foundations.log_param('depth', depth)
foundations.log_params({'epochs': epochs,
                        'batch_size': batch_size,
                        'learning_rate': lrate})

# Log some metrics
accuracy = 0.9
loss = 0.1
foundations.log_metric('accuracy', accuracy)
foundations.log_metric('loss', loss)

# Log an artifact that is already saved to disk
foundations.save_artifact('README.txt', 'Project_README')
Example #8
    def evaluate(self, xtrain, ytrain, xval, yval, num_examples=1):
        ytrain_pred = self.predict_labels(xtrain, raw_prob=True)
        yval_pred = self.predict_labels(xval, raw_prob=True)
        try:
            self.optimum_threshold_filename = f"model_threshold_{'_'.join(str(v) for v in model_params.values())}.npy"
            self.opt_threshold = np.load(os.path.join(model_params['model_save_dir'], self.optimum_threshold_filename)).item()
            print(f"loaded optimum threshold: {self.opt_threshold}")
        except Exception:
            self.opt_threshold = 0.5  # fall back to the default decision threshold

        ytrain_pred_labels = self.get_labels_from_prob(ytrain_pred, threshold=self.opt_threshold)
        yval_pred_labels = self.get_labels_from_prob(yval_pred, threshold=self.opt_threshold)

        train_accuracy = accuracy_score(ytrain, ytrain_pred_labels)
        val_accuracy = accuracy_score(yval, yval_pred_labels)

        train_f1_score = f1_score(ytrain, ytrain_pred_labels)
        val_f1_score = f1_score(yval, yval_pred_labels)
        print (f"train accuracy: {train_accuracy}, train_f1_score: {train_f1_score},"
               f"val accuracy: {val_accuracy}, val_f1_score: {val_f1_score} ")

        try:
            foundations.log_metric('train_accuracy',np.round(train_accuracy,2))
            foundations.log_metric('val_accuracy', np.round(val_accuracy,2))
            foundations.log_metric('train_f1_score', np.round(train_f1_score,2))
            foundations.log_metric('val_f1_score', np.round(val_f1_score,2))
            foundations.log_metric('optimum_threshold', np.round(self.opt_threshold,2))
        except Exception as e:
            print(e)

        # True Positive Example
        ind_tp = np.argwhere(np.equal((yval_pred_labels + yval).astype(int), 2)).reshape(-1, )

        # True Negative Example
        ind_tn = np.argwhere(np.equal((yval_pred_labels + yval).astype(int), 0)).reshape(-1, )

        # False Positive Example
        ind_fp = np.argwhere(np.greater(yval_pred_labels, yval)).reshape(-1, )

        # False Negative Example
        ind_fn = np.argwhere(np.greater(yval, yval_pred_labels)).reshape(-1, )


        path_to_save_spectrograms = './spectrograms'
        if not os.path.isdir(path_to_save_spectrograms):
            os.makedirs(path_to_save_spectrograms)
        specs_saved = os.listdir(path_to_save_spectrograms)
        if len(specs_saved) > 0:
            for file_ in specs_saved:
                os.remove(os.path.join(path_to_save_spectrograms, file_))

        # NOTE: the indices come from the validation set, so sample from xval
        ind_random_tp = np.random.choice(ind_tp, num_examples).reshape(-1, )
        tp_x = [xval[i] for i in ind_random_tp]

        ind_random_tn = np.random.choice(ind_tn, num_examples).reshape(-1, )
        tn_x = [xval[i] for i in ind_random_tn]

        ind_random_fp = np.random.choice(ind_fp, num_examples).reshape(-1, )
        fp_x = [xval[i] for i in ind_random_fp]

        ind_random_fn = np.random.choice(ind_fn, num_examples).reshape(-1, )
        fn_x = [xval[i] for i in ind_random_fn]

        print("Plotting spectrograms to show what the hell the model has learned")
        for i in range(num_examples):
            plot_spectrogram(tp_x[i], path=os.path.join(path_to_save_spetrograms, f'true_positive_{i}.png'))
            plot_spectrogram(tn_x[i], path=os.path.join(path_to_save_spetrograms,f'true_negative_{i}.png'))
            plot_spectrogram(fp_x[i], path=os.path.join(path_to_save_spetrograms,f'false_positive_{i}.png'))
            plot_spectrogram(fn_x[i], path=os.path.join(path_to_save_spetrograms,f'fale_negative_{i}.png'))

        try:
            foundations.save_artifact(os.path.join(path_to_save_spectrograms, f'true_positive_{i}.png'), key='true_positive_example')
            foundations.save_artifact(os.path.join(path_to_save_spectrograms, f'true_negative_{i}.png'), key='true_negative_example')
            foundations.save_artifact(os.path.join(path_to_save_spectrograms, f'false_positive_{i}.png'), key='false_positive_example')
            foundations.save_artifact(os.path.join(path_to_save_spectrograms, f'false_negative_{i}.png'), key='false_negative_example')

        except Exception as e:
            print(e)
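plot_spectrogram is another undefined helper. A minimal sketch, assuming each example is a 2-D spectrogram array (colormap and axis labels are assumptions):

import numpy as np
import matplotlib.pyplot as plt

def plot_spectrogram(spec, path):
    # Render one spectrogram (frequency x time) and save it to disk
    plt.figure(figsize=(6, 4))
    plt.imshow(np.squeeze(spec), aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar()
    plt.xlabel('time frames')
    plt.ylabel('frequency bins')
    plt.savefig(path, bbox_inches='tight')
    plt.close()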
Example #9
import foundations
from foundations_contrib.global_state import current_foundations_context, redis_connection

foundations.log_metric('ugh', 10)

with open('thomas_text.txt', 'w') as f:
    f.write('ugh_square')

foundations.save_artifact('thomas_text.txt', 'just_some_artifact')
foundations.log_param('blah', 20)

redis_connection.set('foundations_testing_job_id', current_foundations_context().pipeline_context().job_id)
Example #10
params = load_parameters()
seed_everything(params['seed'])
log_params(params)

params = parse_params(params)
print(params)

model = CIFAR_Module(params).cuda()
lr_logger = LearningRateLogger()
logger = TensorBoardLogger("../logs", name=params["backbone"])
if USE_FOUNDATIONS:
    from foundations import set_tensorboard_logdir
    set_tensorboard_logdir(f'../logs/{params["backbone"]}')

checkpoint_callback = ModelCheckpoint(save_top_k=1,
                                      monitor='acc',
                                      prefix=str(params["seed"]))
t_params = get_trainer_params(params)
trainer = Trainer(callbacks=[lr_logger],
                  logger=logger,
                  checkpoint_callback=checkpoint_callback,
                  **t_params)
trainer.fit(model)

if USE_FOUNDATIONS and checkpoint_callback.best_model_path != "":
    from foundations import log_metric, save_artifact
    save_artifact(checkpoint_callback.best_model_path,
                  key='best_model_checkpoint')
    log_metric("val_acc", float(checkpoint_callback.best_model_score))

print("Training finished")
Example #11
import foundations

foundations.save_artifact("fan-man.png", "Image")
foundations.save_artifact("ICQ Uh Oh.mp3", "mp3")
foundations.save_artifact("dogge.mp4", "mp4")
foundations.save_artifact("cat.gif", "gif")
foundations.save_artifact("wilhelm.wav", "wave")
Example #12
import os
import os.path as path

import foundations

cwd = os.getcwd()

foundations.save_artifact(filepath=path.join(cwd, 'cool-artifact.txt'))
foundations.save_artifact(
    filepath=path.join(cwd, 'other', 'cool-artifact.txt'))
Example #13
def train(train_dl, val_dl, test_dl, val_dl_iter, model, optimizer, scheduler,
          criterion, params):
    n_epochs = params['n_epochs']
    max_lr = params['max_lr']
    val_rate = params['val_rate']
    batch_repeat = params['batch_repeat']
    records = Records()
    best_metric = 1e9

    os.makedirs('checkpoints', exist_ok=True)

    for epoch in range(n_epochs):
        train_one_epoch(epoch, model, train_dl, max_lr, optimizer, criterion,
                        scheduler, records, batch_repeat)
        if epoch % val_rate == 0:
            validate(model, val_dl, criterion, records)
            # validate(model, test_dl, criterion, records)

            selection_metric = getattr(records, "val_losses")[-1]

            if selection_metric <= best_metric:
                print(
                    f'>>> Saving best model metric={selection_metric:.4f} compared to previous best {best_metric:.4f}'
                )
                best_metric = selection_metric
                checkpoint = {'model': model}

                torch.save(checkpoint, 'checkpoints/best_model.pth')
                if settings.USE_FOUNDATIONS:
                    foundations.save_artifact('checkpoints/best_model.pth',
                                              key='best_model_checkpoint')

            # Save eyeball plot to Atlas GUI
            if settings.USE_FOUNDATIONS:
                display_filename = f'{epoch}_display.png'
                try:
                    data = next(val_dl_iter)
                except StopIteration:
                    val_dl_iter = iter(val_dl)
                    data = next(val_dl_iter)
                # display_predictions_on_image(model, data, name=display_filename)
                # foundations.save_artifact(display_filename, key=f'{epoch}_display')

            # Save metrics plot
            visualize_metrics(records,
                              extra_metric=extra_metric,
                              name='metrics.png')

            # Save metrics plot to Atlas GUI
            if settings.USE_FOUNDATIONS:
                foundations.save_artifact('metrics.png', key='metrics_plot')

    # Log metrics to GUI, taken at the epoch with the lowest validation loss
    best_index = np.argmin(getattr(records, 'val_losses'))

    useful_metrics = records.get_useful_metrics()
    for metric in useful_metrics:
        if settings.USE_FOUNDATIONS:
            foundations.log_metric(metric,
                                   float(getattr(records, metric)[best_index]))
        else:
            print(metric, float(getattr(records, metric)[best_index]))
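Records isn't defined in this example. A minimal, hypothetical container consistent with how it is used above (one list per tracked metric, appended once per epoch):

class Records:
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.val_accs = []

    def update(self, **metrics):
        # e.g. records.update(val_losses=0.31, val_accs=0.88)
        for name, value in metrics.items():
            getattr(self, name).append(value)

    def get_useful_metrics(self):
        return ['train_losses', 'val_losses', 'val_accs']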
Example #14
    def training_loop(iteration):
        """The main training loop encapsulated in a function."""
        step = 0
        epoch = 0
        print("Running training loop")
        while True:
            sess.run(dataset.train_initializer)
            epoch += 1

            # End training if we have passed the epoch limit.
            if epoch > NUM_EPOCHS:
                break

            start_time = time.time()
            # One training epoch.
            print("Epoch: {} out of {}".format(epoch, NUM_EPOCHS))
            while True:
                try:
                    step += 1

                    # End training if we have passed the step limit,
                    # e.g. training_len = ('iterations', 50000).
                    if training_len[0] == 'iterations' and step > training_len[1]:
                        return

                    # Train.

                    step_time = time.time()
                    records = sess.run(
                        [optimize, model.loss, model.targets, model.outputs,
                         model.inputs] + model.train_summaries,
                        {dataset.handle: train_handle})[1:]
                    loss, targets, outputs, inputs = records[:4]
                    records = records[4:]

                    record_summaries(step, records, train_file)

                    if step % 10 == 0:
                        logger.info(
                            "Step {} - Loss: {} - Time per step: {}".format(
                                step, loss,
                                time.time() - step_time))

                    collect_test_summaries(step)

                except tf.errors.OutOfRangeError:
                    break
            logger.info("Time for epoch: {}".format(time.time() - start_time))

        outputs = output_to_rgb(outputs)
        targets = output_to_rgb(targets)

        inputs_artifact_path = save_image(
            inputs, 'inputs_{}'.format(iteration) + '.png')
        targets_artifact_path = save_image(
            targets, 'targets_{}'.format(iteration) + '.png')
        outputs_artifact_path = save_image(
            outputs, 'outputs_{}'.format(iteration) + '.png')

        tensorboard_path = 'lottery_ticket/{}/unet/summaries/'.format(
            iteration)
        tensorboard_file = os.path.join(tensorboard_path,
                                        os.listdir(tensorboard_path)[0])

        f9s.save_artifact(tensorboard_file, 'tensorboard_{}'.format(iteration))

        f9s.log_metric('loss_{}'.format(iteration), float(loss))

        f9s.save_artifact(inputs_artifact_path, 'inputs_{}'.format(iteration))
        f9s.save_artifact(targets_artifact_path,
                          'targets_{}'.format(iteration))
        f9s.save_artifact(outputs_artifact_path,
                          'outputs_{}'.format(iteration))

        # End of epoch handling.
        return
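save_image returns the path it wrote, which is then handed to f9s.save_artifact. A minimal sketch under that assumption (the directory layout is hypothetical):

import os
import matplotlib.pyplot as plt

def save_image(batch, filename, out_dir='artifacts'):
    # Write the first image of the batch to disk and return its path
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, filename)
    plt.imsave(path, batch[0])
    return path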
Example #15
def train(train_dl, val_dl, test_dl, val_dl_iter, model, optimizer, scheduler,
          criterion, params, train_sampler, val_sampler, rank):
    n_epochs = params['n_epochs']
    max_lr = params['max_lr']
    val_rate = params['val_rate']
    batch_repeat = params['batch_repeat']
    history_best = {}
    best_metric = 0

    if rank == 0:
        os.makedirs('checkpoints', exist_ok=True)
        os.makedirs('tensorboard', exist_ok=True)
        if settings.USE_FOUNDATIONS:
            foundations.set_tensorboard_logdir('tensorboard')
        writer = SummaryWriter("tensorboard")
    else:
        writer = None

    for epoch in range(n_epochs):
        train_records = DistributedClassificationMeter(writer=writer,
                                                       phase="train",
                                                       epoch=epoch,
                                                       workers=params["gpus"],
                                                       criterion=criterion)
        if train_sampler:
            train_sampler.set_epoch(epoch)
        train_one_epoch(epoch, model, train_dl, max_lr, optimizer, criterion,
                        scheduler, train_records, batch_repeat, rank, writer,
                        params)
        if epoch % val_rate == 0:
            val_records = DistributedClassificationMeter(
                writer=writer,
                phase="validation",
                epoch=epoch,
                workers=params["gpus"],
                criterion=criterion)
            if val_sampler:
                val_sampler.set_epoch(epoch)
            validate(model, val_dl, criterion, val_records, rank)

            # Remember to flip the comparison operator if you change the selection metric!
            # And remember to update best_metric's initial value to match!
            info = val_records.log_metric(write_scalar=False)
            selection_metric = info["acc"]

            if selection_metric >= best_metric and rank == 0:
                print(
                    f'>>> Saving best model metric={selection_metric:.4f} compared to previous best {best_metric:.4f}'
                )
                best_metric = selection_metric
                checkpoint = {
                    'model': model.module.state_dict(),
                    'params': params
                }
                history_best = {
                    "train_" + key: value
                    for key, value in train_records.get_metric().items()
                }
                for key, value in val_records.get_metric().items():
                    history_best["val_" + key] = value

                torch.save(checkpoint, 'checkpoints/best_model.pth')
                if settings.USE_FOUNDATIONS:
                    foundations.save_artifact('checkpoints/best_model.pth',
                                              key='best_model_checkpoint')

    # Log metrics to GUI
    if rank == 0:
        for metric, value in history_best.items():
            if settings.USE_FOUNDATIONS:
                foundations.log_metric(metric, float(value))
            else:
                print(metric, float(value))
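DistributedClassificationMeter isn't shown; judging by the workers argument, it likely all-reduces each metric across ranks so every worker logs the same value. A hypothetical sketch of that reduction step:

import torch
import torch.distributed as dist

def distributed_mean(value, workers):
    # Sum a scalar across all ranks, then divide by the worker count
    t = torch.tensor([value], dtype=torch.float32).cuda()
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / workers).item()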
Example #16
import os
import os.path as path

import foundations

cwd = os.getcwd()

foundations.save_artifact(filepath=path.join(cwd, 'cool-artifact.txt'),
                          key='this-key')
foundations.save_artifact(filepath=path.join(cwd, 'cooler-artifact.other'),
                          key='this-key')
Example #17
def train(train_dl, val_base_dl, val_augment_dl, display_dl_iter, model,
          optimizer, n_epochs, max_lr, scheduler, criterion, train_source):
    records = Records()
    best_metric = 0.

    os.makedirs('checkpoints', exist_ok=True)

    for epoch in range(n_epochs):
        train_one_epoch(epoch, model, train_dl, max_lr, optimizer, criterion,
                        scheduler, records)
        validate(model, val_base_dl, criterion, records, data_name='base')
        validate(model,
                 val_augment_dl,
                 criterion,
                 records,
                 data_name='augment')

        if train_source == 'both':
            selection_metric = [
                getattr(records, 'base_val_accs')[-1],
                getattr(records, 'augment_val_accs')[-1]
            ]
            selection_metric = np.mean(selection_metric)

        else:
            selection_metric = getattr(records, f"{train_source}_val_accs")[-1]

        if selection_metric >= best_metric:
            print(
                f'>>> Saving best model metric={selection_metric:.4f} compared to previous best {best_metric:.4f}'
            )
            best_metric = selection_metric
            checkpoint = {
                'model': model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            torch.save(checkpoint, 'checkpoints/best_model.pth')
            foundations.save_artifact('checkpoints/best_model.pth',
                                      key='pretrained_model_checkpoint')

        display_filename = f'{epoch}_display.png'
        display_predictions_on_image(model,
                                     val_base_dl.dataset.cached_path,
                                     display_dl_iter,
                                     name=display_filename)

        # Save eyeball plot to Atlas GUI
        foundations.save_artifact(display_filename, key=f'{epoch}_display')

        # Save metrics plot
        visualize_metrics(records,
                          extra_metric=extra_metric,
                          name='metrics.png')

        # Save metrics plot to Atlas GUI
        foundations.save_artifact('metrics.png', key='metrics_plot')

    # Log metrics to GUI
    if train_source == 'both':
        avg_metric = [
            getattr(records, 'base_val_accs'),
            getattr(records, 'augment_val_accs')
        ]
        avg_metric = np.mean(avg_metric, axis=0)
        max_index = np.argmax(avg_metric)

    else:
        max_index = np.argmax(getattr(records, f'{train_source}_val_accs'))

    useful_metrics = records.get_metrics()
    for metric in useful_metrics:
        foundations.log_metric(metric,
                               float(getattr(records, metric)[max_index]))
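visualize_metrics isn't defined in this example. A minimal sketch consistent with its call site (it writes every tracked curve into one metrics.png; the handling of extra_metric is an assumption):

import matplotlib.pyplot as plt

def visualize_metrics(records, extra_metric=None, name='metrics.png'):
    # Plot each metric list in Records on a shared set of axes
    plt.figure(figsize=(8, 5))
    for metric in records.get_metrics():
        plt.plot(getattr(records, metric), label=metric)
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig(name, bbox_inches='tight')
    plt.close()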
Example #18
    model.compile(optimizer=optimizers.Adam(
        lr=model_params['learning_rate'],
        decay=model_params['learning_rate'] / model_params['epochs']),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    return model


# Initialize the model
classifier = create_cnn(model_params_conv)
# Train the model
accuracy, val_predictions, trained_model = train_model(classifier, train_seq_x,
                                                       train_y, valid_seq_x,
                                                       valid_y)
# Evaluate the model
print("Validation Accuracy of Trained CNN Model", accuracy)

# Log metrics to track in foundations
try:
    foundations.log_metric('val_accuracy', accuracy)
except Exception as e:
    print(e)

# save the trained model that can be put in production if needed
trained_model.save('saved_tf_model.h5')
try:
    foundations.save_artifact('saved_tf_model.h5', key='saved_tf_model')
except Exception as e:
    print(e)
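create_cnn itself isn't shown apart from its compile call above. A hypothetical sketch of such a factory for binary text classification; the 'vocab_size' and 'embed_dim' keys are assumptions:

from tensorflow.keras import layers, models, optimizers

def create_cnn(model_params):
    # Small 1-D conv net over embedded token sequences
    model = models.Sequential([
        layers.Embedding(model_params['vocab_size'], model_params['embed_dim']),
        layers.Conv1D(128, 5, activation='relu'),
        layers.GlobalMaxPooling1D(),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer=optimizers.Adam(
        lr=model_params['learning_rate'],
        decay=model_params['learning_rate'] / model_params['epochs']),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model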