Example #1
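                # decay the learning rate, then halve cont_prob while it is
                # still above 0.01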
                model_supervisor.model.lr_decay(args.lr_decay_rate)
                if model_supervisor.model.cont_prob > 0.01:
                    model_supervisor.model.cont_prob *= 0.5


def evaluate(args):
    print('Evaluation:')

    test_data = data_utils.load_dataset(args.test_dataset, args)
    test_data_size = len(test_data)
    args.dropout_rate = 0.0

    dataProcessor = data_utils.vrpDataProcessor()
    model_supervisor = create_model(args)
    test_loss, test_reward = model_supervisor.eval(test_data,
                                                   args.output_trace_flag)

    print('test loss: %.4f test reward: %.4f' % (test_loss, test_reward))


if __name__ == "__main__":
    argParser = arguments.get_arg_parser("vrp")
    args = argParser.parse_args()
    args.cuda = not args.cpu and torch.cuda.is_available()
    random.seed(args.seed)
    np.random.seed(args.seed)
    if args.eval:
        evaluate(args)
    else:
        train(args)
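
All of these examples share the same entry-point boilerplate: parse command-line arguments, enable CUDA only when it is requested and available, seed Python's and NumPy's random number generators for reproducibility, then dispatch to evaluate or train. A minimal self-contained sketch of that pattern (the flags mirror the examples above, but this parser is an assumption, not the repo's actual arguments module):

import argparse
import random

import numpy as np
import torch


def get_arg_parser(title):
    # Stand-in for arguments.get_arg_parser; flags mirror the examples above.
    parser = argparse.ArgumentParser(description=title)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    return parser


if __name__ == "__main__":
    args = get_arg_parser("demo").parse_args()
    # Use CUDA only when it is available and not explicitly disabled.
    args.cuda = not args.cpu and torch.cuda.is_available()
    # Seed both RNGs so data shuffling and sampling are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    print('would run evaluate(args)' if args.eval else 'would run train(args)')
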
Example #2
				if model_supervisor.model.cont_prob > 0.01:
					model_supervisor.model.cont_prob *= 0.5


def evaluate(args):
	print('Evaluation:')

	test_data = data_utils.load_dataset(args.test_dataset, args)
	test_data_size = len(test_data)
	args.dropout_rate = 0.0

	dataProcessor = data_utils.jspDataProcessor(args)
	model_supervisor = create_model(args)
	test_loss, test_reward = model_supervisor.eval(test_data, args.output_trace_flag)

	print('test loss: %.4f test reward: %.4f' % (test_loss, test_reward))


if __name__ == "__main__":
	argParser = arguments.get_arg_parser("jsp")
	args = argParser.parse_args()
	args.cuda = not args.cpu and torch.cuda.is_available()
	random.seed(args.seed)
	np.random.seed(args.seed)
	if args.eval:
		evaluate(args)
	else:
		train(args)

Example #3
def evaluate(args):
	print('Evaluation:')

	test_data = data_utils.load_dataset(args.test_dataset, args)
	test_data_size = len(test_data)

	args.dropout_rate = 0.0

	DataProcessor = data_utils.HalideDataProcessor()

	if args.test_min_len is not None:
		test_data = DataProcessor.prune_dataset(test_data, min_len=args.test_min_len)

	term_vocab, term_vocab_list = DataProcessor.load_term_vocab()
	op_vocab, op_vocab_list = DataProcessor.load_ops()
	args.term_vocab_size = len(term_vocab)
	args.op_vocab_size = len(op_vocab)
	model_supervisor = create_model(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)
	test_loss, test_reward = model_supervisor.eval(
		test_data, args.output_trace_flag,
		args.output_trace_option, args.output_trace_file)

	print('test loss: %.4f test reward: %.4f' % (test_loss, test_reward))


if __name__ == "__main__":
	argParser = arguments.get_arg_parser("Halide")
	args = argParser.parse_args()
	args.cuda = not args.cpu and torch.cuda.is_available()
	random.seed(args.seed)
	np.random.seed(args.seed)
	if args.eval:
		evaluate(args)
	else:
		train(args)
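
Example #3 differs from the first two in that it can prune the test set by a minimum length (test_min_len) and must load term and operator vocabularies before building the model. The real HalideDataProcessor.prune_dataset is not shown on this page; a sketch of what such a filter plausibly does, assuming len(sample) is the relevant length measure:

def prune_dataset(dataset, min_len=None):
    # Keep only samples of at least min_len; pass everything through when unset.
    if min_len is None:
        return dataset
    return [sample for sample in dataset if len(sample) >= min_len]
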
Example #4
            print('target: ', gt_prog)

    print(
        'test loss: %.4f test label acc: %.4f test data acc: %.4f test acc: %.4f '
        % (test_loss, test_label_acc, test_data_acc, test_acc))
    print('Unpredictable samples: %d %.4f' %
          (cnt_unpredictable, cnt_unpredictable * 1.0 / len(test_data)))
    print('Upper bound: %.4f' % (1 - cnt_unpredictable * 1.0 / len(test_data)))
    for i in range(args.num_plot_types):
        print('cnt per category: ', i, cnt_per_category[i])
        if cnt_per_category[i] == 0:
            continue
        print('label acc per category: ', i, label_acc_per_category[i],
              label_acc_per_category[i] * 1.0 / cnt_per_category[i])
        print('data acc per category: ', i, data_acc_per_category[i],
              data_acc_per_category[i] * 1.0 / cnt_per_category[i])
        print('acc per category: ', i, acc_per_category[i],
              acc_per_category[i] * 1.0 / cnt_per_category[i])


if __name__ == "__main__":
    arg_parser = arguments.get_arg_parser('juice')
    args = arg_parser.parse_args()
    args.cuda = not args.cpu and torch.cuda.is_available()
    random.seed(args.seed)
    np.random.seed(args.seed)
    if args.eval:
        evaluate(args)
    else:
        train(args)
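
Example #4 tracks label, data, and overall accuracy per plot-type category using parallel count arrays indexed by category, skipping the division for empty categories. The same bookkeeping written with collections.Counter (a sketch; the original's counters and their update sites are not shown on this page):

from collections import Counter

seen = Counter()
correct = Counter()

def record(category, is_correct):
    # Called once per test sample with its category and correctness.
    seen[category] += 1
    if is_correct:
        correct[category] += 1

def report(num_categories):
    for i in range(num_categories):
        if seen[i] == 0:
            continue  # skip categories with no samples, as in the example
        print('acc per category:', i, correct[i], correct[i] / seen[i])
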
Example #5
def train(args):
    train_data, dev_data, m, sampler = train_start(args)
    reporter = tools.Reporter(log_interval=args.log_interval,
                              logdir=args.model_dir)
    for epoch in range(args.num_epochs):
        for batch_idx, batch in enumerate(sampler):
            res = m.train(batch)
            reporter.record(m.last_step, **res)
            reporter.report()
            if m.last_step % args.eval_every_n == 0:
                m.model.eval()
                stats = {'correct': 0, 'total': 0}
                for dev_idx, dev_batch in enumerate(dev_data):
                    batch_res = m.eval(dev_batch)
                    stats['correct'] += batch_res['correct']
                    stats['total'] += batch_res['total']
                    if dev_idx > args.eval_n_steps:
                        break
                accuracy = float(stats['correct']) / stats['total']
                print("Dev accuracy: %.5f" % accuracy)
                reporter.record(m.last_step, **{'accuracy/dev': accuracy})
                m.model.train()


if __name__ == "__main__":
    parser = arguments.get_arg_parser('Training Text2Code', 'train')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    train(args)
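
The training loop in Example #5 interleaves evaluation every eval_every_n steps: it switches the model to eval mode, accumulates correct/total counts over at most eval_n_steps dev batches, then switches back to train mode so dropout and similar layers behave correctly again. The same logic factored into a helper (a sketch; m.eval(batch) returning a dict with 'correct' and 'total' mirrors the interface used above):

def dev_accuracy(m, dev_data, max_steps):
    m.model.eval()                     # disable dropout etc. for evaluation
    correct, total = 0, 0
    for step, batch in enumerate(dev_data):
        res = m.eval(batch)
        correct += res['correct']
        total += res['total']
        if step > max_steps:           # cap evaluation cost, as above
            break
    m.model.train()                    # restore training mode
    return float(correct) / total
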
Example #6
def evaluate(args):
    print("Evaluation:")
    print("\tModel type: %s\n\tModel path: %s" % (args.model_type, args.model_dir))

    if args.ensemble_parameters:
        assert '#' in args.model_dir
        args_per_model = []
        for parameter in args.ensemble_parameters:
            globstring = args.model_dir.replace("#", parameter)
            paths = glob.glob(globstring)
            assert paths, globstring
            for model_dir in paths:
                print(model_dir)
                assert os.path.exists(model_dir), model_dir
                args_for = copy.deepcopy(args)
                args_for.model_dir = model_dir
                args_per_model.append(args_for)
    else:
        args_per_model = [copy.deepcopy(args)]

    for a in args_per_model:
        tools.restore_args(a)
        arguments.backport_default_args(a)
        datasets.set_vocab(a)

    args = args_per_model[0]

    ms = [models.get_model(a) for a in args_per_model]

    if args.iterative_search_use_overfit_model is not None:
        assert args.iterative_search is not None, "using an overfit model only makes sense if iterative search is being used"
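        # eval() turns the comma-separated "key=value" override string into a dict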
        overfit_model_args = eval("dict({})".format(args.iterative_search_use_overfit_model))
        parsed = vars(arguments.get_arg_parser('overfit', 'eval').parse_args([]))
        parsed.update(overfit_model_args)
        parsed = argparse.Namespace(**parsed)
        tools.restore_args(parsed)
        arguments.backport_default_args(parsed)
        datasets.set_vocab(parsed)
        overfit_model = KarelLGRLOverfitModel(parsed)
        print("Overfit model")
        print(overfit_model.model)
    else:
        overfit_model = None

    if args.eval_final:
        eval_dataset = datasets.get_eval_final_dataset(args, ms[0])
    elif args.eval_train:
        eval_dataset, _ = datasets.get_dataset(args, ms[0], eval_on_train=True)
    else:
        eval_dataset = datasets.get_eval_dataset(args, ms[0])
    if any(m.last_step == 0 for m in ms):
        raise ValueError('Attempting to evaluate on untrained model')
    for m in ms:
        m.model.eval()
    current_executor = executor.get_executor(args)()
    if args.example_id is not None:
        eval_dataset.data = [eval_dataset.task[args.example_id]]

    inference = ensembled_inference([m.inference for m in ms], args.ensemble_mode)

    if isinstance(ms[0], KarelLGRLOverfitModel):
        assert len(ms) == 1
        evaluation.run_overfit_eval(
            eval_dataset, inference,
            args.report_path,
            limit=args.limit)

        return

    if args.iterative_search is not None:
        inference = IterativeSearch(
            inference,
            Strategy.get(args.iterative_search),
            current_executor,
            args.karel_trace_enc != 'none',
            ms[0].batch_processor(for_eval=True),
            start_with_beams=args.iterative_search_start_with_beams,
            time_limit=args.iterative_search_step_limit,
            overfit_model=overfit_model)
    if args.run_predict:
        evaluation.run_predict(eval_dataset, inference, current_executor.execute, args.predict_path,
                               evaluate_on_all=args.evaluate_on_all)
    else:
        evaluation.run_eval(
            args.tag, eval_dataset, inference,
            current_executor.execute, not args.hide_example_info,
            args.report_path,
            limit=args.limit,
            evaluate_on_all=args.evaluate_on_all)


if __name__ == "__main__":
    parser = arguments.get_arg_parser('Evaluating Text2Code', 'eval')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if not args.model_type or (not args.model_dir and args.model_type != 'search'):
        raise ValueError("Specify model_dir and model_type")
    if not args.tag:
        args.tag = args.model_type
    evaluate(args)
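
Example #6 folds several trained models into one inference callable via ensembled_inference, optionally wrapped in IterativeSearch. The combination logic itself is not shown on this page; a hypothetical illustration of one way to pool candidates across models, assuming each inference callable returns scored candidates (the real interface and ensemble_mode semantics may differ):

def ensembled_inference_sketch(inference_fns, top_k=32):
    # Combine several per-model inference callables into one. Assumes each
    # infer(batch) returns a list of (score, candidate) pairs -- an assumption,
    # not the repo's actual interface.
    def combined(batch):
        pooled = []
        for infer in inference_fns:
            pooled.extend(infer(batch))
        # Keep the globally highest-scoring candidates across all models.
        pooled.sort(key=lambda pair: pair[0], reverse=True)
        return pooled[:top_k]
    return combined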