Esempio n. 1
0
    def __init__(self, obs_shapes: Dict[str, Sequence[int]],
                 hidden_units: List[int], non_lin: nn.Module,
                 support_range: Tuple[int, int]):
        """Build the parent network and attach a categorical value head.

        The parent constructor must leave a ``"latent"`` block in
        ``self.perception_dict``; it is read below. Two blocks are stacked on
        top of it:

        * ``"probabilities"``: a ``LinearOutputBlock`` with one output unit per
          integer in the inclusive range ``support_range[0]..support_range[1]``.
        * ``"value"``: a ``FunctionalBlock`` that maps those outputs to a scalar
          with ``support_to_scalar`` (a probability-weighted sum of supports).

        Both blocks are exposed as outputs of the compiled ``self.net``.

        :param obs_shapes: Observation key -> observation shape.
        :param hidden_units: Hidden layer sizes, forwarded to the parent.
        :param non_lin: Non-linearity, forwarded to the parent.
        :param support_range: Inclusive (lower, upper) bounds of the support.
        """
        super().__init__(obs_shapes, hidden_units, non_lin)

        # --- categorical value head: one unit per support value -------------
        lower, upper = support_range[0], support_range[1]
        latent_shapes = self.perception_dict["latent"].out_shapes()
        self.perception_dict["probabilities"] = LinearOutputBlock(
            in_keys="latent",
            out_keys="probabilities",
            in_shapes=latent_shapes,
            output_units=upper - lower + 1)

        # --- scalar value derived from the categorical head -----------------
        def _to_value(head_output: torch.Tensor) -> torch.Tensor:
            return support_to_scalar(head_output, support_range=support_range)

        head_shapes = self.perception_dict["probabilities"].out_shapes()
        self.perception_dict["value"] = FunctionalBlock(
            in_keys="probabilities",
            out_keys="value",
            in_shapes=head_shapes,
            func=_to_value)

        # only the categorical head gets the small-std initialisation
        small_std_init = make_module_init_normc(std=0.01)
        self.perception_dict["probabilities"].apply(small_std_init)

        # --- compile the full inference graph ------------------------------
        self.net = InferenceBlock(
            in_keys=list(obs_shapes.keys()),
            out_keys=["probabilities", "value"],
            in_shapes=list(obs_shapes.values()),
            perception_blocks=self.perception_dict)
Esempio n. 2
0
def test_functional_block_single_arg():
    """A single-input FunctionalBlock wrapping ``torch.squeeze``.

    The (100, 64, 1) input must come out as (100, 64) under the single
    declared output key.
    """
    inputs = build_input_dict(dims=[100, 64, 1])
    block: FunctionalBlock = FunctionalBlock(
        in_keys="in_key",
        out_keys="out_key",
        in_shapes=(100, 64, 1),
        func=torch.squeeze)
    str(block)  # smoke-test the string representation
    outputs = block(inputs)

    assert isinstance(outputs, dict)
    assert all(key in outputs for key in block.out_keys)
    squeezed = outputs[block.out_keys[0]]
    assert squeezed.shape == (100, 64)
Esempio n. 3
0
def test_functional_block_multi_arg_lambda():
    """A two-input FunctionalBlock driven by a lambda.

    The lambda concatenates two (100, 64, 1) tensors along the last axis,
    so the output must be (100, 64, 2).
    """
    unit_shape = (100, 64, 1)
    inputs = build_multi_input_dict(dims=[[100, 64, 1], [100, 64, 1]])

    # lambda parameter names must match the in_keys below
    block: FunctionalBlock = FunctionalBlock(
        in_keys=["in_key_0", "in_key_1"],
        out_keys="out_key",
        in_shapes=[unit_shape, unit_shape],
        func=lambda in_key_0, in_key_1: torch.cat(
            [in_key_0, in_key_1], dim=-1))
    str(block)  # smoke-test the string representation
    outputs = block(inputs)

    assert isinstance(outputs, dict)
    assert all(key in outputs for key in block.out_keys)
    concatenated = outputs[block.out_keys[0]]
    assert concatenated.shape == (100, 64, 2)
Esempio n. 4
0
def test_functional_block_multi_arg_order():
    """Check that inputs are bound to ``func`` by argument name.

    The wrapped function declares its parameters as ``(in_key_1, in_key_0)``,
    the reverse of ``in_keys``. The expected (100, 128) result only comes
    out if each tensor reaches the parameter of the same name.
    """
    inputs = build_multi_input_dict(dims=[[100, 64], [100, 64, 1]])

    def concat_squeezed(in_key_1, in_key_0):
        # parameter order intentionally reversed relative to in_keys
        flattened = in_key_1.squeeze(dim=-1)
        return torch.cat([in_key_0, flattened], dim=-1)

    block: FunctionalBlock = FunctionalBlock(
        in_keys=["in_key_0", "in_key_1"],
        out_keys="out_key",
        in_shapes=[(100, 64), (100, 64, 1)],
        func=concat_squeezed)
    str(block)  # smoke-test the string representation
    outputs = block(inputs)

    assert isinstance(outputs, dict)
    assert all(key in outputs for key in block.out_keys)
    merged = outputs[block.out_keys[0]]
    assert merged.shape == (100, 128)
