def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """Shared implementation of `rollout_step` and `train_step`.

    A forward model predicts the current feature from the previous feature
    and the previous action; an inverse model predicts the previous action
    from the previous and the current feature.

    Args:
        time_step (TimeStep): input time_step data for ICM
        state (Tensor): feature carried over from the previous step
        calc_rewards (bool): whether to calculate the intrinsic reward
    Returns:
        AlgStep:
            output: empty tuple ()
            state: feature of the current step
            info (ICMInfo): intrinsic reward (``()`` when ``calc_rewards``
                is False) and the forward/inverse losses
    """
    prev_action = time_step.prev_action.detach()
    cur_feature = time_step.observation

    # normalize observation for easier prediction
    normalizer = self._observation_normalizer
    if normalizer is not None:
        cur_feature = normalizer.normalize(cur_feature)

    encoder = self._encoding_net
    if encoder is not None:
        cur_feature, _ = encoder(cur_feature)

    last_feature = state

    # Both features are detached here, so the forward loss sends no gradient
    # back through them.
    forward_inputs = [last_feature.detach(), self._encode_action(prev_action)]
    predicted_feature, _ = self._forward_net(inputs=forward_inputs)
    forward_error = math_ops.square(predicted_feature - cur_feature.detach())
    # per-sample loss: nn.MSELoss cannot reduce along a single dim
    forward_loss = 0.5 * forward_error.mean(dim=-1)

    predicted_action, _ = self._inverse_net([last_feature, cur_feature])

    if not self._action_spec.is_discrete:
        # per-sample loss: nn.MSELoss cannot reduce along a single dim
        inverse_error = math_ops.square(predicted_action - prev_action)
        inverse_loss = 0.5 * inverse_error.mean(dim=-1)
    else:
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
        inverse_loss = criterion(
            input=predicted_action, target=prev_action.to(torch.int64))

    if calc_rewards:
        reward = self._reward_normalizer.normalize(forward_loss.detach())
    else:
        reward = ()

    loss_info = LossInfo(
        loss=forward_loss + inverse_loss,
        extra={
            'forward_loss': forward_loss,
            'inverse_loss': inverse_loss,
        })
    return AlgStep(
        output=(), state=cur_feature, info=ICMInfo(
            reward=reward, loss=loss_info))
def global_norm(tensors):
    """Adapted from TF's version.
    Computes the global norm of a nest of tensors. Given a nest of tensors
    `tensors`, this function returns the global norm of all tensors in
    `tensors`. The global norm is computed as:

        `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`

    Any entries in `tensors` that are of type None are ignored. If no tensor
    is left after ignoring them (empty nest or only None entries), a scalar
    zero is returned.

    Args:
        tensors (nested Tensor): a nest of tensors
    Returns:
        norm (Tensor): a scalar tensor
    """
    assert alf.nest.is_nested(tensors), "tensors must be a nest!"
    # Drop None entries *before* the emptiness check: otherwise a nest holding
    # only None would reach `torch.sqrt(sum([]))`, i.e. `torch.sqrt(0)` on a
    # Python int, which raises a TypeError.
    squared_norms = [
        math_ops.square(torch.norm(torch.reshape(t, [-1])))
        for t in alf.nest.flatten(tensors) if t is not None
    ]
    if not squared_norms:
        return torch.zeros((), dtype=torch.float32)
    return torch.sqrt(sum(squared_norms))
