class TorchCustomModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model that just delegates to a fc-net.

    All computation is forwarded to a wrapped ``TorchFC`` built from the same
    spaces, output size and config this model receives.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.torch_sub_model = TorchFC(obs_space, action_space, num_outputs,
                                       model_config, name)

    def forward(self, input_dict, state, seq_lens):
        # Cast the observation to float; note this overwrites the entry in
        # the caller's input_dict.
        float_obs = input_dict["obs"].float()
        input_dict["obs"] = float_obs
        logits, _ = self.torch_sub_model(input_dict, state, seq_lens)
        # No recurrent state is produced, so an empty state list is returned.
        return logits, []

    def value_function(self):
        """Return the sub-model's value estimates flattened to 1-D."""
        return self.torch_sub_model.value_function().reshape(-1)
class TorchParametricActionsModel(DQNTorchModel, nn.Module):
    """PyTorch version of above ParametricActionsModel.

    Logits are the dot product between an "intent" embedding predicted from
    the "cart" observation and each available action's embedding; actions
    disabled by "action_mask" are pushed down to -LARGE_INTEGER.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        nn.Module.__init__(self)
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)
        # The embedding network is built for a Box(-1, 1) of
        # ``true_obs_shape`` and fed obs["cart"] in forward().
        cart_space = Box(-1, 1, shape=true_obs_shape)
        self.action_embed_model = TorchFC(cart_space, action_space,
                                          action_embed_size, model_config,
                                          name + "_action_embed")

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        avail_actions = obs["avail_actions"]
        action_mask = obs["action_mask"]

        # Predicted action embedding, from the cart observation only.
        action_embed, _ = self.action_embed_model({"obs": obs["cart"]})

        # [BATCH, EMBED] -> [BATCH, 1, EMBED] so it broadcasts against the
        # available-actions tensor (documented as [BATCH, MAX_ACTIONS, EMBED]).
        intent_vector = action_embed.unsqueeze(1)
        # Batched dot product -> logits of shape [BATCH, MAX_ACTIONS].
        action_logits = (avail_actions * intent_vector).sum(dim=2)

        # For a 0/1 mask, log(1) = 0 leaves valid actions untouched while
        # log(0) = -inf is clamped to -LARGE_INTEGER, which the EpsilonGreedy
        # exploration component treats as "invalid, never choose".
        inf_mask = torch.log(action_mask).clamp(-float(LARGE_INTEGER),
                                                float("inf"))
        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()
def __init__(self,
             obs_space,
             action_space,
             num_outputs,
             model_config,
             name,
             true_obs_shape=(4, ),
             action_embed_size=2,
             **kw):
    """Build the DQN heads plus an action-embedding sub-network.

    Args:
        true_obs_shape: Shape of the Box (bounds -1..1) the embedding
            network is built for.
        action_embed_size: Number of outputs of the embedding network.
        **kw: Forwarded unchanged to ``DQNTorchModel.__init__``.
    """
    DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                           model_config, name, **kw)
    embed_input_space = Box(-1, 1, shape=true_obs_shape)
    embed_name = name + "_action_embed"
    self.action_embed_model = TorchFC(embed_input_space, action_space,
                                      action_embed_size, model_config,
                                      embed_name)
def init_scheduler(self, action_space, obs_space):
    """Instantiate ``self.class_to_test`` for the torch framework.

    The instance gets this object's temperature-schedule settings, an empty
    ``policy_config``, ``num_workers=0``/``worker_index=0``, and a
    ``FullyConnectedNetwork`` named "fc" built with ``MODEL_DEFAULTS`` and
    ``action_space.n`` outputs.
    """
    temperature_settings = {
        "initial_temperature": self.initial_temperature,
        "final_temperature": self.final_temperature,
        "temperature_timesteps": self.temperature_timesteps,
        "temperature_schedule": self.temperature_schedule,
    }
    fc_model = FullyConnectedNetwork(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        name="fc",
        model_config=MODEL_DEFAULTS,
    )
    return self.class_to_test(
        action_space=action_space,
        framework="torch",
        policy_config={},
        num_workers=0,
        worker_index=0,
        model=fc_model,
        **temperature_settings,
    )
class Agent(TorchModelV2, nn.Module):
    """PyTorch custom model that flattens the input to 1d and delegates to a fc-net.

    The observation is reshaped to ``[batch, prod(obs_shape)]`` and fed to a
    ``TorchFC`` built on a flat ``Box`` whose bounds come from
    ``custom_model_config["observation_min"]`` /
    ``custom_model_config["observation_max"]``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # The original read an undefined name ``config`` here; the bounds
        # live in the model's custom config.
        custom_config = model_config.get("custom_model_config", {})

        # Flattened observation: one vector of prod(obs_shape) floats. int()
        # matters because np.prod(()) returns the float 1.0.
        volume = int(np.prod(obs_space.shape))
        flat_observation_space = spaces.Box(
            low=custom_config["observation_min"],
            high=custom_config["observation_max"],
            shape=(volume, ),
            dtype=np.float32)

        # TODO: Transform to output of any other PyTorch and pass new shape
        # to model.
        # Create default model (for RL).
        TorchModelV2.__init__(self, flat_observation_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.custom_model_config = custom_config
        self.torch_sub_model = TorchFC(flat_observation_space, action_space,
                                       num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        # Collapse every non-batch dimension: [B, d1, d2, ...] -> [B, d1*d2*...].
        volume = int(np.prod(obs.shape[1:]))
        input_dict["obs"] = torch.reshape(obs, [obs.shape[0], volume])
        # TODO: forward() any other PyTorch modules here, pass result to RL
        # algo.
        # Defer to default FC model.
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        """Return the sub-model's value estimates flattened to 1-D."""
        return torch.reshape(self.torch_sub_model.value_function(), [-1])
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """PyTorch version of above ActionMaskingModel.

    Requires a Dict observation space with "observations" and "action_mask"
    entries. Logits come from a TorchFC over "observations"; for a 0/1 mask,
    masked-out actions are pushed down to FLOAT_MIN.
    """

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
    ):
        # If obs_space exposes ``original_space``, the Dict is read from
        # there; otherwise obs_space itself must be the Dict.
        dict_space = getattr(obs_space, "original_space", obs_space)
        assert (isinstance(dict_space, Dict)
                and "action_mask" in dict_space.spaces
                and "observations" in dict_space.spaces)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
        nn.Module.__init__(self)

        self.internal_model = TorchFC(dict_space["observations"],
                                      action_space, num_outputs,
                                      model_config, name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        action_mask = obs["action_mask"]
        raw_logits, _ = self.internal_model({"obs": obs["observations"]})
        # log(1) = 0 leaves valid actions unchanged; log(0) = -inf is
        # clamped up to FLOAT_MIN.
        penalty = torch.log(action_mask).clamp(min=FLOAT_MIN)
        return raw_logits + penalty, state

    def value_function(self):
        return self.internal_model.value_function()
class TorchCustomModel(TorchModelV2, nn.Module):
    """Fc-net model on ``custom_input_space`` with an extra "safety" value branch.

    forward() keeps only the last two observation features (delta_x,
    delta_v). value_function() returns the fc-net's value estimate plus the
    output of a separate MLP (``safe_branch_separate``) applied to the same
    features.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # obs_space is not used: everything is built on the externally
        # defined ``custom_input_space``.
        TorchModelV2.__init__(self, custom_input_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.torch_sub_model = TorchFC(custom_input_space, action_space,
                                       num_outputs, model_config, name)

        # Separate value branch: in_size -> 32 -> 1.
        # np.prod replaces the np.product alias (removed in NumPy 2.0).
        # NOTE(review): the branch is fed the 2 features kept by forward();
        # this assumes custom_input_space has exactly 2 elements.
        prev_safe_layer_size = int(np.prod(custom_input_space.shape))
        vf_layers = []
        activation = model_config.get("fcnet_activation")
        hiddens = [32]
        for size in hiddens:
            vf_layers.append(
                SlimFC(in_size=prev_safe_layer_size,
                       out_size=size,
                       activation_fn=activation,
                       initializer=normc_initializer(1.0)))
            prev_safe_layer_size = size
        vf_layers.append(
            SlimFC(in_size=prev_safe_layer_size,
                   out_size=1,
                   initializer=normc_initializer(0.01),
                   activation_fn=None))
        self.safe_branch_separate = nn.Sequential(*vf_layers)
        # Features seen by the latest forward(); consumed by value_function().
        self.last_in = None

    def forward(self, input_dict, state, seq_lens):
        # Keep only the last 2 values (delta_x, delta_v).
        input_dict["obs"] = input_dict["obs"].float()[:, -2:]
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        self.last_in = input_dict["obs"]
        return fc_out, []

    def value_function(self):
        """Return fc-net value + safety-branch output, flattened to 1-D.

        Raises:
            RuntimeError: if called before any forward() pass.
        """
        if self.last_in is None:
            raise RuntimeError("value_function() called before forward()")
        value = torch.reshape(self.torch_sub_model.value_function(), [-1])
        safety = torch.reshape(
            self.safe_branch_separate(self.last_in).squeeze(1), [-1])
        return value + safety
def __init__(
    self,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    num_outputs: int,
    model_config,
    name: str,
    **customized_model_kwargs
):
    """Set up the optional social-vehicle encoder and the wrapped fc-net.

    The social vehicle config is taken from
    ``model_config["custom_model_config"]["social_vehicle_config"]`` when that
    key exists, else from the ``social_vehicle_config`` keyword argument. Its
    ``encoder`` section names an encoder class and its constructor params; a
    falsy class leaves ``self.social_feature_encoder`` as None.
    """
    super(CustomFCModel, self).__init__(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=model_config,
        name=name,
    )
    nn.Module.__init__(self)

    custom_config = model_config["custom_model_config"]
    social_vehicle_config = (
        custom_config["social_vehicle_config"]
        if "social_vehicle_config" in custom_config
        else customized_model_kwargs["social_vehicle_config"]
    )

    encoder_config = social_vehicle_config["encoder"]
    encoder_class = encoder_config["social_feature_encoder_class"]
    encoder_params = encoder_config["social_feature_encoder_params"]
    if encoder_class:
        self.social_feature_encoder = encoder_class(**encoder_params)
    else:
        self.social_feature_encoder = None

    self.model = TorchFCNet(obs_space, action_space, num_outputs, model_config, name)
class TorchCustomModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model that just delegates to a fc-net.

    The wrapped fc-net is built on the externally defined
    ``custom_input_space``. Its logits layer is followed by a Hardtanh that
    bounds every logit to [-3, 3].
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.torch_sub_model = TorchFC(custom_input_space, action_space,
                                       num_outputs, model_config, name)
        sub_model = self.torch_sub_model
        logit_clip = torch.nn.Hardtanh(min_val=-3, max_val=3)
        sub_model._logits = torch.nn.Sequential(sub_model._logits, logit_clip)

    def forward(self, input_dict, state, seq_lens):
        # Only the last two observation features reach the fc-net; the
        # caller's input_dict entry is overwritten.
        last_two = input_dict["obs"].float()[:, -2:]
        input_dict["obs"] = last_two
        logits, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return logits, []

    def value_function(self):
        """Return the sub-model's value estimates flattened to 1-D."""
        return self.torch_sub_model.value_function().reshape(-1)
class DQNYanivActionMaskModel(DQNTorchModel):
    """DQN model over obs["state"] whose logits are masked by obs["action_mask"]."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kwargs)
        # The fc-net is built on a 0..1 integer Box shaped like the "state"
        # sub-space of the original (unflattened) observation space.
        state_shape = obs_space.original_space["state"].shape
        true_obs_space = Box(low=0, high=1, shape=state_shape, dtype=int)
        self.action_model = TorchFC(true_obs_space, action_space,
                                    num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        action_mask = obs["action_mask"]
        logits, _ = self.action_model({"obs": obs["state"]})
        # log of the mask, clamped into the finite [FLOAT_MIN, FLOAT_MAX] range.
        mask_penalty = torch.log(action_mask).clamp(FLOAT_MIN, FLOAT_MAX)
        return logits + mask_penalty, state

    def value_function(self):
        """Return the fc-net's value estimates flattened to 1-D."""
        return self.action_model.value_function().reshape(-1)
def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
):
    """Validate the Dict observation space and build the internal fc-net.

    The space (``obs_space.original_space`` if present, else ``obs_space``)
    must be a Dict containing "action_mask" and "observations"; the fc-net is
    built on the "observations" sub-space.
    """
    dict_space = getattr(obs_space, "original_space", obs_space)
    assert (isinstance(dict_space, Dict)
            and "action_mask" in dict_space.spaces
            and "observations" in dict_space.spaces)

    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name, **kwargs)
    nn.Module.__init__(self)

    internal_name = name + "_internal"
    self.internal_model = TorchFC(dict_space["observations"], action_space,
                                  num_outputs, model_config, internal_name)