Esempio n. 5
0
def test_functional_block_multi_arg_multi_out():
    """A FunctionalBlock with two inputs and two outputs.

    The returned tuple must be assigned to ``out_keys`` in order, and each
    squeezed tensor must keep its own shape. The wrapped function again
    declares its parameters in reverse order relative to ``in_keys``.
    """
    inputs = build_multi_input_dict(dims=[[100, 64, 32, 1], [100, 64, 1]])

    def squeeze_both(in_key_1, in_key_0):
        return in_key_0.squeeze(), in_key_1.squeeze()

    block: FunctionalBlock = FunctionalBlock(
        in_keys=["in_key_0", "in_key_1"],
        out_keys=["out_key_0", "out_key_1"],
        in_shapes=[(100, 64, 32, 1), (100, 64, 1)],
        func=squeeze_both)
    str(block)  # smoke-test the string representation
    outputs = block(inputs)

    assert isinstance(outputs, dict)
    assert all(key in outputs for key in block.out_keys)
    first_key, second_key = block.out_keys[0], block.out_keys[1]
    assert outputs[first_key].shape == (100, 64, 32)
    assert outputs[second_key].shape == (100, 64)
Esempio n. 6
0
    def __init__(self, obs_shapes: Dict[str, Sequence[int]],
                 action_logits_shapes: Dict[str, Sequence[int]],
                 non_lin: Union[str, type(nn.Module)], with_mask: bool):
        """Build a network that scores inventory pieces against the ordered piece.

        Each of the ``inventory`` and ``ordered_piece`` observations goes
        through a ``DenseBlock`` (``hidden_units`` wide) and then a
        ``LinearOutputBlock`` into an ``embedding_dim``-sized embedding. A
        ``CorrelationBlock`` (``reduce=True``) combines the two embeddings into
        the ``piece_idx`` scores. When ``with_mask`` is set, those scores first
        pass through an ``ActionMaskingBlock`` together with the
        ``inventory_mask`` observation.

        :param obs_shapes: Observation key -> shape. Must contain ``inventory``
            and ``ordered_piece``, plus ``inventory_mask`` if ``with_mask``.
        :param action_logits_shapes: Action key -> logits shape. The first
            entry of ``piece_idx`` is checked against the network output size.
        :param non_lin: Non-linearity passed to the dense blocks.
        :param with_mask: Whether to mask the scores with ``inventory_mask``.
        :raises AssertionError: If the ``piece_idx`` output size does not match
            ``action_logits_shapes['piece_idx'][0]``.
        """
        nn.Module.__init__(self)
        self.obs_shapes = obs_shapes

        hidden_units, embedding_dim = 32, 7

        self.perception_dict = OrderedDict()

        # embed inventory
        # ---------------
        self.perception_dict['inventory_feat'] = DenseBlock(
            in_keys='inventory',
            out_keys='inventory_feat',
            in_shapes=self.obs_shapes['inventory'],
            hidden_units=[hidden_units],
            non_lin=non_lin)

        self.perception_dict['inventory_embed'] = LinearOutputBlock(
            in_keys='inventory_feat',
            out_keys='inventory_embed',
            in_shapes=self.perception_dict['inventory_feat'].out_shapes(),
            output_units=embedding_dim)

        # embed ordered_piece
        # -------------------
        # insert a singleton axis at dim -2 before the dense layers
        self.perception_dict['order_unsqueezed'] = FunctionalBlock(
            in_keys='ordered_piece',
            out_keys='order_unsqueezed',
            in_shapes=self.obs_shapes['ordered_piece'],
            func=lambda x: torch.unsqueeze(x, dim=-2))

        self.perception_dict['order_feat'] = DenseBlock(
            in_keys='order_unsqueezed',
            out_keys='order_feat',
            in_shapes=self.perception_dict['order_unsqueezed'].out_shapes(),
            hidden_units=[hidden_units],
            non_lin=non_lin)

        self.perception_dict['order_embed'] = LinearOutputBlock(
            in_keys='order_feat',
            out_keys='order_embed',
            in_shapes=self.perception_dict['order_feat'].out_shapes(),
            output_units=embedding_dim)

        # compute dot product score
        # -------------------------
        # Build a fresh list with `+` rather than extending the list returned
        # by out_shapes() in place (`+=`). That way a block's shape list can
        # never be mutated through aliasing.
        in_shapes = (self.perception_dict['inventory_embed'].out_shapes()
                     + self.perception_dict['order_embed'].out_shapes())
        # without masking, the correlation output is the final 'piece_idx'
        out_key = 'corr_score' if with_mask else 'piece_idx'
        self.perception_dict[out_key] = CorrelationBlock(
            in_keys=['inventory_embed', 'order_embed'],
            out_keys=out_key,
            in_shapes=in_shapes,
            reduce=True)

        # apply action masking
        if with_mask:
            self.perception_dict['piece_idx'] = ActionMaskingBlock(
                in_keys=['corr_score', 'inventory_mask'],
                out_keys='piece_idx',
                in_shapes=self.perception_dict['corr_score'].out_shapes() +
                [self.obs_shapes['inventory_mask']],
                num_actors=1,
                num_of_actor_actions=None)

        # sanity check: the network output size must match the action space
        assert (self.perception_dict['piece_idx'].out_shapes()[0][0]
                == action_logits_shapes['piece_idx'][0]), \
            "output size of 'piece_idx' does not match action_logits_shapes"

        in_keys = ['ordered_piece', 'inventory']
        if with_mask:
            in_keys.append('inventory_mask')
        self.perception_net = InferenceBlock(
            in_keys=in_keys,
            out_keys='piece_idx',
            in_shapes=[self.obs_shapes[key] for key in in_keys],
            perception_blocks=self.perception_dict)

        # initialize model weights: normc(1.0) over the whole network, then
        # normc(0.01) on the two embedding output blocks
        self.perception_net.apply(make_module_init_normc(1.0))
        self.perception_dict['inventory_embed'].apply(
            make_module_init_normc(0.01))
        self.perception_dict['order_embed'].apply(make_module_init_normc(0.01))