def __init__(self, obs_shapes: Dict[str, Sequence[int]], hidden_units: List[int],
             non_lin: nn.Module, support_range: Tuple[int, int]):
    super().__init__(obs_shapes, hidden_units, non_lin)

    # build categorical value head
    support_set_size = support_range[1] - support_range[0] + 1
    self.perception_dict["probabilities"] = LinearOutputBlock(
        in_keys="latent", out_keys="probabilities",
        in_shapes=self.perception_dict["latent"].out_shapes(),
        output_units=support_set_size)

    # compute value as probability weighted sum of supports
    def _to_scalar(x: torch.Tensor) -> torch.Tensor:
        return support_to_scalar(x, support_range=support_range)

    self.perception_dict["value"] = FunctionalBlock(
        in_keys="probabilities", out_keys="value",
        in_shapes=self.perception_dict["probabilities"].out_shapes(),
        func=_to_scalar)

    module_init = make_module_init_normc(std=0.01)
    self.perception_dict["probabilities"].apply(module_init)

    # compile inference model
    self.net = InferenceBlock(in_keys=list(obs_shapes.keys()),
                              out_keys=["probabilities", "value"],
                              in_shapes=list(obs_shapes.values()),
                              perception_blocks=self.perception_dict)
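
# A minimal sketch (not the library implementation) of what `support_to_scalar` above is
# assumed to compute: collapse the categorical head output into a scalar by taking the
# probability-weighted sum over the integer support set. Whether the softmax is applied
# inside the function or already by the head is an assumption here.
def _support_to_scalar_sketch(head_output: torch.Tensor,
                              support_range: Tuple[int, int]) -> torch.Tensor:
    # integer supports, e.g. support_range=(-10, 10) -> [-10, -9, ..., 10]
    support = torch.arange(support_range[0], support_range[1] + 1,
                           dtype=head_output.dtype, device=head_output.device)
    # normalize to probabilities and reduce over the support axis
    probabilities = torch.softmax(head_output, dim=-1)
    return (probabilities * support).sum(dim=-1)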
def test_functional_block_single_arg():
    """ perception test """
    in_dict = build_input_dict(dims=[100, 64, 1])

    net: FunctionalBlock = FunctionalBlock(in_keys="in_key", out_keys="out_key",
                                           in_shapes=(100, 64, 1), func=torch.squeeze)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 64)
def test_functional_block_multi_arg_lambda():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[100, 64, 1], [100, 64, 1]])

    net: FunctionalBlock = FunctionalBlock(
        in_keys=["in_key_0", "in_key_1"], out_keys="out_key",
        in_shapes=[(100, 64, 1), (100, 64, 1)],
        func=lambda in_key_0, in_key_1: torch.cat((in_key_0, in_key_1), dim=-1))
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 64, 2)
def test_functional_block_multi_arg_order():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[100, 64], [100, 64, 1]])

    def my_func(in_key_1, in_key_0):
        squeeze_in_1 = torch.squeeze(in_key_1, dim=-1)
        return torch.cat((in_key_0, squeeze_in_1), dim=-1)

    net: FunctionalBlock = FunctionalBlock(in_keys=["in_key_0", "in_key_1"], out_keys="out_key",
                                           in_shapes=[(100, 64), (100, 64, 1)], func=my_func)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 128)
def test_functional_block_multi_arg_multi_out():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[100, 64, 32, 1], [100, 64, 1]])

    def my_func(in_key_1, in_key_0):
        return torch.squeeze(in_key_0), torch.squeeze(in_key_1)

    net: FunctionalBlock = FunctionalBlock(in_keys=["in_key_0", "in_key_1"],
                                           out_keys=["out_key_0", "out_key_1"],
                                           in_shapes=[(100, 64, 32, 1), (100, 64, 1)],
                                           func=my_func)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 64, 32)
    assert out_dict[net.out_keys[1]].shape == (100, 64)
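
# The tests above use `build_input_dict` / `build_multi_input_dict` helpers that are not
# shown in this section. A minimal sketch of what they are assumed to do (names, keys and
# randomization are assumptions): build a dict of random tensors with the requested
# dimensions, keyed "in_key" resp. "in_key_0", "in_key_1", ...
def build_input_dict_sketch(dims: List[int]) -> Dict[str, torch.Tensor]:
    return {"in_key": torch.rand(*dims)}


def build_multi_input_dict_sketch(dims: List[List[int]]) -> Dict[str, torch.Tensor]:
    return {f"in_key_{i}": torch.rand(*d) for i, d in enumerate(dims)}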
def __init__(self, obs_shapes: Dict[str, Sequence[int]],
             action_logits_shapes: Dict[str, Sequence[int]],
             non_lin: Union[str, type(nn.Module)], with_mask: bool):
    nn.Module.__init__(self)
    self.obs_shapes = obs_shapes

    hidden_units, embedding_dim = 32, 7
    self.perception_dict = OrderedDict()

    # embed inventory
    # ---------------
    self.perception_dict['inventory_feat'] = DenseBlock(
        in_keys='inventory', out_keys='inventory_feat',
        in_shapes=self.obs_shapes['inventory'],
        hidden_units=[hidden_units], non_lin=non_lin)

    self.perception_dict['inventory_embed'] = LinearOutputBlock(
        in_keys='inventory_feat', out_keys='inventory_embed',
        in_shapes=self.perception_dict['inventory_feat'].out_shapes(),
        output_units=embedding_dim)

    # embed ordered_piece
    # -------------------
    self.perception_dict['order_unsqueezed'] = FunctionalBlock(
        in_keys='ordered_piece', out_keys='order_unsqueezed',
        in_shapes=self.obs_shapes['ordered_piece'],
        func=lambda x: torch.unsqueeze(x, dim=-2))

    self.perception_dict['order_feat'] = DenseBlock(
        in_keys='order_unsqueezed', out_keys='order_feat',
        in_shapes=self.perception_dict['order_unsqueezed'].out_shapes(),
        hidden_units=[hidden_units], non_lin=non_lin)

    self.perception_dict['order_embed'] = LinearOutputBlock(
        in_keys='order_feat', out_keys='order_embed',
        in_shapes=self.perception_dict['order_feat'].out_shapes(),
        output_units=embedding_dim)

    # compute dot product score
    # -------------------------
    in_shapes = self.perception_dict['inventory_embed'].out_shapes()
    in_shapes += self.perception_dict['order_embed'].out_shapes()
    out_key = 'corr_score' if with_mask else 'piece_idx'
    self.perception_dict[out_key] = CorrelationBlock(
        in_keys=['inventory_embed', 'order_embed'], out_keys=out_key,
        in_shapes=in_shapes, reduce=True)

    # apply action masking
    if with_mask:
        self.perception_dict['piece_idx'] = ActionMaskingBlock(
            in_keys=['corr_score', 'inventory_mask'], out_keys='piece_idx',
            in_shapes=self.perception_dict['corr_score'].out_shapes() + [self.obs_shapes['inventory_mask']],
            num_actors=1, num_of_actor_actions=None)

    assert self.perception_dict['piece_idx'].out_shapes()[0][0] == action_logits_shapes['piece_idx'][0]

    in_keys = ['ordered_piece', 'inventory']
    if with_mask:
        in_keys.append('inventory_mask')

    self.perception_net = InferenceBlock(
        in_keys=in_keys, out_keys='piece_idx',
        in_shapes=[self.obs_shapes[key] for key in in_keys],
        perception_blocks=self.perception_dict)

    # initialize model weights
    self.perception_net.apply(make_module_init_normc(1.0))
    self.perception_dict['inventory_embed'].apply(make_module_init_normc(0.01))
    self.perception_dict['order_embed'].apply(make_module_init_normc(0.01))
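
# Illustrative sketch (plain PyTorch, not the library blocks) of what the
# CorrelationBlock + ActionMaskingBlock combination above is assumed to compute:
# score each inventory item against the ordered piece via a dot product of their
# embeddings, then suppress masked-out items so their softmax probability vanishes.
def masked_dot_product_scores(inventory_embed: torch.Tensor,  # (batch, n_items, embed_dim)
                              order_embed: torch.Tensor,      # (batch, 1, embed_dim)
                              inventory_mask: torch.Tensor    # (batch, n_items), 1 = selectable
                              ) -> torch.Tensor:
    # broadcasted dot product along the embedding dimension -> (batch, n_items)
    scores = (inventory_embed * order_embed).sum(dim=-1)
    # masked-out entries get -inf so a downstream softmax assigns them zero probability
    return scores.masked_fill(inventory_mask == 0, float('-inf'))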