예제 #1
0
    def __init__(self, filters, num_outputs, input_shape):
        """Assemble the conv stack and the final 1x1 output convolution.

        Args:
            filters: Sequence of ``(out_channels, kernel, stride)`` specs.
                Every entry except the last becomes a ReLU-activated
                ``SlimConv2d`` padded via ``same_padding``. From the last
                entry only ``stride`` is used.
            num_outputs: Channel count produced by the final 1x1 layer.
            input_shape: ``(w, h, in_channels)`` of the input.
        """
        super(VisionNet, self).__init__()
        (w, h, in_channels) = input_shape
        in_size = [w, h]
        layers = []
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, [stride, stride])
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn="relu"))
            # The next layer consumes what this one produced.
            in_channels = out_channels
            in_size = out_size

        # Only the stride of the last spec is consumed; its channel count and
        # kernel are replaced by `num_outputs` and a fixed 1x1 kernel.
        _, _, stride = filters[-1]
        # Final 1x1 convolution producing `num_outputs` channels. Its input
        # width follows the preceding layer (or the raw input when there is
        # none) instead of a hard-coded 32, which only worked when the last
        # hidden conv happened to have 32 output channels.
        layers.append(
            SlimConv2d(
                in_channels,
                num_outputs,
                [1, 1],
                stride,
                None,  # padding=valid
                activation_fn='relu'))
        # Collapse dims 1..3 into one. The result is [B, num_outputs] only if
        # the remaining spatial size is 1x1.
        layers.append(nn.Flatten(start_dim=1, end_dim=3))
        # Put everything in sequence.
        self._model = nn.Sequential(*layers)
예제 #2
0
 def __init__(self, obs_space, action_space, num_outputs, model_config,
              name):
     """Build the shared conv/dense trunk and the q, pi and beta heads.

     Args:
         obs_space: Observation space; ``obs_space.shape`` is unpacked as
             ``(w, h, in_channels)``.
         action_space: Forwarded to ``TorchModelV2.__init__``.
         num_outputs: Per-option output count. ``TorchModelV2`` is
             initialised with ``num_outputs * num_options``.
         model_config: Reads ``oc_num_options`` and ``oc_option_epsilon``.
         name: Forwarded to ``TorchModelV2.__init__``.
     """
     num_options = model_config.get('oc_num_options')
     TorchModelV2.__init__(self, obs_space, action_space, num_outputs * num_options,
                           model_config, name)
     nn.Module.__init__(self)
     layers = []
     (w, h, in_channels) = obs_space.shape
     in_size = [w, h]
     # Convolutional layers, each followed by ReLU.
     for out_channels, kernel, stride in OCNET_FILTERS:
         padding, out_size = same_padding(in_size, kernel, [stride, stride])
         layers.append(nn.Conv2d(in_channels, out_channels, kernel, stride, padding))
         layers.append(nn.ReLU())
         in_channels = out_channels
         in_size = out_size
     # Dense layer after flattening output, using ReLU.
     hidden_size = OCNET_DENSE
     self.option_epsilon = model_config.get('oc_option_epsilon')
     layers.append(nn.Flatten())
     # nn.Linear needs an integer feature count, not the [w, h] pair.
     # NOTE(review): assumes `same_padding` returns the post-conv spatial
     # size as its second value, so the flattened width is
     # channels * width * height -- confirm against the helper.
     flat_size = int(in_channels * in_size[0] * in_size[1])
     layers.append(nn.Linear(flat_size, hidden_size))
     layers.append(nn.ReLU())
     self._convs = nn.Sequential(*layers)
     # Heads on top of the shared features: q, pi and beta. No separate
     # state-value head is built; `self._v` below stays None here.
     self.q = nn.Linear(hidden_size, num_options)  # Value for each option
     # Action probabilities for each option (softmax over the last axis).
     self.pi = nn.Sequential(
         nn.Linear(hidden_size, num_options * num_outputs),
         View((num_options, num_outputs)),
         nn.Softmax(dim=-1))
     # Termination probabilities. nn.Sequential only accepts module
     # instances, so the activation must be instantiated.
     self.beta = nn.Sequential(nn.Linear(hidden_size, num_options), nn.Sigmoid())
     # Holds the current "base" output (before heads).
     self._features = self._q = self._v = self._pi = self._beta = None
