Example #1
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # TODO: (sven) Support Dicts as well.
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, self.original_space, action_space,
                              num_outputs, model_config, name)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        # self.cnn_type = self.model_config["custom_model_config"].get(
        #     "conv_type", "atari")

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.original_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters":
                    model_config.get("conv_filters",
                                     get_filter_config(component.shape)),
                    "conv_activation":
                    model_config.get("conv_activation"),
                    "post_fcnet_hiddens": [],
                }
                # if self.cnn_type == "atari":
                cnn = ModelCatalog.get_model_v2(component,
                                                action_space,
                                                num_outputs=None,
                                                model_config=config,
                                                framework="torch",
                                                name="cnn_{}".format(i))
                # TODO (sven): add IMPALA-style option.
                # else:
                #    cnn = TorchImpalaVisionNet(
                #        component,
                #        action_space,
                #        num_outputs=None,
                #        model_config=config,
                #        name="cnn_{}".format(i))

                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                self.flatten[i] = int(np.product(component.shape))
                concat_size += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu")
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(Box(
            float("-inf"),
            float("inf"),
            shape=(concat_size, ),
            dtype=np.float32),
                                                       self.action_space,
                                                       None,
                                                       post_fc_stack_config,
                                                       framework="torch",
                                                       name="post_fc_stack")

        # Actions and value heads.
        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            self.num_outputs = concat_size
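
A minimal usage sketch (not from the source): assuming the constructor above belongs to a custom model class, here called ComplexInputNetwork, it is registered with RLlib's ModelCatalog and referenced from the trainer config roughly like this.

from ray.rllib.models import ModelCatalog

# `ComplexInputNetwork` is an assumed name for the class whose __init__ is
# shown above; import it from wherever your project defines it.
ModelCatalog.register_custom_model("complex_input_net", ComplexInputNetwork)

config = {
    "framework": "torch",
    "model": {
        "custom_model": "complex_input_net",
        # Keys consumed by the constructor above:
        "conv_activation": "relu",
        "post_fcnet_hiddens": [256],
        "post_fcnet_activation": "relu",
    },
}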
Example #2
File: fcnet.py  Project: tuyulers5/jav44
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        hiddens = model_config.get("fcnet_hiddens", []) + \
            model_config.get("post_fcnet_hiddens", [])
        activation = model_config.get("fcnet_activation")
        if not model_config.get("fcnet_hiddens", []):
            activation = model_config.get("post_fcnet_activation")
        no_final_linear = model_config.get("no_final_linear")
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.free_log_std = model_config.get("free_log_std")

        # Generate free-floating bias variables for the second half of
        # the outputs.
        if self.free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two", num_outputs)
            num_outputs = num_outputs // 2

        layers = []
        prev_layer_size = int(np.product(obs_space.shape))
        self._logits = None

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            layers.append(
                SlimFC(in_size=prev_layer_size,
                       out_size=size,
                       initializer=normc_initializer(1.0),
                       activation_fn=activation))
            prev_layer_size = size

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and num_outputs:
            layers.append(
                SlimFC(in_size=prev_layer_size,
                       out_size=num_outputs,
                       initializer=normc_initializer(1.0),
                       activation_fn=activation))
            prev_layer_size = num_outputs
        # Finish the layers with the provided sizes (`hiddens`), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if len(hiddens) > 0:
                layers.append(
                    SlimFC(in_size=prev_layer_size,
                           out_size=hiddens[-1],
                           initializer=normc_initializer(1.0),
                           activation_fn=activation))
                prev_layer_size = hiddens[-1]
            if num_outputs:
                self._logits = SlimFC(in_size=prev_layer_size,
                                      out_size=num_outputs,
                                      initializer=normc_initializer(0.01),
                                      activation_fn=None)
            else:
                self.num_outputs = ([int(np.product(obs_space.shape))] +
                                    hiddens[-1:])[-1]

        # Layer to add the log std vars to the state-dependent means.
        if self.free_log_std and self._logits:
            self._append_free_log_std = AppendBiasLayer(num_outputs)

        self._hidden_layers = nn.Sequential(*layers)

        self._value_branch_separate = None
        if not self.vf_share_layers:
            # Build a parallel set of hidden layers for the value net.
            prev_vf_layer_size = int(np.product(obs_space.shape))
            vf_layers = []
            for size in hiddens:
                vf_layers.append(
                    SlimFC(in_size=prev_vf_layer_size,
                           out_size=size,
                           activation_fn=activation,
                           initializer=normc_initializer(1.0)))
                prev_vf_layer_size = size
            self._value_branch_separate = nn.Sequential(*vf_layers)

        self._value_branch = SlimFC(in_size=prev_layer_size,
                                    out_size=1,
                                    initializer=normc_initializer(0.01),
                                    activation_fn=None)
        # Holds the current "base" output (before logits layer).
        self._features = None
        # Holds the last input, in case value branch is separate.
        self._last_flat_in = None
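
For context, a hedged sketch of the forward()/value_function() pair this constructor implies, following the standard TorchModelV2 interface; the bodies in the original fcnet.py may differ in detail.

    def forward(self, input_dict, state, seq_lens):
        # Flatten the observation and keep it for the separate value branch.
        obs = input_dict["obs_flat"].float()
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
        logits = self._logits(self._features) if self._logits \
            else self._features
        if self.free_log_std:
            logits = self._append_free_log_std(logits)
        return logits, state

    def value_function(self):
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            return self._value_branch(
                self._value_branch_separate(self._last_flat_in)).squeeze(1)
        return self._value_branch(self._features).squeeze(1)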
Example #3
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0,\
            "Must provide at least 1 entry in `conv_filters`!"
        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None

        layers = []
        (w, h, in_channels) = obs_space.shape
        in_size = [w, h]
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, [stride, stride])
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn=activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer is a Conv2D and uses num_outputs.
        if no_final_linear and num_outputs:
            layers.append(
                SlimConv2d(
                    in_channels,
                    num_outputs,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))
            out_channels = num_outputs
        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride)
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                self._logits = SlimConv2d(out_channels,
                                          num_outputs, [1, 1],
                                          1,
                                          padding,
                                          activation_fn=None)
            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())
                self.num_outputs = out_channels

        self._convs = nn.Sequential(*layers)

        # Build the value layers
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            self._value_branch = SlimFC(out_channels,
                                        1,
                                        initializer=normc_initializer(0.01),
                                        activation_fn=None)
        else:
            vf_layers = []
            (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel,
                                                 [stride, stride])
                vf_layers.append(
                    SlimConv2d(in_channels,
                               out_channels,
                               kernel,
                               stride,
                               padding,
                               activation_fn=activation))
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           None,
                           activation_fn=activation))

            vf_layers.append(
                SlimConv2d(in_channels=out_channels,
                           out_channels=1,
                           kernel=1,
                           stride=1,
                           padding=None,
                           activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None
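
A hedged sketch of the matching forward()/value_function() methods, assuming the usual conventions for this vision net (NHWC observations, a final 1x1-Conv2D logits layer); the original file may differ.

    def forward(self, input_dict, state, seq_lens):
        self._features = input_dict["obs"].float()
        # Permute NHWC observations to NCHW for the Conv2D stack.
        self._features = self._features.permute(0, 3, 1, 2)
        conv_out = self._convs(self._features)
        # Keep the permuted obs only if a separate value branch needs it.
        if not self._value_branch_separate:
            self._features = conv_out
        if self.last_layer_is_flattened:
            return conv_out, state
        if self._logits:
            conv_out = self._logits(conv_out)
        # Squeeze the trailing (1, 1) spatial dims of the final Conv2D.
        return conv_out.squeeze(3).squeeze(2), state

    def value_function(self):
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            value = self._value_branch_separate(self._features)
            return value.squeeze(3).squeeze(2).squeeze(1)
        features = self._features
        if not self.last_layer_is_flattened:
            features = features.squeeze(3).squeeze(2)
        return self._value_branch(features).squeeze(1)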
Example #4
    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, **kwargs
    ):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        model_config = merge_dicts(BASELINE_CONFIG, model_config["custom_model_config"])
        # new way to get model config directly from keyword arguments
        model_config = merge_dicts(model_config, kwargs)

        state_embed_size = model_config["state_embed_size"]
        self.use_rnn = model_config["use_rnn"]
        rnn_type = model_config["rnn_type"]
        self.use_prev_action_reward = model_config["use_prev_action_reward"]

        action_net_kwargs = model_config["action_net_kwargs"]
        if isinstance(action_space, gym.spaces.Discrete):
            action_net_kwargs.update({"discrete": True, "n": action_space.n})
            self.discrete = True
        else:
            self.discrete = False

        def get_factory(network_name):
            base_module = importlib.import_module(
                f"minerl_rllib.models.torch.{network_name}"
            )
            return getattr(base_module, model_config[f"{network_name}_net"])

        self._pov_network, pov_embed_size = get_factory("pov")(
            **model_config["pov_net_kwargs"]
        )
        self._vector_network, vector_embed_size = get_factory("vector")(
            **model_config["vector_net_kwargs"]
        )
        state_input_size = pov_embed_size + vector_embed_size
        if self.use_prev_action_reward:
            self._action_network, action_embed_size = get_factory("action")(
                **action_net_kwargs
            )
            self._reward_network, reward_embed_size = get_factory("reward")(
                **model_config["reward_net_kwargs"]
            )
            state_input_size += action_embed_size + reward_embed_size

        rnn_config = model_config.get("rnn_config")
        if self.use_rnn:
            state_embed_size = rnn_config["hidden_size"]
            if rnn_type == "lstm":
                self._rnn = LSTMBaseline(state_input_size, **rnn_config)
            elif rnn_type == "gru":
                self._rnn = GRUBaseline(state_input_size, **rnn_config)
            else:
                raise NotImplementedError
        else:
            self._state_network = nn.Sequential(
                nn.Linear(state_input_size, state_embed_size),
                nn.ELU(),
            )
        self._value_head = nn.Sequential(
            nn.Linear(state_embed_size, 1),
        )
        self._policy_head = nn.Sequential(
            nn.Linear(state_embed_size, num_outputs),
        )
Example #5
File: model.py  Project: 51616/ray
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.options = options = model_config['custom_model_config']
        self.z_dist = D.Uniform(options['z_range'][0], options['z_range'][1])
        self.z_dim = options['z_dim']
        hidden_dim = options['hidden_dim']
        num_hiddens = options['num_hiddens']

        obs_shape = obs_space.shape
        if isinstance(options.get('orig_obs_space', None), spaces.Dict):
            obs_shape = options['orig_obs_space']['obs'].shape

        input_dim = int(np.prod(obs_shape)) + self.z_dim
        print(f'input dim: {input_dim}')
        activation_fn = nn.ELU
        initer = lambda w: nn.init.xavier_uniform_(w, 1.0)

        if options['share_hidden']:
            hiddens = [
                SlimFC(input_dim,
                       hidden_dim,
                       activation_fn=activation_fn,
                       initializer=initer),
                nn.LayerNorm(hidden_dim)
            ]
            for _ in range(num_hiddens - 1):
                hiddens.append(
                    SlimFC(hidden_dim,
                           hidden_dim,
                           activation_fn=activation_fn,
                           initializer=initer))

            self.hiddens = nn.Sequential(*hiddens)
            self.logits = SlimFC(hidden_dim + self.z_dim,
                                 num_outputs,
                                 initializer=initer)
            self.values = SlimFC(hidden_dim + self.z_dim,
                                 1,
                                 initializer=initer)
        else:
            hiddens = [
                SlimFC(input_dim,
                       hidden_dim,
                       activation_fn=activation_fn,
                       initializer=initer),
                nn.LayerNorm(hidden_dim)
            ]
            for _ in range(num_hiddens - 1):
                hiddens.append(
                    SlimFC(hidden_dim,
                           hidden_dim,
                           activation_fn=activation_fn,
                           initializer=initer))
            hiddens.append(SlimFC(hidden_dim, num_outputs, initializer=initer))
            self.logits = nn.Sequential(*hiddens)

            hiddens = [
                SlimFC(input_dim,
                       hidden_dim,
                       activation_fn=activation_fn,
                       initializer=initer),
                nn.LayerNorm(hidden_dim)
            ]
            for _ in range(num_hiddens - 1):
                hiddens.append(
                    SlimFC(hidden_dim,
                           hidden_dim,
                           activation_fn=activation_fn,
                           initializer=initer))
            hiddens.append(SlimFC(hidden_dim, 1, initializer=initer))
            self.values = nn.Sequential(*hiddens)
Example #6
    def __init__(self, spec_tree, obs_space, action_space):
        fcnet_hiddens = spec_tree['fcnet_hiddens']
        fcnet_activation = spec_tree['fcnet_activation']
        vf_share_layers = spec_tree['vf_share_layers']
        model_config = {'fcnet_hiddens': fcnet_hiddens,
                        'fcnet_activation': fcnet_activation,
                        'vf_share_layers': vf_share_layers}
        name = 'MLP'
        num_outputs = action_space.n
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # import pdb; pdb.set_trace()
        assert model_config['vf_share_layers'] is False, 'Separate the value and policy branches for simplicity'
        layers = []
        input_layer = nn.Linear(int(np.product(obs_space.shape)), fcnet_hiddens[0])
        if fcnet_activation == 'tanh':
            activation_layer = nn.Tanh()
        elif fcnet_activation == 'relu':
            activation_layer = nn.ReLU()
        else:
            raise ValueError(
                "Unsupported fcnet_activation: {}".format(fcnet_activation))
        
        layers.append(input_layer)
        layers.append(activation_layer)

        for i in range(len(fcnet_hiddens)-1):
            hidden_layer = nn.Linear(fcnet_hiddens[i], fcnet_hiddens[i+1])
            layers.append(hidden_layer)
            layers.append(activation_layer)
        output_layer = nn.Linear(fcnet_hiddens[-1], num_outputs)
        layers.append(output_layer)
        
        self.policy_network = nn.Sequential(*layers)

        value_layers = []
        input_layer = nn.Linear(int(np.product(obs_space.shape)), fcnet_hiddens[0])
        value_layers.append(input_layer)
        value_layers.append(activation_layer)
        for i in range(len(fcnet_hiddens)-1):
            hidden_layer = nn.Linear(fcnet_hiddens[i], fcnet_hiddens[i+1])
            value_layers.append(hidden_layer)
            value_layers.append(activation_layer)
        output_layer = nn.Linear(fcnet_hiddens[-1], 1)
        value_layers.append(output_layer)

        self.value_network = nn.Sequential(*value_layers)

        # self.policy_network = nn.Sequential(
        #                         nn.Linear(int(np.product(obs_space.shape)), layer_size),
        #                         nn.Tanh(),
        #                         nn.Linear(layer_size,layer_size),
        #                         nn.Tanh(),
        #                         nn.Linear(layer_size, num_outputs)
        #                         )
        # self.value_network = nn.Sequential(
        #                         nn.Linear(int(np.product(obs_space.shape)), 64),
        #                         nn.Tanh(),
        #                         nn.Linear(64,64),
        #                         nn.Tanh(),
        #                         nn.Linear(64, 1)
        #                         )
        self._last_flat_in = None
Example #7
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 actor_hidden_activation="relu",
                 actor_hiddens=(256, 256),
                 critic_hidden_activation="relu",
                 critic_hiddens=(256, 256),
                 twin_q=False,
                 initial_alpha=1.0,
                 target_entropy=None,
                 embed_dim=256,
                 augmentation=False,
                 aug_num=2,
                 max_shift=4,
                 encoder_type="impala",
                 **kwargs):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network
            actor_hiddens (list): hidden layers sizes for actor network
            critic_hidden_activation (str): activation for critic network
            critic_hiddens (list): hidden layers sizes for critic network
            twin_q (bool): build twin Q networks.
            initial_alpha (float): The initial value for the to-be-optimized
                alpha parameter (default: 1.0).
            target_entropy (Optional[float]): An optional fixed value for the
                SAC alpha loss term. None or "auto" for automatic calculation
                of this value according to [1] (cont. actions) or [2]
                (discrete actions).

        Note that the core layers for forward() are not defined here, this
        only defines the layers for the output heads. Those layers for
        forward() should be defined in subclasses of SACModel.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.action_dim = action_space.n
        self.discrete = True
        self.action_outs = q_outs = self.action_dim
        self.action_ins = None  # No action inputs for the discrete case.
        self.embed_dim = embed_dim

        h, w, c = obs_space.shape
        shape = (c, h, w)
        # obs embedding
        self.encoder = make_encoder(encoder_type,
                                    shape,
                                    out_features=embed_dim)

        # Build the policy network.
        self.action_model = nn.Sequential()
        ins = embed_dim
        act = get_activation_fn(actor_hidden_activation, framework="torch")
        init = nn.init.xavier_uniform_

        for i, n in enumerate(actor_hiddens):
            self.action_model.add_module(
                "action_{}".format(i),
                SlimFC(ins, n, initializer=init, activation_fn=act))
            ins = n
        self.action_model.add_module(
            "action_out",
            SlimFC(ins, self.action_outs, initializer=init,
                   activation_fn=None))

        # Build the Q-net(s), including target Q-net(s).
        def build_q_net(name_):
            act = get_activation_fn(critic_hidden_activation,
                                    framework="torch")
            init = nn.init.xavier_uniform_
            # For discrete actions, only obs.
            q_net = nn.Sequential()
            ins = embed_dim
            for i, n in enumerate(critic_hiddens):
                q_net.add_module(
                    "{}_hidden_{}".format(name_, i),
                    SlimFC(ins, n, initializer=init, activation_fn=act))
                ins = n

            q_net.add_module(
                "{}_out".format(name_),
                SlimFC(ins, q_outs, initializer=init, activation_fn=None))
            return q_net

        self.q_net = build_q_net("q")
        if twin_q:
            self.twin_q_net = build_q_net("twin_q")
        else:
            self.twin_q_net = None

        # temperature tensor
        self.log_alpha = torch.tensor(data=[np.log(initial_alpha)],
                                      dtype=torch.float32,
                                      requires_grad=True)

        # Auto-calculate the target entropy.
        if target_entropy is None or target_entropy == "auto":
            # See hyperparams in [2] (README.md).
            target_entropy = 0.98 * np.array(-np.log(1.0 / action_space.n),
                                             dtype=np.float32)

        self.target_entropy = torch.tensor(data=[target_entropy],
                                           dtype=torch.float32,
                                           requires_grad=False)

        # NOTE: custom fields
        self.augmentation = augmentation
        self.aug_num = aug_num
        # NOTE: augmentation
        if augmentation:
            obs_shape = obs_space.shape[-2]
            self.trans = nn.Sequential(nn.ReplicationPad2d(max_shift),
                                       RandomCrop((obs_shape, obs_shape)))
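
A hedged illustration of how the pieces built above compose for discrete actions; the helper names below are hypothetical and only show the data flow (the shared encoder embedding feeds both the policy head and the Q-net(s)).

    def get_action_logits(self, obs):
        # obs: [B, C, H, W] float tensor, channel-first as in `shape` above.
        embed = self.encoder(obs)
        return self.action_model(embed)

    def get_q_values(self, obs):
        embed = self.encoder(obs)
        q = self.q_net(embed)
        twin_q = None
        if self.twin_q_net is not None:
            twin_q = self.twin_q_net(embed)
        return q, twin_q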
Example #8
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        self.original_space = (obs_space.original_space if hasattr(
            obs_space, "original_space") else obs_space)

        self.processed_obs_space = (
            self.original_space
            if model_config.get("_disable_preprocessor_api") else obs_space)

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, self.original_space, action_space,
                              num_outputs, model_config, name)

        self.flattened_input_space = flatten_space(self.original_space)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        # self.cnn_type = self.model_config["custom_model_config"].get(
        #     "conv_type", "atari")

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        self.one_hot = {}
        self.flatten_dims = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.flattened_input_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters":
                    model_config["conv_filters"] if "conv_filters"
                    in model_config else get_filter_config(component.shape),
                    "conv_activation":
                    model_config.get("conv_activation"),
                    "post_fcnet_hiddens": [],
                }
                # if self.cnn_type == "atari":
                self.cnns[i] = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="torch",
                    name="cnn_{}".format(i),
                )
                # TODO (sven): add IMPALA-style option.
                # else:
                #    cnn = TorchImpalaVisionNet(
                #        component,
                #        action_space,
                #        num_outputs=None,
                #        model_config=config,
                #        name="cnn_{}".format(i))

                concat_size += self.cnns[i].num_outputs
                self.add_module("cnn_{}".format(i), self.cnns[i])
            # Discrete|MultiDiscrete inputs -> One-hot encode.
            elif isinstance(component, (Discrete, MultiDiscrete)):
                if isinstance(component, Discrete):
                    size = component.n
                else:
                    size = np.sum(component.nvec)
                config = {
                    "fcnet_hiddens": model_config["fcnet_hiddens"],
                    "fcnet_activation": model_config.get("fcnet_activation"),
                    "post_fcnet_hiddens": [],
                }
                self.one_hot[i] = ModelCatalog.get_model_v2(
                    Box(-1.0, 1.0, (size, ), np.float32),
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="torch",
                    name="one_hot_{}".format(i),
                )
                concat_size += self.one_hot[i].num_outputs
            # Everything else (1D Box).
            else:
                size = int(np.product(component.shape))
                config = {
                    "fcnet_hiddens": model_config["fcnet_hiddens"],
                    "fcnet_activation": model_config.get("fcnet_activation"),
                    "post_fcnet_hiddens": [],
                }
                self.flatten[i] = ModelCatalog.get_model_v2(
                    Box(-1.0, 1.0, (size, ), np.float32),
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="torch",
                    name="flatten_{}".format(i),
                )
                self.flatten_dims[i] = size
                concat_size += self.flatten[i].num_outputs

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu"),
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size, ),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="torch",
            name="post_fc_stack",
        )

        # Actions and value heads.
        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=num_outputs,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01),
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01),
            )
        else:
            self.num_outputs = concat_size
Example #9
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
Example #10
    def __init__(self, *args, **kwargs):  # type: ignore
        TorchModelV2.__init__(self, *args, **kwargs)
        nn.Module.__init__(self)
        self._cur_value = None
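
A hedged sketch of what a subclass that only sets up `self._cur_value` typically adds: a forward() that caches the value prediction and a value_function() that returns it. The `trunk`, `logits_head`, and `value_head` layers are hypothetical placeholders for whatever the concrete subclass builds.

    def forward(self, input_dict, state, seq_lens):
        # Hypothetical layers; a concrete subclass builds these in __init__.
        features = self.trunk(input_dict["obs_flat"].float())
        self._cur_value = self.value_head(features).squeeze(1)
        return self.logits_head(features), state

    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value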
Example #11
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )

        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )

        # P(a1 | obs)
        self.a1_logits = SlimFC(
            in_size=num_outputs,
            out_size=2,
            activation_fn=None,
            initializer=normc_init_torch(0.01),
        )

        class _ActionModel(nn.Module):
            def __init__(self):
                nn.Module.__init__(self)
                self.a2_hidden = SlimFC(
                    in_size=1,
                    out_size=16,
                    activation_fn=nn.Tanh,
                    initializer=normc_init_torch(1.0),
                )
                self.a2_logits = SlimFC(
                    in_size=16,
                    out_size=2,
                    activation_fn=None,
                    initializer=normc_init_torch(0.01),
                )

            def forward(self_, ctx_input, a1_input):
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
        # a2_context = tf.keras.layers.Concatenate(axis=1)(
        #     [ctx_input, a1_input])
        self.action_module = _ActionModel()

        self._context = None
Example #12
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):  #,
        #graph_layers, graph_features, graph_tabs, graph_edge_features, cnn_filters, value_cnn_filters, value_cnn_compression, cnn_compression, relative, activation):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.cfg = copy.deepcopy(DEFAULT_OPTIONS)
        self.cfg.update(model_config['custom_model_config'])

        #self.cfg = model_config['custom_options']
        self.n_agents = len(obs_space.original_space['agents'])
        self.graph_features = self.cfg['graph_features']
        self.cnn_compression = self.cfg['cnn_compression']
        self.activation = {
            'relu': nn.ReLU,
            'leakyrelu': nn.LeakyReLU
        }[self.cfg['activation']]

        layers = []
        input_shape = obs_space.original_space['agents'][0]['map'].shape
        (w, h, in_channels) = input_shape

        in_size = [w, h]
        for out_channels, kernel, stride in self.cfg['cnn_filters'][:-1]:
            padding, out_size = same_padding(in_size, kernel, [stride, stride])
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn=self.activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = self.cfg['cnn_filters'][-1]
        layers.append(
            SlimConv2d(in_channels, out_channels, kernel, stride, None))
        layers.append(nn.Flatten(1, -1))
        #if isinstance(cnn_compression, int):
        #    layers.append(nn.Linear(cnn_compression, self.cfg['graph_features']-2)) # reserve 2 for pos
        #    layers.append(self.activation())
        self.coop_convs = nn.Sequential(*layers)
        self.greedy_convs = copy.deepcopy(self.coop_convs)

        self.coop_value_obs_convs = copy.deepcopy(self.coop_convs)
        self.greedy_value_obs_convs = copy.deepcopy(self.coop_convs)

        summary(self.coop_convs,
                device="cpu",
                input_size=(input_shape[2], input_shape[0], input_shape[1]))

        gfl = []
        for i in range(self.cfg['graph_layers']):
            gfl.append(
                gml_adv.GraphFilterBatchGSOA(self.graph_features,
                                             self.graph_features,
                                             self.cfg['graph_tabs'],
                                             self.cfg['agent_split'],
                                             self.cfg['graph_edge_features'],
                                             False))
            #gfl.append(gml.GraphFilterBatchGSO(self.graph_features, self.graph_features, self.cfg['graph_tabs'], self.cfg['graph_edge_features'], False))
            gfl.append(self.activation())

        self.GFL = nn.Sequential(*gfl)

        #gso_sum = torch.zeros(2, 1, 8, 8)
        #self.GFL[0].addGSO(gso_sum)
        #summary(self.GFL, device="cuda" if torch.cuda.is_available() else "cpu", input_size=(self.graph_features, 8))

        logits_inp_features = self.graph_features
        if self.cfg['cnn_residual']:
            logits_inp_features += self.cnn_compression

        post_logits = [
            nn.Linear(logits_inp_features, 64),
            self.activation(),
            nn.Linear(64, 32),
            self.activation()
        ]
        logit_linear = nn.Linear(32, 5)
        nn.init.xavier_uniform_(logit_linear.weight)
        nn.init.constant_(logit_linear.bias, 0)
        post_logits.append(logit_linear)
        self.coop_logits = nn.Sequential(*post_logits)
        self.greedy_logits = copy.deepcopy(self.coop_logits)
        summary(self.coop_logits,
                device="cpu",
                input_size=(logits_inp_features, ))

        ##############################

        layers = []
        input_shape = np.array(obs_space.original_space['state'].shape)
        (w, h, in_channels) = input_shape

        in_size = [w, h]
        for out_channels, kernel, stride in self.cfg['value_cnn_filters'][:-1]:
            padding, out_size = same_padding(in_size, kernel, [stride, stride])
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn=self.activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = self.cfg['value_cnn_filters'][-1]
        layers.append(
            SlimConv2d(in_channels, out_channels, kernel, stride, None))
        layers.append(nn.Flatten(1, -1))

        self.coop_value_cnn = nn.Sequential(*layers)
        self.greedy_value_cnn = copy.deepcopy(self.coop_value_cnn)
        summary(self.greedy_value_cnn,
                device="cpu",
                input_size=(input_shape[2], input_shape[0], input_shape[1]))

        layers = [
            nn.Linear(self.cnn_compression + self.cfg['value_cnn_compression'],
                      64),
            self.activation(),
            nn.Linear(64, 32),
            self.activation()
        ]
        values_linear = nn.Linear(32, 1)
        normc_initializer()(values_linear.weight)
        nn.init.constant_(values_linear.bias, 0)
        layers.append(values_linear)

        self.coop_value_branch = nn.Sequential(*layers)
        self.greedy_value_branch = copy.deepcopy(self.coop_value_branch)
        summary(self.coop_value_branch,
                device="cpu",
                input_size=(self.cnn_compression +
                            self.cfg['value_cnn_compression'], ))

        self._cur_value = None

        self.freeze_coop_value(self.cfg['freeze_coop_value'])
        self.freeze_greedy_value(self.cfg['freeze_greedy_value'])
        self.freeze_coop(self.cfg['freeze_coop'])
        self.freeze_greedy(self.cfg['freeze_greedy'])
Example #13
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, custom_input_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.torch_sub_model = TorchFC(custom_input_space, action_space, num_outputs, model_config, name)
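
Note that `custom_input_space` is presumably defined at module level in the originating project and is left as-is here. A hedged sketch of the delegating forward()/value_function() such a thin wrapper typically pairs with:

    def forward(self, input_dict, state, seq_lens):
        # Delegate the whole forward pass to the wrapped fully connected net.
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        return self.torch_sub_model.value_function().reshape(-1)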
Example #14
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):

        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0,\
            "Must provide at least 1 entry in `conv_filters`!"
        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None
        self.traj_view_framestacking = False

        layers = []
        # Perform Atari framestacking via traj. view API.
        if model_config.get("num_framestacks") != "auto" and \
                model_config.get("num_framestacks", 0) > 1:
            (w, h) = obs_space.shape
            in_channels = model_config["num_framestacks"]
            self.traj_view_framestacking = True
        else:
            (w, h, in_channels) = obs_space.shape

        in_size = [w, h]
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, [stride, stride])
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn=activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer is a Conv2D and uses num_outputs.
        if no_final_linear and num_outputs:
            layers.append(
                SlimConv2d(
                    in_channels,
                    num_outputs,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))
            out_channels = num_outputs
        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride)
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                self._logits = SlimConv2d(out_channels,
                                          num_outputs, [1, 1],
                                          1,
                                          padding,
                                          activation_fn=None)
            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())
                self.num_outputs = out_channels

        self._convs = nn.Sequential(*layers)

        # Build the value layers
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            self._value_branch = SlimFC(out_channels,
                                        1,
                                        initializer=normc_initializer(0.01),
                                        activation_fn=None)
        else:
            vf_layers = []
            if self.traj_view_framestacking:
                (w, h) = obs_space.shape
                in_channels = model_config["num_framestacks"]
            else:
                (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel,
                                                 [stride, stride])
                vf_layers.append(
                    SlimConv2d(in_channels,
                               out_channels,
                               kernel,
                               stride,
                               padding,
                               activation_fn=activation))
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           None,
                           activation_fn=activation))

            vf_layers.append(
                SlimConv2d(in_channels=out_channels,
                           out_channels=1,
                           kernel=1,
                           stride=1,
                           padding=None,
                           activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None

        # Optional: framestacking obs/new_obs for Atari.
        if self.traj_view_framestacking:
            from_ = model_config["num_framestacks"] - 1
            self.view_requirements[SampleBatch.OBS].shift = \
                "-{}:0".format(from_)
            self.view_requirements[SampleBatch.OBS].shift_from = -from_
            self.view_requirements[SampleBatch.OBS].shift_to = 0
            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
                data_col=SampleBatch.OBS,
                shift="-{}:1".format(from_ - 1),
                space=self.view_requirements[SampleBatch.OBS].space,
            )
Example #15
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        obs_space_ = obs_space.original_space
        data, images, privates = obs_space_.spaces['data'], obs_space_.spaces['images'], \
                                 obs_space_.spaces['privates']

        N, T, L = data.shape
        adjusted_data_shape = (T, N * L)
        _, w, h, c = images.shape
        shape = (c * N, w, h)
        self.img_shape = shape

        conv_filters = model_config.get('conv_filters')
        activation = model_config.get("fcnet_activation")
        hiddens = model_config.get("fcnet_hiddens", [100, 100])
        lstm_dim = model_config.get("lstm_cell_size", 128)

        if not conv_filters:
            conv_filters = [16, 32, 32]

        max_pool = [3] * len(conv_filters)

        conv_seqs = []

        self.lstm_net = LSTM(input_dim=adjusted_data_shape[-1],
                             hidden_dim=lstm_dim,
                             num_layers=2)
        for (out_channels, mp) in zip(conv_filters, max_pool):
            conv_seq = ResNet(shape, out_channels, mp)
            shape = conv_seq.get_output_shape()
            conv_seqs.append(conv_seq)
        conv_seqs.append(nn.Flatten())
        self.conv_seqs = nn.ModuleList(conv_seqs)

        prev_layer_size = lstm_dim + int(np.product(privates.shape)) + int(
            np.product(shape))

        layers = []
        for size in hiddens:
            layers.append(
                SlimFC(in_size=prev_layer_size,
                       out_size=size,
                       initializer=normc_initializer(1.0),
                       activation_fn=activation))
            prev_layer_size = size

        self._hidden_layers = nn.Sequential(*layers)
        self._features = None

        self._policy_net = SlimFC(in_size=prev_layer_size,
                                  out_size=num_outputs,
                                  initializer=normc_initializer(1.0),
                                  activation_fn=activation)

        self._value_net = SlimFC(in_size=prev_layer_size,
                                 out_size=1,
                                 initializer=normc_initializer(1.0),
                                 activation_fn=activation)
Example #16
File: fcnet.py  Project: yuishihara/ray
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = get_activation_fn(model_config.get("fcnet_activation"),
                                       framework="torch")
        hiddens = model_config.get("fcnet_hiddens")
        no_final_linear = model_config.get("no_final_linear")

        # TODO(sven): implement case: vf_shared_layers = False.
        # vf_share_layers = model_config.get("vf_share_layers")

        logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
        layers = []
        prev_layer_size = int(np.product(obs_space.shape))
        self._logits = None

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            layers.append(
                SlimFC(in_size=prev_layer_size,
                       out_size=size,
                       initializer=normc_initializer(1.0),
                       activation_fn=activation))
            prev_layer_size = size

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and self.num_outputs:
            layers.append(
                SlimFC(in_size=prev_layer_size,
                       out_size=self.num_outputs,
                       initializer=normc_initializer(1.0),
                       activation_fn=activation))
            prev_layer_size = self.num_outputs
        # Finish the layers with the provided sizes (`hiddens`), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if len(hiddens) > 0:
                layers.append(
                    SlimFC(in_size=prev_layer_size,
                           out_size=hiddens[-1],
                           initializer=normc_initializer(1.0),
                           activation_fn=activation))
                prev_layer_size = hiddens[-1]
            if self.num_outputs:
                self._logits = SlimFC(in_size=prev_layer_size,
                                      out_size=self.num_outputs,
                                      initializer=normc_initializer(0.01),
                                      activation_fn=None)
            else:
                self.num_outputs = ([int(np.product(obs_space.shape))] +
                                    hiddens[-1:])[-1]

        self._hidden_layers = nn.Sequential(*layers)

        # TODO(sven): Implement non-shared value branch.
        self._value_branch = SlimFC(in_size=prev_layer_size,
                                    out_size=1,
                                    initializer=normc_initializer(1.0),
                                    activation_fn=None)
        # Holds the current value output.
        self._cur_value = None
Example #17
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):

        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0,\
            "Must provide at least 1 entry in `conv_filters`!"

        # Post FC net config.
        post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="torch")

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None

        layers = []
        (w, h, in_channels) = obs_space.shape

        in_size = [w, h]
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, stride)
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn=activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            out_channels = out_channels if post_fcnet_hiddens else num_outputs
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # Add (optional) post-fc-stack after last Conv2D layer.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else [])
            for i, out_size in enumerate(layer_sizes):
                layers.append(
                    SlimFC(in_size=out_channels,
                           out_size=out_size,
                           activation_fn=post_fcnet_activation,
                           initializer=normc_initializer(1.0)))
                out_channels = out_size

        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride)
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                if post_fcnet_hiddens:
                    layers.append(nn.Flatten())
                    in_size = out_channels
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    for i, out_size in enumerate(post_fcnet_hiddens +
                                                 [num_outputs]):
                        layers.append(
                            SlimFC(in_size=in_size,
                                   out_size=out_size,
                                   activation_fn=post_fcnet_activation if
                                   i < len(post_fcnet_hiddens) - 1 else None,
                                   initializer=normc_initializer(1.0)))
                        in_size = out_size
                    # Last layer is logits layer.
                    self._logits = layers.pop()

                else:
                    self._logits = SlimConv2d(out_channels,
                                              num_outputs, [1, 1],
                                              1,
                                              padding,
                                              activation_fn=None)

            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())

        self._convs = nn.Sequential(*layers)

        # If our num_outputs still unknown, we need to do a test pass to
        # figure out the output dimensions. This could be the case, if we have
        # the Flatten layer at the end.
        if self.num_outputs is None:
            # Create a B=1 dummy sample and push it through out conv-net.
            dummy_in = torch.from_numpy(self.obs_space.sample()).permute(
                2, 0, 1).unsqueeze(0).float()
            dummy_out = self._convs(dummy_in)
            self.num_outputs = dummy_out.shape[1]

        # Build the value layers
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            self._value_branch = SlimFC(out_channels,
                                        1,
                                        initializer=normc_initializer(0.01),
                                        activation_fn=None)
        else:
            vf_layers = []
            (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel, stride)
                vf_layers.append(
                    SlimConv2d(in_channels,
                               out_channels,
                               kernel,
                               stride,
                               padding,
                               activation_fn=activation))
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           None,
                           activation_fn=activation))

            vf_layers.append(
                SlimConv2d(in_channels=out_channels,
                           out_channels=1,
                           kernel=1,
                           stride=1,
                           padding=None,
                           activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None