class YanivActionMaskModel(TorchModelV2, nn.Module):
    """Fc-net over obs["state"] whose logits are masked by obs["action_mask"]."""

    # Fallback "state" shape, used when obs_space does not expose a Dict
    # ``original_space`` with a "state" entry (the previously hard-coded value).
    _DEFAULT_STATE_SHAPE = (266, )

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        true_obs_space = Box(low=0,
                             high=1,
                             shape=self._state_shape(obs_space),
                             dtype=int)
        self.action_model = TorchFC(true_obs_space, action_space,
                                    num_outputs, model_config, name)

    @classmethod
    def _state_shape(cls, obs_space):
        """Return the "state" sub-space shape, or the default if unavailable."""
        original = getattr(obs_space, "original_space", None)
        sub_spaces = getattr(original, "spaces", None)
        if isinstance(sub_spaces, dict) and "state" in sub_spaces:
            return sub_spaces["state"].shape
        return cls._DEFAULT_STATE_SHAPE

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        action_logits, _ = self.action_model(
            {"obs": input_dict["obs"]["state"]})
        # log of the mask, clamped into the finite [FLOAT_MIN, FLOAT_MAX] range.
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
        return action_logits + inf_mask, state

    def value_function(self):
        """Return the fc-net's value estimates flattened to 1-D."""
        return torch.reshape(self.action_model.value_function(), [-1])
