Ejemplo n.º 1
0
def main():
    """Evaluate a saved SAN model on the dev set and record the results.

    Loads the dev batches and gold answers, restores a fully pickled model
    from a hard-coded checkpoint path, runs prediction and writes the dev
    predictions to ``dev_output_<epoch>.json`` under ``model_dir``. It then
    steps the scheduler, saves a checkpoint and copies it to the "best"
    checkpoint when EM + F1 beats the running best. No training happens in
    this variant; ``epoch`` is a fixed placeholder used only for file names
    and log lines.
    """
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')

    # Resolve the dataset versions BEFORE building any path; otherwise the
    # gold paths silently point at the v1 files even when v2 is requested.
    version = 'v2' if args.v2_on else 'v1'
    gold_version = 'v2.0' if args.v2_on else 'v1.1'

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)
    test_gold_path = gen_gold_name(args.data_dir, args.test_gold, gold_version)

    if args.v2_on:
        # NOTE(review): args.dev_gold is used here as a file path but as a
        # name component in gen_gold_name above -- confirm which is intended.
        dev_labels = load_squad_v2_label(args.dev_gold)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    dev_data = BatchGen(dev_path,
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    # No test batcher is built in this variant, so the test branch further
    # down stays inactive.
    test_data = None
    test_gold = None

    # load golden standard
    dev_gold = load_squad(dev_gold_path)
    if os.path.exists(test_gold_path):
        test_gold = load_squad(test_gold_path)

    model = DocReaderModel(opt, embedding)

    # print network
    headline = '############# Model Arch of SAN #############'
    logger.info('\n{}\n{}\n'.format(headline, model.network))
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    best_em_score, best_f1_score = 0.0, 0.0

    # Placeholder epoch number: only used in output file names and log lines.
    epoch = 2

    # Replace the freshly built model with the one stored on disk.
    logger.info('loading the model from disk........')
    path1 = '/demo-mount/san_mrc/checkpoint/checkpoint_v1_epoch_0_full_model.pt'
    # SECURITY: torch.load unpickles arbitrary objects; only load checkpoints
    # from a trusted source.
    # NOTE(review): the loaded object does not go through setup_eval_embed()
    # or cuda() like the model built above -- confirm the pickle already
    # carries that state.
    model = torch.load(path1)

    # dev eval
    results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
    if args.v2_on:
        metric = evaluate_v2(dev_gold,
                             results,
                             na_prob_thresh=args.classifier_threshold)
        em, f1 = metric['exact'], metric['f1']
        acc = compute_acc(labels, dev_labels)
    else:
        metric = evaluate(dev_gold, results)
        em, f1 = metric['exact_match'], metric['f1']

    # Close the predictions file as soon as it is written instead of keeping
    # it open for the rest of the function.
    output_path = os.path.join(model_dir, 'dev_output_{}.json'.format(epoch))
    with open(output_path, 'w') as f:
        json.dump(results, f)

    if test_data is not None:
        test_results, test_labels = predict_squad(model,
                                                  test_data,
                                                  v2_on=args.v2_on)
        test_output_path = os.path.join(
            model_dir, 'test_output_{}.json'.format(epoch))
        with open(test_output_path, 'w') as f:
            json.dump(test_results, f)

        if test_gold is not None:
            if args.v2_on:
                test_metric = evaluate_v2(
                    test_gold,
                    test_results,
                    na_prob_thresh=args.classifier_threshold)
                test_em, test_f1 = test_metric['exact'], test_metric['f1']
                # NOTE(review): this compares dev predictions with test
                # predictions; it presumably should compare test_labels with
                # gold test labels -- confirm before enabling the test path.
                test_acc = compute_acc(labels, test_labels)
            else:
                test_metric = evaluate(test_gold, test_results)
                test_em, test_f1 = (test_metric['exact_match'],
                                    test_metric['f1'])

    # setting up scheduler
    if model.scheduler is not None:
        logger.info('scheduler_type {}'.format(opt['scheduler_type']))
        if opt['scheduler_type'] == 'rop':
            model.scheduler.step(f1, epoch=epoch)
        else:
            model.scheduler.step()

    # save
    model_file = os.path.join(
        model_dir, 'checkpoint_{}_epoch_{}.pt'.format(version, epoch))
    model.save(model_file, epoch)
    if em + f1 > best_em_score + best_f1_score:
        # model_file already contains model_dir; joining it again would
        # produce a non-existent path whenever model_dir is relative.
        copyfile(
            model_file,
            os.path.join(model_dir, 'best_{}_checkpoint.pt'.format(version)))
        best_em_score, best_f1_score = em, f1
        logger.info('Saved the new best model and prediction')

    logger.warning(
        "Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})"
        .format(epoch, em, f1, best_em_score, best_f1_score))
    if args.v2_on:
        logger.warning("Epoch {0} - ACC: {1:.4f}".format(epoch, acc))
    if metric is not None:
        logger.warning("Detailed Metric at Epoch {0}: {1}".format(
            epoch, metric))

    if (test_data is not None) and (test_gold is not None):
        logger.warning("Epoch {0} - test EM: {1:.3f} F1: {2:.3f}".format(
            epoch, test_em, test_f1))
        if args.v2_on:
            logger.warning("Epoch {0} - test ACC: {1:.4f}".format(
                epoch, test_acc))
Ejemplo n.º 2
0
def main():
    """Train the SAN model and evaluate it on the dev set after every epoch.

    For each epoch: runs one pass over the training batches, predicts on the
    dev set, writes the predictions to ``dev_output_<epoch>.json`` under
    ``model_dir``, steps the learning-rate scheduler, saves a checkpoint and
    copies it to ``best_<version>_checkpoint.pt`` when EM + F1 improves.
    """
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')
    version = 'v1'
    if args.v2_on:
        version = 'v2'
        dev_labels = load_squad_v2_label(args.dev_gold)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    train_data = BatchGen(gen_name(args.data_dir, args.train_data, version),
                          batch_size=args.batch_size,
                          gpu=args.cuda,
                          with_label=args.v2_on)
    dev_data = BatchGen(gen_name(args.data_dir, args.dev_data, version),
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False)

    # load golden standard
    dev_gold = load_squad(args.dev_gold)

    model = DocReaderModel(opt, embedding)
    # print network
    headline = '############# Model Arch of SAN #############'
    logger.info('\n{}\n{}\n'.format(headline, model.network))
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    best_em_score, best_f1_score = 0.0, 0.0

    for epoch in range(0, args.epoches):
        logger.warning('At epoch {}'.format(epoch))
        train_data.reset()
        start = datetime.now()
        for i, batch in enumerate(train_data):
            model.update(batch)
            if (model.updates) % args.log_per_updates == 0 or i == 0:
                # Remaining time = average time per batch so far multiplied
                # by the number of batches left in this epoch.
                logger.info('#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format(
                    model.updates, model.train_loss.avg,
                    str((datetime.now() - start) / (i + 1) * (len(train_data) - i - 1)).split('.')[0]))

        # dev eval
        results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
        if args.v2_on:
            metric = evaluate_v2(dev_gold, results,
                                 na_prob_thresh=args.classifier_threshold)
            em, f1 = metric['exact'], metric['f1']
            acc = compute_acc(labels, dev_labels)
            cls_pr, cls_rec, cls_f1 = compute_classifier_pr_rec(
                labels, dev_labels)
        else:
            metric = evaluate(dev_gold, results)
            em, f1 = metric['exact_match'], metric['f1']

        output_path = os.path.join(model_dir,
                                   'dev_output_{}.json'.format(epoch))
        with open(output_path, 'w') as f:
            json.dump(results, f)

        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()

        # save
        model_file = os.path.join(
            model_dir, 'checkpoint_{}_epoch_{}.pt'.format(version, epoch))
        model.save(model_file, epoch)
        if em + f1 > best_em_score + best_f1_score:
            # model_file already contains model_dir; joining it again would
            # produce a non-existent path whenever model_dir is relative.
            copyfile(
                model_file,
                os.path.join(model_dir,
                             'best_{}_checkpoint.pt'.format(version)))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')

        logger.warning("Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})".format(epoch, em, f1, best_em_score, best_f1_score))
        if args.v2_on:
            logger.warning("Epoch {0} - Precision: {1:.4f}, Recall: {2:.4f}, F1: {3:.4f}, Accuracy: {4:.4f}".format(epoch, cls_pr, cls_rec, cls_f1, acc))
        if metric is not None:
            logger.warning("Detailed Metric at Epoch {0}: {1}".format(epoch, metric))
