def main(env, ctrl_type, ctrl_args, overrides, logdir):
    """Build an MPC policy from the config and run a model-based experiment.

    Args:
        env: Environment name, forwarded to ``create_config``.
        ctrl_type: Controller type; only ``'MPC'`` is supported here.
        ctrl_args: Iterable of ``(key, value)`` pairs of controller arguments.
        overrides: Config override list, forwarded to ``create_config``.
        logdir: Base directory for experiment logs.

    Raises:
        ValueError: If ``ctrl_type`` is not ``'MPC'``.
    """
    set_global_seeds(0)
    # dict() over an iterable of pairs replaces the equivalent comprehension.
    ctrl_args = DotMap(**dict(ctrl_args))

    cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)
    cfg.pprint()

    # `assert` is stripped under `python -O`; validate explicitly instead.
    if ctrl_type != 'MPC':
        raise ValueError("Unsupported ctrl_type: %r (only 'MPC' is supported)"
                         % ctrl_type)
    cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
    exp = MBExperiment(cfg.exp_cfg)

    # exist_ok=True: a bare makedirs crashed on reruns with an existing logdir
    # (the other main() variants in this file already guard against this).
    os.makedirs(exp.logdir, exist_ok=True)
    with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
        f.write(pprint.pformat(cfg.toDict()))

    exp.run_experiment()
def main(args):
    """Run a model-based MPC experiment configured from parsed CLI ``args``.

    Args:
        args: Namespace-like object providing at least ``ctrl_type`` and
            ``load_model_dir`` (``None`` to train from scratch); the full
            object is forwarded to ``create_config``.

    Raises:
        ValueError: If ``args.ctrl_type`` is not ``'MPC'``.
    """
    cfg = create_config(args)
    cfg.pprint()

    # `assert` is stripped under `python -O`; validate explicitly instead.
    if args.ctrl_type != 'MPC':
        raise ValueError("Unsupported ctrl_type: %r (only 'MPC' is supported)"
                         % args.ctrl_type)
    cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
    exp = MBExperiment(cfg.exp_cfg)

    # Optionally warm-start the dynamics model from a previous run's weights.
    if args.load_model_dir is not None:
        exp.policy.model.load_state_dict(
            torch.load(os.path.join(args.load_model_dir, 'weights')))

    # exist_ok=True replaces the race-prone exists()-then-makedirs check.
    os.makedirs(exp.logdir, exist_ok=True)
    with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
        f.write(pprint.pformat(cfg.toDict()))

    exp.run_experiment()
def main(env, ctrl_type, ctrl_args, overrides, model_dir, logdir):
    """Run one recorded rollout using a pretrained dynamics model.

    Forces overrides that load the model from ``model_dir``, skip initial
    rollouts and training, and record a single rollout.

    Args:
        env: Environment name, forwarded to ``create_config``.
        ctrl_type: Controller type; only ``"MPC"`` is supported here.
        ctrl_args: Iterable of ``(key, value)`` pairs of controller arguments.
        overrides: Config override list; extended in place with the
            pretrained-replay settings below.
        model_dir: Directory containing the pretrained model to load.
        logdir: Base directory for experiment logs.

    Raises:
        ValueError: If ``ctrl_type`` is not ``"MPC"``.
    """
    ctrl_args = DotMap(**dict(ctrl_args))

    # Load the pretrained model and record exactly one evaluation rollout.
    overrides.append(["ctrl_cfg.prop_cfg.model_init_cfg.model_dir", model_dir])
    overrides.append(["ctrl_cfg.prop_cfg.model_init_cfg.load_model", "True"])
    overrides.append(["ctrl_cfg.prop_cfg.model_pretrained", "True"])
    overrides.append(["exp_cfg.exp_cfg.ninit_rollouts", "0"])
    overrides.append(["exp_cfg.exp_cfg.ntrain_iters", "1"])
    overrides.append(["exp_cfg.log_cfg.nrecord", "1"])

    cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)
    cfg.pprint()

    if ctrl_type == "MPC":
        cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
    else:
        # Previously a non-MPC ctrl_type fell through silently and
        # MBExperiment was built without a policy, failing later with an
        # opaque attribute error. Fail fast instead.
        raise ValueError("Unsupported ctrl_type: %r (only 'MPC' is supported)"
                         % ctrl_type)

    exp = MBExperiment(cfg.exp_cfg)

    # exist_ok=True: a bare makedirs crashed on reruns with an existing logdir.
    os.makedirs(exp.logdir, exist_ok=True)
    with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
        f.write(pprint.pformat(cfg.toDict()))

    exp.run_experiment()
def main(args):
    """Run a model-based experiment with a selectable exploration policy.

    Args:
        args: Namespace-like object providing at least ``ctrl_type``
            (``'PuP'``, ``'RND'``, or anything else for the default MPC) and
            ``load_model_dir`` (``None`` to train from scratch); the full
            object is forwarded to ``create_config``.

    Raises:
        NotImplementedError: If ``args.ctrl_type`` is ``'RND'`` (unfinished).
    """
    set_global_seeds(0)
    cfg = create_config(args)
    cfg.pprint()

    # For PointmassEnv, switch the optimizer to discrete CEM.
    # NOTE(review): presumably because its action space is discrete — confirm.
    if isinstance(cfg.ctrl_cfg.env, PointmassEnv):
        cfg.ctrl_cfg.opt_cfg.mode = 'DCEM'

    if args.ctrl_type == 'PuP':
        print("Using Pets-using-Pets Policy.")
        cfg.exp_cfg.exp_cfg.policy = ExploreEnsembleVarianceMPC(cfg.ctrl_cfg)
    elif args.ctrl_type == 'RND':
        # `assert False` is stripped under `python -O`, which would have let
        # execution fall through into the unfinished RND branch; raise so the
        # guard always fires.
        raise NotImplementedError("JL: Not implemented fully yet!")
    else:
        print("Using default MPC Policy.")
        cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)

    exp = MBExperiment(cfg.exp_cfg)

    # Optionally warm-start the dynamics model from a previous run's weights.
    if args.load_model_dir is not None:
        exp.policy.model.load_state_dict(
            torch.load(os.path.join(args.load_model_dir, 'weights')))

    # Bug fix: the TRAIN/ADAPT subdirectories were only created when the
    # logdir itself did not exist, so reruns into an existing logdir were
    # missing them. exist_ok=True creates whatever is absent, idempotently.
    os.makedirs(exp.logdir, exist_ok=True)
    os.makedirs(os.path.join(exp.logdir, "TRAIN"), exist_ok=True)
    os.makedirs(os.path.join(exp.logdir, "ADAPT"), exist_ok=True)

    with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
        f.write(pprint.pformat(cfg.toDict()))

    exp.run_experiment()