def run(config_path, config_string, name, *, check_commit=False):
    """Create a new retriever training run from a config file and train it.

    Args:
        config_path: path of the base config, loaded with ``Config.from_file``.
        config_string: optional extra config text, parsed with
            ``Config.from_str`` and merged with the base config. Skipped when
            empty or None.
        name: name given to the new run via ``RetrieverTrainingRuns.new``.
        check_commit: forwarded to ``RetrieverTrainingRuns``. Defaults to
            False, the value this function always used previously, so existing
            callers are unaffected.
    """
    runs = RetrieverTrainingRuns(check_commit=check_commit)
    config = Config.from_file(config_path)
    if config_string:
        # Config.merge takes a list here; presumably later entries (the
        # override string) take precedence over earlier ones — TODO confirm.
        config = Config.merge([config, Config.from_str(config_string)])
    # Named `training_run` rather than `run` so it does not shadow this function.
    training_run = runs.new(config, name)
    training_run.train()
# Command-line entry point: create a new edit-training experiment (or reload an
# existing one) and train it, registering stdout/stderr files in its workspace.
arg_parser = argparse.ArgumentParser()
# One or more experiment specifiers:
#   default            -> new experiment with the default config
#   <integer id>       -> reload the existing experiment with that id
#   cfg1 [cfg2 ...]    -> new experiment from the config files, merged in order
arg_parser.add_argument('exp_id', nargs='+')
# Commit checking is enabled only when explicitly requested with `-c strict`.
# The default is 'disable', which keeps the previous default behaviour (no
# commit check). Previously the comparison was inverted: `-c disable` turned the
# check ON and `-c strict` turned it OFF.
arg_parser.add_argument('-c', '--check_commit', default='disable')
# NOTE(review): --profile is not referenced in this section.
arg_parser.add_argument('-p', '--profile', action='store_true')
args = arg_parser.parse_args()

# create experiment
experiments = EditTrainingRuns(
    check_commit=(args.check_commit == 'strict'))

exp_id = args.exp_id
if exp_id == ['default']:
    # new default experiment
    exp = experiments.new()
elif len(exp_id) == 1 and exp_id[0].isdigit():
    # reload old experiment
    exp = experiments[int(exp_id[0])]
else:
    # new experiment according to configs, folded together pairwise in the
    # order given on the command line
    config = Config.from_file(exp_id[0])
    for filename in exp_id[1:]:
        config = Config.merge(config, Config.from_file(filename))
    exp = experiments.new(config)  # new experiment from config

# start training
exp.workspace.add_file('stdout', 'stdout.txt')
exp.workspace.add_file('stderr', 'stderr.txt')
with save_stdout(exp.workspace.root):
    exp.train()
# Seed the NumPy and PyTorch random number generators from --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# create run
# (commit checking is enabled only when --check_commit is 'strict')
runs = HRLTrainingRuns(check_commit=(args.check_commit == 'strict'))

config_paths = args.config_paths
if len(config_paths) == 1 and config_paths[0].isdigit():
    # A single all-digit argument is treated as the index of an existing run.
    # That run is cloned with the configs loaded from args.reward_configs and
    # the name args.name (exact clone semantics live in HRLTrainingRuns).
    configs = [Config.from_file(p) for p in args.reward_configs]
    run = runs.clone(int(config_paths[0]), configs, args.name)
else:
    # new run according to configs
    configs = [Config.from_file(p) for p in config_paths]
    # merge all configs together
    config = Config.merge(configs)  # later configs overwrite earlier configs
    run = runs.new(config, name=args.name)  # new run from config

# NOTE(review): the original indentation was lost. It is unclear whether the
# description/name metadata below were meant to be set only for new runs
# (inside the `else`) or for cloned runs too — confirm.
run.metadata['description'] = args.description
run.metadata['name'] = args.name
# Record the hostname of the machine this script is running on.
run.metadata['host'] = socket.gethostname()

# start training
# Register stdout.txt / stderr.txt in the run's workspace; save_stdout is
# presumably what captures the streams into them — confirm.
run.workspace.add_file('stdout', 'stdout.txt')
run.workspace.add_file('stderr', 'stderr.txt')
with save_stdout(run.workspace.root):
    try:
        run.train()
    finally:
        # NOTE(review): the body of this `finally` clause is not visible in
        # this section of the file.
eval_samples = args.eval_samples # create experiment experiments = Experiments(check_commit=args.check_commit=='strict') if exp_id == ['default']: # new default experiment exp = experiments.new() elif len(exp_id) == 1 and exp_id[0].isdigit(): # reload old experiment exp = experiments[int(exp_id[0])] else: # new experiment according to configs config = Config.from_file(exp_id[0]) for filename in exp_id[1:]: config = Config.merge(config, Config.from_file(filename)) exp = experiments.new(config) # new experiment from config # add experiment to tracker if args.tracker: exp_type, dataset, seed = ExperimentType.parse_configs(exp_id) with TopLevelTracker(args.tracker) as tracker: tracker.register_result( dataset, exp_type, seed, exp.workspace.root) ################################ # Profiling # from gtd.chrono import Profiling, Profiler # profiler = Profiler.default() # import gtd.ml.seq_batch; profiler.add_module(gtd.ml.seq_batch) # import strongsup.decoder; profiler.add_module(strongsup.decoder)