Example #1
class TorchCustomModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model that just delegates to a fc-net."""
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.torch_sub_model = TorchFC(obs_space, action_space, num_outputs,
                                       model_config, name)

    def forward(self, input_dict, state, seq_lens):
        input_dict["obs"] = input_dict["obs"].float()
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        return torch.reshape(self.torch_sub_model.value_function(), [-1])
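For context, a custom model like this is typically registered with RLlib's ModelCatalog and then referenced by name in the algorithm config. A minimal sketch; the registration name "my_torch_model" is illustrative, not from the original example:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("my_torch_model", TorchCustomModel)

config = {
    "framework": "torch",
    "model": {
        # Refer to the model registered above by its (illustrative) name.
        "custom_model": "my_torch_model",
    },
}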
Example #2
class TorchParametricActionsModel(DQNTorchModel):
    """PyTorch version of above ParametricActionsModel."""
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        self.action_embed_model = TorchFC(Box(-1, 1, shape=true_obs_shape),
                                          action_space, action_embed_size,
                                          model_config, name + "_action_embed")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding
        action_embed, _ = self.action_embed_model(
            {"obs": input_dict["obs"]["cart"]})

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = torch.unsqueeze(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = torch.sum(avail_actions * intent_vector, dim=2)

        # Mask out invalid actions (use -LARGE_INTEGER to tag invalid).
        # These are then recognized by the EpsilonGreedy exploration component
        # as invalid actions that are not to be chosen.
        inf_mask = torch.clamp(torch.log(action_mask), -float(LARGE_INTEGER),
                               float("inf"))
        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()
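The forward() above expects a Dict observation with "avail_actions", "action_mask", and "cart" keys. A sketch of a compatible space, assuming gym spaces; MAX_ACTIONS and the bounds are illustrative, while the embedding size (2) and true obs shape ((4,)) follow the constructor defaults above:

from gym.spaces import Box, Dict

MAX_ACTIONS = 4  # illustrative; must match the environment
observation_space = Dict({
    # One embedding per action; width matches action_embed_size=2.
    "avail_actions": Box(-1, 1, shape=(MAX_ACTIONS, 2)),
    # 1.0 = action currently available, 0.0 = masked out.
    "action_mask": Box(0, 1, shape=(MAX_ACTIONS,)),
    # The true observation; matches true_obs_shape=(4,).
    "cart": Box(-1, 1, shape=(4,)),
})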
Example #3
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        self.action_embed_model = TorchFC(
            Box(-1, 1, shape=true_obs_shape),
            action_space,
            action_embed_size,
            model_config,
            name + "_action_embed",
        )
Example #4
    def init_scheduler(self, action_space, obs_space):
        return self.class_to_test(
            action_space=action_space,
            framework="torch",
            initial_temperature=self.initial_temperature,
            final_temperature=self.final_temperature,
            temperature_timesteps=self.temperature_timesteps,
            temperature_schedule=self.temperature_schedule,
            policy_config={},
            num_workers=0,
            worker_index=0,
            model=FullyConnectedNetwork(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=action_space.n,
                name="fc",
                model_config=MODEL_DEFAULTS,
            ),
        )
Example #5
class Agent(TorchModelV2, nn.Module):
  """PyTorch custom model that flattens the input to 1d and delegates to a fc-net."""

  def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    # "observation_min" and "observation_max" must be supplied via custom_model_config.
    self.custom_model_config = model_config["custom_model_config"]

    # Build a flat (1d) observation space matching the flattened obs.
    volume = int(np.prod(obs_space.shape))
    flat_observation_space = spaces.Box(low=self.custom_model_config["observation_min"],
                                        high=self.custom_model_config["observation_max"],
                                        shape=(volume,), dtype=np.float32)

    # TODO: Transform to output of any other PyTorch and pass new shape to model.

    # Create default model (for RL)
    TorchModelV2.__init__(self, flat_observation_space, action_space, num_outputs, model_config, name)
    nn.Module.__init__(self)
    self.torch_sub_model = TorchFC(flat_observation_space, action_space, num_outputs, model_config, name)

  def forward(self, input_dict, state, seq_lens):
    # Flatten the batched obs to [batch_size, volume] and convert to float.
    obs = input_dict["obs"].float()
    volume = np.prod(obs.shape[1:])  # volume of one observation, excluding the batch dim
    input_dict["obs"] = torch.reshape(obs, [obs.shape[0], volume])

    # TODO: forward() any other PyTorch modules here, pass result to RL algo

    # Defer to default FC model
    fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)

    return fc_out, []

  def value_function(self):
    return torch.reshape(self.torch_sub_model.value_function(), [-1])
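Because __init__ reads "observation_min" and "observation_max" from custom_model_config, those keys have to be supplied when the model is wired up. A minimal sketch; the registration name and bounds are illustrative:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("flatten_agent", Agent)  # illustrative name

config = {
    "framework": "torch",
    "model": {
        "custom_model": "flatten_agent",
        "custom_model_config": {
            "observation_min": 0.0,  # illustrative; must match the env's obs range
            "observation_max": 1.0,
        },
    },
}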
Example #6
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """PyTorch version of above ActionMaskingModel."""
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
        nn.Module.__init__(self)

        self.internal_model = TorchFC(orig_space["observations"], action_space,
                                      num_outputs, model_config,
                                      name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model(
            {"obs": input_dict["obs"]["observations"]})

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()
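The assert in __init__ requires the environment's (original) observation space to be a Dict with exactly the "action_mask" and "observations" keys. A minimal sketch of a compatible pair of spaces; the sizes are illustrative:

from gym.spaces import Box, Dict, Discrete

N_ACTIONS = 5  # illustrative
observation_space = Dict({
    "action_mask": Box(0.0, 1.0, shape=(N_ACTIONS,)),  # 1.0 = action allowed
    "observations": Box(-1.0, 1.0, shape=(10,)),       # actual obs; shape illustrative
})
action_space = Discrete(N_ACTIONS)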
Example #7
class TorchCustomModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model that just delegates to a fc-net."""
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # NOTE: `custom_input_space` is assumed to be defined at module level in
        # the original source; it must match the slice taken in forward() below
        # (the last 2 obs values, i.e. a Box of shape (2,)).
        TorchModelV2.__init__(self, custom_input_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.torch_sub_model = TorchFC(custom_input_space, action_space,
                                       num_outputs, model_config, name)
        prev_safe_layer_size = int(np.prod(custom_input_space.shape))
        vf_layers = []
        activation = model_config.get("fcnet_activation")
        hiddens = [32]
        for size in hiddens:
            vf_layers.append(
                SlimFC(in_size=prev_safe_layer_size,
                       out_size=size,
                       activation_fn=activation,
                       initializer=normc_initializer(1.0)))
            prev_safe_layer_size = size
        vf_layers.append(
            SlimFC(in_size=prev_safe_layer_size,
                   out_size=1,
                   initializer=normc_initializer(0.01),
                   activation_fn=None))
        self.safe_branch_separate = nn.Sequential(*vf_layers)
        self.last_in = None

    def forward(self, input_dict, state, seq_lens):
        input_dict["obs"] = input_dict["obs"].float(
        )[:, -2:]  # takes the last 2 values (delta_x, delta_v)
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        self.last_in = input_dict["obs"]
        return fc_out, []

    def value_function(self):
        value = torch.reshape(self.torch_sub_model.value_function(), [-1])
        safety = torch.reshape(
            self.safe_branch_separate(self.last_in).squeeze(1), [-1])
        return value + safety
Example #8
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config,
        name: str,
        **customized_model_kwargs
    ):
        super(CustomFCModel, self).__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        nn.Module.__init__(self)

        if "social_vehicle_config" in model_config["custom_model_config"]:
            social_vehicle_config = model_config["custom_model_config"][
                "social_vehicle_config"
            ]
        else:
            social_vehicle_config = customized_model_kwargs["social_vehicle_config"]

        social_vehicle_encoder_config = social_vehicle_config["encoder"]
        social_feature_encoder_class = social_vehicle_encoder_config[
            "social_feature_encoder_class"
        ]
        social_feature_encoder_params = social_vehicle_encoder_config[
            "social_feature_encoder_params"
        ]
        self.social_feature_encoder = (
            social_feature_encoder_class(**social_feature_encoder_params)
            if social_feature_encoder_class
            else None
        )

        self.model = TorchFCNet(
            obs_space, action_space, num_outputs, model_config, name
        )
Example #9
class TorchCustomModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model that just delegates to a fc-net."""
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # NOTE: `custom_input_space` is assumed to be defined at module level in
        # the original source (the sub-model only sees the last 2 obs values).
        self.torch_sub_model = TorchFC(custom_input_space, action_space,
                                       num_outputs, model_config, name)
        self.torch_sub_model._logits = torch.nn.Sequential(
            self.torch_sub_model._logits,
            torch.nn.Hardtanh(min_val=-3, max_val=3))

    def forward(self, input_dict, state, seq_lens):
        input_dict["obs"] = input_dict["obs"].float()[:, -2:]
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        return torch.reshape(self.torch_sub_model.value_function(), [-1])
Example #10
class DQNYanivActionMaskModel(DQNTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kwargs)
        true_obs_space = Box(low=0,
                             high=1,
                             shape=obs_space.original_space["state"].shape,
                             dtype=int)
        self.action_model = TorchFC(true_obs_space, action_space, num_outputs,
                                    model_config, name)

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        action_logits, _ = self.action_model(
            {"obs": input_dict["obs"]["state"]})
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
        return action_logits + inf_mask, state

    def value_function(self):
        return torch.reshape(self.action_model.value_function(), [-1])
Example #11
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
        nn.Module.__init__(self)

        self.internal_model = TorchFC(orig_space["observations"], action_space,
                                      num_outputs, model_config,
                                      name + "_internal")
Example #12
class YanivActionMaskModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        true_obs_space = Box(low=0, high=1, shape=(266, ), dtype=int)
        self.action_model = TorchFC(true_obs_space, action_space, num_outputs,
                                    model_config, name)

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]

        action_logits, _ = self.action_model(
            {"obs": input_dict["obs"]["state"]})

        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)

        return action_logits + inf_mask, state

    def value_function(self):
        return torch.reshape(self.action_model.value_function(), [-1])
Example #13
class CustomFC(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.torch_sub_model = TorchFC(obs_space, action_space, num_outputs,
                                       model_config, name)

    def forward(self, input_dict, state, seq_lens):
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        return torch.reshape(self.torch_sub_model.value_function(), [-1])

    def sum_params(self):
        s = 0
        for p in self.parameters():
            s += p.sum()
        return s.item()
Example #14
class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized VF."""
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Base of the model
        self.model = TorchFC(obs_space, action_space, num_outputs,
                             model_config, name)

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred
        input_size = 6 + 6 + 2  # obs + opp_obs + opp_act
        self.central_vf = nn.Sequential(
            SlimFC(input_size, 16, activation_fn=nn.Tanh),
            SlimFC(16, 1),
        )

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        model_out, _ = self.model(input_dict, state, seq_lens)
        return model_out, []

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        input_ = torch.cat(
            [
                obs,
                opponent_obs,
                torch.nn.functional.one_hot(opponent_actions.long(),
                                            2).float(),
            ],
            1,
        )
        return torch.reshape(self.central_vf(input_), [-1])

    @override(ModelV2)
    def value_function(self):
        return self.model.value_function()  # not used
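Note that forward() returns only the policy logits; central_value_function() is meant to be called separately by the policy (e.g., during postprocessing or loss computation) with the opponent's data. A sketch of such a call, assuming `model` is an instance of the class above; the batch size is illustrative and the 6/6/2 sizes follow the hard-coded input_size:

import torch

B = 32  # illustrative batch size
obs = torch.rand(B, 6)
opponent_obs = torch.rand(B, 6)
opponent_actions = torch.randint(0, 2, (B,))  # binary actions, one-hot encoded inside
vf_preds = model.central_value_function(obs, opponent_obs, opponent_actions)  # shape [B]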
Example #15
class TorchRepeatedSpyModel(TorchModelV2, nn.Module):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.fc = FullyConnectedNetwork(
            obs_space.original_space.child_space["location"], action_space,
            num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        ray.experimental.internal_kv._internal_kv_put(
            "torch_rspy_in_{}".format(TorchRepeatedSpyModel.capture_index),
            pickle.dumps(input_dict["obs"].unbatch_all()),
            overwrite=True)
        TorchRepeatedSpyModel.capture_index += 1
        return self.fc({"obs": input_dict["obs"].values["location"][:, 0]},
                       state, seq_lens)

    def value_function(self):
        return self.fc.value_function()
Example #16
class TorchSpyModel(TorchModelV2, nn.Module):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        self.fc = FullyConnectedNetwork(
            obs_space.original_space["sensors"].spaces["position"],
            action_space,
            num_outputs,
            model_config,
            name,
        )

    def forward(self, input_dict, state, seq_lens):
        pos = input_dict["obs"]["sensors"]["position"].detach().cpu().numpy()
        front_cam = input_dict["obs"]["sensors"]["front_cam"][0].detach().cpu().numpy()
        task = (
            input_dict["obs"]["inner_state"]["job_status"]["task"]
            .detach()
            .cpu()
            .numpy()
        )
        ray.experimental.internal_kv._internal_kv_put(
            "torch_spy_in_{}".format(TorchSpyModel.capture_index),
            pickle.dumps((pos, front_cam, task)),
            overwrite=True,
        )
        TorchSpyModel.capture_index += 1
        return self.fc(
            {"obs": input_dict["obs"]["sensors"]["position"]}, state, seq_lens
        )

    def value_function(self):
        return self.fc.value_function()
Example #17
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.obs_sizes = _get_size(obs_space.spaces[0])
        self.n_players = len(obs_space.spaces)
        self.n_actions = action_space.spaces[0].n

        # Flattened observation space per player.
        flat_obs_spaces = []
        for pl in range(self.n_players):
            flat_obs_spaces.append(flatten_space(obs_space.spaces[pl]))

        self.pl_models = {
            pl: FullyConnectedNetwork(flat_obs_spaces[pl],
                                      action_space.spaces[pl],
                                      action_space.spaces[pl].n, model_config,
                                      name)
            for pl in range(self.n_players)
        }

        # Set models as attributes to obtain parameters
        for pl in range(self.n_players):
            setattr(self, "model_{}".format(pl), self.pl_models[pl])
Example #18
class CustomFCModel(TorchModelV2, nn.Module):
    """Example of interpreting repeated observations."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config,
        name: str,
        **customized_model_kwargs
    ):
        super(CustomFCModel, self).__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        nn.Module.__init__(self)

        if "social_vehicle_config" in model_config["custom_model_config"]:
            social_vehicle_config = model_config["custom_model_config"][
                "social_vehicle_config"
            ]
        else:
            social_vehicle_config = customized_model_kwargs["social_vehicle_config"]

        social_vehicle_encoder_config = social_vehicle_config["encoder"]
        social_feature_encoder_class = social_vehicle_encoder_config[
            "social_feature_encoder_class"
        ]
        social_feature_encoder_params = social_vehicle_encoder_config[
            "social_feature_encoder_params"
        ]
        self.social_feature_encoder = (
            social_feature_encoder_class(**social_feature_encoder_params)
            if social_feature_encoder_class
            else None
        )

        self.model = TorchFCNet(
            obs_space, action_space, num_outputs, model_config, name
        )

    def forward(self, input_dict, state, seq_lens):
        low_dim_state = input_dict["obs"]["low_dim_states"]
        social_vehicles_state = input_dict["obs"]["social_vehicles"]

        social_feature = []
        if self.social_feature_encoder is not None:
            social_feature, _ = self.social_feature_encoder(social_vehicles_state)
        else:
            social_feature = [e.reshape(1, -1) for e in social_vehicles_state]

        input_dict["obs"]["social_vehicles"] = (
            torch.cat(social_feature, 0) if len(social_feature) > 0 else []
        )
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()
Example #19
    def __init__(self, obs_space, num_outputs, options):
        TorchModel.__init__(self, obs_space, num_outputs, options)
        self.fc = FullyConnectedNetwork(
            obs_space.original_space.spaces["sensors"].spaces["position"],
            num_outputs, options)
Example #20
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # NOTE: `custom_input_space` is assumed to be defined at module level in
        # the original source.
        TorchModelV2.__init__(self, custom_input_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.torch_sub_model = TorchFC(custom_input_space, action_space, num_outputs, model_config, name)
Example #21
def make_model_and_action_dist(policy, obs_space, action_space, config):
    """create model neural network"""
    policy.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))
    policy.log_stats = config["log_stats"]  # flag to log statistics
    if policy.log_stats:
        policy.stats_dict = {}
        policy.stats_fn = config["stats_fn"]

    # Keys of the observation space that must be used at train and test time ('signal' and 'mask' will be excluded
    # from the actual obs space)
    policy.train_obs_keys = config["train_obs_keys"]
    policy.test_obs_keys = config["test_obs_keys"]

    # Check whether policy observation space is inside a Tuple space
    policy.requires_tupling = False
    if isinstance(action_space, Tuple) and len(action_space.spaces) == 1:
        policy.action_space = action_space.spaces[0]
        action_space = action_space.spaces[0]
        policy.requires_tupling = True
    if not isinstance(action_space, Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    # Get real observation space
    if isinstance(obs_space, Box):
        assert hasattr(obs_space, "original_space"), "Invalid observation space"
        obs_space = obs_space.original_space
        if isinstance(obs_space, Tuple):
            obs_space = obs_space.spaces[0]
    assert isinstance(obs_space, Dict), "Invalid observation space"
    policy.has_action_mask = "action_mask" in obs_space.spaces
    assert all([k in obs_space.spaces for k in policy.train_obs_keys]), "Invalid train keys specification"
    assert all([k in obs_space.spaces for k in policy.test_obs_keys]), "Invalid test keys specification"

    # Get observation space used for training
    if config["train_obs_space"] is None:
        train_obs_space = obs_space
    else:
        train_obs_space = config["train_obs_space"]
        if isinstance(train_obs_space, Box):
            assert hasattr(train_obs_space, "original_space"), "Invalid observation space"
            train_obs_space = train_obs_space.original_space
            if isinstance(train_obs_space, Tuple):
                train_obs_space = train_obs_space.spaces[0]

    # Obs spaces used for training and testing
    sp = Dict({
        k: obs_space.spaces[k]
        for k in policy.test_obs_keys
    })

    policy.real_test_obs_space = flatten_space(sp)
    policy.real_test_obs_space.original_space = sp
    model_space = Dict({
        k: obs_space.spaces[k]
        for k in policy.test_obs_keys if k != "signal" and k != "action_mask"
    })

    sp = Dict({
        k: train_obs_space.spaces[k]
        for k in policy.train_obs_keys
    })
    policy.real_train_obs_space = flatten_space(sp)
    policy.real_train_obs_space.original_space = sp
    policy.n_actions = action_space.n

    def update_target():
        pass

    policy.update_target = update_target
    model = FullyConnectedNetwork(flatten_space(model_space), action_space,
                                  action_space.n, name="FcNet",
                                  model_config=config["model"]).to(policy.device)
    # get_action_dist() returns (dist_class, dist_dim); the policy only needs
    # the distribution class.
    dist_class, _ = ModelCatalog.get_action_dist(action_space, config,
                                                 framework="torch")
    return model, dist_class
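A function with this signature is typically plugged into RLlib's older policy-template API. A hedged sketch of that wiring, assuming the build_torch_policy template available in RLlib releases of this vintage; my_loss_fn is a hypothetical placeholder:

from ray.rllib.policy.torch_policy_template import build_torch_policy

MyMaskedDQNPolicy = build_torch_policy(
    name="MyMaskedDQNPolicy",  # illustrative name
    loss_fn=my_loss_fn,  # hypothetical; supply the algorithm's actual loss
    make_model_and_action_dist=make_model_and_action_dist,
)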
Example #22
class TorchParametricActionsModel(DQNTorchModel):
    """PyTorch version of above ParametricActionsModel."""
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        action_embed_size=2,  # Dimensionality of sentence embeddings  TODO don't make this hard-coded
        **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        self.true_obs_preprocessor = DictFlatteningPreprocessor(
            obs_space.original_space["true_obs"])
        self.action_embed_model = TorchFC(
            Box(-10, 10, self.true_obs_preprocessor.shape), action_space,
            action_embed_size, model_config, name + "_action_embed")

    @staticmethod
    def make_obs_space(embed_dim=768,
                       max_steps=None,
                       max_utterances=5,
                       max_command_length=5,
                       max_variables=10,
                       max_actions=10,
                       **kwargs):
        true_obs = {
            'dialog_history':
            Repeated(Dict({
                'sender': Discrete(3),
                'utterance': Box(-10, 10, shape=(embed_dim, ))
            }),
                     max_len=max_utterances),
            'partial_command':
            Repeated(Box(-10, 10, shape=(embed_dim, )),
                     max_len=max_command_length),
            'variables':
            Repeated(Box(-10, 10, shape=(embed_dim, )), max_len=max_variables),
        }
        if max_steps:
            true_obs['steps'] = Discrete(max_steps)

        # For calculating true_obs_shape, one could return Dict(true_obs) here.

        return Dict({
            "true_obs":
            Dict(true_obs),
            '_action_mask':
            MultiDiscrete([2 for _ in range(max_actions)]),
            '_action_embeds':
            Box(-10, 10, shape=(max_actions, embed_dim)),
        })

    @staticmethod
    def make_action_space(max_actions=10, **kwargs):
        return Discrete(max_actions)

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict['obs']["_action_embeds"]
        action_mask = input_dict["obs"]["_action_mask"]

        true_obs = input_dict["obs"]["true_obs"]
        true_obs = self.true_obs_preprocessor.transform(true_obs)
        # Compute the predicted action embedding
        action_embed, _ = self.action_embed_model({"obs": true_obs})

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = torch.unsqueeze(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = torch.sum(avail_actions * intent_vector, dim=2)

        # Mask out invalid actions (use FLOAT_MIN to tag invalid).
        # These are then recognized by the EpsilonGreedy exploration component
        # as invalid actions that are not to be chosen.
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)

        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()
Example #23
class TorchCustomLossModel(TorchModelV2, nn.Module):
    """PyTorch version of the CustomLossModel above."""
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, input_files):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        nn.Module.__init__(self)

        self.input_files = input_files
        # Create a new input reader per worker.
        self.reader = JsonReader(self.input_files)
        self.fcnet = TorchFC(self.obs_space,
                             self.action_space,
                             num_outputs,
                             model_config,
                             name="fcnet")

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Delegate to our FCNet.
        return self.fcnet(input_dict, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        # Delegate to our FCNet.
        return self.fcnet.value_function()

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):
        """Calculates a custom loss on top of the given policy_loss(es).

        Args:
            policy_loss (List[TensorType]): The list of already calculated
                policy losses (as many as there are optimizers).
            loss_inputs (TensorStruct): Struct of np.ndarrays holding the
                entire train batch.

        Returns:
            List[TensorType]: The altered list of policy losses. In case the
                custom loss should have its own optimizer, make sure the
                returned list is one larger than the incoming policy_loss list.
                In case you simply want to mix in the custom loss into the
                already calculated policy losses, return a list of altered
                policy losses (as done in this example below).
        """
        # Get the next batch from our input files.
        batch = self.reader.next()

        # Define a secondary loss by building a graph copy with weight sharing.
        obs = restore_original_dimensions(torch.from_numpy(
            batch["obs"]).float().to(policy_loss[0].device),
                                          self.obs_space,
                                          tensorlib="torch")
        logits, _ = self.forward({"obs": obs}, [], None)

        # You can also add self-supervised losses easily by referencing tensors
        # created during _build_layers_v2(). For example, an autoencoder-style
        # loss can be added as follows:
        # ae_loss = squared_diff(
        #     loss_inputs["obs"], Decoder(self.fcnet.last_layer))
        print("FYI: You can also use these tensors: {}, ".format(loss_inputs))

        # Compute the IL loss.
        action_dist = TorchCategorical(logits, self.model_config)
        imitation_loss = torch.mean(-action_dist.logp(
            torch.from_numpy(batch["actions"]).to(policy_loss[0].device)))
        self.imitation_loss_metric = imitation_loss.item()
        self.policy_loss_metric = np.mean(
            [loss.item() for loss in policy_loss])

        # Add the imitation loss to each already calculated policy loss term.
        # Alternatively (if custom loss has its own optimizer):
        # return policy_loss + [10 * self.imitation_loss]
        return [loss_ + 10 * imitation_loss for loss_ in policy_loss]

    def metrics(self):
        return {
            "policy_loss": self.policy_loss_metric,
            "imitation_loss": self.imitation_loss_metric,
        }
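Since __init__ takes an extra input_files argument, it has to be forwarded through custom_model_config, whose entries RLlib passes as keyword arguments to the custom model's constructor. A minimal sketch; the registration name and path are placeholders:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("custom_loss_model", TorchCustomLossModel)

config = {
    "framework": "torch",
    "model": {
        "custom_model": "custom_loss_model",
        "custom_model_config": {
            "input_files": "/path/to/offline/data.json",  # placeholder path
        },
    },
}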