Ejemplo n.º 1
0
def _make_policy_model(in_size, out_size):
    """Return a one-hidden-layer MLP whose output is log-probabilities.

    Maps ``in_size`` features through a 512-wide hidden layer followed by
    ReLU, then to ``out_size`` outputs, and finally applies ``LogSoftmax``
    over the last dimension.

    Args:
        in_size: number of input features.
        out_size: number of outputs (size of the normalized last dimension).

    Returns:
        An ``nn.Sequential`` module.
    """
    hidden = 512
    # make_module is project-defined; its second argument presumably selects
    # the weight initialisation for the activation that follows -- TODO confirm.
    hidden_layer = make_module('linear', 'relu', in_size, hidden)
    output_layer = make_module('linear', 'linear', hidden, out_size)
    return nn.Sequential(
        hidden_layer,
        nn.ReLU(),
        output_layer,
        nn.LogSoftmax(dim=-1),
    )
Ejemplo n.º 2
0
def _make_v_model(in_size):
    """Return an MLP mapping ``in_size`` features to a single output.

    Layout: in_size -> 512 -> ReLU -> 256 -> ReLU -> 1, with no activation
    on the final layer.

    Args:
        in_size: number of input features.

    Returns:
        An ``nn.Sequential`` module.
    """
    widths = (in_size, 512, 256)
    stack = []
    # Hidden layers: each linear module is immediately followed by a ReLU.
    for width_in, width_out in zip(widths, widths[1:]):
        stack.append(make_module('linear', 'relu', width_in, width_out))
        stack.append(nn.ReLU())
    # Output head: a single unit with no activation.
    stack.append(make_module('linear', 'linear', widths[-1], 1))
    return nn.Sequential(*stack)
Ejemplo n.º 3
0
    def __init__(
        self,
        space: gym.spaces.Dict,
        names: Iterable[str],
        *,
        embedding_size: int,
        layers: List[int],
    ):
        """Build a shared embedding, per-name models, and an optional FC head.

        Args:
            space: dict space; its ``'grid'`` and ``'item'`` entries' ``high``
                bounds determine the size of the embedding table.
            names: one model is built per name via ``self._make_gv_model``;
                the models are concatenated by ``CatRepresentation``.
            embedding_size: width passed to ``EmbeddingRepresentation``.
            layers: output widths of the fully connected head applied after
                concatenation. Each width gets a linear module followed by
                ReLU; an empty list means the head is ``nn.Identity``.
        """
        super().__init__()
        self.space = space

        # The largest index seen in either grid or item observations, plus
        # one, is the number of rows the embedding table needs.
        grid_high = space['grid'].high.max()
        item_high = space['item'].high.max()
        num_embeddings = max(grid_high + 1, item_high + 1)

        # Built before the per-name models, which are created through
        # self._make_gv_model (defined elsewhere in the class).
        self.embedding = EmbeddingRepresentation(num_embeddings, embedding_size)
        self.cat_representation = CatRepresentation(
            [self._make_gv_model(name) for name in names]
        )
        self.fc_model: nn.Module

        if len(layers) == 0:
            # No head requested: concatenated features pass through unchanged.
            self.fc_model = nn.Identity()
            self._dim = self.cat_representation.dim
            return

        widths = [self.cat_representation.dim] + layers
        # Alternate linear module / ReLU for each consecutive width pair.
        self.fc_model = nn.Sequential(
            *mitt.flatten(
                (make_module('linear', 'relu', w_in, w_out), nn.ReLU())
                for w_in, w_out in mitt.pairwise(widths)
            )
        )
        self._dim = widths[-1]
Ejemplo n.º 4
0
 def __init__(self, representation: Representation, dim: int):
     """Store ``representation`` and build a resize model to ``dim`` features.

     The resize model is a linear module from ``representation.dim`` to
     ``dim`` followed by ReLU. How it is applied is not visible here;
     presumably the forward pass feeds the representation's output through
     it -- TODO confirm.

     Args:
         representation: wrapped representation; its ``dim`` is the input width.
         dim: output width, also exposed as ``self._dim``.
     """
     super().__init__()
     self._representation = representation
     in_dim = representation.dim
     resize_layers = [
         make_module('linear', 'relu', in_dim, dim),
         nn.ReLU(),
     ]
     self._resize_model = nn.Sequential(*resize_layers)
     self._dim = dim
Ejemplo n.º 5
0
    def __init__(self, input_space: gym.spaces.Box, dims: Sequence[int]):
        """Build an MLP over a flat (1-D) Box input space.

        Args:
            input_space: must be a ``gym.spaces.Box`` whose shape has exactly
                one dimension; that dimension is the first layer's input width.
            dims: output widths of the successive layers; must be non-empty.
                Every layer, including the last, is followed by ReLU.

        Raises:
            TypeError: if ``input_space`` is not a Box, or is a Box whose
                shape is not 1-dimensional.
            ValueError: if ``dims`` is empty.
        """
        super().__init__()

        # Type and rank are checked separately so that a Box of the wrong rank
        # gets a message naming the real problem. Both raise TypeError, and
        # both messages start with 'input_space must be Box' so existing
        # message matching keeps working.
        checkraise(
            isinstance(input_space, gym.spaces.Box),
            TypeError,
            'input_space must be Box',
        )
        checkraise(
            len(input_space.shape) == 1,
            TypeError,
            'input_space must be Box with 1-dimensional shape, '
            f'got shape {input_space.shape}',
        )
        checkraise(
            len(dims) > 0,
            ValueError,
            'dims must be non-empty',
        )

        (input_dim,) = input_space.shape
        # Layer widths: the input width followed by each requested width.
        self.dims = list(itt.chain([input_dim], dims))

        # One (linear, ReLU) pair per consecutive (in_dim, out_dim) pair.
        modules = mitt.flatten(
            (make_module('linear', 'relu', in_dim, out_dim), nn.ReLU())
            for in_dim, out_dim in mitt.pairwise(self.dims)
        )
        self.model = nn.Sequential(*modules)
Ejemplo n.º 6
0
def _make_q_model(in_size, out_size):
    """Return a one-hidden-layer MLP with ``out_size`` unnormalized outputs.

    Layout: in_size -> 512 -> ReLU -> out_size, with no activation on the
    output layer (presumably one value per action -- TODO confirm).

    Args:
        in_size: number of input features.
        out_size: number of outputs.

    Returns:
        An ``nn.Sequential`` module.
    """
    hidden_size = 512
    modules = [make_module('linear', 'relu', in_size, hidden_size), nn.ReLU()]
    modules.append(make_module('linear', 'linear', hidden_size, out_size))
    return nn.Sequential(*modules)