class CustomFC(TorchModelV2, nn.Module):
    """Thin wrapper delegating to a TorchFC, plus a parameter-sum helper."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.torch_sub_model = TorchFC(obs_space, action_space, num_outputs,
                                       model_config, name)

    def forward(self, input_dict, state, seq_lens):
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        """Return the sub-model's value estimates flattened to 1-D."""
        return torch.reshape(self.torch_sub_model.value_function(), [-1])

    def sum_params(self):
        """Return the sum of every parameter element as a Python float.

        Summation runs under ``torch.no_grad()`` so no autograd graph is
        recorded. A model with no parameters returns ``0.0``; previously the
        ``int`` accumulator had no ``.item()`` and this raised.
        """
        with torch.no_grad():
            total = 0
            for p in self.parameters():
                total += p.sum()
        return float(total)
class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized VF.

    Actions come from a per-agent TorchFC. The central value function sees
    the agent's obs, the opponent's obs and the opponent's one-hot action
    (sizes 6 + 6 + 2).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Base of the model.
        self.model = TorchFC(obs_space, action_space, num_outputs,
                             model_config, name)

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred.
        obs_size, opp_obs_size, opp_act_size = 6, 6, 2
        central_in = obs_size + opp_obs_size + opp_act_size
        self.central_vf = nn.Sequential(
            SlimFC(central_in, 16, activation_fn=nn.Tanh),
            SlimFC(16, 1),
        )

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        logits, _ = self.model(input_dict, state, seq_lens)
        return logits, []

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Return the centralized value estimate as a 1-D tensor."""
        opp_act_one_hot = torch.nn.functional.one_hot(
            opponent_actions.long(), 2).float()
        joint_input = torch.cat([obs, opponent_obs, opp_act_one_hot], 1)
        return self.central_vf(joint_input).reshape(-1)

    @override(ModelV2)
    def value_function(self):
        return self.model.value_function()  # not used
class TorchRepeatedSpyModel(TorchModelV2, nn.Module):
    """Test helper that records every (repeated) observation it is given.

    Each forward() pickles the unbatched observation into Ray's internal KV
    store under ``torch_rspy_in_<n>``. It then runs an fc-net on the first
    "location" entry.
    """

    # Counter used to give each forward() call its own KV key; it lives on
    # the class, so all instances share it.
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        location_space = obs_space.original_space.child_space["location"]
        self.fc = FullyConnectedNetwork(location_space, action_space,
                                        num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        kv_key = "torch_rspy_in_{}".format(TorchRepeatedSpyModel.capture_index)
        kv_value = pickle.dumps(obs.unbatch_all())
        ray.experimental.internal_kv._internal_kv_put(kv_key,
                                                      kv_value,
                                                      overwrite=True)
        TorchRepeatedSpyModel.capture_index += 1
        first_location = obs.values["location"][:, 0]
        return self.fc({"obs": first_location}, state, seq_lens)

    def value_function(self):
        return self.fc.value_function()
class TorchSpyModel(TorchModelV2, nn.Module):
    """Test helper that records parts of every nested observation it sees.

    Each forward() pickles (position, first front_cam frame, task) as numpy
    arrays into Ray's internal KV store under ``torch_spy_in_<n>``. It then
    runs an fc-net on the sensor position.
    """

    # Counter used to give each forward() call its own KV key; it lives on
    # the class, so all instances share it.
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        position_space = obs_space.original_space["sensors"].spaces["position"]
        self.fc = FullyConnectedNetwork(
            position_space,
            action_space,
            num_outputs,
            model_config,
            name,
        )

    def forward(self, input_dict, state, seq_lens):
        def to_numpy(tensor):
            return tensor.detach().cpu().numpy()

        obs = input_dict["obs"]
        sensors = obs["sensors"]
        pos = to_numpy(sensors["position"])
        front_cam = to_numpy(sensors["front_cam"][0])
        task = to_numpy(obs["inner_state"]["job_status"]["task"])

        kv_key = "torch_spy_in_{}".format(TorchSpyModel.capture_index)
        ray.experimental.internal_kv._internal_kv_put(
            kv_key,
            pickle.dumps((pos, front_cam, task)),
            overwrite=True,
        )
        TorchSpyModel.capture_index += 1
        return self.fc({"obs": sensors["position"]}, state, seq_lens)

    def value_function(self):
        return self.fc.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Build one FullyConnectedNetwork per player.

    ``obs_space`` and ``action_space`` are indexed per player through
    ``.spaces``. Each player's network is built on that player's flattened
    observation space, with ``action_space.spaces[player].n`` outputs.
    """
    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)

    self.obs_sizes = _get_size(obs_space.spaces[0])
    self.n_players = len(obs_space.spaces)
    self.n_actions = action_space.spaces[0].n

    # (Renamed from ``os`` so the local no longer shadows the module name.)
    flat_obs_spaces = [
        flatten_space(obs_space.spaces[player])
        for player in range(self.n_players)
    ]
    player_models = {}
    for player in range(self.n_players):
        player_action_space = action_space.spaces[player]
        player_models[player] = FullyConnectedNetwork(
            flat_obs_spaces[player], player_action_space,
            player_action_space.n, model_config, name)
    self.pl_models = player_models

    # Set models as attributes so their parameters are registered on this
    # module (a plain dict attribute does not register them).
    for player, player_model in player_models.items():
        setattr(self, "model_{}".format(player), player_model)
