Example #1
0
def run():
    """Train the example model, save it, then predict with a model loaded from that save.

    Steps:
      1. Build a trainer from ``get_parameters()`` and train it.
      2. Save the trained trainer manually to ``save_to_dir`` under
         ``trainer_file_name + '_manual_save'``.
      3. Load that file into a freshly built model with
         ``Predictor.from_checkpoint`` and predict on a single sample, a single
         batch, and ``data_loader_steps`` items of a prediction generator.

    NOTE(review): ``save_to_dir`` and ``trainer_file_name`` are not defined in
    this function; they are assumed to be module-level names - confirm.
    """
    N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps = get_parameters()

    current_trainer = get_trainer(N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps)

    # IF get_trainer CONFIGURES A ModelCheckPoint CALLBACK, TRAINING ALSO SAVES current_trainer
    current_trainer.train(num_epochs)

    # YOU CAN ALSO SAVE THIS TRAINER MANUALLY, LIKE THE CODE BELOW
    manual_save_file_name = trainer_file_name + '_manual_save'
    current_trainer.save_trainer(save_to_dir, manual_save_file_name)

    # NOW LETS CREATE A PREDICTOR FROM THE SAVED FILE.
    # The freshly built model has never been trained, so the saved state must be
    # loaded into it. Wrapping it directly in Predictor(model=...) would predict
    # with an untrained model and never read the file saved above.
    device = tu.get_gpu_device_if_available()
    model = eu.get_basic_model(D_in, H, D_out).to(device)
    predictor = Predictor.from_checkpoint(save_to_dir, manual_save_file_name, model, device)

    data_generator_for_predictions = eu.examples_prediction_data_generator(data_loader, data_loader_steps)

    # PREDICT ON A SINGLE SAMPLE
    sample = next(data_generator_for_predictions)[0]
    sample_prediction = predictor.predict_sample(sample)

    # PREDICT ON A SINGLE BATCH
    batch = next(data_generator_for_predictions)
    batch_prediction = predictor.predict_batch(batch)

    # PREDICTION ON A DATA LOADER
    data_loader_predictions = predictor.predict_data_loader(data_generator_for_predictions, data_loader_steps)
Example #2
0
def get_trainer(N, D_in, H, D_out, data_loader, data_loader_steps):
    """Build the Trainer for the train/evaluate/predict example.

    The same ``data_loader`` and ``data_loader_steps`` serve both training and
    validation. The model comes from ``eu.get_basic_model(D_in, H, D_out)`` and
    is trained with a summed MSE loss and Adam (lr=1e-4). Only the loss is
    tracked (no metrics). ``N`` is accepted for signature compatibility and is
    not used.
    """
    dev = tu.get_gpu_device_if_available()
    net = eu.get_basic_model(D_in, H, D_out).to(dev)

    # Keyword arguments are evaluated left to right, so the loss, optimizer,
    # scheduler and callbacks are still created in that order.
    trainer_kwargs = dict(
        model=net,
        device=dev,
        loss_func=nn.MSELoss(reduction='sum').to(dev),
        optimizer=optim.Adam(net.parameters(), lr=1e-4),
        # scheduler=None would also work; DoNothingToLR states "constant LR" explicitly.
        scheduler=DoNothingToLR(),
        metrics=None,  # loss only, no metrics in this example
        train_data_loader=data_loader,
        val_data_loader=data_loader,
        train_steps=data_loader_steps,
        val_steps=data_loader_steps,
        callbacks=[LossOptimizerHandler(), StatsPrint()],
        name='Train-Evaluate-Predict-Example',
    )
    return Trainer(**trainer_kwargs)