예제 #3
0
    def __init__(self, activation, in_size, kernel, stride, output_size, channels=256):
        """Create the policy and value heads.

        Args:
            activation: Activation spec handed to ``SlimConv2d`` as
                ``activation_fn`` for the first layer of each head.
            in_size: Two-element spatial size that is reduced below using
                ``kernel`` and ``stride``.
            kernel: Two-element kernel size used for that reduction.
            stride: Scalar stride used for that reduction.
            output_size: Number of output channels of the policy head.
            channels: Channel count entering both heads (default 256).
        """
        nn.Module.__init__(self)

        self.activation = activation
        self.channels = channels
        self.output_size = output_size

        # Spatial size after an unpadded (kernel, stride) reduction; only
        # used to derive the padding of the policy's second 1x1 conv.
        in_size = [
            np.ceil((in_size[0] - kernel[0]) / stride),
            np.ceil((in_size[1] - kernel[1]) / stride)
        ]
        padding, _ = same_padding(in_size, [1, 1], [1, 1])

        self.layer1 = SlimConv2d(
            self.channels,
            self.channels,
            kernel=1,  # change to different value when representation function is changed?
            stride=1,
            padding=None,
            activation_fn=self.activation)

        self.layer2 = SlimConv2d(
                    self.channels,
                    self.output_size,
                    [1, 1],
                    1,
                    padding,
                    activation_fn=None)

        # nn.Flatten() yields a 2-D tensor, for which the deprecated implicit
        # softmax dim resolves to 1. Spelling it out keeps the result
        # identical and avoids the "implicit dimension choice" warning.
        self.policy = nn.Sequential(
            self.layer1, self.layer2, nn.Flatten(), nn.Softmax(dim=1))

        self.vlayer1 = SlimConv2d(
            self.channels,
            self.channels,
            kernel=1,  # change to different value when representation function is changed?
            stride=1,
            padding=None,
            activation_fn=self.activation)

        # NOTE(review): SlimFC expects `channels` input features after the
        # Flatten, which presumably requires a 1x1 spatial size -- confirm.
        self.vlayer2 = SlimFC(
                self.channels,
                1,
                activation_fn=None)

        self.value = nn.Sequential(self.vlayer1, nn.Flatten(), self.vlayer2)
