def main():
    """Evaluate a trained sequence model on the test split and write reports.

    Parses the command-line options, restores the feature map, target map and
    model weights from ``args.load_checkpoint``, runs ``predict`` on the test
    corpus and writes three text files:

    * ``args.predict_file`` -- one ``sent_id|e1|e2|ddi|type`` line per example;
    * ``args.error_file`` -- misclassified examples with their attention rows;
    * ``args.att_file`` -- attention rows of correctly classified examples
      whose target is not ``'null'``.

    Raises:
        FileNotFoundError: if ``args.load_checkpoint`` is not an existing file.
        AssertionError: if the number of predictions differs from the number
            of test examples.
    """
    parser = options.get_parser('Generator')
    options.add_dataset_args(parser)
    options.add_preprocessing_args(parser)
    options.add_model_args(parser)
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    args.cuda = not args.disable_cuda and torch.cuda.is_available()

    caseless = args.caseless
    batch_size = args.batch_size

    if not os.path.isfile(args.load_checkpoint):
        # FileNotFoundError is a subclass of OSError, so code that caught the
        # previous bare OSError keeps working, but now sees the path.
        raise FileNotFoundError(
            'No checkpoint file found: {}'.format(args.load_checkpoint))
    print('Loading checkpoint file from {}...'.format(args.load_checkpoint))
    # Remap storages to the CPU when CUDA is not used; otherwise a checkpoint
    # saved from a GPU cannot be deserialized on a CPU-only run.
    checkpoint_file = torch.load(args.load_checkpoint,
                                 map_location=None if args.cuda else 'cpu')

    # Only the test split is evaluated here.
    _, _, test_raw_corpus = utils.load_corpus(args.processed_dir, ddi=True)
    test_corpus = [(line.sent, line.type, line.p1, line.p2)
                   for line in test_raw_corpus]

    # preprocessing: reuse the maps stored in the checkpoint
    feature_map = checkpoint_file['f_map']
    target_map = checkpoint_file['t_map']
    test_features, test_targets = utils.build_corpus(
        test_corpus, feature_map, target_map, caseless)

    test_loader = utils.construct_bucket_dataloader(
        test_features, test_targets, feature_map['PAD'], batch_size,
        args.position_bound, is_train=False)

    # build model
    vocab_size = len(feature_map)
    tagset_size = len(target_map)
    model = utils.build_model(args, vocab_size, tagset_size)
    # loss
    criterion = utils.build_loss(args)

    # load states
    model.load_state_dict(checkpoint_file['state_dict'])

    # trainer
    trainer = SeqTrainer(args, model, criterion)

    if args.cuda:
        model.cuda()

    # prediction
    print('Predicting...')
    y_true, y_pred, att_weights = predict(trainer, test_loader, target_map,
                                          cuda=args.cuda)
    # The zips below pair predictions with raw examples positionally, so a
    # length mismatch would silently misalign the reports.
    assert len(y_pred) == len(test_corpus), 'length of prediction is inconsistent with that of data set'

    def write_attention(f, tup, att_weight):
        """Write the attention rows of one example, cut to sentence length."""
        # Only InterAttentionLSTM weights are iterated as given; for every
        # other model the value is wrapped in a one-element list first
        # (presumably a single weight vector per example -- confirm).
        rows = att_weight if args.model == 'InterAttentionLSTM' else [att_weight]
        size = len(tup.sent)
        for row in rows:
            f.write('{}\n'.format(
                ' '.join(str(round(x, 4)) for x in row[:size])))

    # write result: sent_id|e1|e2|ddi|type
    with open(args.predict_file, 'w') as f:
        for tup, pred in zip(test_raw_corpus, y_pred):
            ddi = 0 if pred == 'null' else 1
            f.write('|'.join([tup.sent_id, tup.e1, tup.e2, str(ddi), pred]))
            f.write('\n')

    # error analysis
    print('Analyzing...')
    with open(args.error_file, 'w') as f:
        f.write(' | '.join(['sent_id', 'e1', 'e2', 'target', 'pred']))
        f.write('\n')
        for tup, target, pred, att_weight in zip(test_raw_corpus, y_true,
                                                 y_pred, att_weights):
            if target != pred:
                f.write('{}\n'.format(' '.join(tup.sent)))
                write_attention(f, tup, att_weight)
                f.write('{}\n\n'.format(' | '.join(
                    [tup.sent_id, tup.e1, tup.e2, target, pred])))

    # attention
    print('Writing attention scores...')
    with open(args.att_file, 'w') as f:
        f.write(' | '.join(['target', 'sent', 'att_weight']))
        f.write('\n')
        for tup, target, pred, att_weight in zip(test_raw_corpus, y_true,
                                                 y_pred, att_weights):
            if target == pred and target != 'null':
                f.write('{}\n'.format(target))
                f.write('{}\n'.format(' '.join(tup.sent)))
                write_attention(f, tup, att_weight)