Example #3
0
File: model.py Project: urialon/lpd
def get_trainer(config, num_embeddings, train_data_loader, val_data_loader,
                train_steps, val_steps, checkpoint_dir, checkpoint_file_name,
                summary_writer_dir, num_epochs):
    """Build the Trainer for the multi-input example (older callback API).

    Components:
      * model: ``TestModel(config, num_embeddings)`` on the available device;
      * optimizer: SGD with ``lr=config.LEARNING_RATE`` and ``momentum=0.9``;
      * scheduler: ``ReduceLROnPlateau`` in ``'min'`` mode with patience
        ``config.EARLY_STOPPING_PATIENCE // 2``, stepped by ``SchedulerStep``
        with the current validation loss;
      * loss: ``BCEWithLogitsLoss``; metric ``"acc"`` computed by
        ``binary_accuracy_with_logits``;
      * callbacks (in order):
        - best-only ``ModelCheckPoint`` on ``'val_loss'``;
        - ``Tensorboard`` writing to ``summary_writer_dir``;
        - ``EarlyStopping`` on ``'val_loss'`` with patience
          ``config.EARLY_STOPPING_PATIENCE``;
        - ``EpochEndStats``.

    Returns:
        A configured ``Trainer`` named ``'Multi-Input-Example'``.
    """
    device = tu.get_gpu_device_if_available()
    model = TestModel(config, num_embeddings).to(device)

    # Alternative that was tried:
    #   optim.Adam(params=model.parameters(), lr=config.LEARNING_RATE)
    optimizer = optim.SGD(params=model.parameters(),
                          lr=config.LEARNING_RATE,
                          momentum=0.9)

    # ReduceLROnPlateau must be stepped with the monitored value, so the
    # SchedulerStep callback below is given a scheduler_parameters_func.
    # Alternatives:
    #   DoNothingToLR(optimizer=optimizer)
    #   optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=config.STEP_LR_GAMMA,
    #                             step_size=config.STEP_LR_STEP_SIZE)
    #     -> needs a SchedulerStep WITHOUT scheduler_parameters_func
    plateau_patience = config.EARLY_STOPPING_PATIENCE // 2
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     patience=plateau_patience,
                                                     verbose=True)

    loss_func = nn.BCEWithLogitsLoss().to(device)

    def current_val_loss(active_trainer):
        # Value handed to the scheduler on each step: the latest validation loss.
        return active_trainer.val_stats.get_loss()

    callbacks = [
        SchedulerStep(scheduler_parameters_func=current_val_loss),
        ModelCheckPoint(checkpoint_dir,
                        checkpoint_file_name,
                        monitor='val_loss',
                        save_best_only=True,
                        round_values_on_print_to=7),
        Tensorboard(summary_writer_dir=summary_writer_dir),
        EarlyStopping(patience=config.EARLY_STOPPING_PATIENCE,
                      monitor='val_loss'),
        # Kept last so its prints read naturally after the other callbacks' output.
        EpochEndStats(cb_phase=cbs.CB_ON_EPOCH_END, round_values_on_print_to=7),
    ]

    return Trainer(model=model,
                   device=device,
                   loss_func=loss_func,
                   optimizer=optimizer,
                   scheduler=scheduler,
                   metric_name_to_func={"acc": binary_accuracy_with_logits},
                   train_data_loader=train_data_loader,
                   val_data_loader=val_data_loader,
                   train_steps=train_steps,
                   val_steps=val_steps,
                   num_epochs=num_epochs,
                   callbacks=callbacks,
                   name='Multi-Input-Example')
Example #4
0
def get_trainer_base(D_in, H, D_out):
    """Create the shared building blocks for a binary-classification Trainer.

    Returns:
        A 6-tuple ``(device, model, loss_func, optimizer, scheduler, metrics)``:
          * the model from ``eu.get_basic_model(D_in, H, D_out)`` moved to the
            available device;
          * a ``BCEWithLogitsLoss``;
          * Adam with lr=1e-4;
          * a ``DoNothingToLR`` scheduler;
          * a single ``BinaryAccuracyWithLogits`` metric named ``'Accuracy'``.
    """
    dev = tu.get_gpu_device_if_available()
    net = eu.get_basic_model(D_in, H, D_out).to(dev)
    criterion = nn.BCEWithLogitsLoss().to(dev)
    adam = optim.Adam(net.parameters(), lr=1e-4)
    # scheduler=None would also work; DoNothingToLR makes "constant LR" explicit.
    constant_lr = DoNothingToLR()
    accuracy = BinaryAccuracyWithLogits(name='Accuracy')
    return dev, net, criterion, adam, constant_lr, accuracy
Example #5
0
    def test_metrics_validation(self):
        """Trainer must reject a non-metric ``metrics`` value and accept a real metric.

        A bare lambda passed as ``metrics`` must make the ``Trainer`` constructor
        raise ``ValueError``. A ``BinaryAccuracyWithLogits`` instance must
        produce a ``Trainer`` without error.
        """
        device = tu.get_gpu_device_if_available()
        model = eu.get_basic_model(10, 10, 10).to(device)
        loss_func = nn.BCEWithLogitsLoss().to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        callbacks = [LossOptimizerHandler(), StatsPrint()]
        data_loader = eu.examples_data_generator(10, 10, 10)
        data_loader_steps = 100

        # Every constructor argument except `metrics` is shared by both cases,
        # so the two constructions differ only in the value under test.
        common_kwargs = dict(model=model,
                             device=device,
                             loss_func=loss_func,
                             optimizer=optimizer,
                             scheduler=None,
                             train_data_loader=data_loader,
                             val_data_loader=data_loader,
                             train_steps=data_loader_steps,
                             val_steps=data_loader_steps,
                             callbacks=callbacks,
                             name='Trainer-Test')

        # ASSERT BAD VALUE FOR metrics: a plain callable is not a metric object
        with self.assertRaises(ValueError):
            Trainer(metrics=lambda x, y: x + y, **common_kwargs)

        # ASSERT GOOD VALUE FOR metrics: construction succeeds and yields a Trainer
        trainer = Trainer(metrics=BinaryAccuracyWithLogits('acc'), **common_kwargs)
        self.assertIsInstance(trainer, Trainer)