Ejemplo n.º 3
0
def main():
    """Train the SAN model, optionally resuming from a saved checkpoint.

    When ``args.resume`` is set, the model is loaded from that file and the
    starting epoch and best EM/F1 are parsed from its file name, which must
    follow the ``<prefix>_epoch_<N>_em_<EM>_f1_<F1>.pt`` pattern this function
    writes. Each epoch trains on all training batches, evaluates on the dev
    set, writes ``dev_output_<epoch>.json`` under ``model_dir``, steps the
    scheduler, saves a checkpoint and copies it to a ``best_...`` file when
    EM + F1 improves.
    """
    opt = vars(args)
    logger.info('Loading Squad')
    version = 'v1'
    gold_version = 'v1.1'

    if args.v2_on:
        version = 'v2'
        gold_version = 'v2.0'
        dev_labels = load_squad_v2_label(args.dev_gold)

    logger.info('Loading Meta')
    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))

    logger.info('Loading Train Batcher')

    if args.elmo_on:
        logger.info('ELMO ON')

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    # NOTE(review): dev_gold_path is computed but the gold data below is read
    # from args.dev_gold directly -- confirm which of the two is intended.
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)

    train_data = BatchGen(gen_name(args.data_dir, args.train_data, version),
                          batch_size=args.batch_size,
                          gpu=args.cuda,
                          with_label=args.v2_on,
                          elmo_on=args.elmo_on)

    logger.info('Loading Dev Batcher')
    dev_data = BatchGen(dev_path,
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    logger.info('Loading Golden Standards')
    # load golden standard
    dev_gold = load_squad(args.dev_gold)

    if args.resume:
        logger.info('Loading resumed model')
        model = DocReaderModel.load(args.resume, embedding, gpu=args.cuda)

        # Parse only the file name, so underscores in directory names cannot
        # shift the field positions.
        resume_fields = os.path.basename(args.resume).split('_')
        best_f1_score = float(resume_fields[6].replace('.pt', ''))
        best_em_score = float(resume_fields[4])
        resumed_epoch = int(resume_fields[2]) + 1

        # Replay the scheduler steps of the epochs already trained. The
        # scheduler is optional (see the None check in the epoch loop).
        if model.scheduler is not None:
            for _ in range(resumed_epoch):
                model.scheduler.step()

        logger.info(
            "RESUMING MODEL TRAINING. BEST epoch {} EM {} F1 {} ".format(
                str(resumed_epoch), str(best_em_score), str(best_f1_score)))
    else:
        model = DocReaderModel(opt, embedding)
        best_em_score, best_f1_score = 0.0, 0.0
        resumed_epoch = 0

    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    logger.info("Batch Size {}".format(args.batch_size))
    if args.cuda:
        model.cuda()
    else:
        model.cpu()

    for epoch in range(resumed_epoch, args.epoches):
        logger.warning('At epoch {}'.format(epoch))

        # shuffle training batch
        train_data.reset()
        start = datetime.now()
        for i, batch in enumerate(train_data):
            model.update(batch)
            if (model.updates) % args.log_per_updates == 0 or i == 0:
                # Remaining time = average time per batch so far multiplied
                # by the number of batches left in this epoch.
                logger.info(
                    '#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.
                    format(
                        model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(train_data) - i - 1)).split('.')[0]))

        # dev eval
        results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
        if args.v2_on:
            metric = evaluate_v2(dev_gold,
                                 results,
                                 na_prob_thresh=args.classifier_threshold)
            em, f1 = metric['exact'], metric['f1']
            acc = compute_acc(labels, dev_labels)
        else:
            metric = evaluate(dev_gold, results)
            em, f1 = metric['exact_match'], metric['f1']

        output_path = os.path.join(model_dir,
                                   'dev_output_{}.json'.format(epoch))
        with open(output_path, 'w') as f:
            json.dump(results, f)

        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()

        # save; the file name encodes the epoch and truncated EM/F1 so that
        # a later run can recover them when resuming.
        model_file = os.path.join(
            model_dir,
            'cp_epoch_{}_em_{}_f1_{}.pt'.format(epoch, int(em), int(f1)))
        model.save(model_file, epoch)
        if em + f1 > best_em_score + best_f1_score:
            # model_file already contains model_dir; joining it again would
            # produce a non-existent path whenever model_dir is relative.
            copyfile(
                model_file,
                os.path.join(
                    model_dir, 'best_epoch_{}_em_{}_f1_{}.pt'.format(
                        epoch, int(em), int(f1))))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')

        logger.warning(
            "Epoch {0} - dev EM: {1:.3f} F1: {2:.3f} (best EM: {3:.3f} F1: {4:.3f})"
            .format(epoch, em, f1, best_em_score, best_f1_score))
        if args.v2_on:
            logger.warning("Epoch {0} - ACC: {1:.4f}".format(epoch, acc))
        if metric is not None:
            logger.warning("Detailed Metric at Epoch {0}: {1}".format(
                epoch, metric))