def _estimate_mi(i, batch):
    """Log and return the estimated mutual information for ``batch``.

    The pointwise MI predicted by ``mi_estimator`` is compared with the
    closed-form value computed from ``x``, ``y``, ``z`` and the noise scale
    ``e``. ``mi_estimator`` and ``e`` come from the enclosing scope.

    Args:
        i: tag placed at the beginning of the log line
        batch (dict): must provide the keys 'x', 'y', 'z', 'X', 'Y' and
            'Y_dist'
    Returns:
        float: mean of the estimated pointwise MI over the batch
    """
    pmi_hat = mi_estimator.calc_pmi(batch['X'], batch['Y'], batch['Y_dist'])
    batch_size = pmi_hat.shape[0]
    x = batch['x']
    y = batch['y']
    z = batch['z']

    # Closed-form pointwise MI, only non-zero where z > 0.
    marginal_term = math_ops.square(y - z - z * z) / (e * e + z * z)
    conditional_term = math_ops.square(y - z - x * z) / (e * e)
    log_term = torch.log(1 + (z / e)**2)
    pmi_true = 0.5 * (marginal_term - conditional_term + log_term)
    pmi_true = pmi_true * (z > 0).to(torch.float32)
    pmi_true = pmi_true.sum(dim=-1)

    squared_error = math_ops.square(pmi_true - pmi_hat)
    pmi_rmse = torch.sqrt(torch.mean(squared_error))

    var = torch.var(pmi_hat, dim=0, unbiased=False)
    estimated_mi = float(pmi_hat.mean(dim=0))
    # standard error of the mean estimate
    std = math.sqrt(var / batch_size)
    logging.info("%s estimated_mi=%s std=%s pmi_rmse=%s" %
                 (i, estimated_mi, std, float(pmi_rmse)))
    return estimated_mi
def _summarize_all(path, t, m2, m=None):
    """Write summaries for a tensor and its running moments.

    Args:
        path (str): summary name prefix; a "." is appended when non-empty
        t (Tensor): tensor whose batch min/max are summarized
        m2 (Tensor): second moment; also handed to
            ``_reduce_along_batch_dims`` together with ``t``
        m (Tensor|None): mean. When given, the mean and the variance
            ``m2 - m^2`` are summarized; otherwise the second moment is.
    """
    prefix = path + "." if path else path

    for tag, reducer in (("tensor.batch_min", torch.min),
                         ("tensor.batch_max", torch.max)):
        _summary(prefix + tag, _reduce_along_batch_dims(t, m2, reducer))

    if m is None:
        _summary(prefix + "second_moment", m2)
        return
    _summary(prefix + "mean", m)
    _summary(prefix + "var", m2 - math_ops.square(m))
def _verify_normalization(weights, normalized_tensor, eps):
    """Check ``normalized_tensor`` against a reference normalization.

    The reference is the last entry of ``self._tensors`` normalized with the
    ``weights``-weighted mean and variance of ``self._tensors``. ``self`` is
    taken from the enclosing scope.

    Args:
        weights (Tensor): weights applied elementwise to ``self._tensors``
        normalized_tensor (Tensor): the result under test
        eps (float): variance epsilon used for the reference normalization
    """
    samples = self._tensors
    weighted_mean = (weights * samples).sum()
    weighted_var = (weights * math_ops.square(samples - weighted_mean)).sum()
    expected = alf.layers.normalize_along_batch_dims(
        samples[-1], weighted_mean, weighted_var, variance_epsilon=eps)
    self.assertTensorClose(normalized_tensor, expected, epsilon=1e-4)
def _normalize(m2, t, m=None):
    """Normalize ``t`` with the given moments and optionally clip it.

    ``self`` and ``clip_value`` come from the enclosing scope.

    Args:
        m2 (Tensor): second moment
        t (Tensor): tensor to normalize
        m (Tensor|None): mean. If None, a zero mean is used and ``m2`` itself
            serves as the variance.
    Returns:
        Tensor: the normalized tensor, clamped to
        ``[-clip_value, clip_value]`` when ``clip_value > 0``
    """
    if m is None:
        mean = torch.zeros_like(m2)
        var = m2
    else:
        mean = m
        # Floating point error can push ``m2 - m^2`` slightly below zero;
        # relu clamps such values to zero.
        var = torch.relu(m2 - math_ops.square(m))

    normalized = alf.layers.normalize_along_batch_dims(
        t, mean, var, variance_epsilon=self._variance_epsilon)
    if clip_value > 0:
        normalized = torch.clamp(normalized, -clip_value, clip_value)
    return normalized
