Example #1
    def reset(self, reset_weight=True):
        self.memory.reset()
        if reset_weight:
            self.policy_net.reset()

        # if Q-function transfer is enabled, seed the policy network with the
        # best source Q-function found by the transfer module
        if self.tranfer_module is not None and self.tranfer_module.is_q_transfering():
            self.policy_net.set_Q_source(self.tranfer_module.get_Q_source(),
                                         self.tranfer_module.get_best_error())

        self.optimizer = optimizer_factory(self.optimizer_type,
                                           self.policy_net.parameters(),
                                           self.lr, self.weight_decay)

        self.i_episode = 0
        self.no_need_for_transfer_anymore = False

        # the target network starts as a frozen copy of the policy network
        self.target_net = copy.deepcopy(self.policy_net)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        if self.tranfer_module is not None and self.tranfer_module.is_q_transfering():
            self.target_net.set_Q_source(self.policy_net.Q_source,
                                         self.tranfer_module.get_best_error())
Example #2
    def reset(self, reset_weight=True):
        # release cached GPU memory before rebuilding the optimizer
        torch.cuda.empty_cache()
        if reset_weight:
            self.reset_network()
        self.optimizer = optimizer_factory(self.optimizer_type,
                                           self._policy_network.parameters(),
                                           self.learning_rate,
                                           self.weight_decay)
        self._id_ftq_epoch = None
Example #3
def main(loss_function_str,
         optimizer_str,
         weight_decay,
         learning_rate,
         normalize,
         autoencoder_size,
         n_epochs,
         feature_autoencoder_info,
         workspace,
         device,
         type_ae="AEA",
         N_actions=None,
         writer=None):
    import torch
    loss_function = loss_fonction_factory(loss_function_str)
    makedirs(workspace)
    feature = build_feature_autoencoder(feature_autoencoder_info)
    min_n, max_n = autoencoder_size

    all_transitions = utils.read_samples_for_ae(workspace / "samples", feature,
                                                N_actions)

    # one autoencoder per source of transitions (i.e. one per environment)
    autoencoders = [
        AutoEncoder(n_in=transitions.X.shape[1],
                    n_out=transitions.X.shape[1] *
                    (N_actions if type_ae == "AEA" else 1),
                    min_n=min_n,
                    max_n=max_n,
                    device=device) for transitions in all_transitions
    ]

    path_auto_encoders = workspace / "ae"
    makedirs(path_auto_encoders)
    print("learning_rate", learning_rate)
    print("optimizer_str", optimizer_str)
    print("weight_decay", weight_decay)
    for ienv, transitions in enumerate(all_transitions):
        autoencoders[ienv].reset()
        optimizer = optimizer_factory(optimizer_str,
                                      autoencoders[ienv].parameters(),
                                      lr=learning_rate,
                                      weight_decay=weight_decay)
        # full-batch training: the minibatch size equals the whole dataset
        autoencoders[ienv].fit(transitions,
                               size_minibatch=all_transitions[ienv].X.shape[0],
                               n_epochs=n_epochs,
                               optimizer=optimizer,
                               normalize=normalize,
                               stop_loss=0.01,
                               loss_function=loss_function,
                               writer=writer)

        path_autoencoder = path_auto_encoders / "{}.pt".format(ienv)
        logger.info("saving autoencoder at {}".format(path_autoencoder))
        torch.save(autoencoders[ienv], path_autoencoder)
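
A hypothetical invocation of the main() function from Example #3 is sketched below. Every value is illustrative only; the accepted loss-function and optimizer strings, the workspace layout, and the autoencoder size bounds are assumptions, not taken from the original project. The sketch merely shows the expected argument shapes, e.g. autoencoder_size is unpacked as a (min_n, max_n) pair and workspace must support the "/" path operator.

from pathlib import Path

main(loss_function_str="l2",          # assumed key for loss_fonction_factory
     optimizer_str="ADAM",            # assumed key for optimizer_factory
     weight_decay=0.0,
     learning_rate=1e-3,
     normalize=True,
     autoencoder_size=(32, 256),      # unpacked as (min_n, max_n)
     n_epochs=100,
     feature_autoencoder_info={},     # hypothetical feature config
     workspace=Path("workspace/ae_experiment"),
     device="cpu",
     type_ae="AEA",
     N_actions=4)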
Example #4
    def reset(self, reset_weight=True):
        self.i_episode = 0
        self.memory = Memory()

        if reset_weight:
            self.full_net.reset()

        # target network for the main (full) network
        self.full_target_net = copy.deepcopy(self.full_net)
        self.full_target_net.load_state_dict(self.full_net.state_dict())
        self.full_target_net.eval()
        self.parameters_full_net = self.full_net.parameters()
        self.optimizer_full_net = optimizer_factory(self.optimizer_type,
                                                    self.parameters_full_net,
                                                    self.lr, self.weight_decay)

        if self.tm is not None:
            self.tm.reset()
            self.best_net = self.tm.get_best_Q_source()
            self.previous_diff = -np.inf
            self.previous_idx_best_source = self.tm.idx_best_fit
            logging.info(
                "[INITIAL][i_episode{}] Using {} source cstd={:.2f} {} Q function"
                .format(self.i_episode, Color.BOLD,
                        self.tm.best_source_params()["cstd"], Color.END))
            if self.ratio_learn_test > 0:
                # partial net used to evaluate the Bellman residual on the test batch
                self.memory_partial_learn = Memory()
                self.memory_partial_test = Memory()
                self.partial_net = copy.deepcopy(self.full_net)
                self.partial_net.load_state_dict(self.full_net.state_dict())
                self.partial_net.eval()
                self.partial_target_net = copy.deepcopy(self.partial_net)
                self.partial_target_net.load_state_dict(
                    self.partial_net.state_dict())
                self.partial_target_net.eval()
                self.parameters_partial_net = self.partial_net.parameters()
                self.optimizer_partial_net = optimizer_factory(
                    self.optimizer_type, self.parameters_partial_net, self.lr,
                    self.weight_decay)
        else:
            self.best_net = self.full_net
Example #5
    def reset(self, reset_weight=True):
        self.memory.reset()
        if reset_weight:
            self._policy_network.reset()
        self.optimizer = optimizer_factory(self.optimizer_type,
                                           self._policy_network.parameters(),
                                           self.learning_rate,
                                           self.weight_decay)
        self._id_ftq_epoch = None
        # clear cached batch tensors from the previous run
        self._non_final_mask = None
        self._non_final_next_states = None
        self._state_batch = None
        self._action_batch = None
        self._reward_batch = None
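
All five examples rely on an optimizer_factory helper that is not shown on this page. Below is a minimal sketch of what such a factory might look like, assuming it maps a name string to a configured torch.optim optimizer; the supported strings ("ADAM", "SGD", "RMS_PROP") and the positional-argument order are inferred from the call sites above, not confirmed by the original code.

import torch


def optimizer_factory(optimizer_type, parameters, lr, weight_decay):
    # map a name string to a configured torch.optim optimizer;
    # the names accepted here are assumptions, not the project's actual keys
    if optimizer_type == "ADAM":
        return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
    if optimizer_type == "SGD":
        return torch.optim.SGD(parameters, lr=lr, weight_decay=weight_decay)
    if optimizer_type == "RMS_PROP":
        return torch.optim.RMSprop(parameters, lr=lr, weight_decay=weight_decay)
    raise ValueError("Unknown optimizer type: {}".format(optimizer_type))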