Exemplo n.º 1
0
def start_training(args, parse_args_to_params_fun, TrainClass):
    """Set up a trainer from parsed CLI args and run a full training session.

    Args:
        args: argparse Namespace. Fields read here: cluster, load_config,
            checkpoint_path, batch_size, debug, restart, max_iterations,
            eval_freq, save_freq, no_model_checkpoints, clean_up.
        parse_args_to_params_fun: callable mapping ``args`` to a
            ``(model_params, optimizer_params)`` tuple.
        TrainClass: trainer class instantiated with the parameters above;
            must expose ``checkpoint_path`` and ``train_model``.

    Side effects: may delete the checkpoint directory's contents (on
    ``--restart`` or ``--clean_up``) and pickles ``args`` into it.
    """
    # On the cluster, raise the debug level and report the loss less often.
    if args.cluster:
        set_debug_level(2)
        loss_freq = 250
    else:
        set_debug_level(0)
        loss_freq = 2

    if args.load_config:
        if args.checkpoint_path is None:
            print("[!] ERROR: Please specify the checkpoint path to load the config from.")
            sys.exit(1)
        # Replace the command-line args with the ones stored in the checkpoint.
        args = load_args(args.checkpoint_path)
        # Never auto-delete a directory whose config we just restored.
        args.clean_up = False

    # Setup training
    model_params, optimizer_params = parse_args_to_params_fun(args)
    trainModule = TrainClass(model_params=model_params,
                             optimizer_params=optimizer_params,
                             batch_size=args.batch_size,
                             checkpoint_path=args.checkpoint_path,
                             debug=args.debug)

    def clean_up_dir():
        # Best-effort removal of everything inside the checkpoint directory;
        # individual failures are reported but do not abort the run.
        print(f"Cleaning up directory {trainModule.checkpoint_path}...")
        for file_in_dir in sorted(glob(os.path.join(trainModule.checkpoint_path, "*"))):
            print(f"Removing file {file_in_dir}")
            try:
                if os.path.isfile(file_in_dir):
                    os.remove(file_in_dir)
                elif os.path.isdir(file_in_dir):
                    shutil.rmtree(file_in_dir)
            except Exception as e:
                print(e)

    # --restart wipes any previous run in the same checkpoint directory.
    if args.restart and args.checkpoint_path is not None and os.path.isdir(args.checkpoint_path):
        clean_up_dir()

    # Persist the effective configuration next to the checkpoints so it can
    # be restored later via --load_config.
    args_filename = os.path.join(trainModule.checkpoint_path, PARAM_CONFIG_FILE)
    with open(args_filename, "wb") as f:
        pickle.dump(args, f)

    trainModule.train_model(args.max_iterations,
                            loss_freq=loss_freq,
                            eval_freq=args.eval_freq,
                            save_freq=args.save_freq,
                            no_model_checkpoints=args.no_model_checkpoints)

    # --clean_up removes the directory entirely once training finished.
    if args.clean_up:
        clean_up_dir()
        os.rmdir(trainModule.checkpoint_path)
Exemplo n.º 2
0
def load_our_model(checkpoint_path):
    """Return the cached paraphrasing model, building it on first use.

    The constructed model is stored in the module-level ``OUR_MODEL``
    singleton, so later calls are free.
    """
    global OUR_MODEL
    # Fast path: the model was already built by an earlier call.
    if OUR_MODEL is not None:
        return OUR_MODEL

    stored_args = load_args(checkpoint_path)

    print("-> Loading model...")
    params, _ = unsupervised_args_to_params(stored_args)

    _, _, wordvec_tensor = load_word2vec_from_file()
    net = ModelUnsupervisedContextParaphrasingTemplate(params, wordvec_tensor)

    print(checkpoint_path)
    _ = load_model(checkpoint_path, model=net, load_best_model=True)

    # Move to the compute device and freeze into inference mode.
    net = net.to(get_device())
    net.eval()

    OUR_MODEL = net
    return OUR_MODEL
Exemplo n.º 3
0
        # NOTE(review): the enclosing function's header is outside this chunk;
        # out_s and output_file are presumably prepared above — confirm in the
        # full file.
        with open(output_file, "w") as f:
            f.write(out_s)

    # Echo the generated output to stdout as well.
    print(out_s)


