# Code example #1
# 0
def test_dimension_methods():
    """The MovieLensSmall handler must report 8936 for both input and output width."""
    expected_dim = 8936
    handler = AEDataHandler('MovieLensSmall', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path,
                            test_input_data_path,
                            test_output_data_path)

    assert handler.get_input_dim() == expected_dim
    assert handler.get_output_dim() == expected_dim
# Code example #2
# 0
def test_droplast():
    """Iterate the training loader with ``batch_size=200, drop_last=False``.

    Every batch's first input row and first target row must be 8936 wide, and
    the loader must yield exactly 51 batches.
    """
    handler = AEDataHandler('MovieLensSmall', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path,
                            test_input_data_path,
                            test_output_data_path)
    loader = handler.get_train_dataloader(batch_size=200, drop_last=False)

    # Stays 0 if the loader is empty, so the final count check still runs.
    num_batches = 0
    for num_batches, batch in enumerate(loader, start=1):
        assert len(batch[0][0]) == 8936
        assert len(batch[1][0]) == 8936
    assert num_batches == 51
# Code example #3
# 0
def test_batchsize():
    """Every training batch must hold exactly 200 rows of 8936 features.

    Uses the loader's default ``drop_last``. It is presumably True, since
    exactly 50 full batches are expected; confirm against AEDataHandler.
    """
    handler = AEDataHandler('MovieLensSmall', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path,
                            test_input_data_path,
                            test_output_data_path)
    loader = handler.get_train_dataloader(batch_size=200)

    # Stays 0 if the loader is empty, so the final count check still runs.
    num_batches = 0
    for num_batches, batch in enumerate(loader, start=1):
        assert len(batch[0]) == 200
        assert len(batch[1]) == 200
        assert len(batch[0][0]) == 8936
        assert len(batch[1][0]) == 8936
    assert num_batches == 50
# Code example #4
# 0
def test_length():
    """Check the number of rows reported for each MovieLensSmall split.

    Expected sizes: 10001 training rows, and 2000 rows each for the test and
    validation splits.
    """
    handler = AEDataHandler('MovieLensSmall', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path,
                            test_input_data_path,
                            test_output_data_path)

    assert handler.get_traindata_len() == 10001, 'unexpected train split size'
    assert handler.get_testdata_len() == 2000, 'unexpected test split size'
    assert handler.get_validationdata_len() == 2000, \
        'unexpected validation split size'
# Code example #5
# 0
def test_shuffle():
    """With ``shuffle=False``, two passes over the training loader must start
    with the same first batch of inputs."""
    handler = AEDataHandler('MovieLensSmall', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path,
                            test_input_data_path,
                            test_output_data_path)
    train_dataloader = handler.get_train_dataloader(shuffle=False)

    # Only the first batch of each pass matters, so take it directly rather
    # than walking the whole loader. Assumes the loader is re-iterable like a
    # torch DataLoader, i.e. each iter() starts a fresh pass.
    first_pass_batch = next(iter(train_dataloader), None)
    # Guard against a vacuous pass when the loader yields nothing.
    assert first_pass_batch is not None, 'training dataloader yielded no batches'
    second_pass_batch = next(iter(train_dataloader))

    comparison = second_pass_batch[0] == first_pass_batch[0]
    assert comparison.all()
# Code example #6
# 0
# Module-level setup: build the MovieLensSmall data handler, the per-product
# weight tensor, a MultiVAE model and its two losses.
# `logger` and `dir_path` are defined outside this excerpt.
logger.setLevel(logging.INFO)

# Run on CUDA when it is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# .npy files for each split, all under `dir_path`.
# NOTE(review): the validation/test *output* files end in "_test.npy";
# presumably these hold the held-out targets. Confirm the naming.
train_data_path = os.path.join(dir_path, 'movielens_small_training.npy')
validation_input_data_path = os.path.join(
    dir_path, 'movielens_small_validation_input.npy')
validation_output_data_path = os.path.join(
    dir_path, 'movielens_small_validation_test.npy')
test_input_data_path = os.path.join(dir_path, 'movielens_small_test_input.npy')
test_output_data_path = os.path.join(dir_path, 'movielens_small_test_test.npy')
products_data_path = os.path.join(dir_path, 'movielens_products_data.npy')

# Handler over the splits above, labelled 'MovieLensSmall'.
data_handler = AEDataHandler('MovieLensSmall', train_data_path,
                             validation_input_data_path,
                             validation_output_data_path, test_input_data_path,
                             test_output_data_path)

# Input/output widths reported by the handler. They are presumably consumed
# elsewhere in the file (e.g. to size a model); not used in this excerpt.
input_dim = data_handler.get_input_dim()
output_dim = data_handler.get_output_dim()

# Per-product vector loaded from disk, converted to a float32 tensor and
# moved to `device`. It presumably holds item prices/revenues; TODO confirm.
products_data_np = np.load(products_data_path)
products_data_torch = torch.tensor(products_data_np,
                                   dtype=torch.float32).to(device)

# create model
# The YAML path is relative, so it presumably resolves against the current
# working directory. Confirm how MultiVAE opens it.
model = MultiVAE(params='yaml_files/params_multi_VAE_training.yaml')

# Two objectives: a plain VAE loss, and one weighted by the product vector.
correctness_loss = VAELoss()
revenue_loss = VAELoss(weighted_vector=products_data_torch)
losses = [correctness_loss, revenue_loss]
# Code example #7
# 0
    dir_path, 'movielens_small_validation_test.npy')
test_input_data_path = os.path.join(dir_path, 'movielens_small_test_input.npy')
test_output_data_path = os.path.join(dir_path, 'movielens_small_test_test.npy')
products_data_path = os.path.join(dir_path, 'movielens_products_data.npy')

# Shape of the synthetic fixture: every matrix has one column per item.
_NUM_ITEMS = 8936
_NUM_TRAIN_ROWS = 10000
_NUM_EVAL_ROWS = 2000

# Use a seeded generator so each run writes identical fixture files and any
# failure is reproducible. Drawing float32 directly avoids allocating a
# float64 temporary twice the size of each (already large) matrix.
_rng = np.random.default_rng(0)

np.save(train_data_path,
        _rng.random((_NUM_TRAIN_ROWS, _NUM_ITEMS), dtype=np.float32))
for _eval_path in (validation_input_data_path, validation_output_data_path,
                   test_input_data_path, test_output_data_path):
    np.save(_eval_path,
            _rng.random((_NUM_EVAL_ROWS, _NUM_ITEMS), dtype=np.float32))
# Product weights are saved as float64; they are cast to float32 below when
# converted to a tensor.
np.save(products_data_path, _rng.random(_NUM_ITEMS))

# Handler over the synthetic splits written above.
dataHandler = AEDataHandler('Testing trainer random dataset', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path, test_input_data_path,
                            test_output_data_path)

# Input/output widths reported by the handler (not used in this excerpt).
input_dim = dataHandler.get_input_dim()
output_dim = dataHandler.get_output_dim()

# Reload the product vector and move it to `device` (defined outside this
# excerpt) as a float32 tensor.
products_data_np = np.load(products_data_path)
products_data_torch = torch.tensor(products_data_np,
                                   dtype=torch.float32).to(device)

# create model
# The YAML path is relative, so it presumably resolves against the current
# working directory. Confirm how MultiVAE opens it.
model = MultiVAE(params='yaml_files/params_multi_VAE.yaml')

# Two objectives: a plain VAE loss, and one weighted by the product vector.
correctness_loss = VAELoss()
revenue_loss = VAELoss(weighted_vector=products_data_torch)
losses = [correctness_loss, revenue_loss]