def main():
    """Run a saved Set2Seq2Seq checkpoint over DATA_FILE and append the results to OUT_FILE.

    Steps:
    1. Build a vocabulary.
    2. Build a batch-size-1 ``FruitSeqDataset`` (and its raw string set) from
       ``DATA_FILE``.
    3. Rebuild the model from ``args.param_file`` on ``args.device`` and switch
       it to eval mode.
    4. Pass everything to ``iterate_dataset`` together with ``OUT_FILE``
       opened in append mode.
    """
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    data_set = FruitSeqDataset(voc, dataset_file_path=DATA_FILE, batch_size=1)
    str_set = data_set.load_stringset(DATA_FILE)
    print('done')

    print('rebuilding model from saved parameters in ' + args.param_file + '...')
    # NOTE(review): the model is sized from the freshly built Voc() and the
    # dataset above was built with it, but the checkpoint's own 'voc' replaces
    # it below. This assumes both vocabularies are identical — confirm.
    model = Set2Seq2Seq(voc.num_words).to(args.device)
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(args.param_file, map_location=args.device)
    train_args = checkpoint['args']
    model.load_state_dict(checkpoint['model'])
    voc = checkpoint['voc']
    print('done')

    model.eval()

    print('iterating data set...')
    # The context manager guarantees the output file is flushed and closed,
    # even if iteration raises part-way through.
    with open(OUT_FILE, mode='a') as out_file:
        iterate_dataset(model, voc, str_set, data_set, out_file, train_args)
def get_batches4sim_check(voc, dataset_file_path=args.data_file):
    """Load one dataset file both as raw strings and as a batch-size-1 ChooseDataset.

    Args:
        voc: vocabulary object forwarded to ``ChooseDataset``.
        dataset_file_path: file to read. The default is ``args.data_file`` as
            it was when this function was defined (Python evaluates default
            values once, at definition time).

    Returns:
        A 2-tuple ``(in_set, batch_set)``:
        - ``in_set`` is the result of ``FruitSeqDataset.load_stringset``.
        - ``batch_set`` is a ``ChooseDataset`` with ``batch_size=1``.
    """
    # Tuple elements are evaluated left to right, so the string set is still
    # loaded before the ChooseDataset is constructed.
    return (
        FruitSeqDataset.load_stringset(dataset_file_path),
        ChooseDataset(voc, dataset_file_path=dataset_file_path, batch_size=1),
    )
def main(
    model_name='Img2Seq2Choice',
    dataset_name='ImgChooseDataset',
    out_file_path='data/tmp.txt',
):
    """Rebuild a trained model from ``args.param_file`` and build a listener training file.

    Everything runs on the CPU. Checkpoint loading, model placement and
    dataset tensors all use ``torch.device('cpu')``.

    Args:
        model_name: which model class to rebuild. One of 'Img2Seq2Choice',
            'Set2Seq2Seq' or 'Set2Seq2Choice'.
        dataset_name: how ``args.data_file`` is loaded. One of
            'ImgChooseDataset', 'FruitSeqDataset' or 'ChooseDataset'.
        out_file_path: destination passed to ``build_listener_training_file``.

    Raises:
        ValueError: if ``args.param_file`` is None.
        NotImplementedError: if ``model_name`` or ``dataset_name`` is not
            one of the supported values.
    """
    if args.param_file is None:
        raise ValueError('args.param_file must point to a saved checkpoint')

    device = torch.device('cpu')
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(args.param_file, map_location=device)

    print('rebuilding vocabulary and model...')
    # The vocabulary is only read from the checkpoint for the Set2Seq* models;
    # every other model gets None.
    voc = checkpoint['voc'] if model_name in ('Set2Seq2Seq', 'Set2Seq2Choice') else None
    train_args = checkpoint['args']
    print(train_args)

    if model_name not in ('Img2Seq2Choice', 'Set2Seq2Seq', 'Set2Seq2Choice'):
        raise NotImplementedError('unsupported model_name: %r' % (model_name,))

    # All three model classes share these constructor arguments; the Set2Seq*
    # models additionally take the vocabulary size as their first positional arg.
    model_kwargs = dict(
        msg_length=train_args.max_msg_len,
        msg_vocsize=train_args.msg_vocsize,
        hidden_size=train_args.hidden_size,
        dropout=train_args.dropout_ratio,
        msg_mode=train_args.msg_mode,
    )
    if model_name == 'Img2Seq2Choice':
        model = Img2Seq2Choice(**model_kwargs).to(device)
    elif model_name == 'Set2Seq2Seq':
        model = Set2Seq2Seq(voc.num_words, **model_kwargs).to(device)
    else:  # 'Set2Seq2Choice'
        model = Set2Seq2Choice(voc.num_words, **model_kwargs).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    print('done')

    print('loading and building batch dataset...')
    # NOTE(review): the FruitSeqDataset/ChooseDataset branches need `voc`,
    # which is None unless a Set2Seq* model was selected — confirm callers
    # never pair those datasets with Img2Seq2Choice.
    if dataset_name == 'ImgChooseDataset':
        batch_set = ImgChooseDataset(dataset_dir_path=args.data_file, batch_size=1, device=device)
        # Each input is taken from the batch's ['correct']['label'][0] entry.
        in_set = [batch['correct']['label'][0] for batch in batch_set]
    elif dataset_name == 'FruitSeqDataset':
        batch_set = FruitSeqDataset(voc, dataset_file_path=args.data_file, batch_size=1, device=device)
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    elif dataset_name == 'ChooseDataset':
        batch_set = ChooseDataset(voc, dataset_file_path=args.data_file, batch_size=1, device=device)
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    else:
        # Without this branch an unknown name left batch_set/in_set unbound
        # and failed later with a NameError.
        raise NotImplementedError('unsupported dataset_name: %r' % (dataset_name,))
    print('done')

    build_listener_training_file(model, in_set, batch_set, out_file_path)