Example #1
def main():
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')
    embedding, opt = load_meta(opt, os.path.join(args.data_dir, args.meta))
    train_data = BatchGen(os.path.join(args.data_dir, args.train_data),
                          batch_size=args.batch_size,
                          gpu=args.cuda, elmo_on=args.elmo_on)
    dev_data = BatchGen(os.path.join(args.data_dir, args.dev_data),
                          batch_size=args.batch_size_eval,
                          gpu=args.cuda, is_train=False, elmo_on=args.elmo_on)
    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    model = DocReaderModel(opt, embedding)
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    best_em_score, best_f1_score = 0.0, 0.0

    print("PROGRESS: 00.00%")
    for epoch in range(0, args.epoches):
        logger.warning('At epoch {}'.format(epoch))
        train_data.reset()
        start = datetime.now()
        for i, batch in enumerate(train_data):
            model.update(batch)
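            # periodic logging: running-average training loss plus a rough ETA
            # extrapolated from the time already spent per batch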
            if (i + 1) % args.log_per_updates == 0 or i == 0:
                logger.info('updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format(
                    model.updates, model.train_loss.avg,
                    str((datetime.now() - start) / (i + 1) * (len(train_data) - i - 1)).split('.')[0]))
        # dev eval
        em, f1, results = check(model, dev_data, os.path.join(args.data_dir, args.dev_gold))
        output_path = os.path.join(model_dir, 'dev_output_{}.json'.format(epoch))
        with open(output_path, 'w') as f:
            json.dump(results, f)

        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()
        # save
        model_file = os.path.join(model_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
        if not args.philly_on:
            model.save(model_file)
        if em + f1 > best_em_score + best_f1_score:
            model.save(os.path.join(model_dir, 'best_checkpoint.pt'))
            # copyfile(model_file, os.path.join(model_dir, 'best_checkpoint.pt'))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')
        logger.warning("Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})".format(epoch, em, f1, best_em_score, best_f1_score))
        print("PROGRESS: {0:.2f}%".format(100.0 * (epoch + 1) / args.epoches))
Example #2
def main():
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info(opt)
    embedding, opt, vocab = load_meta(opt, args.meta)
    max_doc = opt['max_doc']
    smooth = opt['smooth']
    is_rep = opt['is_rep']
    eval_step = opt['eval_step']
    curve_file = opt['curve_file']

    training_step = 0
    cur_eval_step = 1

    checkpoint_path = args.resume
    if checkpoint_path == '':
        if not args.if_train:
            print('checkpoint path cannot be empty during testing...')
            exit()
        model = DocReaderModel(opt, embedding)
    else:
        state_dict = torch.load(checkpoint_path)["state_dict"]
        model = DocReaderModel(opt, embedding, state_dict)
    model.setup_eval_embed(embedding)
    logger.info("Total number of params: {}".format(model.total_param))

    if args.cuda:
        model.cuda()

    pred_output_path = os.path.join(model_dir, 'pred_output')
    if not os.path.exists(pred_output_path):
        os.makedirs(pred_output_path)
    full_output_path = os.path.join(model_dir, 'full_output_path')
    if not os.path.exists(full_output_path):
        os.makedirs(full_output_path)

    if args.if_train:
        logger.info('Loading training data')
        train_data = BatchGen(os.path.join(args.data_dir, args.train_data),
                              batch_size=args.batch_size,
                              gpu=args.cuda,
                              doc_maxlen=max_doc)
        logger.info('Loading dev data')
        dev_data = BatchGen(os.path.join(args.data_dir, args.dev_data),
                            batch_size=8,
                            gpu=args.cuda,
                            is_train=False,
                            doc_maxlen=max_doc)
        curve_file = os.path.join(model_dir, curve_file)
        full_path = os.path.join(args.data_dir, args.dev_full)
        pred_output = os.path.join(pred_output_path, str(
            model.updates)) + '.txt'
        full_output = os.path.join(full_output_path, str(
            model.updates)) + '_full.txt'

        for epoch in range(0, args.epoches):
            logger.warning('At epoch {}'.format(epoch))
            train_data.reset()
            start = datetime.now()
            for i, batch in enumerate(train_data):
                training_step += 1
                model.update(batch, smooth, is_rep)
                if (i + 1) % args.log_per_updates == 0:
                    logger.info(
                        'updates[{0:6}] train: loss[{1:.5f}]'
                        ' ppl[{2:.5f}] remaining[{3}]'.format(
                            model.updates, model.train_loss.avg,
                            np.exp(model.train_loss.avg),
                            str((datetime.now() - start) / (i + 1) *
                                (len(train_data) - i - 1)).split('.')[0]))

                    # setting up scheduler
                    if model.scheduler is not None:
                        if opt['scheduler_type'] == 'rop':
                            model.scheduler.step(model.train_loss.avg,
                                                 epoch=epoch)
                        else:
                            model.scheduler.step()

                dev_loss = 0.0
                if (training_step) == cur_eval_step:
                    print('evaluating_step is {} ....'.format(training_step))
                    bleu, bleu_fact, diver_uni, diver_bi = check(
                        model, dev_data, vocab, full_path, pred_output,
                        full_output)
                    dev_loss = eval_test_loss(model,
                                              dev_data).data.cpu().numpy()[0]
                    # dev_loss = dev_loss.data.cpu().numpy()[0]
                    logger.info(
                        'updates[{0:6}] train: loss[{1:.5f}] ppl[{2:.5f}]\n'
                        'dev: loss[{3:.5f}] ppl[{4:.5f}]'.format(
                            model.updates, model.train_loss.avg,
                            np.exp(model.train_loss.avg), dev_loss,
                            np.exp(dev_loss)))
                    print('{0},{1:.5f},{2:.5f},{3:.5f},{4:.5f},'
                          '{5:.5f},{6:.5f},{7:.5f},{8:.5f}\n'.format(
                              model.updates, model.train_loss.avg,
                              np.exp(model.train_loss.avg), dev_loss,
                              np.exp(dev_loss), float(bleu), float(diver_uni),
                              float(diver_bi), float(bleu_fact)))
                    with open(curve_file, 'a+') as fout_dev:
                        fout_dev.write(
                            '{0},{1:.5f},{2:.5f},{3:.5f},{4:.5f},'
                            '{5:.5f},{6:.5f},{7:.5f},{8:.5f}\n'.format(
                                model.updates, model.train_loss.avg,
                                np.exp(model.train_loss.avg), dev_loss,
                                np.exp(dev_loss), float(bleu),
                                float(diver_uni), float(diver_bi),
                                float(bleu_fact)))

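                    # schedule the next evaluation: the first one fires at step 1,
                    # every eval_step updates after that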
                    if cur_eval_step == 1:
                        cur_eval_step = cur_eval_step - 1
                    cur_eval_step += eval_step

                if (i + 1) % (args.log_per_updates * 50) == 0:
                    logger.info(
                        'have saved model as checkpoint_step_{0}_{1:.5f}.pt'.
                        format(model.updates, np.exp(dev_loss)))
                    model_file = os.path.join(
                        model_dir, 'checkpoint_step_{0}_{1:.5f}.pt'.format(
                            model.updates, np.exp(dev_loss)))
                    model.save(model_file, epoch)

            #save
            dev_loss = eval_test_loss(model, dev_data)
            dev_loss = dev_loss.data.cpu().numpy()[0]
            logger.info(
                'have saved model as checkpoint_epoch_{0}_{1}_{2:.5f}.pt'.
                format(epoch, args.learning_rate, np.exp(dev_loss)))
            model_file = os.path.join(
                model_dir, 'checkpoint_epoch_{0}_{1}_{2:.5f}.pt'.format(
                    epoch, args.learning_rate, np.exp(dev_loss)))
            model.save(model_file, epoch)

    else:
        logger.info('Loading evaluation data')
        checkpoint_path = args.resume
        state_dict = torch.load(checkpoint_path)["state_dict"]
        model = DocReaderModel(opt, embedding, state_dict)
        model.setup_eval_embed(embedding)
        logger.info("Total number of params: {}".format(model.total_param))
        if args.cuda:
            model.cuda()

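        # evaluation-only path: run a saved checkpoint over a dataset and report
        # loss, perplexity, BLEU and diversity scores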
        def _eval_output(file_path=args.dev_data,
                         full_path=args.dev_full,
                         test_type='dev'):
            data = BatchGen(os.path.join(args.data_dir, file_path),
                            batch_size=args.batch_size,
                            gpu=args.cuda,
                            is_train=False)
            print(len(data))
            full_path = os.path.join(args.data_dir, full_path)
            pred_output_path = os.path.join('./output/', test_type) + '/'
            full_output_path = os.path.join('./full_output/', test_type) + '/'
            if not os.path.exists(pred_output_path):
                os.makedirs(pred_output_path)
            if not os.path.exists(full_output_path):
                os.makedirs(full_output_path)
            t = args.test_output
            pred_output = pred_output_path + t + '.txt'
            full_output = full_output_path + t + '_full.txt'
            bleu, bleu_fact, diver_uni, diver_bi = check(
                model, data, vocab, full_path, pred_output, full_output)
            _loss = eval_test_loss(model, data)
            _loss = _loss.data.cpu().numpy()[0]
            logger.info('dev loss[{0:.5f}] ppl[{1:.5f}]'.format(
                _loss, np.exp(_loss)))
            print('{0},{1:.5f},{2:.5f},{3:.5f},{4:.5f},{5:.5f},'
                  '{6:.5f},{7:.5f},{8:.5f}\n'.format(
                      model.updates, model.train_loss.avg,
                      np.exp(model.train_loss.avg), _loss, np.exp(_loss),
                      float(bleu), float(diver_uni), float(diver_bi),
                      float(bleu_fact)))

        print('test result is:')
        _eval_output(args.test_data, args.test_full, 'test')
Example #3
def main():
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')
    version = 'v1'
    gold_version = 'v1.1'

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)

    test_path = gen_name(args.data_dir, args.test_data, version)
    test_gold_path = gen_gold_name(args.data_dir, args.test_gold, gold_version)

    if args.v2_on:
        version = 'v2'
        gold_version = 'v2.0'
        dev_labels = load_squad_v2_label(args.dev_gold)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    # train_data = BatchGen(gen_name(args.data_dir, args.train_data, version),
    #                       batch_size=args.batch_size,
    #                       gpu=args.cuda,
    #                       with_label=args.v2_on,
    #                       elmo_on=args.elmo_on)
    # import pdb; pdb.set_trace()
    dev_data = BatchGen(gen_name(args.data_dir, args.dev_data, version),
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    test_data = None
    test_gold = None

    # if os.path.exists(test_path):
    #     test_data = BatchGen(test_path,
    #                         batch_size=args.batch_size,
    #                         gpu=args.cuda, is_train=False, elmo_on=args.elmo_on)

    # load golden standard
    dev_gold = load_squad(dev_gold_path)

    if os.path.exists(test_gold_path):
        test_gold = load_squad(test_gold_path)

    model = DocReaderModel(opt, embedding)  ### model = your_model()

    # model meta str
    headline = '############# Model Arch of SAN #############'
    # print network
    logger.info('\n{}\n{}\n'.format(headline, model.network))
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    best_em_score, best_f1_score = 0.0, 0.0

    # test epoch value..
    epoch = 2
    # for epoch in range(0, args.epoches):
    #     logger.warning('At epoch {}'.format(epoch))
    #     train_data.reset()
    #     start = datetime.now()
    #     for i, batch in enumerate(train_data):
    #         model.update(batch)
    #         if (model.updates) % args.log_per_updates == 0 or i == 0:
    #             logger.info('#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format(
    #                 model.updates, model.train_loss.avg,
    #                 str((datetime.now() - start) / (i + 1) * (len(train_data) - i - 1)).split('.')[0]))
    # dev eval
    # load the best model from disk...
    # import pdb;pdb.set_trace()
    logger.info('loading the model from disk........')
    path1 = '/demo-mount/san_mrc/checkpoint/checkpoint_v1_epoch_0_full_model.pt'
    # model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
    # checkpoint_test = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
    model = torch.load(path1)

    results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
    if args.v2_on:
        metric = evaluate_v2(dev_gold,
                             results,
                             na_prob_thresh=args.classifier_threshold)
        em, f1 = metric['exact'], metric['f1']
        acc = compute_acc(labels, dev_labels)
    else:
        metric = evaluate(dev_gold, results)
        em, f1 = metric['exact_match'], metric['f1']

    output_path = os.path.join(model_dir, 'dev_output_{}.json'.format(epoch))
    with open(output_path, 'w') as f:
        json.dump(results, f)

        if test_data is not None:
            test_results, test_labels = predict_squad(model,
                                                      test_data,
                                                      v2_on=args.v2_on)
            test_output_path = os.path.join(
                model_dir, 'test_output_{}.json'.format(epoch))
            with open(test_output_path, 'w') as f:
                json.dump(test_results, f)

            if (test_gold is not None):
                if args.v2_on:
                    test_metric = evaluate_v2(
                        test_gold,
                        test_results,
                        na_prob_thresh=args.classifier_threshold)
                    test_em, test_f1 = test_metric['exact'], test_metric['f1']
                    test_acc = compute_acc(labels, test_labels)
                else:
                    test_metric = evaluate(test_gold, test_results)
                    test_em, test_f1 = test_metric['exact_match'], test_metric[
                        'f1']

        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()
        # save
        model_file = os.path.join(
            model_dir, 'checkpoint_{}_epoch_{}.pt'.format(version, epoch))

        model.save(model_file, epoch)
        if em + f1 > best_em_score + best_f1_score:
            copyfile(
                os.path.join(model_dir, model_file),
                os.path.join(model_dir,
                             'best_{}_checkpoint.pt'.format(version)))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')

        logger.warning(
            "Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})"
            .format(epoch, em, f1, best_em_score, best_f1_score))
        if args.v2_on:
            logger.warning("Epoch {0} - ACC: {1:.4f}".format(epoch, acc))
        if metric is not None:
            logger.warning("Detailed Metric at Epoch {0}: {1}".format(
                epoch, metric))

        if (test_data is not None) and (test_gold is not None):
            logger.warning("Epoch {0} - test EM: {1:.3f} F1: {2:.3f}".format(
                epoch, test_em, test_f1))
            if args.v2_on:
                logger.warning("Epoch {0} - test ACC: {1:.4f}".format(
                    epoch, test_acc))
Example #4
def main():
	logger.info('Launching the SAN')
	opt = vars(args)
	logger.info('Loading data')
	version = 'v1'
	if args.v2_on:
		version = 'v2'
		dev_labels = load_squad_v2_label(args.dev_gold)

	embedding, opt = load_meta(opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
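	# SQuAD v2 adds unanswerable questions, so batches carry an answerability label when v2_on is set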
	train_data = BatchGen(gen_name(args.data_dir, args.train_data, version),
						  batch_size=args.batch_size,
						  gpu=args.cuda,
						  with_label=args.v2_on)
	dev_data = BatchGen(gen_name(args.data_dir, args.dev_data, version),
						  batch_size=args.batch_size,
						  gpu=args.cuda, is_train=False)

	# load golden standard
	dev_gold = load_squad(args.dev_gold)

	model = DocReaderModel(opt, embedding)
	# model meta str
	headline = '############# Model Arch of SAN #############'
	# print network
	logger.info('\n{}\n{}\n'.format(headline, model.network))
	model.setup_eval_embed(embedding)

	logger.info("Total number of params: {}".format(model.total_param))
	if args.cuda:
		model.cuda()

	best_em_score, best_f1_score = 0.0, 0.0

	for epoch in range(0, args.epoches):
		logger.warning('At epoch {}'.format(epoch))
		train_data.reset()
		start = datetime.now()
		for i, batch in enumerate(train_data):
			#pdb.set_trace()
			model.update(batch)
			if (model.updates) % args.log_per_updates == 0 or i == 0:
				logger.info('#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format(
					model.updates, model.train_loss.avg,
					str((datetime.now() - start) / (i + 1) * (len(train_data) - i - 1)).split('.')[0]))
		# dev eval
		results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
		if args.v2_on:
			metric = evaluate_v2(dev_gold, results, na_prob_thresh=args.classifier_threshold)
			em, f1 = metric['exact'], metric['f1']
			acc = compute_acc(labels, dev_labels)
			cls_pr, cls_rec, cls_f1 = compute_classifier_pr_rec(labels, dev_labels)
		else:
			metric = evaluate(dev_gold, results)
			em, f1 = metric['exact_match'], metric['f1']

		output_path = os.path.join(model_dir, 'dev_output_{}.json'.format(epoch))
		with open(output_path, 'w') as f:
			json.dump(results, f)

		# setting up scheduler
		if model.scheduler is not None:
			logger.info('scheduler_type {}'.format(opt['scheduler_type']))
			if opt['scheduler_type'] == 'rop':
				model.scheduler.step(f1, epoch=epoch)
			else:
				model.scheduler.step()
		# save
		model_file = os.path.join(model_dir, 'checkpoint_{}_epoch_{}.pt'.format(version, epoch))

		model.save(model_file, epoch)
		if em + f1 > best_em_score + best_f1_score:
			copyfile(os.path.join(model_dir, model_file), os.path.join(model_dir, 'best_{}_checkpoint.pt'.format(version)))
			best_em_score, best_f1_score = em, f1
			logger.info('Saved the new best model and prediction')

		logger.warning("Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})".format(epoch, em, f1, best_em_score, best_f1_score))
		if args.v2_on:
			logger.warning("Epoch {0} - Precision: {1:.4f}, Recall: {2:.4f}, F1: {3:.4f}, Accuracy: {4:.4f}".format(epoch, cls_pr, cls_rec, cls_f1, acc))
		if metric is not None:
			logger.warning("Detailed Metric at Epoch {0}: {1}".format(epoch, metric))
Example #5
def main():
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')

    version = 'v2' if args.v2_on else 'v1'
    gold_version = 'v2.0' if args.v2_on else 'v1.1'


    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)
    dev_labels = load_squad_v2_label(dev_gold_path)

    embedding, opt = load_meta(opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    dev_data = BatchGen(dev_path,
                          batch_size=args.batch_size,
                          gpu=args.cuda, is_train=False, elmo_on=args.elmo_on)

    dev_gold = load_squad(dev_gold_path)

    checkpoint_path = args.checkpoint_path
    logger.info(f'path to given checkpoint is {checkpoint_path}')
    checkpoint = torch.load(checkpoint_path) if args.cuda else torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['state_dict']

    # Set up the model
    logger.info('Loading model ...')
    model = DocReaderModel(opt, embedding, state_dict)
    model.setup_eval_embed(embedding)
    logger.info('done')

    if args.cuda:
        model.cuda()

    # dev eval
    logger.info('Predicting ...')
    results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
    logger.info('done')

    # get actual and predicted labels (as lists)
    actual_labels = []
    predicted_labels = []
    dropped = 0
    for key in dev_labels.keys():  # convert from dictionaries to lists
        if key in labels:
            actual_labels.append(dev_labels[key])
            predicted_labels.append(labels[key])
        else:  # keep the two lists aligned when a prediction is missing
            dropped += 1
    print(f'dropped: {dropped}')

    actual_labels = np.array(actual_labels)
    predicted_labels = np.array(predicted_labels)

    # convert from continuous to discrete
    actual_labels = (actual_labels > args.classifier_threshold).astype(np.int32)
    predicted_labels = (predicted_labels > args.classifier_threshold).astype(np.int32)

    # Print all metrics
    print('accuracy', 100 - 100 * sum(abs(predicted_labels-actual_labels)) / len(actual_labels), '%')
    print('confusion matrix', confusion_matrix(actual_labels, predicted_labels))  # sklearn expects (y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(actual_labels, predicted_labels, average='binary')
    print(f'Precision: {precision} recall: {recall} f1-score: {f1}')
Example #6
def main():

    opt = vars(args)
    logger.info('Loading Squad')
    version = 'v1'
    gold_version = 'v1.1'

    if args.v2_on:
        version = 'v2'
        gold_version = 'v2.0'
        dev_labels = load_squad_v2_label(args.dev_gold)

    logger.info('Loading Meta')
    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))

    logger.info('Loading Train Batcher')

    if args.elmo_on:
        logger.info('ELMO ON')

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)

    train_data = BatchGen(gen_name(args.data_dir, args.train_data, version),
                          batch_size=args.batch_size,
                          gpu=args.cuda,
                          with_label=args.v2_on,
                          elmo_on=args.elmo_on)

    logger.info('Loading Dev Batcher')
    dev_data = BatchGen(dev_path,
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    logger.info('Loading Golden Standards')
    # load golden standard
    dev_gold = load_squad(args.dev_gold)

    if len(args.resume) > 0:
        logger.info('Loading resumed model')
        model = DocReaderModel.load(args.resume, embedding, gpu=args.cuda)
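        # recover the resumed epoch and the best EM/F1 scores from the checkpoint
        # filename, which follows the cp_epoch_{epoch}_em_{em}_f1_{f1}.pt pattern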
        resumeSplit = args.resume.split('_')

        best_f1_score = float(resumeSplit[6].replace('.pt', ''))
        best_em_score = float(resumeSplit[4])
        resumed_epoch = int(resumeSplit[2]) + 1

        #step scheduler
        for i in range(resumed_epoch):
            model.scheduler.step()

        logger.info(
            "RESUMING MODEL TRAINING. BEST epoch {} EM {} F1 {} ".format(
                str(resumed_epoch), str(best_em_score), str(best_f1_score)))

    else:
        model = DocReaderModel(opt, embedding)
        best_em_score, best_f1_score = 0.0, 0.0
        resumed_epoch = 0

    # model meta str
    # headline = '############# Model Arch of SAN #############'
    # print network
    # logger.info('\n{}\n{}\n'.format(headline, model.network))
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    logger.info("Batch Size {}".format(args.batch_size))
    if args.cuda:
        model.cuda()
    else:
        model.cpu()

    for epoch in range(resumed_epoch, args.epoches):
        logger.warning('At epoch {}'.format(epoch))

        #shuffle training batch
        train_data.reset()
        start = datetime.now()
        for i, batch in enumerate(train_data):
            model.update(batch)
            if (model.updates) % args.log_per_updates == 0 or i == 0:
                logger.info(
                    '#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.
                    format(
                        model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(train_data) - i - 1)).split('.')[0]))
        # dev eval
        results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
        if args.v2_on:
            metric = evaluate_v2(dev_gold,
                                 results,
                                 na_prob_thresh=args.classifier_threshold)
            em, f1 = metric['exact'], metric['f1']
            acc = compute_acc(labels, dev_labels)
        else:
            metric = evaluate(dev_gold, results)
            em, f1 = metric['exact_match'], metric['f1']

        output_path = os.path.join(model_dir,
                                   'dev_output_{}.json'.format(epoch))
        with open(output_path, 'w') as f:
            json.dump(results, f)

        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()
        # save
        model_file = os.path.join(
            model_dir,
            'cp_epoch_{}_em_{}_f1_{}.pt'.format(epoch, int(em), int(f1)))

        model.save(model_file, epoch)
        if em + f1 > best_em_score + best_f1_score:
            copyfile(
                os.path.join(model_dir, model_file),
                os.path.join(
                    model_dir, 'best_epoch_{}_em_{}_f1_{}.pt'.format(
                        epoch, int(em), int(f1))))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')

        logger.warning(
            "Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})"
            .format(epoch, em, f1, best_em_score, best_f1_score))
        if args.v2_on:
            logger.warning("Epoch {0} - ACC: {1:.4f}".format(epoch, acc))
        if metric is not None:
            logger.warning("Detailed Metric at Epoch {0}: {1}".format(
                epoch, metric))
Example #7
def main():
    logger.info('Launching the SAN')
    start_test = False
    dev_name = 'dev'

    opt = vars(args)
    logger.info('Loading data')
    embedding, opt = load_meta(
        opt, os.path.join(args.multitask_data_path, args.meta))
    gold_data = load_gold(args.dev_datasets, args.data_dir, dev_name=dev_name)
    best_em_score, best_f1_score = 0.0, 0.0

    all_train_batchgen = []
    all_dev_batchgen = []
    all_train_iters = []
    for dataset_name in args.train_datasets:
        path = os.path.join(args.multitask_data_path,
                            dataset_name + '_train.json')
        this_extra_score = extra_score.get(dataset_name, None)

        all_train_batchgen.append(
            BatchGen(path,
                     batch_size=args.batch_size,
                     gpu=args.cuda,
                     dataset_name=dataset_name,
                     doc_maxlen=args.doc_maxlen,
                     drop_less=args.drop_less,
                     num_gpu=args.num_gpu,
                     dropout_w=args.dropout_w,
                     dw_type=args.dw_type,
                     extra_score=this_extra_score,
                     extra_score_cap=args.extra_score_cap))
    all_train_iters = [iter(item) for item in all_train_batchgen]
    for dataset_name in args.dev_datasets:
        path = os.path.join(args.multitask_data_path,
                            dataset_name + '_%s.json' % dev_name)
        all_dev_batchgen.append(
            BatchGen(path,
                     batch_size=args.valid_batch_size,
                     gpu=args.cuda,
                     is_train=False,
                     dataset_name=dataset_name,
                     doc_maxlen=args.doc_maxlen,
                     num_gpu=args.num_gpu))
        if 'marco' in dataset_name:
            rank_path = os.path.join(args.data_dir, dataset_name)
            dev_rank_path = os.path.join(rank_path, 'dev_rank_scores.json')
            dev_rank_scores = load_rank_score(dev_rank_path)
            dev_yn = json.load(
                open(os.path.join(rank_path, 'dev_yn_dict.json')))
            dev_gold_path = os.path.join(args.data_dir, dataset_name,
                                         'dev_original.json')
            dev_gold_data_marco = load_jsonl(dev_gold_path)
    if args.resume_last_epoch:
        latest_time = 0
        for o in os.listdir(model_dir):
            if o.startswith('checkpoint_') and 'trim' not in o:
                edit_time = os.path.getmtime(os.path.join(model_dir, o))
                if edit_time > latest_time:
                    latest_time = edit_time
                    args.resume_dir = model_dir
                    args.resume = o

    if args.resume_dir is not None:
        print('resuming model in ', os.path.join(args.resume_dir, args.resume))
        checkpoint = torch.load(os.path.join(args.resume_dir, args.resume))

        model_opt = checkpoint['config'] if args.resume_options else opt
        model_opt['multitask_data_path'] = opt['multitask_data_path']
        model_opt['covec_path'] = opt['covec_path']
        model_opt['data_dir'] = opt['data_dir']

        if args.resume_options:
            logger.info('resume old options')
        else:
            logger.info('use new options.')
        model_opt['train_datasets'] = checkpoint['config']['train_datasets']

        state_dict = checkpoint['state_dict']
        model = DocReaderModel(model_opt, embedding, state_dict)

        if not args.new_random_state:
            logger.info('use old random state.')
            random.setstate(checkpoint['random_state'])
            torch.random.set_rng_state(checkpoint['torch_state'])
            if args.cuda:
                torch.cuda.set_rng_state(checkpoint['torch_cuda_state'])

        if model.scheduler:
            if args.new_scheduler:
                model.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    model.optimizer, milestones=[2, 5, 8], gamma=args.lr_gamma)
            elif 'scheduler_state' in checkpoint:
                model.scheduler.load_state_dict(checkpoint['scheduler_state'])
            else:
                print('warning: scheduler state not loaded because it was '
                      'not saved in this checkpoint.')
        start_epoch = checkpoint['epoch'] + 1
    else:
        model = DocReaderModel(opt, embedding)
        start_epoch = 0
    logger.info('using {} GPUs'.format(torch.cuda.device_count()))
    headline = '############# Model Arch of SAN #############'
    # print network
    logger.info('\n{}\n{}\n'.format(headline, model.network))
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    all_lens = [len(bg) for bg in all_train_batchgen]

    if args.continue_epoches is not None:
        args.epoches = start_epoch + args.continue_epoches
    num_all_batches = args.epoches * sum(all_lens)
    best_performance = {name: 0.0 for name in args.dev_datasets}
    best_performance['total'] = 0.0

    for epoch in range(start_epoch, args.epoches):
        logger.warning('At epoch {}'.format(epoch))

        # batch indices
        all_call_indices = []
        for train_data in all_train_batchgen:
            train_data.reset()
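        # build this epoch's batch schedule: dataset 0 is the primary task and
        # dataset_include_ratio controls the mix between primary and auxiliary
        # batches; a negative ratio simply uses every batch from every dataset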
        if args.dataset_include_ratio >= 0:
            other_indices = []
            for i in range(1, len(all_train_batchgen)):
                other_indices += [i] * len(all_train_batchgen[i])
            if args.dataset_include_ratio > 1:
                inverse_ratio = 1 / args.dataset_include_ratio
                batch0_indices = [0] * (int(
                    len(other_indices) * inverse_ratio))
            else:
                batch0_indices = [0] * len(all_train_batchgen[0])
                other_picks = int(
                    len(other_indices) * args.dataset_include_ratio)
                other_indices = random.sample(other_indices, other_picks)
            all_call_indices = batch0_indices + other_indices
        else:
            for i in range(len(all_train_batchgen)):
                all_call_indices += [i] * len(all_train_batchgen[i])
        random.shuffle(all_call_indices)
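        # NOTE: only the first 10 shuffled batch indices are kept, so each epoch
        # trains on at most 10 batches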
        all_call_indices = all_call_indices[:10]
        start = datetime.now()
        for i in range(len(all_call_indices)):

            batch_list, name_map = next(all_train_iters[all_call_indices[i]])
            dataset_name = args.train_datasets[all_call_indices[i]]

            model.update(batch_list, name_map, dataset_name)
            if (model.updates) % args.log_per_updates == 0 or i == 0:
                logger.info(
                    'o(*^~^*) Task [{0:2}] #updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'
                    .format(
                        all_call_indices[i], model.updates,
                        model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(all_call_indices) - i - 1)).split('.')[0]))

        em_sum = 0
        f1_sum = 0
        model.eval()
        this_performance = {}
        for i in range(len(all_dev_batchgen)):
            dataset_name = args.dev_datasets[i]
            if dataset_name in ['squad', 'newsqa']:
                em, f1, results, scores = check(model, all_dev_batchgen[i],
                                                gold_data[dataset_name])
                output_path = os.path.join(
                    model_dir,
                    'dev_output_{}_{}.json'.format(dataset_name, epoch))
                output_scores_path = os.path.join(
                    model_dir,
                    'dev_scores_{}_{}.pt'.format(dataset_name, epoch))
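                # retry the prediction dump a few times to survive transient I/O errors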
                for repeat_times in range(10):
                    try:
                        with open(output_path, 'w') as f:
                            json.dump(results, f)
                        with open(output_scores_path, 'wb') as f:
                            pickle.dump(scores, f)
                        break
                    except Exception as e:
                        print('save predict failed. error:', e)
                em_sum += em
                f1_sum += f1
                this_performance[dataset_name] = em + f1
                logger.warning(
                    "Epoch {0} - Task {1:6} dev EM: {2:.3f} F1: {3:.3f}".
                    format(epoch, dataset_name, em, f1))
            elif dataset_name == 'wdw':
                acc, results, scores = check_wdw(model, all_dev_batchgen[i])
                output_path = os.path.join(
                    model_dir,
                    'dev_output_{}_{}.json'.format(dataset_name, epoch))
                output_scores_path = os.path.join(
                    model_dir,
                    'dev_scores_{}_{}.pt'.format(dataset_name, epoch))
                for repeat_times in range(10):
                    try:
                        with open(output_path, 'w') as f:
                            json.dump(results, f)
                        with open(output_scores_path, 'wb') as f:
                            pickle.dump(scores, f)
                        break
                    except Exception as e:
                        print('save predict failed. error:', e)
                em_sum += acc
                f1_sum += acc
                logger.warning(
                    "Epoch {0} - Task {1:6} dev ACC: {2:.3f}".format(
                        epoch, dataset_name, acc))
                this_performance[dataset_name] = acc

            elif 'marco' in dataset_name:
                # dev eval
                output = os.path.join(model_dir,
                                      'dev_pred_{}.json'.format(epoch))
                output_yn = os.path.join(model_dir,
                                         'dev_pred_yn_{}.json'.format(epoch))
                span_output = os.path.join(
                    model_dir, 'dev_pred_span_{}.json'.format(epoch))
                dev_predictions, dev_best_scores, dev_ids_list = eval_model_marco(
                    model, all_dev_batchgen[i])
                answer_list, rank_answer_list, yesno_answer_list = generate_submit(
                    dev_ids_list, dev_best_scores, dev_predictions,
                    dev_rank_scores, dev_yn)

                dev_gold_path = os.path.join(args.data_dir, dataset_name,
                                             'dev_original.json')
                metrics = compute_metrics_from_files(dev_gold_data_marco, \
                                                        rank_answer_list, \
                                                        MAX_BLEU_ORDER)
                rouge_score = metrics['rouge_l']
                blue_score = metrics['bleu_1']
                logger.warning(
                    "Epoch {0} - dev ROUGE-L: {1:.4f} BLEU-1: {2:.4f}".format(
                        epoch, rouge_score, blue_score))

                for metric in sorted(metrics):
                    logger.info('%s: %s' % (metric, metrics[metric]))

                this_performance[dataset_name] = rouge_score + blue_score
        this_performance['total'] = sum([v for v in this_performance.values()])
        model.train()
        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()
        # save
        for try_id in range(10):
            try:
                model_file = os.path.join(
                    model_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
                model.save(model_file, epoch, best_em_score, best_f1_score)
                if em_sum + f1_sum > best_em_score + best_f1_score:
                    copyfile(os.path.join(model_dir, model_file),
                             os.path.join(model_dir, 'best_checkpoint.pt'))
                    best_em_score, best_f1_score = em_sum, f1_sum
                    logger.info('Saved the new best model and prediction')
                break
            except Exception as e:
                print('save model failed: outer step. error=', e)
Example #8
def main():
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')

    version = 'v2' if args.v2_on else 'v1'
    gold_version = 'v2.0' if args.v2_on else 'v1.1'

    train_path = gen_name(args.data_dir, args.train_data, version)
    train_gold_path = gen_gold_name(args.data_dir, 'train', gold_version)

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)

    test_path = gen_name(args.data_dir, args.test_data, version)
    test_gold_path = gen_gold_name(args.data_dir, args.test_gold, gold_version)

    train_labels = load_squad_v2_label(train_gold_path)
    dev_labels = load_squad_v2_label(dev_gold_path)
    #train_labels = load_squad_v2_label(train_gold_path)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    train_data = BatchGen(train_path,
                          batch_size=args.batch_size,
                          gpu=args.cuda,
                          with_label=args.v2_on,
                          elmo_on=args.elmo_on)
    dev_data = BatchGen(dev_path,
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    test_data = None
    test_gold = None

    if os.path.exists(test_path):
        test_data = BatchGen(test_path,
                             batch_size=args.batch_size,
                             gpu=args.cuda,
                             is_train=False,
                             elmo_on=args.elmo_on)

    # load golden standard
    train_gold = load_squad(train_gold_path)
    dev_gold = load_squad(dev_gold_path)
    #train_gold = load_squad(train_gold_path)

    if os.path.exists(test_gold_path):
        test_gold = load_squad(test_gold_path)

    #define csv path
    csv_head = [
        'epoch', 'train_loss', 'train_loss_san', 'train_loss_class', 'dev_em',
        'dev_f1', 'dev_acc', 'train_em', 'train_f1', 'train_acc'
    ]
    csvfile = 'results_{}.csv'.format(args.classifier_gamma)
    csv_path = os.path.join(args.data_dir, csvfile)
    result_params = []

    #load previous checkpoint
    start_epoch = 0
    state_dict = None

    if (args.load_checkpoint != 0):
        start_epoch = args.load_checkpoint + 1
        checkpoint_file = 'checkpoint_{}_epoch_{}.pt'.format(
            version, args.load_checkpoint)
        checkpoint_path = os.path.join(args.model_dir, checkpoint_file)
        logger.info('path to prev checkpoint is {}'.format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        state_dict = checkpoint['state_dict']
        opt = checkpoint['config']
        #logger.warning('the checkpoint is {}'.format(checkpoint))

        #load previous metrics
        with open(csv_path, 'r') as csvfile:
            csvreader = csv.reader(csvfile)
            dummy = next(csvreader)
            for row in csvreader:
                result_params.append(row)

        logger.info('Previous metrics loaded')

    model = DocReaderModel(opt, embedding, state_dict)
    # model meta str
    #headline = '############# Model Arch of SAN #############'
    # print network
    #logger.info('\n{}\n{}\n'.format(headline, model.network))
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    best_em_score, best_f1_score = 0.0, 0.0

    for epoch in range(start_epoch, args.epoches):
        logger.warning('At epoch {}'.format(epoch))

        loss, loss_san, loss_class = 0.0, 0.0, 0.0

        train_data.reset()
        start = datetime.now()
        for i, batch in enumerate(train_data):
            losses = model.update(batch)
            loss += losses[0].item()
            loss_san += losses[1].item()
            if losses[2]:
                loss_class += losses[2].item()

            if (model.updates) % args.log_per_updates == 0 or i == 0:
                logger.info(
                    '#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.
                    format(
                        model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(train_data) - i - 1)).split('.')[0]))

        # train eval
        tr_results, tr_labels = predict_squad(model,
                                              train_data,
                                              v2_on=args.v2_on)
        if args.v2_on and args.classifier_on:
            train_metric = evaluate_v2(
                train_gold,
                tr_results,
                na_prob_thresh=args.classifier_threshold)
            train_em, train_f1 = train_metric['exact'], train_metric['f1']
            train_acc = compute_acc(tr_labels, train_labels)
        else:
            train_metric = evaluate(train_gold, tr_results)
            train_em, train_f1 = train_metric['exact_match'], train_metric[
                'f1']
            train_acc = -1

        # dev eval
        results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
        if args.v2_on and args.classifier_on:
            metric = evaluate_v2(dev_gold,
                                 results,
                                 na_prob_thresh=args.classifier_threshold)
            em, f1 = metric['exact'], metric['f1']
            acc = compute_acc(labels, dev_labels)
        else:
            metric = evaluate(dev_gold, results)
            em, f1 = metric['exact_match'], metric['f1']
            acc = -1

        output_path = os.path.join(model_dir,
                                   'dev_output_{}.json'.format(epoch))
        with open(output_path, 'w') as f:
            json.dump(results, f)

        if test_data is not None:
            test_results, test_labels = predict_squad(model,
                                                      test_data,
                                                      v2_on=args.v2_on)
            test_output_path = os.path.join(
                model_dir, 'test_output_{}.json'.format(epoch))
            with open(test_output_path, 'w') as f:
                json.dump(test_results, f)

            if (test_gold is not None):
                if args.v2_on:
                    test_metric = evaluate_v2(
                        test_gold,
                        test_results,
                        na_prob_thresh=args.classifier_threshold)
                    test_em, test_f1 = test_metric['exact'], test_metric['f1']
                    test_acc = compute_acc(
                        labels, test_labels
                    )  #?? should be test_labels,test_gold_labels
                else:
                    test_metric = evaluate(test_gold, test_results)
                    test_em, test_f1 = test_metric['exact_match'], test_metric[
                        'f1']

        # setting up scheduler
        # halves learning rate every 10 epochs
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()
        # save
        model_file = os.path.join(
            model_dir, 'checkpoint_{}_epoch_{}.pt'.format(version, epoch))

        model.save(model_file, epoch)
        if em + f1 > best_em_score + best_f1_score:
            copyfile(
                os.path.join(model_dir, model_file),
                os.path.join(model_dir,
                             'best_{}_checkpoint.pt'.format(version)))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')

        approx = lambda x: round(x, 3)

        logger.warning(f""" Epoch {str(epoch).zfill(2)} ---
        Train | acc: {approx(train_acc)} EM: {approx(train_em)} F1: {approx(train_f1)} loss ({approx(loss)}) = {approx(loss_san)} + {approx(loss_class)}
        Dev   | acc: {approx(acc)} EM: {approx(em)} F1: {approx(f1)}
        --------------------------------
        """)

        #writing in CSV
        result_params.append([
            epoch, loss, loss_san, loss_class, em, f1, acc, train_em, train_f1,
            train_acc
        ])
        logger.info('Writing in {} the values {}'.format(
            csv_path, result_params))
        with open(csv_path, 'w') as csvfile:
            csvwriter = csv.writer(csvfile)
            csvwriter.writerow(csv_head)
            csvwriter.writerows(result_params)
Example #9
def main():
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')
    version = 'v1'
    if args.v2_on:
        version = 'v2'
        dev_labels = load_squad_v2_label(args.dev_gold)
        dev_labels_adv = load_squad_v2_label('data/adv-dev-v2.0.json')
    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    dev_data = BatchGen(gen_name(args.data_dir, args.dev_data, version),
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False)
    dev_data_adv = BatchGen(gen_name(args.data_dir, 'adv_' + args.dev_data,
                                     version),
                            batch_size=args.batch_size,
                            gpu=args.cuda,
                            is_train=False)

    # load golden standard
    dev_gold = load_squad(args.dev_gold)
    dev_gold_adv = load_squad('data/adv-dev-v2.0.json')

    # TODO
    best_checkpoint_path = os.path.join(
        model_dir, 'best_{}_checkpoint.pt'.format(version))
    check = torch.load(best_checkpoint_path)
    model = DocReaderModel(check['config'],
                           embedding,
                           state_dict=check['state_dict'])
    model.setup_eval_embed(embedding)

    if args.cuda:
        model.cuda()

    # dev eval
    results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
    if args.v2_on:
        metric = evaluate_v2(dev_gold,
                             results,
                             labels,
                             na_prob_thresh=args.classifier_threshold)
        em, f1 = metric['exact'], metric['f1']
        acc = compute_acc(labels, dev_labels)
        print("Original validation EM {}, F1 {}, Acc {}".format(em, f1, acc))
    else:
        metric = evaluate(dev_gold, results)
        em, f1 = metric['exact_match'], metric['f1']

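    # repeat the evaluation on the adversarial dev set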
    results, labels = predict_squad(model, dev_data_adv, v2_on=args.v2_on)
    if args.v2_on:
        metric = evaluate_v2(dev_gold_adv,
                             results,
                             labels,
                             na_prob_thresh=args.classifier_threshold)
        em, f1 = metric['exact'], metric['f1']
        acc = compute_acc(labels, dev_labels_adv)
        print("Adversarial EM {}, F1 {}, Acc {}".format(em, f1, acc))
    else:
        metric = evaluate(dev_gold_adv, results)
        em, f1 = metric['exact_match'], metric['f1']
Example #10
                        batch_size,
                        have_gpu,
                        is_train=False,
                        with_label=True)
    #batches.reset()
    #batches = list(batches)

    model_path = model_root + 'best_checkpoint.pt'

    checkpoint = torch.load(model_path)

    opt = checkpoint['config']
    set_environment(opt['seed'], have_gpu)
    opt['covec_path'] = mtlstm_path
    opt['cuda'] = have_gpu
    opt['multi_gpu'] = False
    opt['max_len'] = max_len
    state_dict = checkpoint['state_dict']
    model = DocReaderModel(opt, state_dict=state_dict)
    model.setup_eval_embed(torch.Tensor(test_embedding))
    logger.info('Loaded model!')

    if have_gpu:
        model.cuda()

    results, score_list = evaluate_squad_v2(model, dev_data)

    dev_gold = load_squad_v2(test_file)

    results = my_evaluation(dev_gold, results, score_list, 0.4)
    logger.info('{}'.format(results))
Example #11
class InteractiveModel:
    def __init__(self, args):
        self.is_cuda = args.cuda
        self.embedding, self.opt, self.vocab = load_meta(vars(args), args.meta)
        self.opt['skip_tokens'] = self.get_skip_tokens(
            self.opt["skip_tokens_file"])
        self.opt['skip_tokens_first'] = self.get_skip_tokens(
            self.opt["skip_tokens_first_file"])
        self.state_dict = th.load(args.model_dir)["state_dict"]
        self.model = DocReaderModel(self.opt, self.embedding, self.state_dict)
        self.model.setup_eval_embed(self.embedding)
        if self.is_cuda:
            self.model.cuda()

    def get_skip_tokens(self, path):
        skip_tokens = None
        if path and os.path.isfile(path):
            skip_tokens = []
            with open(path, 'r') as f:
                for word in f:
                    word = word.strip().rstrip('\n')
                    try:
                        skip_tokens.append(self.vocab[word])
                    except KeyError:
                        print("Token %s not present in dictionary" % word)
        return skip_tokens

    def predict(self, data, top_k=2):
        processed_data = prepare_batch_data(
            [self.preprocess_data(x) for x in data], ground_truth=False)
        prediction, prediction_topks = self.model.predict(processed_data,
                                                          top_k=top_k)
        pred_word = pred2words(prediction, self.vocab)
        # np.str was an alias for the builtin str and is removed in recent NumPy
        prediction = [np.asarray(x, dtype=str).tolist() for x in pred_word]
        return (prediction, prediction_topks)

    def preprocess_data(self, sample, q_cutoff=30, doc_cutoff=500):
        def tok_func(toks):
            return [self.vocab[w] for w in toks]

        fea_dict = {}

        query_tokend = filter_query(sample['query'].strip(),
                                    max_len=q_cutoff).split()
        doc_tokend = filter_fact(sample['fact'].strip()).split()
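        # truncate overly long documents and mark the cut with a <TRNC> token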
        if len(doc_tokend) > doc_cutoff:
            doc_tokend = doc_tokend[:doc_cutoff] + ['<TRNC>']

        # TODO
        fea_dict['query_tok'] = tok_func(query_tokend)
        fea_dict['query_pos'] = []
        fea_dict['query_ner'] = []

        fea_dict['doc_tok'] = tok_func(doc_tokend)
        fea_dict['doc_pos'] = []
        fea_dict['doc_ner'] = []
        fea_dict['doc_fea'] = ''

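        # fall back to a single index-0 token so neither sequence is empty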
        if len(fea_dict['query_tok']) == 0:
            fea_dict['query_tok'] = [0]
        if len(fea_dict['doc_tok']) == 0:
            fea_dict['doc_tok'] = [0]

        return fea_dict