Example #6
0
def get_trainer(N, D_in, H, D_out, data_loader, data_loader_steps):
    """Build a Trainer that demonstrates the KerasDecay learning-rate scheduler.

    KerasDecay applies the Keras time-based decay formula

        LR = INIT_LR * (1. / (1. + decay * step))

    where ``step`` may count batches or epochs. This example steps it once per
    epoch. With INIT_LR=0.1 and decay=0.01 the expected learning rates are:

        epoch 0: 0.1   (starting point)
        epoch 1: 0.1 * (1. / (1. + 0.01 * 1)) = 0.09900990099
        epoch 2: 0.1 * (1. / (1. + 0.01 * 2)) = 0.09803921568
        epoch 3: 0.1 * (1. / (1. + 0.01 * 3)) = 0.09708737864
        epoch 4: 0.1 * (1. / (1. + 0.01 * 4)) = 0.09615384615
        epoch 5: 0.1 * (1. / (1. + 0.01 * 5)) = 0.09523809523

    ``N`` is accepted for signature compatibility and is not used.
    """
    dev = tu.get_gpu_device_if_available()
    net = eu.get_basic_model(D_in, H, D_out).to(dev)
    criterion = nn.CrossEntropyLoss().to(dev)
    adam = optim.Adam(net.parameters(), lr=0.1)
    keras_decay = KerasDecay(adam, decay=0.01, last_step=-1)
    accuracy = CategoricalAccuracyWithLogits('acc')

    callbacks = [
        LossOptimizerHandler(),
        # Step once per epoch; verbose=1 prints each LR change so the table
        # in the docstring can be checked against the log.
        SchedulerStep(apply_on_phase=Phase.EPOCH_END,
                      apply_on_states=State.EXTERNAL,
                      verbose=1),
        StatsPrint(),
    ]

    return Trainer(model=net,
                   device=dev,
                   loss_func=criterion,
                   optimizer=adam,
                   scheduler=keras_decay,
                   metrics=accuracy,
                   train_data_loader=data_loader,
                   val_data_loader=data_loader,
                   train_steps=data_loader_steps,
                   val_steps=data_loader_steps,
                   callbacks=callbacks,
                   name='Keras-Decay-Example')
Example #7
0
def get_trainer(N, D_in, H, D_out, data_loader, data_loader_steps):
    """Build a Trainer whose StepLR scheduler is stepped after every training batch.

    StepLR(gamma=0.999, step_size=1) multiplies the learning rate by 0.999 on
    each scheduler step. SchedulerStep is configured with
    ``apply_on_phase=Phase.BATCH_END`` and ``apply_on_states=State.TRAIN``: it
    is invoked at the end of every batch, but only applied in train mode and
    ignored in val/test modes. ``verbose=1`` prints every change, which is
    noisy at batch level; ``verbose=0`` or ``verbose=2`` may be preferable
    outside this demo. Only the loss is tracked (no metrics), and ``N`` is not
    used.
    """
    dev = tu.get_gpu_device_if_available()
    net = eu.get_basic_model(D_in, H, D_out).to(dev)
    criterion = nn.MSELoss(reduction='sum').to(dev)
    adam = optim.Adam(net.parameters(), lr=1e-4)
    step_lr = optim.lr_scheduler.StepLR(optimizer=adam, gamma=0.999, step_size=1)

    callbacks = [
        LossOptimizerHandler(),
        # apply_on_states also accepts a list, e.g. apply_on_states=[State.TRAIN]
        SchedulerStep(apply_on_phase=Phase.BATCH_END,
                      apply_on_states=State.TRAIN,
                      verbose=1),
        StatsPrint(apply_on_phase=Phase.EPOCH_END),
    ]

    return Trainer(model=net,
                   device=dev,
                   loss_func=criterion,
                   optimizer=adam,
                   scheduler=step_lr,
                   metrics=None,
                   train_data_loader=data_loader,
                   val_data_loader=data_loader,
                   train_steps=data_loader_steps,
                   val_steps=data_loader_steps,
                   callbacks=callbacks,
                   name='Scheduler-Step-On-Batch-Example')
