import os

import torch

# project-local modules; import paths assumed from the surrounding repo
import options
import utils
from trainer import SeqTrainer  # assumed location


def main():
    parser = options.get_parser('Generator')
    options.add_dataset_args(parser)
    options.add_preprocessing_args(parser)
    options.add_model_args(parser)
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    caseless = args.caseless
    batch_size = args.batch_size

    if os.path.isfile(args.load_checkpoint):
        print('Loading checkpoint file from {}...'.format(args.load_checkpoint))
        checkpoint_file = torch.load(args.load_checkpoint)
    else:
        raise OSError('No checkpoint file found: {}'.format(args.load_checkpoint))

    # only the test split is needed here
    _, _, test_raw_corpus = utils.load_corpus(args.processed_dir, ddi=True)
    test_corpus = [(line.sent, line.type, line.p1, line.p2) for line in test_raw_corpus]

    # preprocessing: reuse the feature/target maps stored in the checkpoint
    feature_map = checkpoint_file['f_map']
    target_map = checkpoint_file['t_map']
    test_features, test_targets = utils.build_corpus(test_corpus, feature_map, target_map, caseless)

    # build test dataloader
    test_loader = utils.construct_bucket_dataloader(
        test_features, test_targets, feature_map['PAD'],
        batch_size, args.position_bound, is_train=False)

    # build model
    vocab_size = len(feature_map)
    tagset_size = len(target_map)
    model = utils.build_model(args, vocab_size, tagset_size)

    # loss
    criterion = utils.build_loss(args)

    # load states
    model.load_state_dict(checkpoint_file['state_dict'])

    # trainer
    trainer = SeqTrainer(args, model, criterion)

    if args.cuda:
        model.cuda()

    # prediction
    print('Predicting...')
    y_true, y_pred, att_weights = predict(trainer, test_loader, target_map, cuda=args.cuda)
    assert len(y_pred) == len(test_corpus), 'length of prediction is inconsistent with that of the data set'

    # write result: sent_id|e1|e2|ddi|type
    with open(args.predict_file, 'w') as f:
        for tup, pred in zip(test_raw_corpus, y_pred):
            ddi = 0 if pred == 'null' else 1
            f.write('|'.join([tup.sent_id, tup.e1, tup.e2, str(ddi), pred]))
            f.write('\n')

    # error analysis
    print('Analyzing...')
    with open(args.error_file, 'w') as f:
        f.write(' | '.join(['sent_id', 'e1', 'e2', 'target', 'pred']))
        f.write('\n')
        for tup, target, pred, att_weight in zip(test_raw_corpus, y_true, y_pred, att_weights):
            if target != pred:
                size = len(tup.sent)
                f.write('{}\n'.format(' '.join(tup.sent)))
                # models other than InterAttentionLSTM produce a single weight
                # vector; wrap it so the loop below handles both cases
                if args.model != 'InterAttentionLSTM':
                    att_weight = [att_weight]
                for i in range(len(att_weight)):
                    f.write('{}\n'.format(' '.join(str(round(x, 4)) for x in att_weight[i][:size])))
                f.write('{}\n\n'.format(' | '.join([tup.sent_id, tup.e1, tup.e2, target, pred])))

    # attention scores for correctly classified, non-null pairs
    print('Writing attention scores...')
    with open(args.att_file, 'w') as f:
        f.write(' | '.join(['target', 'sent', 'att_weight']))
        f.write('\n')
        for tup, target, pred, att_weight in zip(test_raw_corpus, y_true, y_pred, att_weights):
            if target == pred and target != 'null':
                size = len(tup.sent)
                f.write('{}\n'.format(target))
                f.write('{}\n'.format(' '.join(tup.sent)))
                if args.model != 'InterAttentionLSTM':
                    att_weight = [att_weight]
                for i in range(len(att_weight)):
                    f.write('{}\n'.format(' '.join(str(round(x, 4)) for x in att_weight[i][:size])))
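# ---------------------------------------------------------------------------
# Hedged sketch: the predict() helper called above is defined elsewhere in the
# repo. A minimal version, assuming the trainer exposes a per-batch prediction
# method (trainer.predict_batch is a hypothetical name) and that target_map
# maps label strings to integer ids, might look like this:
def predict(trainer, test_loader, target_map, cuda=False):
    inv_target_map = {v: k for k, v in target_map.items()}
    y_true, y_pred, att_weights = [], [], []
    for batch in test_loader:
        # assumed API: returns gold ids, predicted ids, and attention weights
        targets, preds, att = trainer.predict_batch(batch, cuda=cuda)
        y_true.extend(inv_target_map[t] for t in targets)
        y_pred.extend(inv_target_map[p] for p in preds)
        att_weights.extend(att)
    return y_true, y_pred, att_weights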
# python3 generate.py --data data-bin/UN/million6way/en-fr/bpe/preprocessed/ \
#     --src_lang en --trg_lang fr --batch-size 80 --gpuid 0 \
#     --model-dir data-bin/UN/million6way/en-fr/checkpoints/lr34clip1/best_gmodel.pt

import argparse
import logging

from torch import cuda

# project-local modules; import paths assumed from the surrounding repo
import data
import options

logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG)

parser = argparse.ArgumentParser(
    description="Driver program for JHU Adversarial-NMT.")

# Load args
options.add_general_args(parser)
options.add_dataset_args(parser)
options.add_checkpoint_args(parser)
options.add_distributed_training_args(parser)
options.add_generation_args(parser)
options.add_generator_model_args(parser)


def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    if args.gpuid:
        cuda.set_device(args.gpuid[0])

    print(args.replace_unk)  # None unless --replace-unk is given

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data, ['test'], args.src_lang, args.trg_lang)
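# ---------------------------------------------------------------------------
# Hedged sketch (hypothetical, not from this repo): when --replace_unk is set,
# generation scripts in this style typically also load the raw text, then
# substitute each <unk> in a hypothesis with the source word that received the
# highest attention weight at that decoding step:
def replace_unk_tokens(hypo_tokens, src_tokens, attn_matrix, unk='<unk>'):
    """attn_matrix[i][j]: attention of hypothesis position i on source position j."""
    replaced = []
    for i, tok in enumerate(hypo_tokens):
        if tok == unk:
            j = max(range(len(src_tokens)), key=lambda j: attn_matrix[i][j])
            replaced.append(src_tokens[j])
        else:
            replaced.append(tok)
    return replaced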
import json
import os
import sys

import torch

# project-local modules; import paths assumed from the surrounding repo
import ddi2013
import options
import utils
from model import RelationTreeModel  # assumed location
from trainer import TreeTrainer  # assumed location


def main():
    parser = options.get_parser('Generator')
    options.add_dataset_args(parser)
    options.add_preprocessing_args(parser)
    options.add_model_args(parser)
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()

    # restore the args the checkpoint was trained with
    model_path = args.load_checkpoint + '.model'
    args_path = args.load_checkpoint + '.json'
    with open(args_path, 'r') as f:
        _args = json.load(f)['args']
    for k, v in _args.items():
        setattr(args, k, v)

    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    print(args)

    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # increase recursion depth for deep parse trees
    sys.setrecursionlimit(10000)

    # load dataset
    train_raw_corpus, val_raw_corpus, test_raw_corpus = utils.load_corpus(args.processed_dir, ddi=False)
    assert train_raw_corpus and val_raw_corpus and test_raw_corpus, \
        'Corpus not found, please run preprocess.py to obtain corpus!'
    train_corpus = [(line.sent, line.type, line.p1, line.p2) for line in train_raw_corpus]
    val_corpus = [(line.sent, line.type, line.p1, line.p2) for line in val_raw_corpus]

    caseless = args.caseless
    batch_size = args.batch_size

    # build vocab
    sents = [tup[0] for tup in train_corpus + val_corpus]
    feature_map = utils.build_vocab(sents, min_count=args.min_count, caseless=caseless)
    target_map = ddi2013.target_map

    # get class weights
    _, train_targets = utils.build_corpus(train_corpus, feature_map, target_map, caseless)
    class_weights = torch.Tensor(utils.get_class_weights(train_targets)) if args.class_weight else None

    # load datasets
    _, _, test_loader = utils.load_datasets(args.processed_dir, args.train_size, args, feature_map, dataloader=True)

    # build model
    vocab_size = len(feature_map)
    tagset_size = len(target_map)
    model = RelationTreeModel(vocab_size, tagset_size, args)

    # loss
    criterion = utils.build_loss(args, class_weights=class_weights)

    # load states
    assert os.path.isfile(model_path), 'Checkpoint not found!'
    print('Loading checkpoint file from {}...'.format(model_path))
    checkpoint_file = torch.load(model_path)
    model.load_state_dict(checkpoint_file['state_dict'])

    # trainer
    trainer = TreeTrainer(args, model, criterion)

    # predict
    y_true, y_pred, treelists, f1_by_len = predict(trainer, test_loader, target_map, cuda=args.cuda)

    # replace node indices with words so printed trees are readable
    for tup, treelist in zip(test_raw_corpus, treelists):
        for t in treelist:
            t.idx = tup.sent[t.idx] if t.idx < len(tup.sent) else None

    # prediction
    print('Predicting...')
    # write result: sent_id|e1|e2|ddi|type
    with open(args.predict_file, 'w') as f:
        for tup, pred in zip(test_raw_corpus, y_pred):
            ddi = 0 if pred == 'null' else 1
            f.write('|'.join([tup.sent_id, tup.e1, tup.e2, str(ddi), pred]))
            f.write('\n')

    def print_info(f, tup, target, pred, root):
        f.write('{}\n'.format(' '.join(tup.sent)))
        f.write('{}\n'.format(' | '.join([tup.sent_id, tup.e1, tup.e2, target, pred])))
        f.write('{}\n\n'.format(root))

    # error analysis
    print('Analyzing...')
    with open(args.error_file, 'w') as f:
        f.write(' | '.join(['sent_id', 'e1', 'e2', 'target', 'pred']))
        f.write('\n')
        for tup, target, pred, treelist in zip(test_raw_corpus, y_true, y_pred, treelists):
            if target != pred:
                print_info(f, tup, target, pred, treelist[-1])

    # correctly classified, non-null pairs
    print('Writing correct predictions...')
    with open(args.correct_file, 'w') as f:
        f.write(' | '.join(['sent_id', 'e1', 'e2', 'target', 'pred']))
        f.write('\n')
        for tup, target, pred, treelist in zip(test_raw_corpus, y_true, y_pred, treelists):
            if target == pred and target != 'null':
                print_info(f, tup, target, pred, treelist[-1])
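# ---------------------------------------------------------------------------
# Hedged sketch: utils.get_class_weights is defined elsewhere in the repo. One
# common implementation (an assumption, not necessarily what the repo does) is
# inverse class-frequency weighting, normalized so the most frequent class has
# weight 1.0. Assumes targets are contiguous integer class ids starting at 0.
from collections import Counter


def get_class_weights(targets):
    counts = Counter(targets)
    most_common = max(counts.values())
    # rarer classes get proportionally larger weights in the loss
    return [most_common / counts[c] for c in range(len(counts))]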