def test_droplast():
    """With ``drop_last=False`` the trailing partial batch is kept.

    Every batch except the last must be full (200 rows). The final batch must
    be non-empty and smaller than 200. There must be 51 batches in total, one
    more than ``test_batchsize`` sees with the default loader settings.
    """
    movieLensDataHandler = AEDataHandler('MovieLensSmall', train_data_path,
                                         validation_input_data_path,
                                         validation_output_data_path,
                                         test_input_data_path,
                                         test_output_data_path)

    batch_size = 200
    train_dataloader = movieLensDataHandler.get_train_dataloader(
        batch_size=batch_size, drop_last=False)

    batch_lengths = []
    for batch in train_dataloader:
        # The item dimension must be intact for inputs and targets alike.
        assert 8936 == len(batch[0][0])
        assert 8936 == len(batch[1][0])
        # Inputs and targets must always carry the same number of rows.
        assert len(batch[0]) == len(batch[1])
        batch_lengths.append(len(batch[0]))

    assert 51 == len(batch_lengths)
    # All batches but the last are full; the last one is the remainder that
    # drop_last=False keeps (the default loader drops it, giving 50 batches).
    assert all(n == batch_size for n in batch_lengths[:-1])
    assert 0 < batch_lengths[-1] < batch_size


def test_batchsize():
    """Batches from the default train loader are exactly 200 rows each.

    Checks inputs and targets alike for the row count (200) and the item
    dimension (8936), and pins the total number of batches at 50.
    """
    handler = AEDataHandler('MovieLensSmall', train_data_path,
                            validation_input_data_path,
                            validation_output_data_path,
                            test_input_data_path,
                            test_output_data_path)

    loader = handler.get_train_dataloader(batch_size=200)

    # enumerate(start=1) leaves n_batches equal to the number of batches
    # seen; it stays 0 if the loader is empty.
    n_batches = 0
    for n_batches, batch in enumerate(loader, start=1):
        inputs, targets = batch[0], batch[1]
        assert len(inputs) == 200
        assert len(targets) == 200
        assert len(inputs[0]) == 8936
        assert len(targets[0]) == 8936

    assert n_batches == 50


def test_shuffle():
    """With ``shuffle=False`` two separate passes start with the same batch.

    Draws the first batch from two independent iterations over the train
    dataloader and checks that inputs and targets match element-wise.
    """
    movieLensDataHandler = AEDataHandler('MovieLensSmall', train_data_path,
                                         validation_input_data_path,
                                         validation_output_data_path,
                                         test_input_data_path,
                                         test_output_data_path)

    train_dataloader = movieLensDataHandler.get_train_dataloader(shuffle=False)

    # Only the first batch of each pass matters, so don't consume a whole
    # epoch. The None default turns an empty loader into an explicit failure
    # instead of letting the test pass without comparing anything.
    first_batch = next(iter(train_dataloader), None)
    second_batch = next(iter(train_dataloader), None)
    assert first_batch is not None, "train dataloader yielded no batches"
    assert second_batch is not None, "train dataloader yielded no batches"

    # Element-wise comparison; .all() collapses it to a single truth value.
    assert (second_batch[0] == first_batch[0]).all()
    assert (second_batch[1] == first_batch[1]).all()