Exemplo n.º 1
0
 def test_huber_loss_config(self):
     """Round-trip a HuberLossConfig through from_dict/to_dict and compare."""
     expected = {
         'clip': 0.1,
         'collect': False,
         'input_layer': 'images',
         'name': 'l',
         'output_layer': 'relu_1',
         'weights': 1.0,
     }
     restored = HuberLossConfig.from_dict(expected)
     self.assert_equal_losses(restored.to_dict(), expected)
Exemplo n.º 2
0
 def model_fn(features, labels, mode):
     """Build a DDQN model for the given mode and apply it to the inputs."""
     ddqn_kwargs = dict(
         graph_fn=graph_fn,
         loss=HuberLossConfig(),
         num_states=env.num_states,
         num_actions=env.num_actions,
         optimizer=SGDConfig(learning_rate=0.01),
         exploration_config=DecayExplorationConfig(),
         target_update_frequency=10,
         summaries='all',
     )
     ddqn = plx.models.DDQNModel(mode, **ddqn_kwargs)
     return ddqn(features, labels)
Exemplo n.º 3
0
    def __init__(self,
                 mode,
                 graph_fn,
                 num_states,
                 num_actions,
                 loss=None,
                 optimizer=None,
                 metrics=None,
                 discount=0.97,
                 exploration_config=None,
                 use_target_graph=True,
                 target_update_frequency=5,
                 is_continuous=False,
                 dueling='mean',
                 use_expert_demo=False,
                 summaries='all',
                 clip_gradients=0.5,
                 clip_embed_gradients=0.1,
                 name="Model"):
        """Initialize a Q-learning model.

        Args:
            mode: estimator/model mode, forwarded to the base class.
            graph_fn: callable building the network graph.
            num_states: number of environment states.
            num_actions: number of available actions.
            loss: loss config; defaults to `HuberLossConfig()` when `None`.
            optimizer: optimizer config, forwarded to the base class.
            metrics: metrics config, forwarded to the base class.
            discount: reward discount factor (gamma).
            exploration_config: exploration strategy configuration.
            use_target_graph: whether to maintain a separate target graph.
            target_update_frequency: steps between target-graph updates.
            is_continuous: whether the action space is continuous.
            dueling: dueling aggregation mode (e.g. 'mean').
            use_expert_demo: whether expert demonstrations are used.
            summaries: which summaries to collect.
            clip_gradients: gradient clipping value.
            clip_embed_gradients: embedding gradient clipping value.
            name: model name.
        """
        self.num_states = num_states
        self.num_actions = num_actions
        self.exploration_config = exploration_config
        self.discount = discount
        self.use_target_graph = use_target_graph
        self.target_update_frequency = target_update_frequency
        self.is_continuous = is_continuous
        self.dueling = dueling
        self.use_expert_demo = use_expert_demo
        # Explicit `is None` check: the previous `loss or HuberLossConfig()`
        # would silently discard any falsy-but-valid loss configuration.
        loss = loss if loss is not None else HuberLossConfig()

        super(BaseQModel,
              self).__init__(mode=mode,
                             name=name,
                             model_type=self.Types.RL,
                             graph_fn=graph_fn,
                             loss=loss,
                             optimizer=optimizer,
                             metrics=metrics,
                             summaries=summaries,
                             clip_gradients=clip_gradients,
                             clip_embed_gradients=clip_embed_gradients)

        # Lazily-built train/target graphs; populated later by the model.
        self._train_graph = None
        self._target_graph = None