def _make_policy_model(in_size, out_size):
    """Assemble the policy network as an ``nn.Sequential``.

    The stack is, in order: a ``make_module('linear', 'relu', ...)`` layer
    from ``in_size`` to 512, a ReLU, a ``make_module('linear', 'linear', ...)``
    layer from 512 to ``out_size``, and a LogSoftmax over the last dimension.

    Args:
        in_size: size forwarded to ``make_module`` as the first layer's input.
        out_size: size forwarded to ``make_module`` as the last layer's output.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    # NOTE(review): ``make_module`` is defined outside this block; its second
    # argument presumably names the nonlinearity that follows the layer (used
    # for initialisation) -- confirm against its definition.
    hidden_layer = make_module('linear', 'relu', in_size, 512)
    output_layer = make_module('linear', 'linear', 512, out_size)

    # LogSoftmax(dim=-1) normalises over the final (out_size) axis.
    log_probs = nn.LogSoftmax(dim=-1)
    return nn.Sequential(hidden_layer, nn.ReLU(), output_layer, log_probs)
def _make_v_model(in_size):
    """Assemble the value network as an ``nn.Sequential``.

    The stack is two (``make_module('linear', 'relu', ...)``, ReLU) pairs
    going ``in_size`` -> 512 -> 256, followed by a final
    ``make_module('linear', 'linear', 256, 1)`` layer with no activation.

    Args:
        in_size: size forwarded to ``make_module`` as the first layer's input.

    Returns:
        The assembled ``nn.Sequential`` model, whose last layer has output
        size 1.
    """
    stack = []

    # Hidden part: each layer built for 'relu' is followed by a ReLU.
    for n_in, n_out in ((in_size, 512), (512, 256)):
        stack.append(make_module('linear', 'relu', n_in, n_out))
        stack.append(nn.ReLU())

    # Output head: built with the 'linear' tag and left without activation.
    stack.append(make_module('linear', 'linear', 256, 1))
    return nn.Sequential(*stack)
def __init__(
    self,
    space: gym.spaces.Dict,
    names: Iterable[str],
    *,
    embedding_size: int,
    layers: List[int],
):
    """Build the embedding, the concatenated sub-models and the optional MLP.

    Args:
        space: dict space; its ``'grid'`` and ``'item'`` entries are read
            for their ``high`` bounds.
        names: one sub-model is created per name via ``self._make_gv_model``.
        embedding_size: forwarded to ``EmbeddingRepresentation``.
        layers: output sizes of the fully-connected layers stacked on top of
            the concatenated representation. When empty, ``fc_model`` is an
            identity and ``_dim`` equals the concatenated dimension.
    """
    super().__init__()
    self.space = space

    # The embedding count is one more than the largest upper bound found in
    # either the 'grid' or the 'item' subspace.
    # NOTE(review): presumably so a single embedding can index values from
    # both subspaces -- confirm against EmbeddingRepresentation's usage.
    grid_count = space['grid'].high.max() + 1
    item_count = space['item'].high.max() + 1
    num_embeddings = max(grid_count, item_count)
    self.embedding = EmbeddingRepresentation(num_embeddings, embedding_size)

    self.cat_representation = CatRepresentation(
        [self._make_gv_model(name) for name in names]
    )

    self.fc_model: nn.Module
    if len(layers) == 0:
        # No extra layers requested: pass the concatenation through as-is.
        self.fc_model = nn.Identity()
        self._dim = self.cat_representation.dim
    else:
        dims = [self.cat_representation.dim] + layers

        # Alternate linear layer / ReLU for every consecutive size pair.
        fc_stack = []
        for in_dim, out_dim in zip(dims, dims[1:]):
            fc_stack.append(make_module('linear', 'relu', in_dim, out_dim))
            fc_stack.append(nn.ReLU())

        self.fc_model = nn.Sequential(*fc_stack)
        self._dim = dims[-1]
def __init__(self, representation: Representation, dim: int):
    """Wrap ``representation`` with a layer that maps its output to ``dim``.

    Args:
        representation: the wrapped representation; its ``dim`` attribute is
            used as the input size of the resize layer.
        dim: output size of the resize layer, also stored as ``self._dim``.
    """
    super().__init__()
    self._representation = representation

    # Resize model: one make_module('linear', 'relu', ...) layer from the
    # wrapped representation's dim to ``dim``, followed by a ReLU.
    projection = make_module('linear', 'relu', representation.dim, dim)
    self._resize_model = nn.Sequential(projection, nn.ReLU())

    self._dim = dim
def __init__(self, input_space: gym.spaces.Box, dims: Sequence[int]):
    """Build a stack of (linear, ReLU) layers on top of a 1-D Box input.

    Args:
        input_space: must be a ``gym.spaces.Box`` whose shape has exactly
            one axis; that axis length is the input size of the first layer.
        dims: output sizes of the successive layers; must be non-empty.

    NOTE(review): ``checkraise`` is defined outside this block; it
    presumably raises the given exception type with the given message when
    its first argument is false -- confirm.
    """
    super().__init__()

    # Validate arguments before touching ``input_space.shape`` for unpacking.
    is_flat_box = (
        isinstance(input_space, gym.spaces.Box)
        and len(input_space.shape) == 1
    )
    checkraise(is_flat_box, TypeError, 'input_space must be Box')
    checkraise(len(dims) > 0, ValueError, 'dims must be non-empty')

    (input_dim,) = input_space.shape

    # Full size sequence: the input size followed by every requested size.
    self.dims = [input_dim, *dims]

    # One (linear, ReLU) pair per consecutive pair of sizes.
    layer_stack = []
    for in_dim, out_dim in mitt.pairwise(self.dims):
        layer_stack.append(make_module('linear', 'relu', in_dim, out_dim))
        layer_stack.append(nn.ReLU())

    self.model = nn.Sequential(*layer_stack)
def _make_q_model(in_size, out_size):
    """Assemble the Q network as an ``nn.Sequential``.

    The stack is a ``make_module('linear', 'relu', ...)`` layer from
    ``in_size`` to 512, a ReLU, and a ``make_module('linear', 'linear', ...)``
    layer from 512 to ``out_size`` with no final activation.

    Args:
        in_size: size forwarded to ``make_module`` as the first layer's input.
        out_size: size forwarded to ``make_module`` as the last layer's output.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    modules = [
        make_module('linear', 'relu', in_size, 512),
        nn.ReLU(),
    ]
    # Output head is appended last and is not followed by an activation.
    modules.append(make_module('linear', 'linear', 512, out_size))
    return nn.Sequential(*modules)