Ejemplo n.º 4
0
def main():
    """Score the answerability classifier of a saved SAN checkpoint.

    Loads the checkpoint at ``args.checkpoint_path``, predicts on the dev
    set, aligns predicted and gold labels by question key, thresholds both
    at ``args.classifier_threshold`` and prints accuracy, the confusion
    matrix and binary precision / recall / F1.
    """
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')

    version = 'v2' if args.v2_on else 'v1'
    gold_version = 'v2.0' if args.v2_on else 'v1.1'

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)
    dev_labels = load_squad_v2_label(dev_gold_path)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    dev_data = BatchGen(dev_path,
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    dev_gold = load_squad(dev_gold_path)

    checkpoint_path = args.checkpoint_path
    logger.info(f'path to given checkpoint is {checkpoint_path}')
    # Remap tensors to the CPU when CUDA is not requested; None keeps
    # torch.load's default placement.
    map_location = None if args.cuda else 'cpu'
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    state_dict = checkpoint['state_dict']

    # Set up the model
    logger.info('Loading model ...')
    model = DocReaderModel(opt, embedding, state_dict)
    model.setup_eval_embed(embedding)
    logger.info('done')

    if args.cuda:
        model.cuda()

    # dev eval
    logger.info('Predicting ...')
    results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
    logger.info('done')

    # Convert the two key -> label mappings into aligned lists. A key is
    # only kept when BOTH mappings contain it; appending the gold label
    # before the prediction lookup failed used to leave the two lists with
    # different lengths.
    actual_labels = []
    predicted_labels = []
    dropped = 0
    for key, actual in dev_labels.items():
        if key not in labels:
            dropped += 1
            continue
        actual_labels.append(actual)
        predicted_labels.append(labels[key])
    print(f'dropped: {dropped}')

    if not actual_labels:
        # Nothing to score; avoid dividing by zero below.
        logger.warning('No overlapping keys between gold and predicted '
                       'labels; skipping metrics')
        return

    actual_labels = np.array(actual_labels)
    predicted_labels = np.array(predicted_labels)

    # convert from continuous to discrete
    actual_labels = (actual_labels > args.classifier_threshold).astype(np.int32)
    predicted_labels = (predicted_labels > args.classifier_threshold).astype(np.int32)

    # Print all metrics
    mismatches = np.abs(predicted_labels - actual_labels).sum()
    print('accuracy', 100 - 100 * mismatches / len(actual_labels), '%')
    # Argument order (gold first, prediction second) now matches the
    # precision_recall_fscore_support call below -- assumes both helpers
    # come from scikit-learn, where rows are gold labels.
    print('confusion matrix', confusion_matrix(actual_labels, predicted_labels))
    precision, recall, f1, _ = precision_recall_fscore_support(
        actual_labels, predicted_labels, average='binary')
    print(f'Precision: {precision} recall: {recall} f1-score: {f1}')
Ejemplo n.º 5
0
def main():
    """Train the SAN model while tracking train/dev metrics in a CSV file.

    Optionally resumes from ``checkpoint_<version>_epoch_<N>.pt`` in
    ``args.model_dir`` when ``args.load_checkpoint`` is non-zero, reloading
    the previously written metric rows as well. Each epoch trains on all
    training batches, evaluates on the train and dev sets (and on the test
    set when its data file exists), writes prediction JSON files under
    ``model_dir``, steps the scheduler, saves a checkpoint, keeps a copy of
    the best one, and rewrites the CSV with all rows collected so far.
    """
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')

    version = 'v2' if args.v2_on else 'v1'
    gold_version = 'v2.0' if args.v2_on else 'v1.1'

    train_path = gen_name(args.data_dir, args.train_data, version)
    train_gold_path = gen_gold_name(args.data_dir, 'train', gold_version)

    dev_path = gen_name(args.data_dir, args.dev_data, version)
    dev_gold_path = gen_gold_name(args.data_dir, args.dev_gold, gold_version)

    test_path = gen_name(args.data_dir, args.test_data, version)
    test_gold_path = gen_gold_name(args.data_dir, args.test_gold, gold_version)

    train_labels = load_squad_v2_label(train_gold_path)
    dev_labels = load_squad_v2_label(dev_gold_path)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    train_data = BatchGen(train_path,
                          batch_size=args.batch_size,
                          gpu=args.cuda,
                          with_label=args.v2_on,
                          elmo_on=args.elmo_on)
    dev_data = BatchGen(dev_path,
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False,
                        elmo_on=args.elmo_on)

    test_data = None
    test_gold = None
    test_gold_labels = None

    if os.path.exists(test_path):
        test_data = BatchGen(test_path,
                             batch_size=args.batch_size,
                             gpu=args.cuda,
                             is_train=False,
                             elmo_on=args.elmo_on)

    # load golden standard
    train_gold = load_squad(train_gold_path)
    dev_gold = load_squad(dev_gold_path)

    if os.path.exists(test_gold_path):
        test_gold = load_squad(test_gold_path)
        if args.v2_on:
            # Gold answerability labels, needed for a meaningful test
            # accuracy (loaded the same way as the train/dev labels above).
            test_gold_labels = load_squad_v2_label(test_gold_path)

    # define csv path
    csv_head = [
        'epoch', 'train_loss', 'train_loss_san', 'train_loss_class', 'dev_em',
        'dev_f1', 'dev_acc', 'train_em', 'train_f1', 'train_acc'
    ]
    csv_name = 'results_{}.csv'.format(args.classifier_gamma)
    csv_path = os.path.join(args.data_dir, csv_name)
    result_params = []

    # load previous checkpoint
    start_epoch = 0
    state_dict = None
    best_em_score, best_f1_score = 0.0, 0.0

    if args.load_checkpoint != 0:
        start_epoch = args.load_checkpoint + 1
        checkpoint_file = 'checkpoint_{}_epoch_{}.pt'.format(
            version, args.load_checkpoint)
        checkpoint_path = os.path.join(args.model_dir, checkpoint_file)
        logger.info('path to prev checkpoint is {}'.format(checkpoint_path))
        # Remap tensors to the CPU when CUDA is not requested; None keeps
        # torch.load's default placement.
        checkpoint = torch.load(checkpoint_path,
                                map_location=None if args.cuda else 'cpu')
        state_dict = checkpoint['state_dict']
        opt = checkpoint['config']

        # load previous metrics (newline='' as recommended for csv files)
        with open(csv_path, 'r', newline='') as csv_in:
            csvreader = csv.reader(csv_in)
            next(csvreader, None)  # skip the header row
            result_params.extend(row for row in csvreader if row)

        # Restore the best dev score seen so far, so that the first resumed
        # epoch does not unconditionally overwrite the best checkpoint.
        for row in result_params:
            try:
                prev_em, prev_f1 = float(row[4]), float(row[5])
            except (IndexError, ValueError):
                continue
            if prev_em + prev_f1 > best_em_score + best_f1_score:
                best_em_score, best_f1_score = prev_em, prev_f1

        logger.info('Previous metrics loaded')

    model = DocReaderModel(opt, embedding, state_dict)
    model.setup_eval_embed(embedding)

    logger.info("Total number of params: {}".format(model.total_param))
    if args.cuda:
        model.cuda()

    for epoch in range(start_epoch, args.epoches):
        logger.warning('At epoch {}'.format(epoch))

        # Per-epoch sums of the losses returned by model.update().
        loss, loss_san, loss_class = 0.0, 0.0, 0.0

        train_data.reset()
        start = datetime.now()
        for i, batch in enumerate(train_data):
            losses = model.update(batch)
            loss += losses[0].item()
            loss_san += losses[1].item()
            if losses[2]:
                loss_class += losses[2].item()

            if (model.updates) % args.log_per_updates == 0 or i == 0:
                # Remaining time = average time per batch so far multiplied
                # by the number of batches left in this epoch.
                logger.info(
                    '#updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.
                    format(
                        model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(train_data) - i - 1)).split('.')[0]))

        # train eval
        tr_results, tr_labels = predict_squad(model,
                                              train_data,
                                              v2_on=args.v2_on)
        if args.v2_on and args.classifier_on:
            train_metric = evaluate_v2(
                train_gold,
                tr_results,
                na_prob_thresh=args.classifier_threshold)
            train_em, train_f1 = train_metric['exact'], train_metric['f1']
            train_acc = compute_acc(tr_labels, train_labels)
        else:
            train_metric = evaluate(train_gold, tr_results)
            train_em, train_f1 = (train_metric['exact_match'],
                                  train_metric['f1'])
            train_acc = -1  # sentinel: no classifier accuracy in this mode

        # dev eval
        results, labels = predict_squad(model, dev_data, v2_on=args.v2_on)
        if args.v2_on and args.classifier_on:
            metric = evaluate_v2(dev_gold,
                                 results,
                                 na_prob_thresh=args.classifier_threshold)
            em, f1 = metric['exact'], metric['f1']
            acc = compute_acc(labels, dev_labels)
        else:
            metric = evaluate(dev_gold, results)
            em, f1 = metric['exact_match'], metric['f1']
            acc = -1  # sentinel: no classifier accuracy in this mode

        output_path = os.path.join(model_dir,
                                   'dev_output_{}.json'.format(epoch))
        with open(output_path, 'w') as f:
            json.dump(results, f)

        if test_data is not None:
            test_results, test_labels = predict_squad(model,
                                                      test_data,
                                                      v2_on=args.v2_on)
            test_output_path = os.path.join(
                model_dir, 'test_output_{}.json'.format(epoch))
            with open(test_output_path, 'w') as f:
                json.dump(test_results, f)

            if test_gold is not None:
                if args.v2_on:
                    test_metric = evaluate_v2(
                        test_gold,
                        test_results,
                        na_prob_thresh=args.classifier_threshold)
                    test_em, test_f1 = test_metric['exact'], test_metric['f1']
                    # Compare test predictions with the gold test labels
                    # (previously dev predictions were compared with test
                    # predictions).
                    test_acc = compute_acc(test_labels, test_gold_labels)
                else:
                    test_metric = evaluate(test_gold, test_results)
                    test_em, test_f1 = (test_metric['exact_match'],
                                        test_metric['f1'])
                    test_acc = -1

                # Report the test metrics instead of discarding them.
                logger.warning(
                    "Epoch {0} - test EM: {1:.3f} F1: {2:.3f}".format(
                        epoch, test_em, test_f1))
                if args.v2_on:
                    logger.warning("Epoch {0} - test ACC: {1:.4f}".format(
                        epoch, test_acc))

        # setting up scheduler
        if model.scheduler is not None:
            logger.info('scheduler_type {}'.format(opt['scheduler_type']))
            if opt['scheduler_type'] == 'rop':
                model.scheduler.step(f1, epoch=epoch)
            else:
                model.scheduler.step()

        # save
        model_file = os.path.join(
            model_dir, 'checkpoint_{}_epoch_{}.pt'.format(version, epoch))
        model.save(model_file, epoch)
        if em + f1 > best_em_score + best_f1_score:
            # model_file already contains model_dir; joining it again would
            # produce a non-existent path whenever model_dir is relative.
            copyfile(
                model_file,
                os.path.join(model_dir,
                             'best_{}_checkpoint.pt'.format(version)))
            best_em_score, best_f1_score = em, f1
            logger.info('Saved the new best model and prediction')

        logger.warning(f""" Epoch {str(epoch).zfill(2)} ---
        Train | acc: {round(train_acc, 3)} EM: {round(train_em, 3)} F1: {round(train_f1, 3)} loss ({round(loss, 3)}) = {round(loss_san, 3)} + {round(loss_class, 3)}
        Dev   | acc: {round(acc, 3)} EM: {round(em, 3)} F1: {round(f1, 3)}
        --------------------------------
        """)

        # writing in CSV: the whole file is rewritten every epoch with the
        # header plus all rows collected so far.
        result_params.append([
            epoch, loss, loss_san, loss_class, em, f1, acc, train_em, train_f1,
            train_acc
        ])
        logger.info('Writing in {} the values {}'.format(
            csv_path, result_params))
        with open(csv_path, 'w', newline='') as csv_out:
            csvwriter = csv.writer(csv_out)
            csvwriter.writerow(csv_head)
            csvwriter.writerows(result_params)