Example #8
0
def get_trainer(params):
    """Build the Trainer for the DataLoader example.

    ``params`` supplies the keys ``'H'``, ``'D_out'`` and ``'embedding_dim'``
    for the model. ``num_embeddings``, ``train_data_loader``,
    ``val_data_loader``, ``train_dataset`` and ``val_dataset`` are NOT
    parameters: they are looked up in the enclosing (module) scope.

    Training uses ``BCEWithLogitsLoss``, Adam (lr=0.1) and
    ``StepLR(step_size=1, gamma=0.99)`` stepped after every training batch.
    Early stopping monitors the validation loss (min mode, patience 3).
    """
    dev = tu.get_gpu_device_if_available()

    net = Model(params['H'], params['D_out'], num_embeddings,
                params['embedding_dim']).to(dev)

    criterion = nn.BCEWithLogitsLoss().to(dev)
    adam = optim.Adam(net.parameters(), lr=0.1)
    lr_schedule = optim.lr_scheduler.StepLR(adam, step_size=1, gamma=0.99)
    accuracy = BinaryAccuracyWithLogits(name='acc')

    callbacks = [
        LossOptimizerHandler(),
        # Step the LR at the end of each batch, only while training.
        SchedulerStep(apply_on_phase=Phase.BATCH_END,
                      apply_on_states=State.TRAIN),
        EarlyStopping(callback_monitor=CallbackMonitor(monitor_type=MonitorType.LOSS,
                                                       stats_type=StatsType.VAL,
                                                       patience=3,
                                                       monitor_mode=MonitorMode.MIN)),
        StatsPrint(round_values_on_print_to=7),
    ]

    # NOTE(review): the step counts below are len(dataset), i.e. sample counts.
    # For a torch DataLoader, len(loader) is the number of batches; the two only
    # agree when batch_size=1. Confirm which one the Trainer expects here.
    return Trainer(model=net,
                   device=dev,
                   loss_func=criterion,
                   optimizer=adam,
                   scheduler=lr_schedule,
                   metrics=accuracy,
                   train_data_loader=train_data_loader,
                   val_data_loader=val_data_loader,
                   train_steps=len(train_dataset),
                   val_steps=len(val_dataset),
                   callbacks=callbacks,
                   name='DataLoader-Example')
Example #9
0
def get_trainer_base(D_in, H, D_out):
    """Create the shared building blocks for a binary-classification Trainer with several metrics.

    Returns:
        A 6-tuple ``(device, model, loss_func, optimizer, scheduler, metrics)``:
          * the model from ``eu.get_basic_model(D_in, H, D_out)`` on the
            available device;
          * a ``BCEWithLogitsLoss``;
          * Adam with lr=1e-4;
          * a ``DoNothingToLR`` scheduler;
          * a list of five metrics: accuracy, inaccuracy, true positives, true
            negatives and truthfulness.

    The two count metrics use two classes and a threshold of 0.0.
    """
    dev = tu.get_gpu_device_if_available()
    net = eu.get_basic_model(D_in, H, D_out).to(dev)
    criterion = nn.BCEWithLogitsLoss().to(dev)
    adam = optim.Adam(net.parameters(), lr=1e-4)
    # scheduler=None would also work; DoNothingToLR makes "constant LR" explicit.
    constant_lr = DoNothingToLR()

    binary_count_kwargs = dict(num_classes=2, threshold=0.0)
    metric_list = [
        BinaryAccuracyWithLogits(name='Accuracy'),
        InaccuracyWithLogits(name='InAccuracy'),
        TruePositives(**binary_count_kwargs),
        TrueNegatives(**binary_count_kwargs),
        Truthfulness(name='Truthfulness'),
    ]

    return dev, net, criterion, adam, constant_lr, metric_list