class CustomFCModel(TorchModelV2, nn.Module):
    """Example of interpreting repeated observations.

    Optionally encodes the "social_vehicles" part of a dict observation with a
    configurable encoder, then delegates to a ``TorchFCNet``.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config,
        name: str,
        **customized_model_kwargs
    ):
        super(CustomFCModel, self).__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        nn.Module.__init__(self)

        # Prefer the config from custom_model_config; fall back to the
        # constructor keyword argument.
        custom_config = model_config["custom_model_config"]
        if "social_vehicle_config" in custom_config:
            social_vehicle_config = custom_config["social_vehicle_config"]
        else:
            social_vehicle_config = customized_model_kwargs["social_vehicle_config"]

        encoder_config = social_vehicle_config["encoder"]
        encoder_class = encoder_config["social_feature_encoder_class"]
        encoder_params = encoder_config["social_feature_encoder_params"]
        # A falsy encoder class disables encoding (see forward()).
        self.social_feature_encoder = (
            encoder_class(**encoder_params) if encoder_class else None
        )

        self.model = TorchFCNet(obs_space, action_space, num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        social_vehicles_state = input_dict["obs"]["social_vehicles"]

        if self.social_feature_encoder is not None:
            social_feature, _ = self.social_feature_encoder(social_vehicles_state)
        else:
            # Without an encoder, reshape each element to a single row.
            social_feature = [e.reshape(1, -1) for e in social_vehicles_state]

        # Overwrites the caller's observation entry in place.
        input_dict["obs"]["social_vehicles"] = (
            torch.cat(social_feature, 0) if len(social_feature) > 0 else []
        )
        # NOTE(review): calling .forward() directly bypasses ModelV2.__call__;
        # confirm TorchFCNet actually reads the modified "obs" entry rather
        # than a pre-flattened copy of the original observation.
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()
def __init__(self, obs_space, num_outputs, options):
    """Build an fc-net on the sensors/position sub-space (legacy TorchModel API)."""
    TorchModel.__init__(self, obs_space, num_outputs, options)
    sensor_spaces = obs_space.original_space.spaces["sensors"].spaces
    self.fc = FullyConnectedNetwork(sensor_spaces["position"], num_outputs,
                                    options)
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Build the model and its TorchFC on ``custom_input_space``.

    ``obs_space`` is accepted for interface compatibility but is not used:
    both this model and its sub-model are built on the externally defined
    ``custom_input_space``.
    """
    input_space = custom_input_space
    TorchModelV2.__init__(self, input_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)
    self.torch_sub_model = TorchFC(input_space, action_space, num_outputs,
                                   model_config, name)
