def main(args):
    """Entry point: load expert data and model config, then train or evaluate.

    Parameters
    ----------
    args : argparse.Namespace
        Expects at least: seed, task, e_filename, model_id, task_mode.

    Raises
    ------
    ValueError
        If ``args.task_mode`` is neither "train" nor "evaluate".
    """
    set_global_seeds(args.seed)

    # Task layout on disk: <root>/task/<task>/{data/, task.json, model/}
    task_path = os.path.join(root_path, "task", args.task)
    data_path = os.path.join(task_path, "data")

    task_desc_path = os.path.join(task_path, "task.json")
    with open(task_desc_path, "r") as f:
        task_desc = json.load(f)

    # Expert demonstrations used as the IRL target.
    e_path = os.path.join(data_path, args.e_filename)
    D_e_bc = np.load(e_path)
    data = {"irl": D_e_bc}

    # Model hyper-parameters, keyed by model id.
    model_path = os.path.join(task_path, "model",
                              "{}.json".format(args.model_id))
    with open(model_path, "r") as f:
        params = json.load(f)

    if args.task_mode == "train":
        train(data, task_desc, params, args, task_path)
    elif args.task_mode == "evaluate":
        evaluate(data, task_desc, params, args, task_path)
    else:
        # Previously an unrecognized mode fell through silently and the
        # program did nothing; fail loudly instead.
        raise ValueError("unknown task_mode: {}".format(args.task_mode))
def train_bc(task, params, ob_space, ac_space, args, env):
    """Pretrain a policy with behavior cloning and return its checkpoint path.

    Parameters
    ----------
    task : dict
        Task description; "env_id" is read.
    params : dict
        Model hyper-parameters (currently unused here; kept for interface
        consistency with the other train_* entry points).
    ob_space, ac_space : gym spaces
        Observation and action spaces for the policy network.
    args : argparse.Namespace
        Expects: task, expert_path, traj_limitation, seed, policy_hidden_size,
        dim_phi, checkpoint_dir, BC_max_iter.
    env : gym.Env
        Used only to read ``env.action_space.n``.

    Returns
    -------
    str
        Path prefix of the (possibly freshly trained) BC checkpoint.
    """
    task_path = os.path.join(root_path, "task", args.task)
    plot_path = os.path.join(task_path, "result")

    dataset = GymDataset(expert_path=args.expert_path,
                         traj_limitation=args.traj_limitation)

    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        # MLP policy with a feature layer (phi) used later by IRL.
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size_phi=args.policy_hidden_size,
                                    num_hid_layers_phi=2,
                                    dim_phi=args.dim_phi)

    env_name = task["env_id"]
    name = "pi.{}.{}".format(env_name.lower().split("-")[0],
                             args.traj_limitation)
    pi = policy_fn(name, ob_space, ac_space)
    n_action = env.action_space.n

    fname = "ckpt.bc.{}.{}".format(args.traj_limitation, args.seed)
    savedir_fname = osp.join(args.checkpoint_dir, fname, fname)
    # Only retrain when no checkpoint index already exists on disk.
    if not os.path.exists(savedir_fname + ".index"):
        savedir_fname = learn(pi, dataset, env_name, n_action,
                              prefix="bc",
                              seed=args.seed,
                              traj_lim=args.traj_limitation,
                              max_iters=args.BC_max_iter,
                              ckpt_dir=osp.join(args.checkpoint_dir, fname),
                              plot_dir=plot_path,
                              task_name=task["env_id"],
                              verbose=True)
    logger.log(savedir_fname + " saved")  # was "...saved" with no separator
    # The return was commented out, so callers always received None even
    # though the checkpoint path is the whole point of this function.
    return savedir_fname
