def reset(self, reset_weight=True):
    self.memory.reset()
    if reset_weight:
        self.policy_net.reset()
        if self.tranfer_module is not None and self.tranfer_module.is_q_transfering():
            self.policy_net.set_Q_source(self.tranfer_module.get_Q_source(),
                                         self.tranfer_module.get_best_error())
    self.optimizer = optimizer_factory(self.optimizer_type,
                                       self.policy_net.parameters(),
                                       self.lr,
                                       self.weight_decay)
    self.i_episode = 0
    self.no_need_for_transfer_anymore = False
    # Target net: a frozen copy of the policy net, kept in eval mode.
    self.target_net = copy.deepcopy(self.policy_net)
    self.target_net.load_state_dict(self.policy_net.state_dict())
    self.target_net.eval()
    if self.tranfer_module is not None and self.tranfer_module.is_q_transfering():
        self.target_net.set_Q_source(self.policy_net.Q_source,
                                     self.tranfer_module.get_best_error())
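# --- Hedged sketch, not part of the project ---------------------------------
# The deepcopy / load_state_dict / eval() sequence above is the standard
# target-network synchronisation idiom. A minimal, self-contained example of
# that idiom follows; the tiny network and its layer sizes are illustrative
# assumptions, not the project's architecture.
import copy
import torch.nn as nn

policy_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
target_net = copy.deepcopy(policy_net)                # same architecture and buffers
target_net.load_state_dict(policy_net.state_dict())  # identical weights
target_net.eval()                                     # freeze eval-time behaviour (dropout, batch norm)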
def reset(self, reset_weight=True):
    torch.cuda.empty_cache()
    if reset_weight:
        self.reset_network()
    self.optimizer = optimizer_factory(self.optimizer_type,
                                       self._policy_network.parameters(),
                                       self.learning_rate,
                                       self.weight_decay)
    self._id_ftq_epoch = None
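# --- Hedged sketch, not part of the project ---------------------------------
# optimizer_factory is used by every reset() in this section. A plausible
# string-to-optimizer dispatch is sketched below; the real factory's accepted
# names and signature may differ, so treat this purely as an assumption.
import torch

def optimizer_factory_sketch(name, params, lr, weight_decay):
    if name == "ADAM":
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == "RMS_PROP":
        return torch.optim.RMSprop(params, lr=lr, weight_decay=weight_decay)
    if name == "SGD":
        return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay)
    raise ValueError("unknown optimizer: {}".format(name))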
def main(loss_function_str, optimizer_str, weight_decay, learning_rate,
         normalize, autoencoder_size, n_epochs, feature_autoencoder_info,
         workspace, device, type_ae="AEA", N_actions=None, writer=None):
    import torch

    loss_function = loss_fonction_factory(loss_function_str)
    makedirs(workspace)
    feature = build_feature_autoencoder(feature_autoencoder_info)
    min_n, max_n = autoencoder_size
    all_transitions = utils.read_samples_for_ae(workspace / "samples", feature, N_actions)
    # One autoencoder per environment; for type "AEA" the output layer is
    # replicated once per action.
    autoencoders = [
        AutoEncoder(n_in=transitions.X.shape[1],
                    n_out=transitions.X.shape[1] * (N_actions if type_ae == "AEA" else 1),
                    min_n=min_n,
                    max_n=max_n,
                    device=device)
        for transitions in all_transitions
    ]
    path_auto_encoders = workspace / "ae"
    makedirs(path_auto_encoders)
    logger.info("learning_rate={} optimizer_str={} weight_decay={}".format(
        learning_rate, optimizer_str, weight_decay))
    for ienv, transitions in enumerate(all_transitions):
        autoencoders[ienv].reset()
        optimizer = optimizer_factory(optimizer_str,
                                      autoencoders[ienv].parameters(),
                                      lr=learning_rate,
                                      weight_decay=weight_decay)
        # Full-batch fit: the minibatch size is the whole sample set.
        autoencoders[ienv].fit(transitions,
                               size_minibatch=all_transitions[ienv].X.shape[0],
                               n_epochs=n_epochs,
                               optimizer=optimizer,
                               normalize=normalize,
                               stop_loss=0.01,
                               loss_function=loss_function,
                               writer=writer)
        path_autoencoder = path_auto_encoders / "{}.pt".format(ienv)
        logger.info("saving autoencoder at {}".format(path_autoencoder))
        torch.save(autoencoders[ienv], path_autoencoder)
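# --- Hedged sketch, not part of the project ---------------------------------
# What a single full-batch autoencoder fit in the loop above roughly boils
# down to (the real AutoEncoder.fit also handles normalisation, per-action
# outputs when type_ae == "AEA", and writer logging). All names below are
# illustrative assumptions.
import torch
import torch.nn as nn

def fit_sketch(model: nn.Module, X: torch.Tensor, n_epochs: int,
               optimizer: torch.optim.Optimizer, loss_function,
               stop_loss: float = 0.01):
    for _ in range(n_epochs):
        optimizer.zero_grad()
        loss = loss_function(model(X), X)  # reconstruction loss on the full batch
        loss.backward()
        optimizer.step()
        if loss.item() < stop_loss:  # early stop, mirroring stop_loss above
            break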
def reset(self, reset_weight=True):
    self.i_episode = 0
    self.memory = Memory()
    if reset_weight:
        self.full_net.reset()
    # Target net of the main (greedy, classic) net.
    self.full_target_net = copy.deepcopy(self.full_net)
    self.full_target_net.load_state_dict(self.full_net.state_dict())
    self.full_target_net.eval()
    self.parameters_full_net = self.full_net.parameters()
    self.optimizer_full_net = optimizer_factory(self.optimizer_type,
                                                self.parameters_full_net,
                                                self.lr,
                                                self.weight_decay)
    if self.tm is not None:
        self.tm.reset()
        self.best_net = self.tm.get_best_Q_source()
        self.previous_diff = -np.inf
        self.previous_idx_best_source = self.tm.idx_best_fit
        logging.info("[INITIAL][i_episode{}] Using {} source cstd={:.2f} {} Q function"
                     .format(self.i_episode, Color.BOLD,
                             self.tm.best_source_params()["cstd"], Color.END))
        if self.ratio_learn_test > 0:
            # Nets used to evaluate the Bellman residual on a held-out test batch.
            self.memory_partial_learn = Memory()
            self.memory_partial_test = Memory()
            self.partial_net = copy.deepcopy(self.full_net)
            self.partial_net.load_state_dict(self.full_net.state_dict())
            self.partial_net.eval()
            self.partial_target_net = copy.deepcopy(self.partial_net)
            self.partial_target_net.load_state_dict(self.partial_net.state_dict())
            self.partial_target_net.eval()
            self.parameters_partial_net = self.partial_net.parameters()
            self.optimizer_partial_net = optimizer_factory(self.optimizer_type,
                                                           self.parameters_partial_net,
                                                           self.lr,
                                                           self.weight_decay)
    else:
        self.best_net = self.full_net
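# --- Hedged sketch, not part of the project ---------------------------------
# Why partial_net exists: when ratio_learn_test > 0, transitions are split so
# that the partial net trains on one part while the Bellman residual is
# scored on held-out transitions. A minimal split helper under that
# assumption; the function name and the meaning of the ratio are illustrative.
import random

def split_transitions(transitions, ratio_test):
    shuffled = list(transitions)
    random.shuffle(shuffled)
    n_test = int(len(shuffled) * ratio_test)
    return shuffled[n_test:], shuffled[:n_test]  # (learn batch, test batch)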
def reset(self, reset_weight=True):
    self.memory.reset()
    if reset_weight:
        self._policy_network.reset()
    self.optimizer = optimizer_factory(self.optimizer_type,
                                       self._policy_network.parameters(),
                                       self.learning_rate,
                                       self.weight_decay)
    self._id_ftq_epoch = None
    # Clear the minibatch tensors cached from the previous FTQ run.
    self._non_final_mask = None
    self._non_final_next_states = None
    self._state_batch = None
    self._action_batch = None
    self._reward_batch = None