Example #10
0
def get_trainer_base(D_in, H, D_out, num_classes, labels=None):
    """Create the shared building blocks for a multi-class Trainer with count metrics.

    Args:
        D_in, H, D_out: sizes forwarded to ``eu.get_basic_model``.
        num_classes: number of classes passed to every count metric.
        labels: optional class names passed to every count metric. Defaults to
            ``['Cat', 'Dog', 'Bird']``, the names this helper always used before
            the parameter existed. That default only fits ``num_classes == 3``,
            so pass explicit names for any other class count.

    Returns:
        A 6-tuple ``(device, model, loss_func, optimizer, scheduler, metrics)``:
          * the basic model on the available device;
          * a ``CrossEntropyLoss``;
          * Adam with lr=1e-4;
          * a ``DoNothingToLR`` scheduler;
          * ``[TruePositives, FalsePositives, TrueNegatives, FalseNegatives]``,
            each with ``threshold=0``.
    """
    device = tu.get_gpu_device_if_available()

    model = eu.get_basic_model(D_in, H, D_out).to(device)

    loss_func = nn.CrossEntropyLoss().to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # CAN ALSO USE scheduler=None, BUT DoNothingToLR IS MORE EXPLICIT
    scheduler = DoNothingToLR()

    # None sentinel instead of a list default, so no mutable default is shared across calls.
    if labels is None:
        labels = ['Cat', 'Dog', 'Bird']

    metrics = [
        TruePositives(num_classes, labels=labels, threshold=0),
        FalsePositives(num_classes, labels=labels, threshold=0),
        TrueNegatives(num_classes, labels=labels, threshold=0),
        FalseNegatives(num_classes, labels=labels, threshold=0),
    ]

    return device, model, loss_func, optimizer, scheduler, metrics
Example #11
0
    def test_loss_handler_validation(self):
        """Training without a LossOptimizerHandler callback must fail.

        A Trainer whose callbacks lack ``LossOptimizerHandler`` must raise
        ``ValueError`` from ``train``. Once the handler is appended to
        ``trainer.callbacks``, the same trainer must train without error.
        """
        dev = tu.get_gpu_device_if_available()
        net = eu.get_basic_model(10, 10, 10).to(dev)
        criterion = nn.BCEWithLogitsLoss().to(dev)
        adam = optim.Adam(net.parameters(), lr=1e-4)
        lr_decay = KerasDecay(adam, 0.0001, last_step=-1)
        accuracy = BinaryAccuracyWithLogits(name='acc')

        # Deliberately missing LossOptimizerHandler.
        callbacks = [StatsPrint()]

        loader = eu.examples_data_generator(10, 10, 10)
        steps = 100
        epochs = 5
        quiet = 0

        trainer = Trainer(model=net,
                          device=dev,
                          loss_func=criterion,
                          optimizer=adam,
                          scheduler=lr_decay,
                          metrics=accuracy,
                          train_data_loader=loader,
                          val_data_loader=loader,
                          train_steps=steps,
                          val_steps=steps,
                          callbacks=callbacks,
                          name='Trainer-Test')

        with self.assertRaises(ValueError):
            trainer.train(epochs, quiet)

        # With the handler in place, the same trainer must run normally.
        trainer.callbacks.append(LossOptimizerHandler())
        trainer.train(epochs, verbose=0)
Example #12
0
def get_trainer(N, D_in, H, D_out, data_loader, data_loader_steps, mode):
    """Build a Trainer that demonstrates gradient-accumulation callbacks.

    Args:
        N: unused; kept for signature compatibility with the other examples.
        D_in, H, D_out: sizes forwarded to ``eu.get_basic_model``.
        data_loader: used for both training and validation.
        data_loader_steps: step count for both training and validation.
        mode: selects the single callback:
            ``'LossOptimizerHandlerAccumulateBatches'`` configures
            ``LossOptimizerHandlerAccumulateBatches(min_num_batchs_before_backprop=10)``;
            ``'LossOptimizerHandlerAccumulateSamples'`` configures
            ``LossOptimizerHandlerAccumulateSamples(min_num_samples_before_backprop=20)``.

    Returns:
        A configured ``Trainer`` named ``'Accumulate-Grads-Example'`` (no
        scheduler, loss only).

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.
    """
    supported_modes = ('LossOptimizerHandlerAccumulateBatches',
                       'LossOptimizerHandlerAccumulateSamples')
    # Validate before allocating anything. With an unknown mode, two independent
    # `if`s would leave `callbacks` unbound and fail later with UnboundLocalError.
    if mode not in supported_modes:
        raise ValueError(f"Unsupported mode {mode!r}; expected one of {supported_modes}")

    device = tu.get_gpu_device_if_available()

    model = eu.get_basic_model(D_in, H, D_out).to(device)

    loss_func = nn.MSELoss(reduction='sum').to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    metrics = None  # THIS EXAMPLE DOES NOT USE METRICS, ONLY LOSS

    if mode == 'LossOptimizerHandlerAccumulateBatches':
        callbacks = [
            LossOptimizerHandlerAccumulateBatches(
                min_num_batchs_before_backprop=10, verbose=1),
        ]
    else:  # 'LossOptimizerHandlerAccumulateSamples' (validated above)
        callbacks = [
            LossOptimizerHandlerAccumulateSamples(
                min_num_samples_before_backprop=20, verbose=1),
        ]

    trainer = Trainer(
        model=model,
        device=device,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=None,
        metrics=metrics,
        train_data_loader=data_loader,  # DATA LOADER WILL YIELD BATCH SIZE OF 1
        val_data_loader=data_loader,
        train_steps=data_loader_steps,
        val_steps=data_loader_steps,
        callbacks=callbacks,
        name='Accumulate-Grads-Example')
    return trainer