# Example #2
# 0
#python3 generate.py --data data-bin/UN/million6way/en-fr/bpe/preprocessed/  --src_lang en --trg_lang fr  --batch-size 80 --gpuid 0 --model-dir data-bin/UN/million6way/en-fr/checkpoints/lr34clip1/best_gmodel.pt

# Root logger: timestamped "<time> <level>: <message>" records, DEBUG and up.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)s: %(message)s',
)

parser = argparse.ArgumentParser(description="Driver program for JHU Adversarial-NMT.")

# Register every option group on the parser, in the same order as before.
for _register_group in (
        options.add_general_args,
        options.add_dataset_args,
        options.add_checkpoint_args,
        options.add_distributed_training_args,
        options.add_generation_args,
        options.add_generator_model_args,
):
    _register_group(parser)


def main(args):

    use_cuda = (len(args.gpuid) >= 1)
    if args.gpuid:
        cuda.set_device(args.gpuid[0])
        print(args.replace_unk)  #None
        # Load dataset
        if args.replace_unk is None:
            dataset = data.load_dataset(
                args.data,
                ['test'],
                args.src_lang,
def main():
    """Evaluate a trained RelationTreeModel on the test split and write reports.

    Parses the command-line options, overrides them with the arguments stored
    in ``<args.load_checkpoint>.json``, rebuilds the vocabulary from the train
    and validation corpora, restores the weights from
    ``<args.load_checkpoint>.model``, runs ``predict`` on the test loader and
    writes three text files:

    * ``args.predict_file`` -- one ``sent_id|e1|e2|ddi|type`` line per example;
    * ``args.error_file`` -- misclassified examples with the last tree of
      their tree list;
    * ``args.correct_file`` -- correctly classified examples whose target is
      not ``'null'``, in the same layout.

    Raises:
        AssertionError: if the ``.model`` checkpoint file is missing or any
            of the three corpora is empty.
    """
    parser = options.get_parser('Generator')
    options.add_dataset_args(parser)
    options.add_preprocessing_args(parser)
    options.add_model_args(parser)
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()

    model_path = args.load_checkpoint + '.model'
    args_path = args.load_checkpoint + '.json'
    with open(args_path, 'r') as f:
        saved_args = json.load(f)['args']
    # Stored arguments take precedence over the command line. A plain loop is
    # used because the assignments are side effects, not a list to build.
    for key, value in saved_args.items():
        setattr(args, key, value)

    args.cuda = not args.disable_cuda and torch.cuda.is_available()

    print(args)

    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # increase recursion depth
    sys.setrecursionlimit(10000)

    # Fail fast: check the checkpoint before the corpus/vocabulary work below
    # instead of after it.
    assert os.path.isfile(model_path), "Checkpoint not found!"

    # load dataset
    train_raw_corpus, val_raw_corpus, test_raw_corpus = utils.load_corpus(
        args.processed_dir, ddi=False)
    assert train_raw_corpus and val_raw_corpus and test_raw_corpus, 'Corpus not found, please run preprocess.py to obtain corpus!'
    train_corpus = [(line.sent, line.type, line.p1, line.p2)
                    for line in train_raw_corpus]
    val_corpus = [(line.sent, line.type, line.p1, line.p2)
                  for line in val_raw_corpus]

    caseless = args.caseless

    # build vocab from the train and validation sentences
    sents = [tup[0] for tup in train_corpus + val_corpus]
    feature_map = utils.build_vocab(sents,
                                    min_count=args.min_count,
                                    caseless=caseless)
    target_map = ddi2013.target_map

    # get class weights (only the training targets are needed)
    _, train_targets = utils.build_corpus(train_corpus, feature_map,
                                          target_map, caseless)
    class_weights = torch.Tensor(
        utils.get_class_weights(train_targets)) if args.class_weight else None

    # load datasets; only the test loader is used here
    _, _, test_loader = utils.load_datasets(args.processed_dir,
                                            args.train_size,
                                            args,
                                            feature_map,
                                            dataloader=True)

    # build model
    vocab_size = len(feature_map)
    tagset_size = len(target_map)
    model = RelationTreeModel(vocab_size, tagset_size, args)

    # loss
    criterion = utils.build_loss(args, class_weights=class_weights)

    # load states
    print('Loading checkpoint file from {}...'.format(model_path))
    # Remap storages to the CPU when CUDA is not used; otherwise a checkpoint
    # saved from a GPU cannot be deserialized on a CPU-only run.
    checkpoint_file = torch.load(model_path,
                                 map_location=None if args.cuda else 'cpu')
    model.load_state_dict(checkpoint_file['state_dict'])

    # trainer
    trainer = TreeTrainer(args, model, criterion)

    # predict (the fourth return value is not used by this script)
    print('Predicting...')
    y_true, y_pred, treelists, _f1_by_len = predict(trainer,
                                                    test_loader,
                                                    target_map,
                                                    cuda=args.cuda)

    # assign words to roots: replace each node's token index with the token
    # itself, or None when the index lies outside the sentence
    for tup, treelist in zip(test_raw_corpus, treelists):
        for t in treelist:
            t.idx = tup.sent[t.idx] if t.idx < len(tup.sent) else None

    # write result: sent_id|e1|e2|ddi|type
    with open(args.predict_file, 'w') as f:
        for tup, pred in zip(test_raw_corpus, y_pred):
            ddi = 0 if pred == 'null' else 1
            f.write('|'.join([tup.sent_id, tup.e1, tup.e2, str(ddi), pred]))
            f.write('\n')

    def print_info(f, tup, target, pred, root):
        """Write one example: sentence, id/entity/label line, then ``root``."""
        f.write('{}\n'.format(' '.join(tup.sent)))
        f.write('{}\n'.format(' | '.join(
            [tup.sent_id, tup.e1, tup.e2, target, pred])))
        f.write('{}\n\n'.format(root))

    # error analysis
    print('Analyzing...')
    with open(args.error_file, 'w') as f:
        f.write(' | '.join(['sent_id', 'e1', 'e2', 'target', 'pred']))
        f.write('\n')
        for tup, target, pred, treelist in zip(test_raw_corpus, y_true, y_pred,
                                               treelists):
            if target != pred:
                print_info(f, tup, target, pred, treelist[-1])

    # correctly classified, non-'null' examples
    print('Writing attention scores...')
    with open(args.correct_file, 'w') as f:
        f.write(' | '.join(['target', 'sent', 'att_weight']))
        f.write('\n')
        for tup, target, pred, treelist in zip(test_raw_corpus, y_true, y_pred,
                                               treelists):
            if target == pred and target != 'null':
                print_info(f, tup, target, pred, treelist[-1])