コード例 #1
0
ファイル: custom_loss.py プロジェクト: zqxyz73/ray
 def _build_layers_v2(self, input_dict, num_outputs, options):
     """Build the model from a fully connected network in a shared scope.

     Stores ``input_dict["obs"]`` on ``self.obs_in`` and the constructed
     network on ``self.fcnet``.

     Args:
         input_dict: Mapping that contains at least an ``"obs"`` entry; it
             is forwarded unchanged to ``FullyConnectedNetwork``.
         num_outputs: Forwarded to ``FullyConnectedNetwork``.
         options: Forwarded to ``FullyConnectedNetwork``.

     Returns:
         Tuple ``(outputs, last_layer)`` taken from the constructed network.
     """
     self.obs_in = input_dict["obs"]

     # The network is created under the "shared" variable scope with
     # AUTO_REUSE, so building it again under the same scope name is
     # permitted rather than raising on already-existing variables.
     with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
         shared_net = FullyConnectedNetwork(
             input_dict, self.obs_space, self.action_space, num_outputs,
             options)
         self.fcnet = shared_net

     return shared_net.outputs, shared_net.last_layer
コード例 #2
0
    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Build a fully connected network and hook in an eager callback.

        The network's last layer is passed through ``tf.py_function`` so that
        ``self.forward_eager`` is invoked on it, and the returned outputs are
        made to depend on that call.

        Args:
            input_dict: Forwarded to ``FullyConnectedNetwork``.
            num_outputs: Forwarded to ``FullyConnectedNetwork``.
            options: Forwarded to ``FullyConnectedNetwork``.

        Returns:
            Tuple of (identity of the network outputs, float32 result of the
            ``tf.py_function`` call on the last layer).
        """
        self.fcnet = FullyConnectedNetwork(
            input_dict, self.obs_space, self.action_space, num_outputs,
            options)

        # NOTE(review): forward_eager is defined outside this snippet; it is
        # assumed to return a single float32 tensor -- confirm.
        eager_features = tf.py_function(
            self.forward_eager, [self.fcnet.last_layer], tf.float32)

        # Create the identity op inside the dependency scope so that fetching
        # the outputs also triggers the py_function op.
        with tf.control_dependencies([eager_features]):
            gated_outputs = tf.identity(self.fcnet.outputs)

        return gated_outputs, eager_features
コード例 #3
0
    def _get_model(input_dict, obs_space, action_space, num_outputs, options,
                   state_in, seq_lens):
        """Pick and construct the model for the given inputs.

        Selection order:
          1. If ``options`` has a truthy ``"custom_model"`` entry, the model
             registered under that name is looked up in the global registry
             and constructed (it also receives ``state_in`` and ``seq_lens``).
          2. Otherwise, if the observation has more than two non-batch
             dimensions, a ``VisionNetwork`` is constructed.
          3. Otherwise a ``FullyConnectedNetwork`` is constructed.

        Args:
            input_dict: Mapping with an ``"obs"`` entry exposing ``.shape``.
            obs_space: Forwarded to the chosen model constructor.
            action_space: Forwarded to the chosen model constructor.
            num_outputs: Forwarded to the chosen model constructor.
            options: Model options; ``"custom_model"`` selects a registered
                custom model when set.
            state_in: Passed only to custom models.
            seq_lens: Passed only to custom models.

        Returns:
            The constructed model instance.
        """
        custom_name = options.get("custom_model")
        if custom_name:
            logger.debug("Using custom model {}".format(custom_name))
            custom_cls = _global_registry.get(RLLIB_MODEL, custom_name)
            return custom_cls(
                input_dict,
                obs_space,
                action_space,
                num_outputs,
                options,
                state_in=state_in,
                seq_lens=seq_lens)

        # Subtract one so the leading batch dimension is not counted.
        non_batch_rank = len(input_dict["obs"].shape) - 1

        if non_batch_rank > 2:
            default_cls = VisionNetwork
        else:
            default_cls = FullyConnectedNetwork
        return default_cls(input_dict, obs_space, action_space, num_outputs,
                           options)