def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """
    Args:
        time_step (TimeStep): input time step data, where the
            observation is skill-augmented observation. The skill should be
            a one-hot vector.
        state (Tensor): state for DIAYN (previous skill) which should be
            a one-hot vector.
        calc_rewards (bool): if False, only return the losses.
    Returns:
        AlgStep:
            output: empty tuple ()
            state: skill
            info (DIAYNInfo): intrinsic reward (``()`` when ``calc_rewards``
                is False) and the per-sample discriminator loss
    """
    observations_aug = time_step.observation
    step_type = time_step.step_type
    observation, skill = observations_aug
    prev_skill = state.detach()

    # normalize observation for easier prediction
    if self._observation_normalizer is not None:
        observation = self._observation_normalizer.normalize(observation)

    # Without an encoding net the discriminator consumes the (normalized)
    # observation directly. Assigning the default first keeps `feature`
    # bound in that configuration.
    feature = observation
    if self._encoding_net is not None:
        feature, _ = self._encoding_net(observation)

    skill_pred, _ = self._discriminator_net(feature)

    if self._skill_spec.is_discrete:
        # prev_skill is one-hot, so argmax recovers the class index
        loss = torch.nn.CrossEntropyLoss(reduction='none')(
            input=skill_pred, target=torch.argmax(prev_skill, dim=-1))
    else:
        # nn.MSELoss doesn't support reducing along a dim
        loss = torch.sum(math_ops.square(skill_pred - prev_skill), dim=-1)

    # Zero out the loss on the first step of an episode.
    valid_masks = (step_type != to_tensor(StepType.FIRST)).to(
        torch.float32)
    loss *= valid_masks

    intrinsic_reward = ()
    if calc_rewards:
        # a lower discriminator loss gives a higher reward
        intrinsic_reward = -loss.detach()
        intrinsic_reward = self._reward_normalizer.normalize(
            intrinsic_reward)

    return AlgStep(
        output=(),
        state=skill,
        info=DIAYNInfo(reward=intrinsic_reward, loss=loss))
def _summarize_all(path, t, m2, m):
    """Write summaries for a tensor and whichever running moments exist.

    Args:
        path (str): summary name prefix; a "." is appended when non-empty
        t (Tensor): tensor whose batch min/max are summarized
        m2 (Tensor|None): second moment
        m (Tensor|None): mean
    """
    if path:
        path += "."
    # Pick the moment by presence, not by truthiness: `m2 or m` would call
    # bool() on a tensor, which raises for multi-element tensors and wrongly
    # falls through to `m` when a one-element `m2` equals zero.
    spec = TensorSpec.from_tensor(m2 if m2 is not None else m)
    _summary(path + "tensor.batch_min",
             _reduce_along_batch_dims(t, spec, torch.min))
    _summary(path + "tensor.batch_max",
             _reduce_along_batch_dims(t, spec, torch.max))
    if m is not None:
        _summary(path + "mean", m)
        if m2 is not None:
            # variance from the first two moments
            _summary(path + "var", m2 - math_ops.square(m))
    elif m2 is not None:
        _summary(path + "second_moment", m2)
def _sampling_forward(self, inputs):
    """Encode the data into latent space then do sampling.

    Args:
        inputs (nested Tensor): if a prior network is provided, this is a
            tuple of ``(prior_input, new_observation)``; otherwise it is
            passed to the preprocess network as is.
    Returns:
        tuple:
        - z (Tensor): ``z`` is a tensor of shape (``B``, ``z_dim``).
        - kl_loss (Tensor): ``kl_loss`` is a tensor of shape (``B``,).
    """
    if self._z_prior_network:
        prior_input, new_obs = inputs
        prior_z_mean_and_log_var, _ = self._z_prior_network(prior_input)
        # first z_dim entries are the prior mean, the rest the log variance
        prior_z_mean = prior_z_mean_and_log_var[..., :self._z_dim]
        prior_z_log_var = prior_z_mean_and_log_var[..., self._z_dim:]
        inputs = (prior_input, new_obs, prior_z_mean_and_log_var)

    latents, _ = self._preprocess_network(inputs)
    z_mean = self._z_mean(latents)
    z_log_var = self._z_log_var(latents)

    if self._z_prior_network:
        # z_mean / z_log_var are offsets relative to the prior, so the KL
        # term is computed from the offsets before they are added back.
        kl_div_loss = math_ops.square(z_mean) / torch.exp(prior_z_log_var) + \
            torch.exp(z_log_var) - z_log_var - 1.0
        z_mean = z_mean + prior_z_mean
        z_log_var = z_log_var + prior_z_log_var
    else:
        kl_div_loss = math_ops.square(z_mean) + torch.exp(
            z_log_var) - 1.0 - z_log_var

    kl_div_loss = 0.5 * torch.sum(kl_div_loss, dim=-1)
    # reparameterization sampling: z = u + var ** 0.5 * eps
    # randn_like draws the noise with z_mean's device and dtype, so sampling
    # does not depend on the process-wide default device.
    eps = torch.randn_like(z_mean)
    z = z_mean + torch.exp(z_log_var * 0.5) * eps
    return z, kl_div_loss
