def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Build the convolutional trunk plus the policy-logits and value heads.

    Args:
        obs_space: Observation space. Its ``shape`` is unpacked into exactly
            three values ``(w, h, in_channels)``. Presumably these are
            channels-last image observations; confirm against callers.
        action_space: Action space, forwarded to ``TorchModelV2``.
        num_outputs: Output width of the logits head.
        model_config: Model config dict. Only ``"conv_filters"`` is read
            here: a sequence of ``(out_channels, kernel, stride)`` triples.
            A missing or empty value falls back to
            ``_get_filter_config(obs_space.shape)``.
        name: Model name, forwarded to ``TorchModelV2``.
    """
    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)

    # Any falsy value (None, empty list, ...) selects the default filter stack.
    conv_spec = (model_config.get("conv_filters")
                 or _get_filter_config(obs_space.shape))

    width, height, channels = obs_space.shape
    spatial = [width, height]
    conv_stack = []

    # Every layer but the last gets explicit padding from valid_padding.
    # The running spatial size is threaded through so each subsequent
    # layer computes its padding from the previous layer's output size.
    for out_channels, kernel, stride in conv_spec[:-1]:
        pad, next_spatial = valid_padding(spatial, kernel, [stride, stride])
        conv_stack.append(
            SlimConv2d(channels, out_channels, kernel, stride, pad))
        channels, spatial = out_channels, next_spatial

    # The final conv layer is built with padding=None rather than a computed
    # padding value.
    out_channels, kernel, stride = conv_spec[-1]
    conv_stack.append(SlimConv2d(channels, out_channels, kernel, stride, None))
    self._convs = nn.Sequential(*conv_stack)

    # Both heads take `out_channels` inputs (the last conv layer's channel
    # count). Presumably forward() flattens/squeezes the conv output to that
    # width -- confirm there.
    self._logits = SlimFC(out_channels, num_outputs,
                          initializer=nn.init.xavier_uniform_)
    self._value_branch = SlimFC(out_channels, 1,
                                initializer=normc_initializer())

    # Only initialised here; presumably filled in by forward() (not visible).
    self._cur_value = None
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Build the activated conv trunk plus the policy-logits and value heads.

    Args:
        obs_space: Observation space. Its ``shape`` is unpacked into exactly
            three values ``(w, h, in_channels)``. Presumably these are
            channels-last image observations; confirm against callers.
        action_space: Action space, forwarded to ``TorchModelV2``.
        num_outputs: Output width of the logits head.
        model_config: Model config dict. Two keys are read here:
            - ``"conv_activation"`` is resolved via ``get_activation_fn``
              (framework ``"torch"``) and applied to every conv layer.
            - ``"conv_filters"`` is a sequence of
              ``(out_channels, kernel, stride)`` triples. A missing or
              empty value falls back to
              ``_get_filter_config(obs_space.shape)``.
            NOTE: "no_final_linear" and "vf_share_layers" are NOT read
            by this constructor.
        name: Model name, forwarded to ``TorchModelV2``.
    """
    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)

    act_fn = get_activation_fn(
        model_config.get("conv_activation"), framework="torch")

    # Any falsy value (None, empty list, ...) selects the default filter stack.
    conv_spec = (model_config.get("conv_filters")
                 or _get_filter_config(obs_space.shape))

    width, height, channels = obs_space.shape
    spatial = [width, height]
    conv_stack = []

    # Every layer but the last gets explicit padding from valid_padding.
    # The running spatial size is threaded through so each subsequent
    # layer computes its padding from the previous layer's output size.
    for out_channels, kernel, stride in conv_spec[:-1]:
        pad, next_spatial = valid_padding(spatial, kernel, [stride, stride])
        conv_stack.append(
            SlimConv2d(channels, out_channels, kernel, stride, pad,
                       activation_fn=act_fn))
        channels, spatial = out_channels, next_spatial

    # The final conv layer is built with padding=None but the same activation.
    out_channels, kernel, stride = conv_spec[-1]
    conv_stack.append(
        SlimConv2d(channels, out_channels, kernel, stride, None,
                   activation_fn=act_fn))
    self._convs = nn.Sequential(*conv_stack)

    # Both heads take `out_channels` inputs (the last conv layer's channel
    # count). Presumably forward() flattens/squeezes the conv output to that
    # width -- confirm there.
    self._logits = SlimFC(out_channels, num_outputs,
                          initializer=nn.init.xavier_uniform_)
    self._value_branch = SlimFC(out_channels, 1,
                                initializer=normc_initializer())

    # Intended to hold the current "base" output (before the logits layer);
    # only initialised here, presumably set by forward() (not visible).
    self._features = None