def _unwrap_obs_space(space):
    """Strip RLlib observation-space wrappers down to the underlying space.

    A ``Box`` must carry the pre-flattening space in ``original_space``
    (asserted); a ``Tuple`` is replaced by its first sub-space.
    """
    if isinstance(space, Box):
        assert hasattr(space, "original_space"), "Invalid observation space"
        space = space.original_space
    if isinstance(space, Tuple):
        space = space.spaces[0]
    return space


def _flatten_sub_dict(space, keys):
    """Flatten the sub-``Dict`` of *space* holding *keys*.

    The unflattened ``Dict`` is attached to the result as ``original_space``.
    """
    sub_space = Dict({k: space.spaces[k] for k in keys})
    flat_space = flatten_space(sub_space)
    flat_space.original_space = sub_space
    return flat_space


def make_model_and_action_dist(policy, obs_space, action_space, config):
    """Create the model neural network and action distribution for *policy*.

    Also sets bookkeeping attributes on *policy*: device, stats flags,
    train/test observation keys and their flattened spaces, action-mask
    presence, tupling flag, ``n_actions`` and a no-op ``update_target``.

    Returns:
        ``(model, ModelCatalog.get_action_dist(...))``. The model is a
        ``FullyConnectedNetwork`` on the flattened test-time observation
        keys, excluding "signal" and "action_mask".

    Raises:
        UnsupportedSpaceException: if the (unwrapped) action space is not
            ``Discrete``.
    """
    policy.device = (torch.device("cuda")
                     if torch.cuda.is_available() else torch.device("cpu"))
    policy.log_stats = config["log_stats"]  # flag to log statistics
    if policy.log_stats:
        policy.stats_dict = {}
        policy.stats_fn = config["stats_fn"]

    # Keys of the observation space that must be used at train and test time
    # ('signal' and 'action_mask' are excluded from the model's input space).
    policy.train_obs_keys = config["train_obs_keys"]
    policy.test_obs_keys = config["test_obs_keys"]

    # Unwrap a single-element Tuple *action* space; requires_tupling records
    # that this happened.
    policy.requires_tupling = False
    if isinstance(action_space, Tuple) and len(action_space.spaces) == 1:
        policy.action_space = action_space.spaces[0]
        action_space = action_space.spaces[0]
        policy.requires_tupling = True

    if not isinstance(action_space, Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    # Get real observation space.
    obs_space = _unwrap_obs_space(obs_space)
    assert isinstance(obs_space, Dict), "Invalid observation space"
    policy.has_action_mask = "action_mask" in obs_space.spaces

    assert all(k in obs_space.spaces for k in policy.train_obs_keys), \
        "Invalid train keys specification"
    assert all(k in obs_space.spaces for k in policy.test_obs_keys), \
        "Invalid test keys specification"

    # Observation space used for training (defaults to the real one).
    configured_train_space = config["train_obs_space"]
    train_obs_space = _unwrap_obs_space(
        obs_space if configured_train_space is None else configured_train_space)

    # Flattened obs spaces used for testing and training.
    policy.real_test_obs_space = _flatten_sub_dict(obs_space,
                                                   policy.test_obs_keys)
    policy.real_train_obs_space = _flatten_sub_dict(train_obs_space,
                                                    policy.train_obs_keys)
    model_space = Dict({
        k: obs_space.spaces[k]
        for k in policy.test_obs_keys if k not in ("signal", "action_mask")
    })

    policy.n_actions = action_space.n

    def update_target():
        pass  # intentionally a no-op

    policy.update_target = update_target

    model = FullyConnectedNetwork(flatten_space(model_space),
                                  action_space,
                                  action_space.n,
                                  name="FcNet",
                                  model_config=config['model']).to(policy.device)
    # NOTE(review): RLlib documents ModelCatalog.get_action_dist as returning
    # a (dist_class, dist_dim) pair, and it is passed the full config rather
    # than config["model"]. Confirm callers expect this nested tuple instead
    # of a bare distribution class.
    return model, ModelCatalog.get_action_dist(action_space,
                                               config,
                                               framework='torch')
class TorchParametricActionsModel(DQNTorchModel):
    """PyTorch version of above ParametricActionsModel.

    An intent embedding is predicted from the flattened "true_obs". Each
    candidate action is scored by its dot product with that action's
    embedding in "_action_embeds". Actions disabled in "_action_mask" get
    their log-mask (clamped to FLOAT_MIN) added.
    """

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            # Dimensionality of sentence embeddings.
            # TODO: don't make this hard-coded.
            # NOTE(review): forward() multiplies [B, 1, action_embed_size]
            # with _action_embeds of shape [B, max_actions, embed_dim]
            # (make_obs_space defaults embed_dim to 768), so this must equal
            # embed_dim for the product to broadcast.
            action_embed_size=2,
            **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        self.true_obs_preprocessor = DictFlatteningPreprocessor(
            obs_space.original_space["true_obs"])
        self.action_embed_model = TorchFC(
            Box(-10, 10, self.true_obs_preprocessor.shape), action_space,
            action_embed_size, model_config, name + "_action_embed")

    @staticmethod
    def make_obs_space(embed_dim=768,
                       max_steps=None,
                       max_utterances=5,
                       max_command_length=5,
                       max_variables=10,
                       max_actions=10,
                       **kwargs):
        """Return the Dict observation space this model expects."""
        true_obs = {
            'dialog_history':
            Repeated(Dict({
                'sender': Discrete(3),
                'utterance': Box(-10, 10, shape=(embed_dim, ))
            }),
                     max_len=max_utterances),
            'partial_command':
            Repeated(Box(-10, 10, shape=(embed_dim, )),
                     max_len=max_command_length),
            'variables':
            Repeated(Box(-10, 10, shape=(embed_dim, )),
                     max_len=max_variables),
        }
        if max_steps:
            true_obs['steps'] = Discrete(max_steps)

        # Dict(true_obs) alone is what the true_obs preprocessor (and hence
        # the embedding network's input shape) is built from.
        return Dict({
            "true_obs": Dict(true_obs),
            '_action_mask': MultiDiscrete([2 for _ in range(max_actions)]),
            '_action_embeds': Box(-10, 10, shape=(max_actions, embed_dim)),
        })

    @staticmethod
    def make_action_space(max_actions=10, **kwargs):
        """Return a Discrete action space with ``max_actions`` choices."""
        return Discrete(max_actions)

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        # (A leftover ``pdb.set_trace()`` breakpoint was removed here.)
        avail_actions = input_dict['obs']["_action_embeds"]
        action_mask = input_dict["obs"]["_action_mask"]

        true_obs = input_dict["obs"]["true_obs"]
        # NOTE(review): confirm the preprocessor's transform() accepts the
        # batched values RLlib passes here.
        true_obs = self.true_obs_preprocessor.transform(true_obs)

        # Compute the predicted action embedding.
        action_embed, _ = self.action_embed_model({"obs": true_obs})

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = torch.unsqueeze(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = torch.sum(avail_actions * intent_vector, dim=2)

        # Mask out invalid actions: log(0) = -inf is clamped to FLOAT_MIN.
        # These are then recognized by the EpsilonGreedy exploration component
        # as invalid actions that are not to be chosen.
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()
class TorchCustomLossModel(TorchModelV2, nn.Module):
    """PyTorch version of the CustomLossModel above.

    Delegates forward/value to a TorchFC. It mixes an imitation loss, computed
    on batches read from ``input_files``, into every policy loss.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, input_files):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        nn.Module.__init__(self)

        self.input_files = input_files
        # Create a new input reader per worker.
        self.reader = JsonReader(self.input_files)
        self.fcnet = TorchFC(self.obs_space,
                             self.action_space,
                             num_outputs,
                             model_config,
                             name="fcnet")
        # Set by custom_loss(); None until the first loss computation so
        # that metrics() can be called safely at any time.
        self.policy_loss_metric = None
        self.imitation_loss_metric = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Delegate to our FCNet.
        return self.fcnet(input_dict, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        # Delegate to our FCNet.
        return self.fcnet.value_function()

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):
        """Calculates a custom loss on top of the given policy_loss(es).

        Args:
            policy_loss (List[TensorType]): The list of already calculated
                policy losses (as many as there are optimizers).
            loss_inputs (TensorStruct): Struct of np.ndarrays holding the
                entire train batch.

        Returns:
            List[TensorType]: The altered list of policy losses. In case the
                custom loss should have its own optimizer, make sure the
                returned list is one larger than the incoming policy_loss list.
                In case you simply want to mix in the custom loss into the
                already calculated policy losses, return a list of altered
                policy losses (as done in this example below).
        """
        # Get the next batch from our input files.
        batch = self.reader.next()

        # Define a secondary loss by building a graph copy with weight sharing.
        obs = restore_original_dimensions(
            torch.from_numpy(batch["obs"]).float().to(policy_loss[0].device),
            self.obs_space,
            tensorlib="torch")
        logits, _ = self.forward({"obs": obs}, [], None)

        # You can also add self-supervised losses easily by referencing tensors
        # created during _build_layers_v2(). For example, an autoencoder-style
        # loss can be added as follows:
        # ae_loss = squared_diff(
        #     loss_inputs["obs"], Decoder(self.fcnet.last_layer))
        print("FYI: You can also use these tensors: {}, ".format(loss_inputs))

        # Compute the IL loss.
        action_dist = TorchCategorical(logits, self.model_config)
        imitation_loss = torch.mean(-action_dist.logp(
            torch.from_numpy(batch["actions"]).to(policy_loss[0].device)))
        self.imitation_loss_metric = imitation_loss.item()
        self.policy_loss_metric = np.mean(
            [loss.item() for loss in policy_loss])

        # Add the imitation loss to each already calculated policy loss term.
        # Alternatively (if custom loss has its own optimizer):
        # return policy_loss + [10 * self.imitation_loss]
        return [loss_ + 10 * imitation_loss for loss_ in policy_loss]

    def metrics(self):
        """Return the latest loss values recorded by custom_loss().

        Entries not computed yet (before the first custom_loss() call) are
        omitted instead of raising AttributeError.
        """
        stats = {
            "policy_loss": self.policy_loss_metric,
            "imitation_loss": self.imitation_loss_metric,
        }
        return {key: value for key, value in stats.items() if value is not None}