if __name__ == '__main__':
    # Command-line interface: checkpoint location plus input/output files.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Folder(name) where checkpoints are saved")
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help=
        "Input file which contains the sentences. Format of each line: \"premise\" #SEP# \"hypothesis\"")
    parser.add_argument(
        "--output_file",
        type=str,
        default="infer_out.txt",
        help=
        "File to which the predictions should be written out. Default: infer_out.txt")
    args = parser.parse_args()

    # Restore the training-time configuration and load the best checkpoint.
    stored_args = load_args(args.checkpoint_path)
    model = load_model_from_args(stored_args,
                                 args.checkpoint_path,
                                 load_best_model=True)

    run_inference(model, args.input_file, args.output_file, load_file=True)
Exemplo n.º 4
0
	# parser.add_argument("--all", help="Evaluating all experiments in the checkpoint folder (specified by checkpoint path) if not already done", action="store_true")
	# NOTE(review): `parser` is defined above this chunk — confirm the flags
	# read below (--overwrite, --full_senteval, --visualize_embeddings) there.
	args = parser.parse_args()
	# Expand the checkpoint path glob into a deterministic (sorted) run list.
	model_list = sorted(glob(args.checkpoint_path))
	transfer_datasets = get_transfer_datasets()

	for model_checkpoint in model_list:
		# if not os.path.isfile(os.path.join(model_checkpoint, "results.txt")):
		# 	print("Skipped " + str(model_checkpoint) + " because of missing results file." )
		# 	continue
		
		# Unless --overwrite is given, skip evaluation stages whose result
		# files already exist in this checkpoint directory.
		skip_standard_eval = not args.overwrite and os.path.isfile(os.path.join(model_checkpoint, "evaluation.pik"))
		skip_sent_eval = not args.overwrite and os.path.isfile(os.path.join(model_checkpoint, "sent_eval.pik"))
		skip_extra_eval = (not args.overwrite and os.path.isfile(os.path.join(model_checkpoint, "extra_evaluation.txt")))

		try:
			model, tasks = load_model_from_args(load_args(model_checkpoint))
			evaluater = MultiTaskEval(model, tasks)
			evaluater.dump_errors(model_checkpoint, task_name=evaluater.tasks[0].name)
			evaluater.test_best_model(model_checkpoint, 
									  run_standard_eval=(not skip_standard_eval), 
									  run_training_set=False,
									  run_sent_eval=(not skip_sent_eval),
									  run_extra_eval=(not skip_extra_eval),
									  light_senteval=(not args.full_senteval))
			if args.visualize_embeddings:
				evaluater.visualize_tensorboard(model_checkpoint, replace_old_files=args.overwrite, additional_datasets=transfer_datasets)
		except RuntimeError as e:
			# A single model failing to load must not abort the whole sweep.
			print("[!] Runtime error while loading " + model_checkpoint)
			print(e)
			continue
	# evaluater.evaluate_all_models(args.checkpoint_path)
Exemplo n.º 5
0
	parser.add_argument("--momentum", help="Apply momentum to SGD optimizer", type=float, default=0.0)

	args = parser.parse_args()
	print(args)
	# On the cluster, raise the debug level and report the loss less often.
	if args.cluster:
		set_debug_level(2)
		loss_freq = 500
	else:
		set_debug_level(0)
		loss_freq = 50

	if args.load_config:
		if args.checkpoint_path is None:
			print("[!] ERROR: Please specify the checkpoint path to load the config from.")
			sys.exit(1)
		# Replace the command-line args with the ones stored in the checkpoint.
		args = load_args(args.checkpoint_path)

	# Setup training
	# NOTE(review): model_type from args_to_params is unused here — the trainer
	# receives args.model instead; confirm that is intentional.
	tasks, model_type, model_params, optimizer_params, multitask_params = args_to_params(args)
	trainModule = MultiTaskTrain(tasks=tasks,
								 model_type=args.model, 
								 model_params=model_params,
								 optimizer_params=optimizer_params, 
								 multitask_params=multitask_params,
								 batch_size=args.batch_size,
								 checkpoint_path=args.checkpoint_path, 
								 debug=args.debug
								 )

	# --restart wipes a previous run in the same checkpoint directory; the
	# cleanup logic continues past the end of this chunk.
	# NOTE(review): the message below misspells "directory" ("directiory").
	if args.restart and args.checkpoint_path is not None and os.path.isdir(args.checkpoint_path):
		print("Cleaning up directiory " + str(args.checkpoint_path) + "...")