def test_networkbody_visual():
    torch.manual_seed(0)
    vec_obs_size = 4
    obs_size = (84, 84, 3)
    network_settings = NetworkSettings()
    obs_shapes = [(vec_obs_size,), obs_size]

    networkbody = NetworkBody(
        create_observation_specs_with_shapes(obs_shapes), network_settings
    )
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = 0.1 * torch.ones((1, 84, 84, 3))
    sample_vec_obs = torch.ones((1, vec_obs_size))
    obs = [sample_vec_obs] + [sample_obs]

    for _ in range(150):
        encoded, _ = networkbody(obs)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # In the last step, values should be close to 1
    for _enc in encoded.flatten().tolist():
        assert _enc == pytest.approx(1.0, abs=0.1)
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
    super().__init__()
    self._policy_specs = specs
    state_encoder_settings = NetworkSettings(
        normalize=False,
        hidden_units=settings.encoding_size,
        num_layers=2,
        vis_encode_type=EncoderType.SIMPLE,
        memory=None,
    )
    self._state_encoder = NetworkBody(
        specs.observation_shapes, state_encoder_settings
    )
    self._action_flattener = ModelUtils.ActionFlattener(specs)

    self.inverse_model_action_prediction = torch.nn.Sequential(
        LinearEncoder(2 * settings.encoding_size, 1, 256),
        linear_layer(256, self._action_flattener.flattened_size),
    )

    self.forward_model_next_state_prediction = torch.nn.Sequential(
        LinearEncoder(
            settings.encoding_size + self._action_flattener.flattened_size, 1, 256
        ),
        linear_layer(256, settings.encoding_size),
    )
def test_networkbody_lstm():
    torch.manual_seed(0)
    obs_size = 4
    seq_len = 6
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(
            sequence_length=seq_len, memory_size=12
        )
    )
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(
        create_observation_specs_with_shapes(obs_shapes), network_settings
    )
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
    sample_obs = torch.ones((seq_len, obs_size))

    for _ in range(300):
        encoded, _ = networkbody(
            [sample_obs], memories=torch.ones(1, 1, 12), sequence_length=seq_len
        )
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # In the last step, values should be close to 1
    for _enc in encoded.flatten().tolist():
        assert _enc == pytest.approx(1.0, abs=0.1)
def test_networkbody_vector():
    torch.manual_seed(0)
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(
        create_observation_specs_with_shapes(obs_shapes),
        network_settings,
        encoded_act_size=2,
    )
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = 0.1 * torch.ones((1, obs_size))
    sample_act = 0.1 * torch.ones((1, 2))

    for _ in range(300):
        encoded, _ = networkbody([sample_obs], sample_act)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # In the last step, values should be close to 1
    # (compare plain floats, as in the other tests)
    for _enc in encoded.flatten().tolist():
        assert _enc == pytest.approx(1.0, abs=0.1)
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
    super().__init__()
    self._action_spec = specs.action_spec
    state_encoder_settings = NetworkSettings(
        normalize=False,
        hidden_units=settings.encoding_size,
        num_layers=2,
        vis_encode_type=EncoderType.SIMPLE,
        memory=None,
    )
    self._state_encoder = NetworkBody(
        specs.observation_specs, state_encoder_settings
    )
    self._action_flattener = ActionFlattener(self._action_spec)

    self.inverse_model_action_encoding = torch.nn.Sequential(
        LinearEncoder(2 * settings.encoding_size, 1, 256)
    )

    if self._action_spec.continuous_size > 0:
        self.continuous_action_prediction = linear_layer(
            256, self._action_spec.continuous_size
        )
    if self._action_spec.discrete_size > 0:
        self.discrete_action_prediction = linear_layer(
            256, sum(self._action_spec.discrete_branches)
        )

    self.forward_model_next_state_prediction = torch.nn.Sequential(
        LinearEncoder(
            settings.encoding_size + self._action_flattener.flattened_size, 1, 256
        ),
        linear_layer(256, settings.encoding_size),
    )
class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
        return hidden
def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
    super().__init__()
    state_encoder_settings = NetworkSettings(
        normalize=True,
        hidden_units=settings.encoding_size,
        num_layers=3,
        vis_encode_type=EncoderType.SIMPLE,
        memory=None,
    )
    self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)
def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
    super().__init__()
    state_encoder_settings = settings.network_settings
    if state_encoder_settings.memory is not None:
        state_encoder_settings.memory = None
        logger.warning(
            "memory was specified in network_settings but is not supported by RND. "
            "It is being ignored."
        )
    self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
    super().__init__()
    self._policy_specs = specs
    self._use_vail = settings.use_vail
    self._settings = settings

    state_encoder_settings = NetworkSettings(
        normalize=False,
        hidden_units=settings.encoding_size,
        num_layers=2,
        vis_encode_type=EncoderType.SIMPLE,
        memory=None,
    )
    self._state_encoder = NetworkBody(
        specs.observation_shapes, state_encoder_settings
    )
    self._action_flattener = ModelUtils.ActionFlattener(specs)

    encoder_input_size = settings.encoding_size
    if settings.use_actions:
        encoder_input_size += (
            self._action_flattener.flattened_size + 1
        )  # + 1 is for done

    self.encoder = torch.nn.Sequential(
        linear_layer(encoder_input_size, settings.encoding_size),
        Swish(),
        linear_layer(settings.encoding_size, settings.encoding_size),
        Swish(),
    )

    estimator_input_size = settings.encoding_size
    if settings.use_vail:
        estimator_input_size = self.z_size
        self._z_sigma = torch.nn.Parameter(
            torch.ones((self.z_size), dtype=torch.float), requires_grad=True
        )
        self._z_mu_layer = linear_layer(
            settings.encoding_size,
            self.z_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
        )
        self._beta = torch.nn.Parameter(
            torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
        )

    self._estimator = torch.nn.Sequential(
        linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
    )
class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_obs = len(self._encoder.processors)
        np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
        hidden, _ = self._encoder.forward(tensor_obs)
        self._encoder.update_normalization(mini_batch)
        return hidden
class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        state_encoder_settings = settings.network_settings
        if state_encoder_settings.memory is not None:
            state_encoder_settings.memory = None
            logger.warning(
                "memory was specified in network_settings but is not supported by "
                "RND. It is being ignored."
            )
        self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_obs = len(self._encoder.processors)
        np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
        hidden, _ = self._encoder.forward(tensor_obs)
        self._encoder.update_normalization(mini_batch)
        return hidden
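The class above is only the encoder half of Random Network Distillation (Burda et al., 2018): the RND reward provider, which is not part of this excerpt, holds two such networks — a frozen, randomly initialized target and a trained predictor — and uses the squared prediction error between their embeddings as the intrinsic reward. A minimal sketch of that pairing, with hypothetical names `rnd_target` and `rnd_predictor` for the two `RNDNetwork` instances:

# Hypothetical sketch: rnd_target / rnd_predictor are two RNDNetwork instances;
# only the predictor's parameters receive gradient updates.
with torch.no_grad():
    target = rnd_target(mini_batch)        # frozen, randomly initialized network
prediction = rnd_predictor(mini_batch)     # trained to imitate the target
reward = torch.sum((prediction - target) ** 2, dim=1)  # per-sample novelty signal
predictor_loss = reward.mean()             # shrinks the reward on familiar states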
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
    super().__init__()
    self._action_spec = specs.action_spec

    state_encoder_settings = settings.network_settings
    if state_encoder_settings.memory is not None:
        state_encoder_settings.memory = None
        logger.warning(
            "memory was specified in network_settings but is not supported by "
            "Curiosity. It is being ignored."
        )

    self._state_encoder = NetworkBody(
        specs.observation_specs, state_encoder_settings
    )
    self._action_flattener = ActionFlattener(self._action_spec)

    self.inverse_model_action_encoding = torch.nn.Sequential(
        LinearEncoder(2 * state_encoder_settings.hidden_units, 1, 256)
    )

    if self._action_spec.continuous_size > 0:
        self.continuous_action_prediction = linear_layer(
            256, self._action_spec.continuous_size
        )
    if self._action_spec.discrete_size > 0:
        self.discrete_action_prediction = linear_layer(
            256, sum(self._action_spec.discrete_branches)
        )

    self.forward_model_next_state_prediction = torch.nn.Sequential(
        LinearEncoder(
            state_encoder_settings.hidden_units
            + self._action_flattener.flattened_size,
            1,
            256,
        ),
        linear_layer(256, state_encoder_settings.hidden_units),
    )
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
    super().__init__()
    self._use_vail = settings.use_vail
    self._settings = settings

    encoder_settings = settings.network_settings
    if encoder_settings.memory is not None:
        encoder_settings.memory = None
        logger.warning(
            "memory was specified in network_settings but is not supported by "
            "GAIL. It is being ignored."
        )

    self._action_flattener = ActionFlattener(specs.action_spec)
    unencoded_size = (
        self._action_flattener.flattened_size + 1 if settings.use_actions else 0
    )  # +1 is for dones
    self.encoder = NetworkBody(
        specs.observation_specs, encoder_settings, unencoded_size
    )

    estimator_input_size = encoder_settings.hidden_units
    if settings.use_vail:
        estimator_input_size = self.z_size
        self._z_sigma = torch.nn.Parameter(
            torch.ones((self.z_size), dtype=torch.float), requires_grad=True
        )
        self._z_mu_layer = linear_layer(
            encoder_settings.hidden_units,
            self.z_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
        )
        self._beta = torch.nn.Parameter(
            torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
        )

    self._estimator = torch.nn.Sequential(
        linear_layer(estimator_input_size, 1, kernel_gain=0.2), torch.nn.Sigmoid()
    )
class CuriosityNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(
            specs.observation_shapes, state_encoder_settings
        )
        self._action_flattener = ModelUtils.ActionFlattener(specs)

        self.inverse_model_action_prediction = torch.nn.Sequential(
            LinearEncoder(2 * settings.encoding_size, 1, 256),
            linear_layer(256, self._action_flattener.flattened_size),
        )

        self.forward_model_next_state_prediction = torch.nn.Sequential(
            LinearEncoder(
                settings.encoding_size + self._action_flattener.flattened_size, 1, 256
            ),
            linear_layer(256, settings.encoding_size),
        )

    def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the current state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_processors)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        return hidden

    def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the next state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_processors)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["next_vector_in"], dtype=torch.float
                )
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["next_visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        return hidden

    def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the per-branch softmax probabilities.
        """
        inverse_model_input = torch.cat(
            (self.get_current_state(mini_batch), self.get_next_state(mini_batch)),
            dim=1,
        )
        hidden = self.inverse_model_action_prediction(inverse_model_input)
        if self._policy_specs.is_action_continuous():
            return hidden
        else:
            branches = ModelUtils.break_into_branches(
                hidden, self._policy_specs.discrete_action_branches
            )
            branches = [torch.softmax(b, dim=1) for b in branches]
            return torch.cat(branches, dim=1)

    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to
        predict the next state embedding.
        """
        if self._policy_specs.is_action_continuous():
            action = ModelUtils.list_to_tensor(
                mini_batch["actions"], dtype=torch.float
            )
        else:
            action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), action), dim=1
        )
        return self.forward_model_next_state_prediction(forward_model_input)

    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on
        the action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        if self._policy_specs.is_action_continuous():
            sq_difference = (
                ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
                - predicted_action
            ) ** 2
            sq_difference = torch.sum(sq_difference, dim=1)
            return torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                    2,
                )[1]
            )
        else:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action + self.EPSILON) * true_action, dim=1
            )
            return torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(
                        mini_batch["masks"], dtype=torch.float
                    ),  # use masks not action_masks
                    2,
                )[1]
            )

    def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Calculates the curiosity reward for the mini_batch. Corresponds to the
        error between the predicted and actual next state.
        """
        predicted_next_state = self.predict_next_state(mini_batch)
        target = self.get_next_state(mini_batch)
        sq_difference = 0.5 * (target - predicted_next_state) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        return sq_difference

    def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the loss for the next state prediction.
        """
        return torch.mean(
            ModelUtils.dynamic_partition(
                self.compute_reward(mini_batch),
                ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                2,
            )[1]
        )
class DiscriminatorNetwork(torch.nn.Module):
    gradient_penalty_weight = 10.0
    z_size = 128
    alpha = 0.0005
    mutual_information = 0.5
    EPSILON = 1e-7
    initial_beta = 0.0

    def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        self._use_vail = settings.use_vail
        self._settings = settings

        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(
            specs.observation_shapes, state_encoder_settings
        )
        self._action_flattener = ModelUtils.ActionFlattener(specs)

        encoder_input_size = settings.encoding_size
        if settings.use_actions:
            encoder_input_size += (
                self._action_flattener.flattened_size + 1
            )  # + 1 is for done

        self.encoder = torch.nn.Sequential(
            linear_layer(encoder_input_size, settings.encoding_size),
            Swish(),
            linear_layer(settings.encoding_size, settings.encoding_size),
            Swish(),
        )

        estimator_input_size = settings.encoding_size
        if settings.use_vail:
            estimator_input_size = self.z_size
            self._z_sigma = torch.nn.Parameter(
                torch.ones((self.z_size), dtype=torch.float), requires_grad=True
            )
            self._z_mu_layer = linear_layer(
                settings.encoding_size,
                self.z_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
            )
            self._beta = torch.nn.Parameter(
                torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
            )

        self._estimator = torch.nn.Sequential(
            linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
        )

    def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the action Tensor. In the continuous case, corresponds to the
        action. In the discrete case, corresponds to the concatenation of
        one-hot action Tensors.
        """
        return self._action_flattener.forward(
            torch.as_tensor(mini_batch["actions"], dtype=torch.float)
        )

    def get_state_encoding(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the observation input.
        """
        n_vis = len(self._state_encoder.visual_encoders)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                torch.as_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                torch.as_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
                for i in range(n_vis)
            ],
        )
        return hidden

    def compute_estimate(
        self, mini_batch: AgentBuffer, use_vail_noise: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Given a mini_batch, computes the estimate (how strongly the
        discriminator believes the data was sampled from the demonstration data).
        :param mini_batch: The AgentBuffer of data
        :param use_vail_noise: Only used with VAIL: if true, samples the code;
            if false, returns the mean of the code.
        """
        encoder_input = self.get_state_encoding(mini_batch)
        if self._settings.use_actions:
            actions = self.get_action_input(mini_batch)
            dones = torch.as_tensor(mini_batch["done"], dtype=torch.float)
            encoder_input = torch.cat([encoder_input, actions, dones], dim=1)
        hidden = self.encoder(encoder_input)
        z_mu: Optional[torch.Tensor] = None
        if self._settings.use_vail:
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        estimate = self._estimator(hidden)
        return estimate, z_mu

    def compute_loss(
        self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Given a policy mini_batch and an expert mini_batch, computes the loss
        of the discriminator.
        """
        total_loss = torch.zeros(1)
        stats_dict: Dict[str, np.ndarray] = {}
        policy_estimate, policy_mu = self.compute_estimate(
            policy_batch, use_vail_noise=True
        )
        expert_estimate, expert_mu = self.compute_estimate(
            expert_batch, use_vail_noise=True
        )
        stats_dict["Policy/GAIL Policy Estimate"] = (
            policy_estimate.mean().detach().cpu().numpy()
        )
        stats_dict["Policy/GAIL Expert Estimate"] = (
            expert_estimate.mean().detach().cpu().numpy()
        )
        discriminator_loss = -(
            torch.log(expert_estimate + self.EPSILON)
            + torch.log(1.0 - policy_estimate + self.EPSILON)
        ).mean()
        stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu().numpy()
        total_loss += discriminator_loss
        if self._settings.use_vail:
            # KL divergence loss (encourage latent representation to be normal)
            kl_loss = torch.mean(
                -torch.sum(
                    1
                    + (self._z_sigma ** 2).log()
                    - 0.5 * expert_mu ** 2
                    - 0.5 * policy_mu ** 2
                    - (self._z_sigma ** 2),
                    dim=1,
                )
            )
            vail_loss = self._beta * (kl_loss - self.mutual_information)
            with torch.no_grad():
                self._beta.data = torch.max(
                    self._beta + self.alpha * (kl_loss - self.mutual_information),
                    torch.tensor(0.0),
                )
            total_loss += vail_loss
            stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy()
            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
        if self.gradient_penalty_weight > 0.0:
            total_loss += (
                self.gradient_penalty_weight
                * self.compute_gradient_magnitude(policy_batch, expert_batch)
            )
        return total_loss, stats_dict

    def compute_gradient_magnitude(
        self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> torch.Tensor:
        """
        Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
        especially for off-policy. Computes gradients w.r.t. a randomly
        interpolated input.
        """
        policy_obs = self.get_state_encoding(policy_batch)
        expert_obs = self.get_state_encoding(expert_batch)
        obs_epsilon = torch.rand(policy_obs.shape)
        encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
        if self._settings.use_actions:
            # Interpolate between policy and expert actions/dones as well
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
            action_epsilon = torch.rand(policy_action.shape)
            policy_dones = torch.as_tensor(policy_batch["done"], dtype=torch.float)
            expert_dones = torch.as_tensor(expert_batch["done"], dtype=torch.float)
            dones_epsilon = torch.rand(policy_dones.shape)
            encoder_input = torch.cat(
                [
                    encoder_input,
                    action_epsilon * policy_action
                    + (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones
                    + (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
        hidden = self.encoder(encoder_input)
        if self._settings.use_vail:
            use_vail_noise = True
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        hidden = self._estimator(hidden)
        estimate = torch.mean(torch.sum(hidden, dim=1))
        # create_graph=True so the penalty itself is differentiable when
        # total_loss.backward() is called
        gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1) ** 2)
        return gradient_mag
class CuriosityNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__()
        self._action_spec = specs.action_spec
        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(
            specs.observation_specs, state_encoder_settings
        )
        self._action_flattener = ActionFlattener(self._action_spec)

        self.inverse_model_action_encoding = torch.nn.Sequential(
            LinearEncoder(2 * settings.encoding_size, 1, 256)
        )

        if self._action_spec.continuous_size > 0:
            self.continuous_action_prediction = linear_layer(
                256, self._action_spec.continuous_size
            )
        if self._action_spec.discrete_size > 0:
            self.discrete_action_prediction = linear_layer(
                256, sum(self._action_spec.discrete_branches)
            )

        self.forward_model_next_state_prediction = torch.nn.Sequential(
            LinearEncoder(
                settings.encoding_size + self._action_flattener.flattened_size, 1, 256
            ),
            linear_layer(256, settings.encoding_size),
        )

    def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the current state embedding from a mini_batch.
        """
        n_obs = len(self._state_encoder.processors)
        np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
        hidden, _ = self._state_encoder.forward(tensor_obs)
        return hidden

    def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the next state embedding from a mini_batch.
        """
        n_obs = len(self._state_encoder.processors)
        np_obs = ObsUtil.from_buffer_next(mini_batch, n_obs)
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
        hidden, _ = self._state_encoder.forward(tensor_obs)
        return hidden

    def predict_action(self, mini_batch: AgentBuffer) -> ActionPredictionTuple:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the per-branch softmax probabilities.
        """
        inverse_model_input = torch.cat(
            (self.get_current_state(mini_batch), self.get_next_state(mini_batch)),
            dim=1,
        )
        continuous_pred = None
        discrete_pred = None
        hidden = self.inverse_model_action_encoding(inverse_model_input)
        if self._action_spec.continuous_size > 0:
            continuous_pred = self.continuous_action_prediction(hidden)
        if self._action_spec.discrete_size > 0:
            raw_discrete_pred = self.discrete_action_prediction(hidden)
            branches = ModelUtils.break_into_branches(
                raw_discrete_pred, self._action_spec.discrete_branches
            )
            branches = [torch.softmax(b, dim=1) for b in branches]
            discrete_pred = torch.cat(branches, dim=1)
        return ActionPredictionTuple(continuous_pred, discrete_pred)

    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to
        predict the next state embedding.
        """
        actions = AgentAction.from_buffer(mini_batch)
        flattened_action = self._action_flattener.forward(actions)
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), flattened_action), dim=1
        )
        return self.forward_model_next_state_prediction(forward_model_input)

    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on
        the action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        actions = AgentAction.from_buffer(mini_batch)
        _inverse_loss = 0
        if self._action_spec.continuous_size > 0:
            sq_difference = (
                actions.continuous_tensor - predicted_action.continuous
            ) ** 2
            sq_difference = torch.sum(sq_difference, dim=1)
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(
                        mini_batch[BufferKey.MASKS], dtype=torch.float
                    ),
                    2,
                )[1]
            )
        if self._action_spec.discrete_size > 0:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    actions.discrete_tensor, self._action_spec.discrete_branches
                ),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action.discrete + self.EPSILON) * true_action,
                dim=1,
            )
            _inverse_loss += torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(
                        mini_batch[BufferKey.MASKS], dtype=torch.float
                    ),  # use masks not action_masks
                    2,
                )[1]
            )
        return _inverse_loss

    def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Calculates the curiosity reward for the mini_batch. Corresponds to the
        error between the predicted and actual next state.
        """
        predicted_next_state = self.predict_next_state(mini_batch)
        target = self.get_next_state(mini_batch)
        sq_difference = 0.5 * (target - predicted_next_state) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        return sq_difference

    def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the loss for the next state prediction.
        """
        return torch.mean(
            ModelUtils.dynamic_partition(
                self.compute_reward(mini_batch),
                ModelUtils.list_to_tensor(
                    mini_batch[BufferKey.MASKS], dtype=torch.float
                ),
                2,
            )[1]
        )