def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """Compute the prediction loss against the target net and, optionally,
    the intrinsic reward derived from it.

    Args:
        time_step (TimeStep): input time_step data
        state (tuple): empty tuple ()
        calc_rewards (bool): whether to calculate rewards
    Returns:
        AlgStep:
            output: empty tuple ()
            state: empty tuple ()
            info: ICMInfo
    """
    obs = time_step.observation

    num_kept = self._keep_stacked_frames
    if num_kept > 0:
        # Only the last ``num_kept`` entries along dim 1 are kept.
        obs = obs[:, -num_kept:, ...]

    normalizer = self._observation_normalizer
    if normalizer is not None:
        obs = normalizer.normalize(obs)

    encoder = self._encoder_net
    if encoder is not None:
        # the encoder is evaluated without gradient tracking
        with torch.no_grad():
            obs, _ = encoder(obs)

    predicted, _ = self._predictor_net(obs)
    # the target net is evaluated without gradient tracking as well
    with torch.no_grad():
        target, _ = self._target_net(obs)

    loss = math_ops.square(predicted - target).sum(dim=-1)

    reward = ()
    if calc_rewards:
        reward = loss.detach()
        if self._reward_normalizer:
            reward = self._reward_normalizer.normalize(
                reward, clip_value=self._reward_clip_value)

    info = ICMInfo(reward=reward, loss=LossInfo(loss=loss))
    return AlgStep(output=(), state=(), info=info)
