def main(args):
    set_global_seeds(args.seed)

    task_path = os.path.join(root_path, "task", args.task)
    data_path = os.path.join(task_path, "data")

    task_desc_path = os.path.join(task_path, "task.json")

    with open(task_desc_path, "r") as f:
        task_desc = json.load(f)

    e_path_async = os.path.join(data_path, args.e_filename)
    D_e_bc = np.load(e_path_async)

    data = {}
    data["irl"] = D_e_bc

    model_path = os.path.join(task_path, "model",
                              "{}.json".format(args.model_id))
    with open(model_path, "r") as f:
        params = json.load(f)

    if args.task_mode == "train":
        train(data, task_desc, params, args, task_path)

    elif args.task_mode == "evaluate":
        evaluate(data, task_desc, params, args, task_path)
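
# Illustrative only: a minimal argparse builder for the flags main() itself
# reads (seed, task, task_mode, e_filename, model_id). The flag names and
# defaults below are assumptions; train()/evaluate() consume further
# attributes (e.g. pretrain, traj_limitation, save_path) that a real driver
# would have to add.
def _build_arg_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--task_mode", choices=["train", "evaluate"],
                        default="train")
    parser.add_argument("--e_filename", type=str, required=True)
    parser.add_argument("--model_id", type=str, required=True)
    return parser
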
def train_bc(task, params, ob_space, ac_space, args, env):
    task_path = os.path.join(root_path, "task", args.task)
    plot_path = os.path.join(task_path, "result")

    dataset = GymDataset(expert_path=args.expert_path,
                         traj_limitation=args.traj_limitation)

    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size_phi=args.policy_hidden_size,
                                    num_hid_layers_phi=2,
                                    dim_phi=args.dim_phi)

    env_name = task["env_id"]
    name = "pi.{}.{}".format(env_name.lower().split("-")[0],
                             args.traj_limitation)
    pi = policy_fn(name, ob_space, ac_space)
    n_action = env.action_space.n

    fname = "ckpt.bc.{}.{}".format(args.traj_limitation, args.seed)
    savedir_fname = osp.join(args.checkpoint_dir, fname, fname)

    if not os.path.exists(savedir_fname + ".index"):
        savedir_fname = learn(pi,
                              dataset,
                              env_name,
                              n_action,
                              prefix="bc",
                              seed=args.seed,
                              traj_lim=args.traj_limitation,
                              max_iters=args.BC_max_iter,
                              ckpt_dir=osp.join(args.checkpoint_dir, fname),
                              plot_dir=plot_path,
                              task_name=task["env_id"],
                              verbose=True)
        logger.log(savedir_fname + " saved")


#    Optional: roll out the trained BC policy in the environment to report
#    average episode length and return.
#    avg_len, avg_ret = run_gym(env,
#                               policy_fn,
#                               savedir_fname,
#                               timesteps_per_batch=args.horizon,
#                               number_trajs=10,
#                               stochastic_policy=args.stochastic_policy,
#                               save=args.save_sample,
#                               reuse=True)
    return savedir_fname
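
# A minimal sketch (not part of the original file) showing how the checkpoint
# prefix returned by train_bc() can be restored into a policy built under the
# same variable scope, mirroring the U.initialize()/U.load_state() calls used
# inside train() below. The helper name is hypothetical.
def _restore_bc_policy(policy_fn, scope_name, ob_space, ac_space, ckpt_prefix):
    pi = policy_fn(scope_name, ob_space, ac_space)
    U.initialize()
    U.load_state(ckpt_prefix)  # saver prefix, e.g. ".../ckpt.bc.<traj_limitation>.<seed>"
    return pi
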
def train(data, task_desc, params, args, task_path):
    import gym
    ob_dim = data["irl"]["ob_list"][0].shape[0]
    #ob_dim = ob_dim // 4
    c = np.max([
        np.abs(np.min(data["irl"]["ob_list"])),
        np.abs(np.max(data["irl"]["ob_list"]))
    ])
    ob_low = np.ones(ob_dim) * -c
    ob_high = np.ones(ob_dim) * c
    ob_space = gym.spaces.Box(low=ob_low, high=ob_high)
    n_action = 5
    ac_space = gym.spaces.Discrete(n=n_action)

    if args.pretrain:
        model_path = os.path.join(root_path, "task", args.task, "model")
        fname = "ckpt.bc.{}.{}".format(args.traj_limitation, args.seed)
        ckpt_dir = os.path.join(model_path, fname)
        pretrained_path = os.path.join(ckpt_dir, fname)

        if not os.path.exists(os.path.join(ckpt_dir, "checkpoint")):
            print("==== pretraining starts ===")
            pretrained_path = train_bc_sepsis(task_desc, params, ob_space,
                                              ac_space, args)

        U.make_session(num_cpu=1).__enter__()
        set_global_seeds(args.seed)

        def mlp_pi_wrapper(name, ob_space, ac_space, reuse=False):
            return mlp_policy.MlpPolicy(name=name,
                                        ob_space=ob_space,
                                        ac_space=ac_space,
                                        reuse=reuse,
                                        hid_size_phi=args.policy_hidden_size,
                                        num_hid_layers_phi=2,
                                        dim_phi=args.dim_phi)

        # just imitation learning
        #def mlp_pi_wrapper(name, ob_space, ac_space, reuse=False):
        #    return mlp_policy.MlpPolicyOriginal(name=name,
        #                                    ob_space=ob_space,
        #                                    ac_space=ac_space,
        #                                    reuse=reuse,
        #                                    hid_size=args.policy_hidden_size,
        #                                    num_hid_layers=2)

        env_name = task_desc["env_id"]
        scope_name = "pi.{}.{}".format(env_name.lower().split("-")[0],
                                       args.traj_limitation)

        pi_bc = mlp_pi_wrapper(scope_name, ob_space, ac_space)
        U.initialize()
        U.load_state(pretrained_path)
        phi_bc = pi_bc.featurize

        def phi_old(s, a):
            """
            Featurize (s, a) by appending the raw action to phi(s).

            For discrete actions, phi_discrete_action below one-hot encodes
            the action instead of appending it directly.
            """
            # expect phi(s) -> (N, state_dim)
            # expect a -> (N, action_dim)
            phi_s = phi_bc(s)

            if len(phi_s.shape) == 1:
                # s -> (1, state_dim)
                phi_s = np.expand_dims(phi_s, axis=0)

            # a scalar action (e.g. a = 5) becomes a one-element list
            try:
                if a == int(a):
                    a = [a]
            except (TypeError, ValueError):
                # a is already a sequence/array
                pass

            a = np.array(a)
            # if a = [5]
            if len(a.shape) == 1:
                a = np.expand_dims(a, axis=1)
            # otherwise if a = [[5], [3]]
            phi_sa = np.hstack((phi_s, a))
            return phi_sa

        def phi_discrete_action(n_action):
            def f(s, a):
                # expect phi(s) -> (N, state_dim)
                # expect a -> (N, action_dim)
                phi_s = phi_bc(s)

                # a scalar action (e.g. a = 5) becomes a one-element list
                try:
                    if a == int(a):
                        a = [a]
                except (TypeError, ValueError):
                    # a is already a sequence/array
                    pass
                a = np.array(a)

                a_onehot = np.eye(n_action)[a.astype(int)]

                if len(phi_s.shape) == 1:
                    # s -> (1, state_dim)
                    phi_s = np.expand_dims(phi_s, axis=0)

                try:
                    phi_sa = np.hstack((phi_s, a_onehot))
                except ValueError:
                    # a arrived with shape (N, 1), so the one-hot lookup gave
                    # shape (N, 1, n_action); drop the middle axis and retry
                    a_onehot = a_onehot.reshape(a_onehot.shape[0],
                                                a_onehot.shape[2])
                    phi_sa = np.hstack((phi_s, a_onehot))
                return phi_sa

            return f

        if isinstance(ac_space, gym.spaces.Discrete):
            phi = phi_discrete_action(ac_space.n)
        elif isinstance(ac_space, gym.spaces.Box):
            phi = phi_continuous_action
        else:
            raise NotImplementedError

    D = data["irl"]
    obs = D["ob_list"].reshape(-1, D["ob_list"].shape[-1])
    obs_p1 = D["ob_next_list"].reshape(-1, D["ob_next_list"].shape[-1])
    # assuming an action dimensionality (dof) of 1
    acs = D["ac_list"].reshape(-1)
    new = D["new"].reshape(-1)

    data = {}
    data["s"] = obs
    data["a"] = acs
    data["s_next"] = obs_p1
    data["done"] = data["absorb"] = new

    data["phi_sa"] = phi(obs, acs)
    data["phi_fn"] = phi
    data["phi_fn_s"] = phi_bc

    data["psi_sa"] = data["phi_sa"]
    data["psi_fn"] = phi

    evaluator = ALEvaluator(data, task_desc["gamma"], env=None)
    data_path = os.path.join(task_path, "data")

    pi_0 = pi_bc

    phi_dim = data["phi_sa"].shape[1]

    model_id = "{}.{}".format(params["id"], params["version"])

    if model_id == "mma.0":
        result = train_mma(pi_0, phi_dim, task_desc, params, data, evaluator,
                           ob_space, ac_space)
    elif model_id == "mma.1":
        result = train_mma(pi_0, phi_dim, task_desc, params, data, evaluator,
                           ob_space, ac_space)
    elif model_id == "mma.2":
        #result = train_scirl_v2(data, phi_bc, evaluator, phi_dim, task_desc, params)
        #result = train_scirl_v3(data, phi_bc, evaluator)
        result = train_scirl(data, phi_bc, evaluator)
    else:
        raise NotImplementedError

    name = "{}.{}.{}".format(model_id, args.n_e, args.seed)
    result_path = os.path.join(args.save_path, name + "train.log")

    with open(result_path, "w") as f:
        # writes are buffered and flushed when the with-block closes the file
        for step in range(params["n_iteration"] + 1):
            data_points = [
                step,
                round(result["margin_mu"][step], 2),
                round(result["margin_v"][step], 2),
                round(result["a_match"][step], 2)
            ]
            f.write("{}\t{}\t{}\t{}\n".format(*data_points))

    with open(os.path.join(args.save_path, name + ".pkl"), "wb") as f:
        pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
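
# Standalone sketch (an assumption, not part of the original file) of the
# one-hot state-action featurization that phi_discrete_action() performs,
# with the learned phi_bc replaced by a fixed array so the shapes can be
# checked without a TensorFlow session.
def _phi_discrete_demo():
    n_action = 5
    phi_s = np.array([[0.1, 0.2, 0.3],
                      [0.4, 0.5, 0.6]])         # stands in for phi_bc(s), shape (N, dim_phi)
    a = np.array([2, 4])                        # discrete actions, shape (N,)
    a_onehot = np.eye(n_action)[a.astype(int)]  # shape (N, n_action)
    phi_sa = np.hstack((phi_s, a_onehot))       # shape (N, dim_phi + n_action), here (2, 8)
    return phi_sa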