def main(load_memory, generate_envs, feature_str, gamma, ftq_params, ftq_net_params, device, normalize_reward,
         workspace, seed,
         lambda_=0., **args):
    envs, params = envs_factory.generate_envs(**generate_envs)
    e = envs[0]
    set_seed(seed, e)

    feature = feature_factory(feature_str)

    ftq = PytorchFittedQ(
        device=device,
        policy_network=NetFTQ(n_in=len(feature(e.reset(), e)), n_out=e.action_space.n, **ftq_net_params),
        action_str=None if not hasattr(e, "action_str") else e.action_str,
        test_policy=None,
        gamma=gamma,
        **ftq_params
    )

    rm = Memory()
    rm.load_memory(**load_memory)

    transitions_ftq, _ = urpy.datas_to_transitions(rm.memory, e, feature, lambda_, normalize_reward)
    logger.info("[learning ftq with full batch] #samples={} ".format(len(transitions_ftq)))
    ftq.reset(True)
    ftq.workspace = workspace
    makedirs(ftq.workspace)
    ftq.fit(transitions_ftq)
    ftq.save_policy()
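All of these examples call a makedirs helper (imported in Example #9 from ncarrara.utils.os) rather than os.makedirs directly. Its definition is not shown on this page; a minimal sketch, assuming it only wraps os.makedirs so that already-existing directories are not an error, could be:

# Hypothetical sketch of the makedirs helper assumed by these examples,
# not the project's actual implementation.
import os

def makedirs(path):
    # Accept str or pathlib.Path; create intermediate directories, ignore existing ones.
    if path is not None:
        os.makedirs(str(path), exist_ok=True)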
Example #2
 def save(self, path=None):
     import os
     if path is None:
         path = self.workspace / "dqn.pt"
     makedirs(os.path.dirname(path))
     logger.info("saving dqn at {}".format(path))
     torch.save(self.full_net, path)
     return path
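Since save() stores the whole module with torch.save, it can be restored with torch.load. A minimal counterpart, assuming the same full_net attribute (hypothetical, not shown in the source):

# Hypothetical counterpart to save(); assumes the entire module was pickled by torch.save.
import torch

def load(self, path, device="cpu"):
    # map_location lets a GPU-trained network be loaded on CPU-only machines.
    self.full_net = torch.load(path, map_location=device)
    return self.full_net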
Example #3
 def dump_to_workspace(self, filename="config.json"):
     """
     Dump the configuration to a JSON file in the workspace.
     """
     makedirs(self.workspace)
     print(self.dict)
     with open(self.workspace / filename, 'w') as f:
         json.dump(self.dict, f, indent=2)
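The dumped file can be read back with the standard json module. A small round-trip sketch, assuming the workspace is a pathlib.Path (hypothetical helper, not part of the source):

# Hypothetical reader for the config.json written by dump_to_workspace.
import json
from pathlib import Path

def read_config(workspace, filename="config.json"):
    with open(Path(workspace) / filename, 'r') as f:
        return json.load(f)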
Example #4
def main(loss_function_str,
         optimizer_str,
         weight_decay,
         learning_rate,
         normalize,
         autoencoder_size,
         n_epochs,
         feature_autoencoder_info,
         workspace,
         device,
         type_ae="AEA",
         N_actions=None,
         writer=None):
    import torch
    loss_function = loss_fonction_factory(loss_function_str)
    makedirs(workspace)
    feature = build_feature_autoencoder(feature_autoencoder_info)
    min_n, max_n = autoencoder_size

    all_transitions = utils.read_samples_for_ae(workspace / "samples", feature,
                                                N_actions)

    autoencoders = [
        AutoEncoder(n_in=transitions.X.shape[1],
                    n_out=transitions.X.shape[1] *
                    (N_actions if type_ae == "AEA" else 1),
                    min_n=min_n,
                    max_n=max_n,
                    device=device) for transitions in all_transitions
    ]

    path_auto_encoders = workspace / "ae"
    makedirs(path_auto_encoders)
    print("learning_rate", learning_rate)
    print("optimizer_str", optimizer_str)
    print("weight_decay", weight_decay)
    # exit()
    for ienv, transitions in enumerate(all_transitions):
        autoencoders[ienv].reset()
        optimizer = optimizer_factory(optimizer_str,
                                      autoencoders[ienv].parameters(),
                                      lr=learning_rate,
                                      weight_decay=weight_decay)
        # for x,y in zip(transitions.X,transitions.A):
        #     print(x,"->",y)
        autoencoders[ienv].fit(transitions,
                               size_minibatch=all_transitions[ienv].X.shape[0],
                               n_epochs=n_epochs,
                               optimizer=optimizer,
                               normalize=normalize,
                               stop_loss=0.01,
                               loss_function=loss_function,
                               writer=writer)

        path_autoencoder = path_auto_encoders / "{}.pt".format(ienv)
        logger.info("saving autoencoder at {}".format(path_autoencoder))
        torch.save(autoencoders[ienv], path_autoencoder)
Example #5
    def load_tensorboardX(self):
        if self.is_tensorboardX:
            from tensorboardX import SummaryWriter
            empty_directory(self.workspace / "tensorboard")
            makedirs(self.workspace / "tensorboard")
            # exit()

            self.writer = SummaryWriter(str(self.workspace / "tensorboard"))
            command = "tensorboard --logdir {} --port 6009 &".format(
                str(self.workspace / "tensorboard"))
            self.logger.info("running command \"{}\"".format(command))
            os.system(command)
Example #6
def main(C,
         config_file,
         override_param_grid,
         override_device_str=None,
         f=lambda x: print("Hello")):
    with open(config_file, 'r') as infile:
        import json

        dict = json.load(infile)
        workspace = Path(dict["general"]["workspace"])
        makedirs(workspace)

    if "matplotlib_backend" in dict["general"]:
        backend = dict["general"]["matplotlib_backend"]
    else:
        backend = "Agg"

    logger.info("override device : {}".format(override_device_str))
    C.load_pytorch(override_device_str).load_matplotlib(backend)

    grid = ParameterGrid(override_param_grid)

    if os.path.exists(workspace / "params"):
        with open(workspace / "params", 'r') as infile:
            lines = infile.readlines()
            id_offset = re.match('^id=([0-9]+) ', lines[-1])
        id_offset = int(id_offset.group(1)) + 1
    else:
        id_offset = 0

    str_params = ""
    for i_config, params in enumerate(grid):
        str_params += "id=" + str(id_offset + i_config) + ' ' + ''.join(
            [k + "=" + str(v) + ' ' for k, v in params.items()]) + '\n'

    with open(workspace / "params", 'a') as infile:
        infile.write(str_params)

    for i_config, params in enumerate(grid):
        i_config = id_offset + i_config
        for k, v in params.items():
            keys = k.split('.')

            tochange = dict
            for ik in range(len(keys) - 1):
                tochange = tochange[keys[ik]]
            tochange[keys[-1]] = v
        dict["general"]["workspace"] = str(workspace / str(i_config))
        C.load(dict).create_fresh_workspace(force=True).load_tensorboardX()
        C.dump_to_workspace()

        print("\n-------- i_config={} ----------\n".format(i_config))
        f(C)
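The parameter grid addresses nested configuration entries with dotted keys ("section.key"), which the inner loop resolves by walking the dictionary. A standalone illustration of that traversal (hypothetical helper, same logic as above):

# Illustration of the dotted-key override used in the loop above.
def set_by_dotted_key(config, dotted_key, value):
    keys = dotted_key.split('.')
    node = config
    for k in keys[:-1]:          # walk down to the parent of the target key
        node = node[k]
    node[keys[-1]] = value       # overwrite the leaf value

cfg = {"general": {"seed": 0}}
set_by_dotted_key(cfg, "general.seed", 42)
assert cfg["general"]["seed"] == 42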
Example #7
 def save(self, path, as_json=True, indent=0):
     self.logger.info("saving memory at {}".format(path))
     makedirs(os.path.dirname(path))
     memory = [t._asdict() for t in self.memory]
     if as_json:
         with open(path, 'w') as f:
             if indent > 0:
                 json_str = json.dumps(memory, indent=indent)
             else:
                 json_str = json.dumps(memory)
             f.write(json_str)
     else:
         with open(path, 'wb') as f:
             pickle.dump(memory, f)
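Several examples above read this file back through rm.load_memory(...), whose code is not shown here. A minimal counterpart to save(), assuming memory is simply restored as the list of dumped transitions (hypothetical, not the project's actual load_memory):

# Hypothetical counterpart to save().
import json
import pickle

def load(self, path, as_json=True):
    if as_json:
        with open(path, 'r') as f:
            self.memory = json.load(f)
    else:
        with open(path, 'rb') as f:
            self.memory = pickle.load(f)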
Example #8
def epsilon_decay(start=1.0, decay=0.01, N=100, savepath=None):
    if savepath is not None:
        makedirs(savepath)
    if decay == 0:
        decays = np.full(N, start)
    elif decay > 0:
        decays = start * np.exp(-decay * np.arange(N))
    else:
        raise Exception("Decay must be non-negative")
    str_decay = pretty_format_list(decays)
    logger.info("Epsilons (decayed) : [{}]".format(str_decay))
    if logger.getEffectiveLevel() <= logging.DEBUG:
        plt.plot(range(len(decays)), decays)
        plt.title("epsilon decays")
        if savepath is not None:
            plt.savefig(Path(savepath) / "epsilon_decay")
        plt.show()
        plt.close()
    return decays
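The positive-decay branch is a plain exponential schedule, epsilon_i = start * exp(-decay * i). A quick standalone check (illustration only):

# Quick check of the exponential epsilon schedule (illustration only).
import numpy as np

start, decay, N = 1.0, 0.01, 100
decays = start * np.exp(-decay * np.arange(N))
assert np.isclose(decays[0], start)   # the schedule begins at start
assert np.all(np.diff(decays) < 0)    # and decreases monotonically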
Example #9
    def create_fresh_workspace(self, force=False):
        r = ''
        while r != 'y' and r != 'n':
            if force:
                r = 'y'
            else:
                r = input(
                    "are you sure you want to erase workspace {} [y/n] ?".
                    format(self.workspace))
            from ncarrara.utils.os import makedirs
            if r == 'y':
                self.__check__()
                self.__clean_workspace()

            elif r == 'n':
                makedirs(self.workspace)
            else:
                print("Only [y/n]")
        return self
Example #10
def main(load_memory, generate_envs, feature_str, gamma, gamma_c, bftq_params,
         bftq_net_params, workspace, seed, device, normalize_reward, general,
         **args):
    logger = logging.getLogger(__name__)

    envs, params = envs_factory.generate_envs(**generate_envs)
    e = envs[0]
    e.reset()
    set_seed(seed, e)
    feature = feature_factory(feature_str)

    bftq = PytorchBudgetedFittedQ(
        device=device,
        workspace=workspace,
        actions_str=get_actions_str(e),
        policy_network=NetBFTQ(size_state=len(feature(e.reset(), e)),
                               n_actions=e.action_space.n,
                               **bftq_net_params),
        gamma=gamma,
        gamma_c=gamma_c,
        split_batches=general["gpu"]["split_batches"],
        cpu_processes=general["cpu"]["processes"],
        env=e,
        **bftq_params)

    makedirs(workspace)
    rm = Memory()
    rm.load_memory(**load_memory)

    _, transitions_bftq = urpy.datas_to_transitions(rm.memory, e, feature, 0,
                                                    normalize_reward)
    logger.info("[learning bftq with full batch] #samples={} ".format(
        len(transitions_bftq)))

    bftq.reset(True)
    _ = bftq.fit(transitions_bftq)

    bftq.save_policy()
Example #11
def create_Q_histograms(title, values, path, labels, lims=(-1.1, 1.1)):
    makedirs(path)
    plt.clf()
    maxfreq = 0.
    lims = update_lims(lims, values)
    fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
    n, bins, patches = ax.hist(x=values,
                               label=labels,
                               alpha=1.,
                               stacked=False,
                               bins=np.linspace(
                                   *lims, 100))  # , alpha=0.7, rwidth=0.85)
    plt.grid(axis='y', alpha=0.75)

    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.legend(loc='upper right')
    if not os.path.exists(path):
        os.mkdir(path)
    plt.savefig(path / title)
    plt.show()
    plt.close()
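A typical call passes one array of Q-values per action plus matching labels (hypothetical data, illustration only):

# Hypothetical usage of create_Q_histograms with random Q-values for two actions.
import numpy as np
from pathlib import Path

values = [np.random.uniform(-1, 1, 500), np.random.uniform(-1, 1, 500)]
create_Q_histograms("Q-values", values, Path("/tmp/q_hists"), labels=["action_0", "action_1"])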
Example #12
def fast_create_Q_histograms_for_actions(title,
                                         QQ,
                                         path,
                                         labels,
                                         mask_action=None,
                                         lims=(-1.1, 1.1)):
    makedirs(path)

    if mask_action is None:
        mask_action = np.zeros(len(QQ[0]))
    labs = []
    for i, label in enumerate(labels):
        # labels[i] = label[0:6]
        if mask_action[i] != 1:
            # labs.append(label[0:6])
            labs.append(label)
    Nact = len(mask_action)
    values = []
    for act in range(Nact):
        if mask_action[act] != 1:
            value = []
            for i in range(len(QQ)):
                value.append(QQ[i][act])
            values.append(value)

    lims = update_lims(lims, values)
    fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
    ax.hist(values, bins=np.linspace(*lims, 100), alpha=1.0, stacked=True)
    plt.grid(axis='y', alpha=0.75)
    plt.legend(labs)
    plt.title(title)
    # plt.show()
    if not os.path.exists(path):
        os.mkdir(path)
    plt.savefig(path / title)
    plt.show()
    plt.close()
Example #13
def main(generate_envs,
         feature_str,
         gamma,
         gamma_c,
         ftq_params,
         ftq_net_params,
         device,
         epsilon_decay,
         N_trajs,
         trajs_by_ftq_batch,
         normalize_reward,
         workspace,
         seed,
         save_memory,
         general,
         lambda_=0,
         **args):
    envs, params = envs_factory.generate_envs(**generate_envs)
    e = envs[0]
    set_seed(seed, e)
    rm = Memory()
    feature = feature_factory(feature_str)

    def build_fresh_ftq():
        ftq = PytorchFittedQ(
            device=device,
            policy_network=NetFTQ(n_in=len(feature(e.reset(), e)),
                                  n_out=e.action_space.n,
                                  **ftq_net_params),
            action_str=None if not hasattr(e, "action_str") else e.action_str,
            test_policy=None,
            gamma=gamma,
            **ftq_params)
        return ftq

    # Prepare learning
    i_traj = 0
    decays = math_utils.epsilon_decay(**epsilon_decay,
                                      N=N_trajs,
                                      savepath=workspace)
    batch_sizes = near_split(N_trajs, size_bins=trajs_by_ftq_batch)
    pi_epsilon_greedy_config = {
        "__class__": repr(EpsilonGreedyPolicy),
        "pi_greedy": {
            "__class__": repr(RandomPolicy)
        },
        "pi_random": {
            "__class__": repr(RandomPolicy)
        },
        "epsilon": decays[0]
    }

    # Main loop
    trajs = []
    for batch, batch_size in enumerate(batch_sizes):
        # Prepare workers
        cpu_processes = min(
            general["cpu"]["processes_when_linked_with_gpu"] or os.cpu_count(),
            batch_size)
        workers_n_trajectories = near_split(batch_size, cpu_processes)
        workers_start = np.cumsum(workers_n_trajectories)
        workers_traj_indexes = [
            np.arange(*times) for times in zip(
                np.insert(workers_start[:-1], 0, 0), workers_start)
        ]
        workers_seeds = np.random.randint(0, 10000, cpu_processes).tolist()
        workers_epsilons = [
            decays[i_traj + indexes] for indexes in workers_traj_indexes
        ]
        workers_params = list(
            zip_with_singletons(generate_envs, pi_epsilon_greedy_config,
                                workers_seeds, gamma, gamma_c,
                                workers_n_trajectories, None, workers_epsilons,
                                None, general["dictConfig"]))

        # Collect trajectories
        logger.info(
            "Collecting trajectories with {} workers...".format(cpu_processes))
        if cpu_processes == 1:
            results = [execute_policy_from_config(*workers_params[0])]
        else:
            with Pool(processes=cpu_processes) as pool:
                results = pool.starmap(execute_policy_from_config,
                                       workers_params)
        i_traj += sum([len(trajectories) for trajectories, _ in results])

        # Fill memory
        [
            rm.push(*sample) for trajectories, _ in results
            for trajectory in trajectories for sample in trajectory
        ]
        transitions_ftq, _ = datas_to_transitions(rm.memory, e, feature,
                                                  lambda_, normalize_reward)

        # Fit model
        logger.info(
            "[BATCH={}]---------------------------------------".format(batch))
        logger.info(
            "[BATCH={}][learning ftq pi greedy] #samples={} #traj={}".format(
                batch, len(transitions_ftq), i_traj))
        logger.info(
            "[BATCH={}]---------------------------------------".format(batch))
        ftq = build_fresh_ftq()
        ftq.reset(True)
        ftq.workspace = workspace / "batch={}".format(batch)
        makedirs(ftq.workspace)

        if isinstance(e, EnvGridWorld):

            for trajectories, _ in results:
                for traj in trajectories:
                    trajs.append(traj)

            w = World(e)
            w.draw_frame()
            w.draw_lattice()
            w.draw_cases()
            w.draw_source_trajectories(trajs)
            w.save((ftq.workspace / "bftq_on_2dworld_sources").as_posix())

        ftq.fit(transitions_ftq)

        # Save policy
        network_path = ftq.save_policy()
        os.system("cp {}/policy.pt {}/final_policy.pt".format(
            ftq.workspace, workspace))

        # Update greedy policy
        pi_epsilon_greedy_config["pi_greedy"] = {
            "__class__": repr(PytorchFittedPolicy),
            "feature_str": feature_str,
            "network_path": network_path,
            "device": ftq.device
        }
    if save_memory is not None:
        rm.save(workspace / save_memory["path"], save_memory["as_json"])
Example #14
def create_Q_histograms_for_actions(title,
                                    QQ,
                                    path,
                                    labels,
                                    mask_action=None,
                                    lims=(-1.1, 1.1)):
    makedirs(path)
    #
    # # set up style cycles

    if mask_action is None:
        mask_action = np.zeros(len(QQ[0]))
    labs = []
    for i, label in enumerate(labels):
        # labels[i] = label[0:6]
        if mask_action[i] != 1:
            # labs.append(label[0:6])
            labs.append(label)
    Nact = len(mask_action)
    values = []
    for act in range(Nact):
        if mask_action[act] != 1:
            value = []
            for i in range(len(QQ)):
                value.append(QQ[i][act])
            values.append(value)
    plt.clf()

    lims = update_lims(lims, values)
    edges = np.linspace(*lims, 200, endpoint=True)
    hist_func = partial(np.histogram, bins=edges)
    colors = plt.get_cmap('tab10').colors
    #['b', 'g', 'r', 'c', 'm', 'y', 'k']
    hatchs = ['/', '*', '+', '|']
    cols = []
    hats = []
    for i, lab in enumerate(labels):
        cols.append(colors[i % len(colors)])
        hats.append(hatchs[i % len(hatchs)])

    color_cycle = cycler(facecolor=cols)
    label_cycle = cycler('label', labs)
    hatch_cycle = cycler('hatch', hats)
    fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
    arts = stack_hist(ax,
                      values,
                      color_cycle + label_cycle + hatch_cycle,
                      hist_func=hist_func,
                      labels=labs)

    plt.grid(axis='y', alpha=0.75)

    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.title(title)
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    # plt.legend()
    if not os.path.exists(path):
        os.mkdir(path)
    plt.savefig(path / title)
    plt.close()
Example #15
def main(betas_test, policy_path, generate_envs, feature_str, device,
         workspace, gamma, gamma_c, bftq_params, seed, N_trajs, path_results,
         general, **args):
    if not os.path.isabs(policy_path):
        policy_path = workspace / policy_path

    pi_config = {
        "__class__": repr(PytorchBudgetedFittedPolicy),
        "feature_str": feature_str,
        "network_path": policy_path,
        "betas_for_discretisation":
        eval(bftq_params["betas_for_discretisation"]),
        "device": device,
        "hull_options": general["hull_options"],
        "clamp_Qc": bftq_params["clamp_Qc"]
    }
    mock_env = envs_factory.generate_envs(**generate_envs)[0][0]
    makedirs(workspace / "trajs")

    makedirs(path_results)
    set_seed(seed)
    try:
        for beta in eval(betas_test):
            # Prepare workers
            cpu_processes = min(
                general["cpu"]["processes_when_linked_with_gpu"]
                or os.cpu_count(), N_trajs)
            workers_n_trajectories = near_split(N_trajs, cpu_processes)
            workers_seeds = np.random.randint(0, 10000, cpu_processes).tolist()
            workers_params = list(
                zip_with_singletons(
                    generate_envs, pi_config, workers_seeds, gamma, gamma_c,
                    workers_n_trajectories, beta, None,
                    "{}/beta={}.results".format(path_results,
                                                beta), general["dictConfig"]))
            logger.info("Collecting trajectories with {} workers...".format(
                cpu_processes))
            with Pool(cpu_processes) as pool:
                results = pool.starmap(execute_policy_from_config,
                                       workers_params)
                rez = np.concatenate([result for _, result in results], axis=0)

                trajs = []
                for t, _ in results:
                    trajs += t
            print("BFTQ({:.2f}) : {}".format(beta, format_results(rez)))

            if isinstance(mock_env, EnvGridWorld):
                from ncarrara.utils_rl.environments.gridworld.world import World
                w = World(mock_env)
                w.draw_frame()
                w.draw_lattice()
                w.draw_cases()
                w.draw_test_trajectories(trajs)
                pp = (workspace / "trajs" / "trajs_beta").as_posix()
                w.save(pp + "={:.2f}".format(beta))
        if isinstance(mock_env, EnvGridWorld):
            os.system("convert -delay 10 -loop 0 " + workspace.as_posix() +
                      "/trajs/" + "*.png " + workspace.as_posix() + "/out.gif")

    except FileNotFoundError as e:
        logger.warning("Could not load policy: {}".format(e))
Example #16
def main(generate_envs, feature_str, betas_for_exploration, gamma, gamma_c,
         bftq_params, bftq_net_params, N_trajs, workspace, seed, device,
         normalize_reward, trajs_by_ftq_batch, epsilon_decay, general, **args):
    # Prepare BFTQ
    envs, params = envs_factory.generate_envs(**generate_envs)
    e = envs[0]
    set_seed(seed, e)
    rm = Memory()
    feature = feature_factory(feature_str)

    def build_fresh_bftq():
        bftq = PytorchBudgetedFittedQ(
            device=device,
            workspace=workspace / "batch=0",
            actions_str=get_actions_str(e),
            policy_network=NetBFTQ(size_state=len(feature(e.reset(), e)),
                                   n_actions=e.action_space.n,
                                   **bftq_net_params),
            gamma=gamma,
            gamma_c=gamma_c,
            cpu_processes=general["cpu"]["processes"],
            env=e,
            split_batches=general["gpu"]["split_batches"],
            hull_options=general["hull_options"],
            **bftq_params)
        return bftq

    # Prepare learning
    i_traj = 0
    decays = math_utils.epsilon_decay(**epsilon_decay,
                                      N=N_trajs,
                                      savepath=workspace)
    betas_for_exploration = np.array(eval(betas_for_exploration))
    memory_by_batch = [get_current_memory()]
    batch_sizes = near_split(N_trajs, size_bins=trajs_by_ftq_batch)
    pi_epsilon_greedy_config = {
        "__class__": repr(EpsilonGreedyPolicy),
        "pi_greedy": {
            "__class__": repr(RandomBudgetedPolicy)
        },
        "pi_random": {
            "__class__": repr(RandomBudgetedPolicy)
        },
        "epsilon": decays[0],
        "hull_options": general["hull_options"],
        "clamp_Qc": bftq_params["clamp_Qc"]
    }

    # Main loop
    trajs = []
    for batch, batch_size in enumerate(batch_sizes):
        # Prepare workers
        cpu_processes = min(
            general["cpu"]["processes_when_linked_with_gpu"] or os.cpu_count(),
            batch_size)
        workers_n_trajectories = near_split(batch_size, cpu_processes)
        workers_start = np.cumsum(workers_n_trajectories)
        workers_traj_indexes = [
            np.arange(*times) for times in zip(
                np.insert(workers_start[:-1], 0, 0), workers_start)
        ]
        if betas_for_exploration.size:
            workers_betas = [
                betas_for_exploration.take(indexes, mode='wrap')
                for indexes in workers_traj_indexes
            ]
        else:
            workers_betas = [
                np.random.random(indexes.size)
                for indexes in workers_traj_indexes
            ]
        workers_seeds = np.random.randint(0, 10000, cpu_processes).tolist()
        workers_epsilons = [
            decays[i_traj + indexes] for indexes in workers_traj_indexes
        ]
        workers_params = list(
            zip_with_singletons(generate_envs, pi_epsilon_greedy_config,
                                workers_seeds, gamma, gamma_c,
                                workers_n_trajectories, workers_betas,
                                workers_epsilons, None, general["dictConfig"]))

        # Collect trajectories
        logger.info(
            "Collecting trajectories with {} workers...".format(cpu_processes))
        if cpu_processes == 1:
            results = []
            for params in workers_params:
                results.append(execute_policy_from_config(*params))
        else:
            with Pool(processes=cpu_processes) as pool:
                results = pool.starmap(execute_policy_from_config,
                                       workers_params)
        i_traj += sum([len(trajectories) for trajectories, _ in results])

        # Fill memory
        [
            rm.push(*sample) for trajectories, _ in results
            for trajectory in trajectories for sample in trajectory
        ]

        transitions_ftq, transition_bftq = datas_to_transitions(
            rm.memory, e, feature, 0, normalize_reward)

        # Fit model
        logger.info(
            "[BATCH={}]---------------------------------------".format(batch))
        logger.info(
            "[BATCH={}][learning bftq pi greedy] #samples={} #traj={}".format(
                batch, len(transition_bftq), i_traj))
        logger.info(
            "[BATCH={}]---------------------------------------".format(batch))
        bftq = build_fresh_bftq()
        bftq.reset(True)
        bftq.workspace = workspace / "batch={}".format(batch)
        makedirs(bftq.workspace)
        if isinstance(e, EnvGridWorld):
            for trajectories, _ in results:
                for traj in trajectories:
                    trajs.append(traj)

            w = World(e)
            w.draw_frame()
            w.draw_lattice()
            w.draw_cases()
            w.draw_source_trajectories(trajs)
            w.save((bftq.workspace / "bftq_on_2dworld_sources").as_posix())
        q = bftq.fit(transition_bftq)

        # Save policy
        network_path = bftq.save_policy()
        os.system("cp {}/policy.pt {}/policy.pt".format(
            bftq.workspace, workspace))

        # Save memory
        save_memory(bftq, memory_by_batch, by_batch=False)

        # Update greedy policy
        pi_epsilon_greedy_config["pi_greedy"] = {
            "__class__": repr(PytorchBudgetedFittedPolicy),
            "feature_str": feature_str,
            "network_path": network_path,
            "betas_for_discretisation": bftq.betas_for_discretisation,
            "device": bftq.device,
            "hull_options": general["hull_options"],
            "clamp_Qc": bftq_params["clamp_Qc"]
        }

        if isinstance(e, EnvGridWorld):

            def pi(state, beta):
                import torch
                from ncarrara.budgeted_rl.bftq.pytorch_budgeted_fittedq import convex_hull, \
                    optimal_pia_pib
                with torch.no_grad():
                    hull = convex_hull(s=torch.tensor([state],
                                                      device=device,
                                                      dtype=torch.float32),
                                       Q=q,
                                       action_mask=np.zeros(e.action_space.n),
                                       id="run_" + str(state),
                                       disp=False,
                                       betas=bftq.betas_for_discretisation,
                                       device=device,
                                       hull_options=general["hull_options"],
                                       clamp_Qc=bftq_params["clamp_Qc"])
                    opt, _ = optimal_pia_pib(beta=beta,
                                             hull=hull,
                                             statistic={})
                return opt

            def qr(state, a, beta):
                import torch
                s = torch.tensor([[state]], device=device)
                b = torch.tensor([[[beta]]], device=device)
                sb = torch.cat((s, b), dim=2)
                return q(sb).squeeze()[a]

            def qc(state, a, beta):
                import torch
                s = torch.tensor([[state]], device=device)
                b = torch.tensor([[[beta]]], device=device)
                sb = torch.cat((s, b), dim=2)
                return q(sb).squeeze()[e.action_space.n + a]

            w = World(e, bftq.betas_for_discretisation)
            w.draw_frame()
            w.draw_lattice()
            w.draw_cases()
            w.draw_policy_bftq(pi, qr, qc, bftq.betas_for_discretisation)
            w.save((bftq.workspace / "bftq_on_2dworld").as_posix())

    save_memory(bftq, memory_by_batch, by_batch=True)
Example #17
0
def main(policy_path, generate_envs, feature_str, device, workspace, bftq_params, seed, general,
         betas_test, N_trajs, gamma, gamma_c, bftq_net_params, **args):
    if not os.path.isabs(policy_path):
        policy_path = workspace / policy_path

    env = envs_factory.generate_envs(**generate_envs)[0][0]
    feature = feature_factory(feature_str)

    bftq = PytorchBudgetedFittedQ(
        device=device,
        workspace=workspace,
        actions_str=get_actions_str(env),
        policy_network=NetBFTQ(size_state=len(feature(env.reset(), env)), n_actions=env.action_space.n,
                               **bftq_net_params),
        gamma=gamma,
        gamma_c=gamma_c,
        cpu_processes=general["cpu"]["processes"],
        env=env,
        hull_options=general["hull_options"],
        **bftq_params)
    bftq.reset(True)

    pi_config = {
        "__class__": repr(PytorchBudgetedFittedPolicy),
        "feature_str": feature_str,
        "network_path": policy_path,
        "betas_for_discretisation": eval(bftq_params["betas_for_discretisation"]),
        "device": device,
        "hull_options": general["hull_options"],
        "clamp_Qc": bftq_params["clamp_Qc"],
        "env": env
    }
    pi = policy_factory(pi_config)

    # Iterate over betas
    for beta in eval(betas_test):
        logger.info("Rendering with beta={}".format(beta))
        set_seed(seed, env)
        for traj in range(N_trajs):
            done = False
            pi.reset()
            info_env = {}
            info_pi = {"beta": beta}
            t = 0

            # Make a workspace for trajectories
            traj_workspace = workspace / "trajs" / "beta={}".format(beta) / "traj={}".format(traj)
            makedirs(traj_workspace)
            bftq.workspace = traj_workspace
            monitor = MonitorV2(env, traj_workspace, add_subdirectory=False)
            obs = monitor.reset()

            # Run trajectory
            while not done:
                action_mask = get_action_mask(env)
                info_pi = merge_two_dicts(info_pi, info_env)
                bftq.draw_Qr_and_Qc(obs, pi.network, "render_t={}".format(t), show=False)
                a, _, info_pi = pi.execute(obs, action_mask, info_pi)
                render(env, workspace, t, a)
                obs, _, done, info_env = monitor.step(a)
                t += 1
            monitor.close()