Example #13
0
def get_trainer(D_in, H, D_out, data_loader, data_loader_steps, num_epochs):
    """Build the basic-example Trainer (older callback API).

    The model is an ``nn.Sequential`` of two ``Dense`` layers:
    ``D_in -> H`` with ``F.relu`` activation, then ``H -> D_out`` with no
    activation. Both layers have a bias.

    It is trained for ``num_epochs`` with a summed MSE loss and Adam (lr=1e-4).
    The scheduler is a ``DoNothingToLR`` (constant LR). No metrics are tracked,
    and ``data_loader`` serves both training and validation.
    """
    dev = tu.get_gpu_device_if_available()

    layers = [
        Dense(D_in, H, use_bias=True, activation=F.relu),
        Dense(H, D_out, use_bias=True, activation=None),
    ]
    net = nn.Sequential(*layers).to(dev)

    criterion = nn.MSELoss(reduction='sum')
    adam = optim.Adam(net.parameters(), lr=1e-4)
    # scheduler=None would also work; DoNothingToLR makes "constant LR" explicit.
    constant_lr = DoNothingToLR(optimizer=adam)

    callbacks = [
        SchedulerStep(),
        EpochEndStats(cb_phase=cbs.CB_ON_EPOCH_END, round_values_on_print_to=7),
    ]

    return Trainer(model=net,
                   device=dev,
                   loss_func=criterion,
                   optimizer=adam,
                   scheduler=constant_lr,
                   metric_name_to_func=None,  # loss only
                   train_data_loader=data_loader,
                   val_data_loader=data_loader,
                   train_steps=data_loader_steps,
                   val_steps=data_loader_steps,
                   num_epochs=num_epochs,
                   callbacks=callbacks,
                   name='Basic-Example')
Example #14
0
def get_trainer(config, num_embeddings, train_data_loader, val_data_loader,
                train_steps, val_steps, checkpoint_dir, checkpoint_file_name,
                summary_writer_dir):
    """Build the Trainer for the multi-input example (CallbackMonitor API).

    Components:
      * model: ``TestModel(config, num_embeddings)`` on the available device;
      * optimizer: SGD with ``lr=config.LEARNING_RATE`` and ``momentum=0.9``;
      * scheduler: ``ReduceLROnPlateau`` (``'min'`` mode, patience
        ``config.EARLY_STOPPING_PATIENCE // 2``), stepped with the validation loss;
      * loss: ``BCEWithLogitsLoss``;
      * metrics: ``'Accuracy'`` and two-class true positives ``'TP'``;
      * callbacks: loss/optimizer handling, scheduler stepping, Tensorboard,
        epoch-end early stopping on validation loss (patience
        ``config.EARLY_STOPPING_PATIENCE``), epoch-end stats printing with a
        normalized confusion matrix, and best-only checkpointing on
        validation loss.

    Returns:
        A configured ``Trainer`` named ``'Multi-Input-Example'``.
    """
    dev = tu.get_gpu_device_if_available()
    net = TestModel(config, num_embeddings).to(dev)

    sgd = optim.SGD(params=net.parameters(),
                    lr=config.LEARNING_RATE,
                    momentum=0.9)

    # ReduceLROnPlateau needs the monitored value on every step, hence the
    # scheduler_parameters_func given to SchedulerStep below.
    plateau = optim.lr_scheduler.ReduceLROnPlateau(
        sgd,
        mode='min',
        patience=config.EARLY_STOPPING_PATIENCE // 2,
        verbose=True)

    criterion = nn.BCEWithLogitsLoss().to(dev)

    metric_list = [
        BinaryAccuracyWithLogits(name='Accuracy'),
        TruePositives(num_classes=2, threshold=0, name='TP'),
    ]

    def latest_val_loss(callback_context):
        return callback_context.val_stats.get_loss()

    callbacks = [
        LossOptimizerHandler(),
        SchedulerStep(scheduler_parameters_func=latest_val_loss),
        Tensorboard(summary_writer_dir=summary_writer_dir),
        EarlyStopping(apply_on_phase=Phase.EPOCH_END,
                      apply_on_states=State.EXTERNAL,
                      callback_monitor=CallbackMonitor(
                          monitor_type=MonitorType.LOSS,
                          stats_type=StatsType.VAL,
                          monitor_mode=MonitorMode.MIN,
                          patience=config.EARLY_STOPPING_PATIENCE)),
        StatsPrint(apply_on_phase=Phase.EPOCH_END,
                   round_values_on_print_to=7,
                   print_confusion_matrix_normalized=True),
        # ModelCheckPoint goes last: if it saves, the state of every other
        # callback is already up to date for this epoch.
        ModelCheckPoint(checkpoint_dir=checkpoint_dir,
                        checkpoint_file_name=checkpoint_file_name,
                        callback_monitor=CallbackMonitor(monitor_type=MonitorType.LOSS,
                                                         stats_type=StatsType.VAL,
                                                         monitor_mode=MonitorMode.MIN),
                        save_best_only=True,
                        round_values_on_print_to=7),
    ]

    return Trainer(model=net,
                   device=dev,
                   loss_func=criterion,
                   optimizer=sgd,
                   scheduler=plateau,
                   metrics=metric_list,
                   train_data_loader=train_data_loader,
                   val_data_loader=val_data_loader,
                   train_steps=train_steps,
                   val_steps=val_steps,
                   callbacks=callbacks,
                   name='Multi-Input-Example')