def __test_conditional_mi_estimator(self,
                                    estimator='ML',
                                    switch_xy=False,
                                    use_default_model=True,
                                    eps=0.02,
                                    dim=2):
    """Estimate the conditional mutual information MI(X;Y|Z)

    X, Y and Z are generated by the following procedure:
        Z ~ N(0, 1)
        X|z ~ N(z, 1)
        if z > 0:
            Y|x,z ~ N(z + xz, e^2)
        else:
            Y|x,z ~ N(0, 1)
    When z>0, [X, Y] ~ N([z, z+z^2], [[1, z], [z, e^2+z^2]])
    MI(X;Y|z) = 0.5 * log(1+z^2/e^2)

    The estimator is trained for 20000 steps on batches of 512 samples, with
    an evaluation logged every 1000 steps. A final estimate on a batch of
    16384 samples is asserted to be within ``eps`` of the reference value.

    NOTE(review): the name does not start with ``test``, so unittest
    discovery presumably skips this method -- confirm this is intentional.

    Args:
        estimator (str): passed to ``MIEstimator`` as ``estimator_type``;
            'ML' selects ``NetML`` as the custom model, anything else
            ``NetJSD`` (only when ``use_default_model`` is False)
        switch_xy (bool): if True, estimate with ``[z, y]`` as X and ``x`` as
            Y instead of ``[z, x]`` as X and ``y`` as Y
        use_default_model (bool): if True, no custom model is passed to
            ``MIEstimator``
        eps (float): tolerance for the final assertion
        dim (int): size of the last dimension of x, y and z
    Returns:
        tuple:
        - mi (Tensor): reference mutual information
        - estimated_mi (float): estimate from the final evaluation batch
    """
    x_spec = [
        alf.TensorSpec(shape=(dim, ), dtype=torch.float32),
        alf.TensorSpec(shape=(dim, ), dtype=torch.float32)
    ]
    y_spec = alf.TensorSpec(shape=(dim, ), dtype=torch.float32)
    if use_default_model:
        model = None
    elif estimator == 'ML':
        model = NetML(x_spec)
    else:
        model = NetJSD([x_spec, y_spec])
    mi_estimator = MIEstimator(
        x_spec=x_spec,
        y_spec=y_spec,
        fc_layers=(256, 256),
        model=model,
        estimator_type=estimator,
        optimizer=alf.optimizers.AdamTF(lr=2e-4))

    # Monte Carlo reference value from 10000 samples of z. The factor 0.25 is
    # presumably 0.5 (from the MI formula) times 0.5 (only z > 0 contributes,
    # by symmetry of N(0, 1)) -- TODO confirm.
    z = torch.randn(10000, )
    e = 0.5
    mi = 0.25 * dim * torch.mean(torch.log(1 + (z / e)**2))

    def _get_batch(batch_size, z=None):
        # Sample a batch following the procedure in the docstring. If z is
        # given it is used as is instead of being sampled.
        if z is None:
            z = torch.randn(batch_size, dim)
        x_dist = DiagMultivariateNormal(loc=z, scale=torch.ones_like(z))
        # mask is 1 where z > 0 and 0 elsewhere
        mask = (z > 0).to(torch.float32)
        # distribution of y given z: N(z + z^2, e^2 + z^2) where z > 0,
        # N(0, 1) elsewhere
        y_dist = DiagMultivariateNormal(
            loc=(z + z * z) * mask,
            scale=1 - mask + mask * torch.sqrt(e * e + z * z))
        x = x_dist.sample()
        # y given x and z: mean z + xz with std e where z > 0, standard
        # normal elsewhere
        y = (z + x * z) * mask + (1 - mask + e * mask) * torch.randn(
            batch_size, dim)
        if not switch_xy:
            X = [z, x]
            Y = y
            Y_dist = y_dist
        else:
            X = [z, y]
            Y = x
            Y_dist = x_dist
        return dict(x=x, y=y, z=z, X=X, Y=Y, Y_dist=Y_dist)

    def _estimate_mi(i, batch):
        # Log and return the mean estimated pointwise MI of the batch; `i` is
        # only used as the tag at the start of the log line.
        estimated_pmi = mi_estimator.calc_pmi(batch['X'], batch['Y'],
                                              batch['Y_dist'])
        batch_size = estimated_pmi.shape[0]
        x, y, z = batch['x'], batch['y'], batch['z']
        # closed-form pointwise MI, compared against the estimate below
        pmi = 0.5 * (math_ops.square(y - z - z * z) / (e * e + z * z) -
                     math_ops.square(y - z - x * z) /
                     (e * e) + torch.log(1 + (z / e)**2))
        # zero out the entries where z <= 0
        pmi = pmi * (z > 0).to(torch.float32)
        pmi = torch.sum(pmi, dim=-1)
        pmi_rmse = torch.sqrt(
            torch.mean(math_ops.square(pmi - estimated_pmi)))
        estimated_mi = estimated_pmi.mean(dim=0)
        var = torch.var(estimated_pmi, dim=0, unbiased=False)
        estimated_mi = float(estimated_mi)
        # the logged std is the standard error of the mean estimate
        logging.info("%s estimated_mi=%s std=%s pmi_rmse=%s" % (
            i, estimated_mi, math.sqrt(var / batch_size), float(pmi_rmse)))
        return estimated_mi

    batch_size = 512
    info = "mi=%s estimator=%s use_default_model=%s switch_xy=%s dim=%s" % (
        float(mi), estimator, use_default_model, switch_xy, dim)

    def _train():
        # One gradient step on a fresh batch. `batch_size` is read from the
        # enclosing scope at call time (512 during the training loop below).
        batch = _get_batch(batch_size)
        alg_step = mi_estimator.train_step((batch['X'], batch['Y']),
                                           y_distribution=batch['Y_dist'])
        mi_estimator.update_with_gradient(alg_step.info)
        return alg_step

    for i in range(20000):
        _train()
        # log an intermediate estimate every 1000 steps
        if i % 1000 == 0:
            batch = _get_batch(batch_size)
            _estimate_mi(i, batch)
    # final evaluation on a larger batch
    batch_size = 16384
    batch = _get_batch(batch_size)
    estimated_mi = _estimate_mi(info, batch)
    self.assertAlmostEqual(estimated_mi, mi, delta=eps)
    # Set detail_result=True to show the conditional mutual information for
    # different values of z
    detail_result = False
    if detail_result:
        for z in torch.arange(-2., 2.001, 0.125):
            batch = _get_batch(batch_size, z * torch.ones(batch_size, dim))
            # the relu makes the reference value 0 for z <= 0
            info = "z={z} mi={mi}".format(
                z=float(z),
                mi=float(
                    0.5 * torch.log(1 + math_ops.square(F.relu(z / e)))))
            _estimate_mi(info, batch)
    return mi, estimated_mi