예제 #4
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Set up the conv trunk, the optional logits layer and the value head.

        Args:
            obs_space: Observation space; its ``shape`` is unpacked as
                ``(w, h, channels)``.
            action_space: Forwarded to ``TorchModelV2.__init__``.
            num_outputs: Requested output size. When falsy, the trunk ends
                in a ``Flatten`` and ``self.num_outputs`` is set to the
                channel count of the last filter.
            model_config: Model settings. Reads ``conv_filters`` (filled in
                from ``get_filter_config`` when missing or empty),
                ``conv_activation``, ``no_final_linear`` and
                ``vf_share_layers``.
            name: Forwarded to ``TorchModelV2.__init__``.
        """
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        cfg = self.model_config
        conv_act = cfg.get("conv_activation")
        filter_specs = cfg["conv_filters"]
        skip_final_linear = cfg.get("no_final_linear")
        share_vf_trunk = cfg.get("vf_share_layers")

        # True when the trunk ends in a Flatten instead of an
        # n x (1,1) Conv2D.
        self.last_layer_is_flattened = False
        self._logits = None

        # Policy trunk: padded convs for every spec but the last.
        width, height, channels = obs_space.shape
        spatial = [width, height]
        trunk = []
        for n_out, k_size, step in filter_specs[:-1]:
            pad, next_spatial = same_padding(spatial, k_size, [step, step])
            trunk.append(
                SlimConv2d(in_channels=channels,
                           out_channels=n_out,
                           kernel=k_size,
                           stride=step,
                           padding=pad,
                           activation_fn=conv_act))
            channels, spatial = n_out, next_spatial

        last_out, last_kernel, last_stride = filter_specs[-1]

        # With `no_final_linear` (and a known output size) the last conv
        # itself emits `num_outputs` channels; otherwise it keeps the
        # configured width and a separate output layer follows.
        conv_emits_outputs = bool(skip_final_linear and num_outputs)
        if conv_emits_outputs:
            last_out = num_outputs
        trunk.append(
            SlimConv2d(in_channels=channels,
                       out_channels=last_out,
                       kernel=last_kernel,
                       stride=last_stride,
                       padding=None,  # padding=valid
                       activation_fn=conv_act))

        if not conv_emits_outputs:
            if num_outputs:
                # Known output size: add an exact `num_outputs`-sized
                # (1,1)-Conv2D on top of the unpadded last conv.
                spatial = [
                    np.ceil((spatial[0] - last_kernel[0]) / last_stride),
                    np.ceil((spatial[1] - last_kernel[1]) / last_stride)
                ]
                pad, _ = same_padding(spatial, [1, 1], [1, 1])
                self._logits = SlimConv2d(in_channels=last_out,
                                          out_channels=num_outputs,
                                          kernel=[1, 1],
                                          stride=1,
                                          padding=pad,
                                          activation_fn=None)
            else:
                # Unknown output size: flatten and report the last conv's
                # channel count as the number of outputs.
                self.last_layer_is_flattened = True
                trunk.append(nn.Flatten())
                self.num_outputs = last_out

        self._convs = nn.Sequential(*trunk)

        # Value function: either a single FC on the shared features or a
        # fully separate conv stack.
        self._value_branch_separate = None
        self._value_branch = None
        if share_vf_trunk:
            self._value_branch = SlimFC(last_out,
                                        1,
                                        initializer=normc_initializer(0.01),
                                        activation_fn=None)
        else:
            vf_width, vf_height, vf_channels = obs_space.shape
            vf_spatial = [vf_width, vf_height]
            vf_stack = []
            for n_out, k_size, step in filter_specs[:-1]:
                pad, next_spatial = same_padding(vf_spatial, k_size,
                                                 [step, step])
                vf_stack.append(
                    SlimConv2d(in_channels=vf_channels,
                               out_channels=n_out,
                               kernel=k_size,
                               stride=step,
                               padding=pad,
                               activation_fn=conv_act))
                vf_channels, vf_spatial = n_out, next_spatial

            vf_out, vf_kernel, vf_stride = filter_specs[-1]
            vf_stack.append(
                SlimConv2d(in_channels=vf_channels,
                           out_channels=vf_out,
                           kernel=vf_kernel,
                           stride=vf_stride,
                           padding=None,
                           activation_fn=conv_act))

            # Single-channel 1x1 conv producing the value estimate.
            vf_stack.append(
                SlimConv2d(in_channels=vf_out,
                           out_channels=1,
                           kernel=1,
                           stride=1,
                           padding=None,
                           activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_stack)

        # Cache slot for the current trunk output (before the logits layer).
        self._features = None
예제 #5
0
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Build the conv trunk, optional post-FC stack, logits and value layers.

        Args:
            obs_space: Observation space; ``obs_space.shape`` is unpacked as
                ``(w, h, in_channels)``.
            action_space: Passed through to ``TorchModelV2.__init__``.
            num_outputs: Requested output size. When falsy, the trunk ends
                in ``nn.Flatten``; if ``self.num_outputs`` is still None
                afterwards it is taken from a dummy forward pass.
            model_config: Reads ``conv_filters``, ``conv_activation``,
                ``post_fcnet_hiddens``, ``post_fcnet_activation``,
                ``no_final_linear`` and ``vf_share_layers``. Note that
                ``conv_filters`` is written back into this dict when it is
                missing or empty.
            name: Passed through to ``TorchModelV2.__init__``.
        """

        # Fall back to a filter config derived from the observation shape.
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!"

        # Post FC net config.
        post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="torch"
        )

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None

        layers = []
        (w, h, in_channels) = obs_space.shape

        # All filters but the last: padded convs whose output size is tracked
        # through `in_size` / `in_channels`.
        in_size = [w, h]
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, stride)
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    padding,
                    activation_fn=activation,
                )
            )
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            # Without a post-FC stack the last conv itself emits `num_outputs`.
            out_channels = out_channels if post_fcnet_hiddens else num_outputs
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation,
                )
            )

            # Add (optional) post-fc-stack after last Conv2D layer.
            # The last configured hidden size is replaced by `num_outputs`;
            # the list is empty when no post-FC hiddens are configured.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else []
            )
            # Note: `out_size` is re-used here as an FC width (it held a
            # spatial size in the conv loop above).
            for i, out_size in enumerate(layer_sizes):
                layers.append(
                    SlimFC(
                        in_size=out_channels,
                        out_size=out_size,
                        activation_fn=post_fcnet_activation,
                        initializer=normc_initializer(1.0),
                    )
                )
                out_channels = out_size

        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation,
                )
            )

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                # Spatial size after the unpadded last conv.
                # NOTE(review): this indexes `kernel` and divides by `stride`
                # directly, so it presumably expects a 2-element kernel and a
                # scalar stride -- confirm against the filter configs in use.
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride),
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                if post_fcnet_hiddens:
                    layers.append(nn.Flatten())
                    # From here on `in_size` is the FC input width (an int),
                    # no longer a [w, h] pair.
                    in_size = out_channels
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    # The activation is applied only where
                    # i < len(post_fcnet_hiddens) - 1, so the final two
                    # layers of this stack are created without one.
                    for i, out_size in enumerate(post_fcnet_hiddens + [num_outputs]):
                        layers.append(
                            SlimFC(
                                in_size=in_size,
                                out_size=out_size,
                                activation_fn=post_fcnet_activation
                                if i < len(post_fcnet_hiddens) - 1
                                else None,
                                initializer=normc_initializer(1.0),
                            )
                        )
                        in_size = out_size
                    # Last layer is logits layer.
                    # It is removed from `layers`, so it is not part of
                    # `self._convs`.
                    self._logits = layers.pop()

                else:
                    self._logits = SlimConv2d(
                        out_channels,
                        num_outputs,
                        [1, 1],
                        1,
                        padding,
                        activation_fn=None,
                    )

            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())

        self._convs = nn.Sequential(*layers)

        # If our num_outputs still unknown, we need to do a test pass to
        # figure out the output dimensions. This could be the case, if we have
        # the Flatten layer at the end.
        if self.num_outputs is None:
            # Create a B=1 dummy sample and push it through out conv-net.
            # permute(2, 0, 1) moves the last axis of the sample to the front
            # before the batch axis is added.
            dummy_in = (
                torch.from_numpy(self.obs_space.sample())
                .permute(2, 0, 1)
                .unsqueeze(0)
                .float()
            )
            dummy_out = self._convs(dummy_in)
            self.num_outputs = dummy_out.shape[1]

        # Build the value layers
        # Exactly one of the two attributes below is populated.
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            # Single FC layer sized by the current `out_channels`.
            self._value_branch = SlimFC(
                out_channels, 1, initializer=normc_initializer(0.01), activation_fn=None
            )
        else:
            # Separate conv stack built from the same filter config as the
            # trunk above, followed by a 1-channel 1x1 conv.
            vf_layers = []
            (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel, stride)
                vf_layers.append(
                    SlimConv2d(
                        in_channels,
                        out_channels,
                        kernel,
                        stride,
                        padding,
                        activation_fn=activation,
                    )
                )
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,
                    activation_fn=activation,
                )
            )

            vf_layers.append(
                SlimConv2d(
                    in_channels=out_channels,
                    out_channels=1,
                    kernel=1,
                    stride=1,
                    padding=None,
                    activation_fn=None,
                )
            )
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None
예제 #6
0
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        """Build the conv trunk, logits and value layers, with optional framestacking.

        Args:
            obs_space: Observation space. Its ``shape`` is unpacked as
                ``(w, h)`` when framestacking is active, otherwise as
                ``(w, h, in_channels)``.
            action_space: Passed through to ``TorchModelV2.__init__``.
            num_outputs: Requested output size. When falsy, the trunk ends
                in ``nn.Flatten`` and ``self.num_outputs`` is set to the
                channel count of the last filter.
            model_config: Reads ``conv_filters``, ``conv_activation``,
                ``post_fcnet_hiddens``, ``post_fcnet_activation``,
                ``no_final_linear``, ``vf_share_layers`` and
                ``num_framestacks``. Note that ``conv_filters`` is written
                back into this dict when it is missing or empty.
            name: Passed through to ``TorchModelV2.__init__``.
        """

        # Fall back to a filter config derived from the observation shape.
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0,\
            "Must provide at least 1 entry in `conv_filters`!"

        # Post FC net config.
        post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="torch")

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None
        self.traj_view_framestacking = False

        layers = []
        # Perform Atari framestacking via traj. view API.
        # Active only for an explicit `num_framestacks` > 1 (not "auto"); the
        # stack count then serves as the number of input channels.
        if model_config.get("num_framestacks") != "auto" and \
                model_config.get("num_framestacks", 0) > 1:
            (w, h) = obs_space.shape
            in_channels = model_config["num_framestacks"]
            self.traj_view_framestacking = True
        else:
            (w, h, in_channels) = obs_space.shape

        # All filters but the last: padded convs whose output size is tracked
        # through `in_size` / `in_channels`.
        in_size = [w, h]
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, [stride, stride])
            layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           padding,
                           activation_fn=activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            # Without a post-FC stack the last conv itself emits `num_outputs`.
            out_channels = out_channels if post_fcnet_hiddens else num_outputs
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # Add (optional) post-fc-stack after last Conv2D layer.
            # The last configured hidden size is replaced by `num_outputs`;
            # the list is empty when no post-FC hiddens are configured.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else [])
            # Note: `out_size` is re-used here as an FC width (it held a
            # spatial size in the conv loop above).
            for i, out_size in enumerate(layer_sizes):
                layers.append(
                    SlimFC(in_size=out_channels,
                           out_size=out_size,
                           activation_fn=post_fcnet_activation,
                           initializer=normc_initializer(1.0)))
                out_channels = out_size

        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                # Spatial size after the unpadded last conv.
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride)
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                if post_fcnet_hiddens:
                    layers.append(nn.Flatten())
                    # From here on `in_size` is the FC input width (an int),
                    # no longer a [w, h] pair.
                    in_size = out_channels
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    # The activation is applied only where
                    # i < len(post_fcnet_hiddens) - 1, so the final two
                    # layers of this stack are created without one.
                    for i, out_size in enumerate(post_fcnet_hiddens +
                                                 [num_outputs]):
                        layers.append(
                            SlimFC(in_size=in_size,
                                   out_size=out_size,
                                   activation_fn=post_fcnet_activation if
                                   i < len(post_fcnet_hiddens) - 1 else None,
                                   initializer=normc_initializer(1.0)))
                        in_size = out_size
                    # Last layer is logits layer.
                    # It is removed from `layers`, so it is not part of
                    # `self._convs`.
                    self._logits = layers.pop()

                else:
                    self._logits = SlimConv2d(out_channels,
                                              num_outputs, [1, 1],
                                              1,
                                              padding,
                                              activation_fn=None)

            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())
                self.num_outputs = out_channels

        self._convs = nn.Sequential(*layers)

        # Build the value layers
        # Exactly one of the two attributes below is populated.
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            # Single FC layer sized by the current `out_channels`.
            self._value_branch = SlimFC(out_channels,
                                        1,
                                        initializer=normc_initializer(0.01),
                                        activation_fn=None)
        else:
            # Separate conv stack built from the same filter config (and the
            # same framestacking input layout) as the trunk above.
            vf_layers = []
            if self.traj_view_framestacking:
                (w, h) = obs_space.shape
                in_channels = model_config["num_framestacks"]
            else:
                (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel,
                                                 [stride, stride])
                vf_layers.append(
                    SlimConv2d(in_channels,
                               out_channels,
                               kernel,
                               stride,
                               padding,
                               activation_fn=activation))
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(in_channels,
                           out_channels,
                           kernel,
                           stride,
                           None,
                           activation_fn=activation))

            # 1-channel 1x1 conv producing the value output.
            vf_layers.append(
                SlimConv2d(in_channels=out_channels,
                           out_channels=1,
                           kernel=1,
                           stride=1,
                           padding=None,
                           activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None

        # Optional: framestacking obs/new_obs for Atari.
        if self.traj_view_framestacking:
            # OBS is requested over the shift range
            # "-(num_framestacks - 1):0".
            from_ = model_config["num_framestacks"] - 1
            self.view_requirements[SampleBatch.OBS].shift = \
                "-{}:0".format(from_)
            self.view_requirements[SampleBatch.OBS].shift_from = -from_
            self.view_requirements[SampleBatch.OBS].shift_to = 0
            # NEXT_OBS is declared as a view on the OBS column with the
            # window moved one step forward.
            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
                data_col=SampleBatch.OBS,
                shift="-{}:1".format(from_ - 1),
                space=self.view_requirements[SampleBatch.OBS].space,
            )
예제 #7
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the cooperative and greedy sub-networks of the model.

        For each role ("coop" and "greedy") this creates a per-agent map
        CNN, a value-observation CNN, a logits head, a global-state value
        CNN and a value head. The greedy modules start as deep copies of
        the coop ones. A shared graph-filter stack (``self.GFL``) is built
        as well, and the freeze flags from the config are applied last.

        Args:
            obs_space: Observation space whose ``original_space`` provides
                ``'agents'`` (each with a ``'map'`` entry) and ``'state'``.
            action_space: Forwarded to ``TorchModelV2.__init__``.
            num_outputs: Forwarded to ``TorchModelV2.__init__``.
            model_config: Its ``'custom_model_config'`` entry overrides a
                deep copy of ``DEFAULT_OPTIONS``.
            name: Forwarded to ``TorchModelV2.__init__``.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Effective options: defaults overridden by the custom model config.
        self.cfg = copy.deepcopy(DEFAULT_OPTIONS)
        self.cfg.update(model_config['custom_model_config'])
        cfg = self.cfg

        agent_spaces = obs_space.original_space['agents']
        self.n_agents = len(agent_spaces)
        self.graph_features = cfg['graph_features']
        self.cnn_compression = cfg['cnn_compression']
        activation_classes = {'relu': nn.ReLU, 'leakyrelu': nn.LeakyReLU}
        self.activation = activation_classes[cfg['activation']]

        def conv_stack(shape, filters_key):
            """Return conv layers plus a trailing Flatten for `shape`."""
            width, height, channels = shape
            spatial = [width, height]
            modules = []
            for n_out, k_size, step in cfg[filters_key][:-1]:
                pad, next_spatial = same_padding(spatial, k_size,
                                                 [step, step])
                modules.append(
                    SlimConv2d(channels,
                               n_out,
                               k_size,
                               step,
                               pad,
                               activation_fn=self.activation))
                channels, spatial = n_out, next_spatial

            # The last filter is applied without padding.
            n_out, k_size, step = cfg[filters_key][-1]
            modules.append(SlimConv2d(channels, n_out, k_size, step, None))
            modules.append(nn.Flatten(1, -1))
            return modules

        # Per-agent map encoders.
        map_shape = agent_spaces[0]['map'].shape
        self.coop_convs = nn.Sequential(*conv_stack(map_shape, 'cnn_filters'))
        self.greedy_convs = copy.deepcopy(self.coop_convs)

        self.coop_value_obs_convs = copy.deepcopy(self.coop_convs)
        self.greedy_value_obs_convs = copy.deepcopy(self.coop_convs)

        # The summary takes the shape with its last axis moved to the front.
        summary(self.coop_convs,
                device="cpu",
                input_size=(map_shape[2], map_shape[0], map_shape[1]))

        # Graph filter stack: one filter plus activation per layer.
        graph_modules = []
        for _ in range(cfg['graph_layers']):
            graph_modules.append(
                gml_adv.GraphFilterBatchGSOA(self.graph_features,
                                             self.graph_features,
                                             cfg['graph_tabs'],
                                             cfg['agent_split'],
                                             cfg['graph_edge_features'],
                                             False))
            graph_modules.append(self.activation())

        self.GFL = nn.Sequential(*graph_modules)

        # Logits head; with `cnn_residual` its input is widened by
        # `cnn_compression`.
        policy_in = self.graph_features
        if cfg['cnn_residual']:
            policy_in += self.cnn_compression

        policy_head = [
            nn.Linear(policy_in, 64),
            self.activation(),
            nn.Linear(64, 32),
            self.activation()
        ]
        action_linear = nn.Linear(32, 5)
        nn.init.xavier_uniform_(action_linear.weight)
        nn.init.constant_(action_linear.bias, 0)
        policy_head.append(action_linear)
        self.coop_logits = nn.Sequential(*policy_head)
        self.greedy_logits = copy.deepcopy(self.coop_logits)
        summary(self.coop_logits, device="cpu", input_size=(policy_in, ))

        # Global-state encoders for the value function.
        state_shape = np.array(obs_space.original_space['state'].shape)
        self.coop_value_cnn = nn.Sequential(
            *conv_stack(state_shape, 'value_cnn_filters'))
        self.greedy_value_cnn = copy.deepcopy(self.coop_value_cnn)
        summary(self.greedy_value_cnn,
                device="cpu",
                input_size=(state_shape[2], state_shape[0], state_shape[1]))

        # Value head; input width is `cnn_compression` plus
        # `value_cnn_compression`.
        value_in = self.cnn_compression + cfg['value_cnn_compression']
        value_head = [
            nn.Linear(value_in, 64),
            self.activation(),
            nn.Linear(64, 32),
            self.activation()
        ]
        values_linear = nn.Linear(32, 1)
        normc_initializer()(values_linear.weight)
        nn.init.constant_(values_linear.bias, 0)
        value_head.append(values_linear)

        self.coop_value_branch = nn.Sequential(*value_head)
        self.greedy_value_branch = copy.deepcopy(self.coop_value_branch)
        summary(self.coop_value_branch, device="cpu", input_size=(value_in, ))

        self._cur_value = None

        # Apply the configured freeze flags.
        self.freeze_coop_value(cfg['freeze_coop_value'])
        self.freeze_greedy_value(cfg['freeze_greedy_value'])
        self.freeze_coop(cfg['freeze_coop'])
        self.freeze_greedy(cfg['freeze_greedy'])