Example #15
0
    def test_save_and_load(self):
        """A full-trainer checkpoint written by ModelCheckPoint can be reloaded.

        Trains for ``num_epochs`` with ``save_best_only=False`` and
        ``save_full_trainer=True``. It then loads the file named for the final
        epoch (``trainer_epoch_<num_epochs>``) via ``Trainer.load_trainer``. The
        loaded trainer is compared with the one that produced the file on three
        points: the epoch counter, the optimizer learning rates, and the
        checkpoint monitor's best value.
        """
        gu.seed_all(42)
        checkpoint_dir = os.path.dirname(__file__) + '/trainer_checkpoint/'
        file_prefix = 'trainer'

        dev = tu.get_gpu_device_if_available()
        net = eu.get_basic_model(10, 10, 10).to(dev)
        criterion = nn.CrossEntropyLoss().to(dev)
        adam = optim.Adam(net.parameters(), lr=1e-4)
        lr_decay = KerasDecay(adam, 0.0001, last_step=-1)
        accuracy = CategoricalAccuracyWithLogits(name='acc')

        callbacks = [
            LossOptimizerHandler(),
            # Index 1: the checkpoint callback whose monitor is compared at the end.
            ModelCheckPoint(checkpoint_dir=checkpoint_dir,
                            checkpoint_file_name=file_prefix,
                            callback_monitor=CallbackMonitor(monitor_type=MonitorType.LOSS,
                                                             stats_type=StatsType.VAL,
                                                             monitor_mode=MonitorMode.MIN),
                            save_best_only=False,
                            save_full_trainer=True,
                            verbose=0),
            SchedulerStep(apply_on_phase=Phase.BATCH_END, apply_on_states=State.TRAIN),
            StatsPrint(),
        ]

        loader = eu.examples_data_generator(10, 10, 10, category_out=True)
        steps = 100
        num_epochs = 5

        trainer = Trainer(model=net,
                          device=dev,
                          loss_func=criterion,
                          optimizer=adam,
                          scheduler=lr_decay,
                          metrics=accuracy,
                          train_data_loader=loader,
                          val_data_loader=loader,
                          train_steps=steps,
                          val_steps=steps,
                          callbacks=callbacks,
                          name='Trainer-Test')

        trainer.train(num_epochs, verbose=0)

        # NOTE(review): load_trainer receives the SAME model/optimizer/scheduler
        # objects the trained trainer uses, so the learning-rate comparison below
        # may compare an optimizer with itself. Passing freshly built objects
        # would make that assertion meaningful; first confirm that load_trainer
        # restores optimizer state into them.
        loaded = Trainer.load_trainer(dir_path=checkpoint_dir,
                                      file_name=f'{file_prefix}_epoch_{num_epochs}',
                                      model=net,
                                      device=dev,
                                      loss_func=criterion,
                                      optimizer=adam,
                                      scheduler=lr_decay,
                                      train_data_loader=loader,
                                      val_data_loader=loader,
                                      train_steps=steps,
                                      val_steps=steps)

        self.assertEqual(loaded.epoch, trainer.epoch)
        self.assertListEqual(tu.get_lrs_from_optimizer(loaded.optimizer),
                             tu.get_lrs_from_optimizer(trainer.optimizer))
        self.assertEqual(loaded.callbacks[1].monitor._get_best(),
                         trainer.callbacks[1].monitor._get_best())
