def test_droplast():
    """Check that ``drop_last=False`` keeps the trailing partial batch.

    With ``batch_size=200`` the loader is expected to yield 51 batches.
    All but the last are full 200-row batches, and the last holds the
    remaining (fewer than 200) rows. This is the counterpart of
    ``test_batchsize``, which expects exactly 50 full batches when
    ``drop_last`` is not passed.
    """
    movieLensDataHandler = AEDataHandler('MovieLensSmall', train_data_path, validation_input_data_path, validation_output_data_path, test_input_data_path, test_output_data_path)
    train_dataloader = movieLensDataHandler.get_train_dataloader(
        batch_size=200, drop_last=False)
    batch_sizes = []
    for batch in train_dataloader:
        # each row of both the input (batch[0]) and target (batch[1])
        # is expected to have 8936 entries
        assert 8936 == len(batch[0][0])
        assert 8936 == len(batch[1][0])
        # inputs and targets must stay paired row-for-row
        assert len(batch[0]) == len(batch[1])
        batch_sizes.append(len(batch[0]))
    assert 51 == len(batch_sizes)
    # The point of drop_last=False: every batch except the last is full,
    # and the last one is the non-empty remainder that drop_last=True
    # would have discarded.
    assert all(200 == size for size in batch_sizes[:-1])
    assert 0 < batch_sizes[-1] < 200
def test_batchsize():
    """Verify batch count and shape for a train loader with ``batch_size=200``.

    The loader must produce exactly 50 batches. In each batch, both the
    input (``batch[0]``) and the target (``batch[1]``) hold 200 rows of
    8936 entries each.
    """
    handler = AEDataHandler('MovieLensSmall', train_data_path, validation_input_data_path, validation_output_data_path, test_input_data_path, test_output_data_path)
    loader = handler.get_train_dataloader(batch_size=200)
    n_batches = 0
    for batch in loader:
        inputs, targets = batch[0], batch[1]
        # row count first, for input then target ...
        for part in (inputs, targets):
            assert len(part) == 200
        # ... then the width of the first row, for input then target
        for part in (inputs, targets):
            assert len(part[0]) == 8936
        n_batches += 1
    assert n_batches == 50
def test_shuffle():
    """Check that an unshuffled train loader is deterministic across passes.

    With ``shuffle=False``, two separate iterations over the same loader
    must start with identical input batches (``batch[0]``).
    """
    movieLensDataHandler = AEDataHandler('MovieLensSmall', train_data_path, validation_input_data_path, validation_output_data_path, test_input_data_path, test_output_data_path)
    # test that disabling shuffle gives a reproducible batch order
    train_dataloader = movieLensDataHandler.get_train_dataloader(shuffle=False)
    # Only the first batch of each pass is needed, so don't walk the
    # whole loader. A fresh iter() starts a new pass.
    first_batch = next(iter(train_dataloader), None)
    # Guard against a vacuous pass: an empty loader would otherwise make
    # the comparison below never run.
    assert first_batch is not None, "train dataloader yielded no batches"
    # If the loader is a one-shot iterable, this raises StopIteration
    # instead of silently skipping the comparison.
    second_pass_batch = next(iter(train_dataloader))
    comparison = second_pass_batch[0] == first_batch[0]
    assert comparison.all()