Ejemplo n.º 1
0
 def reset(self, reset_weight=True):
     """Reset the training state of the agent.

     Empties the CUDA caching allocator, optionally re-initialises the
     network, rebuilds the optimizer from ``self.config["optimizer"]`` and
     sets the epoch counter back to zero.

     :param reset_weight: when true, ``self.reset_network()`` is called
         before the optimizer is rebuilt.
     """
     torch.cuda.empty_cache()
     if reset_weight:
         self.reset_network()
     opt_cfg = self.config["optimizer"]
     # Arguments are passed positionally: type, parameters, learning rate,
     # weight decay.
     self.optimizer = optimizer_factory(
         opt_cfg["type"],
         self._value_network.parameters(),
         opt_cfg["learning_rate"],
         opt_cfg["weight_decay"],
     )
     self.epoch = 0
Ejemplo n.º 2
0
 def __init__(self, env, config=None):
     """Create the online/target networks, loss function and optimizer.

     :param env: environment, forwarded to the base agent constructor.
     :param config: optional configuration, forwarded to the base agent
         constructor; ``self.config`` is read afterwards.
     """
     super(DQNAgent, self).__init__(env, config)
     model_config = self.config["model"]
     self.value_net = model_factory(model_config)
     self.target_net = model_factory(model_config)
     # Copy the online network's weights into the target network, then
     # switch the target network to evaluation mode.
     self.target_net.load_state_dict(self.value_net.state_dict())
     self.target_net.eval()
     self.device = choose_device(self.config["device"])
     for network in (self.value_net, self.target_net):
         network.to(self.device)
     self.loss_function = loss_function_factory(self.config["loss_function"])
     optimizer_config = self.config["optimizer"]
     # The whole optimizer section (including "type") is forwarded as kwargs.
     self.optimizer = optimizer_factory(optimizer_config["type"],
                                        self.value_net.parameters(),
                                        **optimizer_config)
     self.steps = 0
Ejemplo n.º 3
0
 def __init__(self, env, config=None):
     """Set up the DQN networks, loss function and optimizer.

     :param env: environment, forwarded to the base agent constructor and
         used to size the model configuration.
     :param config: optional configuration, forwarded to the base agent
         constructor; ``self.config`` is read afterwards.
     """
     super(DQNAgent, self).__init__(env, config)
     model_config = self.config["model"]
     # NOTE(review): presumably adapts the model config to the environment
     # (e.g. input/output sizes) in place — confirm in size_model_config.
     size_model_config(self.env, model_config)
     self.value_net = model_factory(model_config)
     self.target_net = model_factory(model_config)
     # Start the target network from the online network's weights and put
     # it in evaluation mode.
     self.target_net.load_state_dict(self.value_net.state_dict())
     self.target_net.eval()
     n_params = trainable_parameters(self.value_net)
     logger.debug("Number of trainable parameters: {}".format(n_params))
     self.device = choose_device(self.config["device"])
     for network in (self.value_net, self.target_net):
         network.to(self.device)
     self.loss_function = loss_function_factory(self.config["loss_function"])
     optimizer_config = self.config["optimizer"]
     # The whole optimizer section (including "type") is forwarded as kwargs.
     self.optimizer = optimizer_factory(optimizer_config["type"],
                                        self.value_net.parameters(),
                                        **optimizer_config)
     self.steps = 0
Ejemplo n.º 4
0
 def initialize_model(self):
     """Re-initialise the value network and build a fresh optimizer for it.

     The optimizer is created from ``self.config["optimizer"]``, whose
     entries (including "type") are forwarded as keyword arguments.
     """
     self.value_net.reset()
     settings = self.config["optimizer"]
     self.optimizer = optimizer_factory(settings["type"],
                                        self.value_net.parameters(),
                                        **settings)