def train(data, task_desc, params, args, task_path):
    """Train an IRL model (MMA or SCIRL variant) on expert trajectories.

    Parameters
    ----------
    data : dict
        Holds "irl": expert rollout arrays with keys "ob_list",
        "ob_next_list", "ac_list", "new".
    task_desc : dict
        Task description; "env_id" and "gamma" are read.
    params : dict
        Model hyper-parameters; "id", "version", "n_iteration" are read.
    args : argparse.Namespace
        Command-line options (seed, pretrain, traj_limitation, task,
        policy_hidden_size, dim_phi, n_e, save_path, ...).
    task_path : str
        Root directory of the task.

    Side effects: opens a TF session, loads a pretrained BC policy, and
    writes "<model_id>.<n_e>.<seed>train.log" plus a ".pkl" result file
    under args.save_path.
    """
    import gym

    # Observation dimensionality taken from the first expert observation.
    ob_dim = data["irl"]["ob_list"][0].shape[0]
    #ob_dim = ob_dim // 4
    # Symmetric bound c = max(|min|, |max|) over all expert observations,
    # used to define a box observation space centered at 0.
    c = np.max([
        np.abs(np.min(data["irl"]["ob_list"])),
        np.abs(np.max(data["irl"]["ob_list"]))
    ])
    ob_low = np.ones(ob_dim) * -c
    ob_high = np.ones(ob_dim) * c
    ob_space = gym.spaces.Box(low=ob_low, high=ob_high)

    # NOTE(review): action space hard-coded to 5 discrete actions —
    # confirm this matches the task's actual action set.
    n_action = 5
    ac_space = gym.spaces.Discrete(n=n_action)

    if args.pretrain:
        # Reuse an existing behavior-cloning checkpoint if present;
        # otherwise run BC pretraining first.
        model_path = os.path.join(root_path, "task", args.task, "model")
        fname = "ckpt.bc.{}.{}".format(args.traj_limitation, args.seed)
        ckpt_dir = os.path.join(model_path, fname)
        pretrained_path = os.path.join(ckpt_dir, fname)
        if not os.path.exists(os.path.join(ckpt_dir, "checkpoint")):
            print("==== pretraining starts ===")
            pretrained_path = train_bc_sepsis(task_desc, params, ob_space,
                                              ac_space, args)
    # NOTE(review): if args.pretrain is falsy, pretrained_path is never
    # bound, and U.load_state below raises NameError — confirm pretrain
    # is always enabled for this code path.

    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)

    def mlp_pi_wrapper(name, ob_space, ac_space, reuse=False):
        # MLP policy whose feature layer (phi) serves as the IRL basis.
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size_phi=args.policy_hidden_size,
                                    num_hid_layers_phi=2,
                                    dim_phi=args.dim_phi)
    # just imitation learning
    #def mlp_pi_wrapper(name, ob_space, ac_space, reuse=False):
    #    return mlp_policy.MlpPolicyOriginal(name=name,
    #                                        ob_space=ob_space,
    #                                        ac_space=ac_space,
    #                                        reuse=reuse,
    #                                        hid_size=args.policy_hidden_size,
    #                                        num_hid_layers=2)

    # Scope name must match the one the BC checkpoint was saved under.
    env_name = task_desc["env_id"]
    scope_name = "pi.{}.{}".format(env_name.lower().split("-")[0],
                                   args.traj_limitation)
    pi_bc = mlp_pi_wrapper(scope_name, ob_space, ac_space)
    U.initialize()
    U.load_state(pretrained_path)
    # State featurizer phi(s) taken from the pretrained BC policy.
    phi_bc = pi_bc.featurize

    def phi_old(s, a):
        """Concatenate phi(s) with the raw action column.

        TODO: if action is discrete
        one hot encode action and concatenate with phi(s)
        """
        # expect phi(s) -> (N, state_dim)
        # expect a -> (N, action_dim)
        phi_s = phi_bc(s)
        if len(phi_s.shape) == 1:
            # s -> (1, state_dim)
            phi_s = np.expand_dims(phi_s, axis=0)
        # if a = 5 (bare scalar): wrap it into a list
        try:
            if a == int(a):
                a = [a]
        except:
            pass
        a = np.array(a)
        # if a = [5]: add a trailing axis to make it a column
        if len(a.shape) == 1:
            a = np.expand_dims(a, axis=1)
        # otherwise if a = [[5], [3]]: already column-shaped
        phi_sa = np.hstack((phi_s, a))
        return phi_sa

    def phi_discrete_action(n_action):
        # Build phi(s, a) = [phi(s), onehot(a)] for an n_action-way space.
        def f(s, a):
            # expect phi(s) -> (N, state_dim)
            # expect a -> (N, action_dim)
            phi_s = phi_bc(s)
            # Scalar action -> singleton list (non-numeric a falls through).
            try:
                if a == int(a):
                    a = [a]
            except:
                pass
            a = np.array(a)
            a_onehot = np.eye(n_action)[a.astype(int)]
            if len(phi_s.shape) == 1:
                # s -> (1, state_dim)
                phi_s = np.expand_dims(phi_s, axis=0)
            try:
                phi_sa = np.hstack((phi_s, a_onehot))
            except:
                # a had shape (N, 1), so the one-hot came out (N, 1, n);
                # drop the middle axis and retry the concat.
                a_onehot = a_onehot.reshape(a_onehot.shape[0],
                                            a_onehot.shape[2])
                phi_sa = np.hstack((phi_s, a_onehot))
            return phi_sa
        return f

    if isinstance(ac_space, gym.spaces.Discrete):
        phi = phi_discrete_action(ac_space.n)
    elif isinstance(ac_space, gym.spaces.Box):
        # NOTE(review): phi_continuous_action is not defined in this file —
        # confirm it exists elsewhere; otherwise this branch raises NameError.
        phi = phi_continuous_action
    else:
        raise NotImplementedError

    # Flatten trajectories into (T_total, dim) arrays.
    D = data["irl"]
    obs = D["ob_list"].reshape(-1, D["ob_list"].shape[-1])
    obs_p1 = D["ob_next_list"].reshape(-1, D["ob_next_list"].shape[-1])
    #assuming action dof of 1
    acs = D["ac_list"].reshape(-1)
    new = D["new"].reshape(-1)

    # Rebind `data` to the flat training dictionary consumed downstream
    # (shadows the function parameter from here on).
    data = {}
    data["s"] = obs
    data["a"] = acs
    data["s_next"] = obs_p1
    data["done"] = data["absorb"] = new
    data["phi_sa"] = phi(obs, acs)
    data["phi_fn"] = phi
    data["phi_fn_s"] = phi_bc
    # psi features are simply aliased to the phi features here.
    data["psi_sa"] = data["phi_sa"]
    data["psi_fn"] = phi

    evaluator = ALEvaluator(data, task_desc["gamma"], env=None)
    data_path = os.path.join(task_path, "data")

    # The BC policy is the initial policy for the IRL loop.
    pi_0 = pi_bc
    phi_dim = data["phi_sa"].shape[1]

    # Dispatch on "<id>.<version>" from the model params.
    model_id = "{}.{}".format(params["id"], params["version"])
    if model_id == "mma.0":
        result = train_mma(pi_0, phi_dim, task_desc, params, data,
                           evaluator, ob_space, ac_space)
    elif model_id == "mma.1":
        # NOTE(review): identical call to the "mma.0" branch — confirm
        # whether version 1 should use different settings.
        result = train_mma(pi_0, phi_dim, task_desc, params, data,
                           evaluator, ob_space, ac_space)
    elif model_id == "mma.2":
        #result = train_scirl_v2(data, phi_bc, evaluator, phi_dim, task_desc, params)
        #result = train_scirl_v3(data, phi_bc, evaluator)
        result = train_scirl(data, phi_bc, evaluator)
    else:
        raise NotImplementedError

    # Write per-iteration metrics as tab-separated lines.
    # NOTE(review): "train.log" is appended with no "." separator, producing
    # e.g. "mma.0.10.0train.log" — confirm whether ".train.log" was intended.
    name = "{}.{}.{}".format(model_id, args.n_e, args.seed)
    result_path = os.path.join(args.save_path, name + "train.log")
    with open(result_path, "w") as f:
        #flush?
        for step in range(params["n_iteration"] + 1):
            data_points = [
                step,
                round(result["margin_mu"][step], 2),
                round(result["margin_v"][step], 2),
                round(result["a_match"][step], 2)
            ]
            f.write("{}\t{}\t{}\t{}\n".format(*data_points))

    # Pickle the full result dict alongside the log.
    with open(os.path.join(args.save_path, name + ".pkl"), "wb") as f:
        pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)