def get_dataloader(self):
    """Convert ``self.datasets`` (pandas.DataFrame per split) into PyTorch DataLoaders.

    Reads the batch size from ``self.config[HParamKey.BATCH_SIZE]``, converts the
    train and validation splits to vocabulary indices via ``snli_token2id`` with
    ``self.word2idx``, and stores the resulting DataLoaders in
    ``self.loaders[lType.TRAIN]`` and ``self.loaders[lType.VAL]``.

    Returns:
        None. Side effect: populates ``self.loaders``.
    """
    batch_size = self.config[HParamKey.BATCH_SIZE]

    def _index_split(split):
        # Map one split's tokens to FastText vocabulary indices;
        # returns the (premise, hypothesis, label) triple from snli_token2id.
        return snli_token2id(self.datasets[split], self.word2idx)

    def _make_loader(indexed):
        # Wrap an indexed (premise, hypothesis, label) triple in a DataLoader.
        # NOTE(review): shuffle=True is kept for BOTH splits to match the
        # original behavior, but shuffling the validation set is unusual —
        # confirm this is intended.
        prem, hypo, label = indexed
        return torch.utils.data.DataLoader(
            dataset=snliDataset(prem, hypo, label),
            batch_size=batch_size,
            collate_fn=snli_collate_func,
            shuffle=True)

    # Convert both splits first so the log messages keep their original order.
    train_indexed = _index_split(lType.TRAIN)
    val_indexed = _index_split(lType.VAL)
    logger.info("Converted to indices! ")

    # Create DataLoader
    logger.info("Creating DataLoader...")
    self.loaders[lType.TRAIN] = _make_loader(train_indexed)
    self.loaders[lType.VAL] = _make_loader(val_indexed)
    logger.info("DataLoader generated!")
logging.config.dictConfig(LogConfig) logger.info("START JOB NLI on device ({})".format(DEVICE)) # Load pre-trained embeddings of FastText # word2idx, idx2word, ft_embs = get_fasttext_embedding(50000, 'cc') logger.info("Pre-trained embeddings loaded!") # logger.info("\n===== word2idx ======\n{}\n=====================".format(word2idx)) # Load train/validation sets train_set, val_set = get_snli_data() logger.info( "\n===== train/validation sets =====\nTrain sample: {}\nValidation sample: {}" .format(len(train_set), len(val_set))) # Convert to indices based on FastText vocabulary train_prem, train_hypo, train_label = snli_token2id(train_set, word2idx) val_prem, val_hypo, val_label = snli_token2id(val_set, word2idx) logger.info("Converted to indices! ") # Create DataLoader logger.info("Creating DataLoader...") BATCH_SIZE = 32 train_dataset = snliDataset(train_prem, train_hypo, train_label) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, collate_fn=snli_collate_func, shuffle=True) val_dataset = snliDataset(val_prem, val_hypo, val_label) val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, collate_fn=snli_collate_func,