Пример #1
0
def test_check_input6():
    """Saving into a non-empty directory must be rejected with a ValueError.

    The current working directory ('.') is used as the save path.
    Presumably it is not empty, so constructing or training the trainer
    should refuse it with the message below.
    """
    expected_message = ('Please make sure that the directory'
                        ' where you want to save the models is empty!')
    with pytest.raises(ValueError, match=expected_message):
        trainer = SingleObjectiveTrainer(dataHandler, model, correctness_loss,
                                         validation_metrics, '.', params)
        trainer.train()
Пример #2
0
def test_check_input2():
    """Passing ``None`` in place of the model must raise a TypeError."""
    expected_message = ('Please check you are using the right model object,'
                        ' or the right order of the attributes!')
    with pytest.raises(TypeError, match=expected_message):
        trainer = SingleObjectiveTrainer(dataHandler, None, correctness_loss,
                                         validation_metrics, save_to_path,
                                         params)
        trainer.train()
Пример #3
0
def test_check_input9():
    """A non-optimizer object as the 7th argument must raise a TypeError.

    The model itself is passed in the slot that, judging by the expected
    error message, is presumably the optimizer argument.
    """
    bogus_optimizer = model
    expected_message = (
        'Please make sure that the optimizer is a pytorch Optimizer object!')
    with pytest.raises(TypeError, match=expected_message):
        trainer = SingleObjectiveTrainer(dataHandler, model, correctness_loss,
                                         validation_metrics, save_to_path,
                                         params, bogus_optimizer)
        trainer.train()
Пример #4
0
def test_check_input5():
    """A non-metric object among the validation metrics must raise TypeError.

    A copy of the shared ``validation_metrics`` list is made so the
    module-level list is not mutated. The copy's first entry is then
    replaced by the model object, which is not a metric.
    """
    # Build the invalid input outside ``pytest.raises`` so that only the
    # trainer construction/training can satisfy the expected exception;
    # a failure in this setup now surfaces as a real error instead of
    # potentially being absorbed by the context manager.
    validation_metrics_tmp = validation_metrics.copy()
    validation_metrics_tmp[0] = model
    with pytest.raises(
            TypeError,
            match='Please check you are using the right metric objects,' +
            ' or the right order of the attributes!'):
        trainer = SingleObjectiveTrainer(dataHandler, model, correctness_loss,
                                         validation_metrics_tmp, save_to_path,
                                         params)
        trainer.train()
Пример #5
0
def test_check_input3():
    """A model without an ``initialize_model()`` method must raise TypeError."""

    class ModelWithoutInit(nn.Module):
        # Minimal module that deliberately omits initialize_model().
        def forward(self):
            return 1

    # ``match`` is a regular expression, so the parentheses are escaped.
    expected_message = (r'Please check if your models has '
                        r'initialize_model\(\) method defined!')
    with pytest.raises(TypeError, match=expected_message):
        trainer = SingleObjectiveTrainer(dataHandler, ModelWithoutInit(),
                                         correctness_loss, validation_metrics,
                                         save_to_path, params)
        trainer.train()
Пример #6
0
def test_init_objects():
    """Constructing the trainer must create its loader, Pareto manager and validator."""
    trainer = SingleObjectiveTrainer(dataHandler, model, correctness_loss,
                                     validation_metrics, save_to_path, params)
    # isinstance() is the idiomatic type check (instead of ``type(x) == T``).
    assert isinstance(trainer._train_dataloader, DataLoader)
    assert isinstance(trainer.pareto_manager, ParetoManager)
    # The Pareto manager must be pointed at the save path given to the trainer.
    assert trainer.pareto_manager.path == save_to_path
    assert isinstance(trainer.validator, Validator)
Пример #7
0
def test_read_yaml_params():
    """Hyper-parameters from ``params`` must be exposed as trainer attributes.

    Floating-point values are compared with ``pytest.approx`` so the check
    does not hinge on exact binary rounding of values that may be parsed
    from text or derived arithmetically (e.g. ``beta_step``).
    """
    trainer = SingleObjectiveTrainer(dataHandler, model, correctness_loss,
                                     validation_metrics, save_to_path, params)
    assert trainer.seed == 42
    assert trainer.learning_rate == pytest.approx(1e-3)
    assert trainer.batch_size_training == 500
    assert trainer.shuffle_training is True
    assert trainer.drop_last_batch_training is True
    assert trainer.batch_size_validation == 500
    assert trainer.shuffle_validation is True
    assert trainer.drop_last_batch_validation is False
    assert trainer.number_of_epochs == 50
    assert trainer.anneal is True
    assert trainer.beta_start == 0
    assert trainer.beta_cap == pytest.approx(0.3)
    assert trainer.beta_step == pytest.approx(0.3 / 10000)
Пример #8
0
def test_check_input8():
    """Missing (``None``) or empty validation metrics must raise ValueError."""
    invalid_cases = [
        # No metrics object at all.
        (None, 'The validation_metrics are None,'
               ' please make sure to give valid validation_metrics!'),
        # An empty list: at least one metric is required.
        ([], 'Please check you have defined at least one validation metric!'),
    ]
    for bad_metrics, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=expected_message):
            trainer = SingleObjectiveTrainer(dataHandler, model,
                                             correctness_loss, bad_metrics,
                                             save_to_path, params)
            trainer.train()
Пример #9
0
# Paths to the MovieLens-small test split and the per-product data.
# NOTE(review): `dir_path`, `train_data_path`, the validation paths, `device`
# and `save_to_path` are not defined in this fragment; presumably they are
# set earlier in the script.
test_input_data_path = os.path.join(dir_path, 'movielens_small_test_input.npy')
test_output_data_path = os.path.join(dir_path, 'movielens_small_test_test.npy')
products_data_path = os.path.join(dir_path, 'movielens_products_data.npy')

# Wrap the train/validation/test arrays in an AEDataHandler.
data_handler = AEDataHandler('MovieLensSmall', train_data_path,
                             validation_input_data_path,
                             validation_output_data_path, test_input_data_path,
                             test_output_data_path)

# NOTE(review): input_dim / output_dim are not used below in this fragment.
input_dim = data_handler.get_input_dim()
output_dim = data_handler.get_output_dim()

# Load the product data as a NumPy array (passed to RevenueAtK below) and
# also as a float32 tensor moved to `device`.
# NOTE(review): products_data_torch is not used below in this fragment.
products_data_np = np.load(products_data_path)
products_data_torch = torch.tensor(products_data_np,
                                   dtype=torch.float32).to(device)

# create model
# NOTE(review): the YAML path is relative, so it presumably resolves against
# the current working directory; confirm the script is run from the expected
# location.
model = MultiVAE(params='yaml_files/params_multi_VAE_training.yaml')

# correctness loss
loss = VAELoss()

# Validation metrics at cutoff k=10; RevenueAtK receives the product array
# loaded above as its `revenue` argument.
recallAtK = RecallAtK(k=10)
revenueAtK = RevenueAtK(k=10, revenue=products_data_np)
validation_metrics = [recallAtK, revenueAtK]

# Train and print the Pareto front collected by the trainer's Pareto manager.
# NOTE(review): no `params` argument is passed here, unlike the
# SingleObjectiveTrainer calls in the tests; presumably a default is used.
# Confirm.
# NOTE(review): `_pareto_front` is a private attribute of the Pareto manager.
trainer = SingleObjectiveTrainer(data_handler, model, loss, validation_metrics,
                                 save_to_path)
trainer.train()
print(trainer.pareto_manager._pareto_front)