Ejemplo n.º 6
0
def main():
    """Evaluate the best saved checkpoint on the regular and adversarial dev sets.

    Loads ``best_<version>_checkpoint.pt`` from ``model_dir``, rebuilds the
    model from the stored config and weights, then predicts on the original
    dev set and on the adversarial dev set, printing EM / F1 (and classifier
    accuracy for SQuAD v2) for each.
    """
    logger.info('Launching the SAN')
    opt = vars(args)
    logger.info('Loading data')

    # Single definition of the adversarial gold file used for both the
    # labels and the gold answers.
    adv_gold_path = 'data/adv-dev-v2.0.json'

    version = 'v1'
    dev_labels = dev_labels_adv = None
    if args.v2_on:
        version = 'v2'
        dev_labels = load_squad_v2_label(args.dev_gold)
        dev_labels_adv = load_squad_v2_label(adv_gold_path)

    embedding, opt = load_meta(
        opt, gen_name(args.data_dir, args.meta, version, suffix='pick'))
    dev_data = BatchGen(gen_name(args.data_dir, args.dev_data, version),
                        batch_size=args.batch_size,
                        gpu=args.cuda,
                        is_train=False)
    dev_data_adv = BatchGen(gen_name(args.data_dir, 'adv_' + args.dev_data,
                                     version),
                            batch_size=args.batch_size,
                            gpu=args.cuda,
                            is_train=False)

    # load golden standard
    dev_gold = load_squad(args.dev_gold)
    dev_gold_adv = load_squad(adv_gold_path)

    best_checkpoint_path = os.path.join(
        model_dir, 'best_{}_checkpoint.pt'.format(version))
    # Remap tensors to the CPU when CUDA is not requested; None keeps
    # torch.load's default placement.
    check = torch.load(best_checkpoint_path,
                       map_location=None if args.cuda else 'cpu')
    model = DocReaderModel(check['config'],
                           embedding,
                           state_dict=check['state_dict'])
    model.setup_eval_embed(embedding)

    if args.cuda:
        model.cuda()

    # Run the same evaluation on both splits; each entry pairs the batches
    # with the gold answers and gold labels that belong to that split.
    splits = (
        ('Original validation', dev_data, dev_gold, dev_labels),
        ('Adversarial', dev_data_adv, dev_gold_adv, dev_labels_adv),
    )
    for split_name, batches, gold, gold_labels in splits:
        results, labels = predict_squad(model, batches, v2_on=args.v2_on)
        if args.v2_on:
            metric = evaluate_v2(gold,
                                 results,
                                 labels,
                                 na_prob_thresh=args.classifier_threshold)
            em, f1 = metric['exact'], metric['f1']
            acc = compute_acc(labels, gold_labels)
            print("{} EM {}, F1 {}, Acc {}".format(split_name, em, f1, acc))
        else:
            # Score against this split's own gold answers (the adversarial
            # pass used to be scored against the original dev gold) and
            # report the result instead of discarding it.
            metric = evaluate(gold, results)
            em, f1 = metric['exact_match'], metric['f1']
            print("{} EM {}, F1 {}".format(split_name, em, f1))