Example #16
0
    def test_save_and_predict(self):
        """Predictions must be reproduced by every way of building a Predictor.

        Trains a ``TestModel`` for ``num_epochs`` and checks that:
          * training changed the prediction for a fixed sample;
          * ``Predictor.from_trainer`` reproduces the trained prediction;
          * ``Predictor.from_checkpoint`` on the ModelCheckPoint file
            (``checkpoint_best_only``) reproduces it;
          * ``Predictor.from_checkpoint`` on the manually saved trainer file
            (``trainer``) reproduces it.
        """
        save_to_dir = os.path.dirname(__file__) + '/trainer_checkpoint/'
        checkpoint_file_name = 'checkpoint'
        trainer_file_name = 'trainer'

        device = tu.get_gpu_device_if_available()

        model = TestModel().to(device)

        loss_func = nn.BCEWithLogitsLoss().to(device)

        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        scheduler = None

        metrics = BinaryAccuracyWithLogits(name='acc')

        callbacks = [
            LossOptimizerHandler(),
            ModelCheckPoint(checkpoint_dir=save_to_dir,
                            checkpoint_file_name=checkpoint_file_name,
                            callback_monitor=CallbackMonitor(
                                monitor_type=MonitorType.LOSS,
                                stats_type=StatsType.VAL,
                                monitor_mode=MonitorMode.MIN),
                            save_best_only=True,
                            save_full_trainer=False),
        ]

        data_loader = data_generator()
        data_loader_steps = 100
        num_epochs = 5

        trainer = Trainer(model=model,
                          device=device,
                          loss_func=loss_func,
                          optimizer=optimizer,
                          scheduler=scheduler,
                          metrics=metrics,
                          train_data_loader=data_loader,
                          val_data_loader=data_loader,
                          train_steps=data_loader_steps,
                          val_steps=data_loader_steps,
                          callbacks=callbacks,
                          name='Predictor-Trainer-Test')

        x1_x2, _labels = next(data_loader)
        _ = trainer.predict_batch(x1_x2)  # JUST TO CHECK THAT IT FUNCTIONS

        # First element of each of the two inputs -> one two-input sample.
        sample = [x1_x2[0][0], x1_x2[1][0]]

        # PREDICT BEFORE TRAIN
        sample_prediction_before_train = trainer.predict_sample(sample)

        trainer.train(num_epochs, verbose=0)

        # PREDICT AFTER TRAIN
        sample_prediction_from_trainer = trainer.predict_sample(sample)

        # SAVE THE TRAINER
        trainer.save_trainer(save_to_dir, trainer_file_name)

        # Training must have changed the prediction; otherwise the equality
        # checks below would pass trivially.
        self.assertFalse(
            (sample_prediction_before_train == sample_prediction_from_trainer).all())

        def assert_reproduces_trained_prediction(predictor):
            # Check the predictor under test itself, not only the trainer: its
            # output must differ from the untrained prediction and match the
            # trained one exactly.
            prediction = predictor.predict_sample(sample)
            self.assertFalse((sample_prediction_before_train == prediction).all())
            self.assertTrue((prediction == sample_prediction_from_trainer).all())

        #-----------------------------------------------#
        # CREATE PREDICTOR FROM CURRENT TRAINER
        #-----------------------------------------------#
        assert_reproduces_trained_prediction(Predictor.from_trainer(trainer))

        #-----------------------------------------------#
        # LOAD MODEL CHECKPOINT AS NEW PREDICTOR
        #-----------------------------------------------#
        # NOTE(review): this is the best-only checkpoint (min validation loss).
        # It equals the final model only if the last epoch had the best
        # validation loss. Confirm that holds for this setup, or this check can
        # be flaky.
        fresh_device = tu.get_gpu_device_if_available()
        fresh_model = TestModel().to(fresh_device)
        assert_reproduces_trained_prediction(
            Predictor.from_checkpoint(save_to_dir,
                                      checkpoint_file_name + '_best_only',
                                      fresh_model, fresh_device))

        #-----------------------------------------------#
        # LOAD TRAINER CHECKPOINT AS NEW PREDICTOR
        #-----------------------------------------------#
        fresh_device = tu.get_gpu_device_if_available()
        fresh_model = TestModel().to(fresh_device)
        assert_reproduces_trained_prediction(
            Predictor.from_checkpoint(save_to_dir,
                                      trainer_file_name,
                                      fresh_model, fresh_device))