def test_softmax(self):
    logits = torch.from_numpy(
        np.array([
            [0.5, 0.5],
            [1.0, 0.5],
        ], dtype=np.float64))

    # With a full mask, each row sums to 1
    mask_1 = torch.ones(size=logits.shape, dtype=torch.bool)
    y1 = masked_softmax(logits=logits, mask=mask_1)
    self.assertEqual(y1.shape, (2, 2))
    self.assertAlmostEqual(y1.sum().item(), 2.0)

    # Masked-out entries receive zero probability
    mask_2 = torch.from_numpy(np.array([[1, 0], [1, 0]], dtype=bool))
    y2 = masked_softmax(logits=logits, mask=mask_2)
    total = y2.sum(dim=0, keepdim=False)
    self.assertTrue(np.allclose(total, np.array([2, 0])))
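
# masked_softmax itself is not shown in this file. Below is a minimal sketch
# consistent with the assertions above -- an assumption, not the project's
# implementation. It relies on the torch import of the surrounding module:
# masked-out logits are pushed to -inf before the softmax and zeroed afterwards.
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Exclude masked entries from the softmax normalizer
    masked_logits = logits.masked_fill(~mask, float('-inf'))
    probs = torch.softmax(masked_logits, dim=-1)
    # Force exact zeros on masked entries
    return probs * mask.to(probs.dtype)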
def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict:
    data = self.parse_observations(observations)

    # Cast actions to tensor
    if actions is not None:
        actions = torch.as_tensor(actions, dtype=torch.float, device=self.device)

    # SO3Vec: (batches, atoms, taus, ms, 2)
    covariats = self.cg_model(data)

    # Compute invariants
    invariats = self.atomic_scalars(covariats)  # (batches, atoms, inv_feats)

    # Focus
    focus_logits = self.phi_focus(invariats)  # (batches, atoms, 1)
    focus_logits = focus_logits.squeeze(-1)  # (batches, atoms)
    focus_probs = masked_softmax(focus_logits, mask=data['focus_mask'])  # (batches, atoms)
    focus_dist = torch.distributions.Categorical(probs=focus_probs)

    # focus: (batches, 1)
    if actions is not None:
        focus = torch.round(actions[:, :1]).long()
    elif self.training:
        focus = focus_dist.sample().unsqueeze(-1)
    else:
        focus = torch.argmax(focus_probs, dim=-1).unsqueeze(-1)

    focus_oh = to_one_hot(focus,
                          num_classes=self.observation_space.canvas_space.size,
                          device=self.device)  # (batches, atoms)
    focused_cov = so3_tools.select_atomic_covariats(covariats, focus_oh)  # (batches, taus, ms, 2)
    focused_inv = so3_tools.select_atomic_invariats(invariats, focus_oh)  # (batches, feats)

    # Element
    element_logits = self.phi_element(focused_inv)  # (batches, zs)
    element_probs = masked_softmax(element_logits, mask=data['element_mask'])  # (batches, zs)
    element_dist = torch.distributions.Categorical(probs=element_probs)

    # element: (batches, 1)
    if actions is not None:
        element = torch.round(actions[:, 1:2]).long()
    elif self.training:
        element = element_dist.sample().unsqueeze(-1)
    else:
        element = torch.argmax(element_probs, dim=-1).unsqueeze(-1)

    # Crop the channels belonging to the selected element
    offsets = self.channel_offsets.expand(len(observations), -1)  # (batches, channels_per_element)
    indices = offsets + element * self.num_channels_per_element
    element_cov = so3_tools.select_taus(focused_cov, indices=indices)
    element_inv = self.atomic_scalars(element_cov)  # (batches, inv_feats)

    # Distance: Gaussian mixture model
    # gmm_log_probs, d_mean_trans: (batches, gaussians)
    gmm_log_probs, d_mean_trans = self.phi_d(element_inv).split(self.num_gaussians, dim=-1)
    distance_mean = torch.tanh(d_mean_trans) * self.distance_half_width + self.distance_center
    distance_dist = GaussianMixtureModel(log_probs=gmm_log_probs,
                                         means=distance_mean,
                                         stds=torch.exp(self.distance_log_stds).clamp(min=1e-6))

    # distance: (batches, 1)
    if actions is not None:
        distance = actions[:, 2:3]
    elif self.training:
        # Ensure that the sampled distance is > 0
        distance = distance_dist.sample().clamp(min=0.001).unsqueeze(-1)
    else:
        distance = distance_dist.argmax().unsqueeze(-1)

    # Condition on distance
    transformed_d = distance.unsqueeze(1).unsqueeze(1).expand(-1, self.num_channels_per_element, 1, -1)
    transformed_d = self.pad_zeros(transformed_d)
    distance_so3 = SO3Vec([transformed_d])
    cond_cov = self.cg_mix(element_cov, distance_so3)
    so3_dist = self.get_so3_distribution(a_lms=cond_cov, empty=data['empty'])

    # orientation: (batches, 3)
    if actions is not None:
        orientation = actions[..., 3:6]
    elif self.training:
        orientation = so3_dist.sample()
    else:
        orientation = so3_dist.argmax()

    # Log probability
    log_prob_list = [
        focus_dist.log_prob(focus.squeeze(-1)),
        element_dist.log_prob(element.squeeze(-1)),
        distance_dist.log_prob(distance.squeeze(-1)),
        so3_dist.log_prob(orientation),
    ]
    log_prob = torch.stack(log_prob_list, dim=-1).sum(dim=-1)  # (batches, )

    # Entropy (discrete sub-actions only)
    entropy_list = [
        focus_dist.entropy(),
        element_dist.entropy(),
    ]
    entropy = torch.stack(entropy_list, dim=-1).sum(dim=-1)  # (batches, )
    # Value function
    # value_mask: (batches, atoms)
    # invariats: (batches, atoms, feats)
    trans_invariats = self.phi_trans(invariats)
    value_feats = torch.einsum(  # type: ignore
        'ba,baf->bf', data['value_mask'].to(self.dtype), trans_invariats)  # (batches, inv_feats)
    value = self.phi_v(value_feats).squeeze(-1)  # (batches, )

    # Action
    response: Dict[str, Any] = {}
    if actions is None:
        actions = torch.cat([focus.float(), element.float(), distance, orientation], dim=-1)

    # Build the corresponding action in the action space
    response['actions'] = [self.to_action_space(a, o) for a, o in zip(actions, observations)]
    response.update({
        'a': actions,  # (batches, subactions)
        'logp': log_prob,  # (batches, )
        'ent': entropy,  # (batches, )
        'v': value,  # (batches, )
        'dists': [focus_dist, element_dist, distance_dist, so3_dist],
    })

    return response
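
# GaussianMixtureModel is defined elsewhere; the sketch below is an assumption
# reconstructed from the three calls made above (log_prob, sample, argmax), not
# the project's implementation. It marginalizes a Normal per mixture component
# under Categorical weights; torch comes from the surrounding module.
class GaussianMixtureModel:
    def __init__(self, log_probs: torch.Tensor, means: torch.Tensor, stds: torch.Tensor):
        # log_probs, means: (batches, gaussians); stds broadcastable to means
        self.mixture = torch.distributions.Categorical(logits=log_probs)
        self.components = torch.distributions.Normal(loc=means, scale=stds)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        # log p(x) = logsumexp_k [ log pi_k + log N(x | mu_k, sigma_k) ]
        log_pi = self.mixture.logits  # Categorical stores normalized log-probs
        return torch.logsumexp(log_pi + self.components.log_prob(value.unsqueeze(-1)), dim=-1)

    def sample(self) -> torch.Tensor:
        # Choose a component per batch entry, then sample from it
        idx = self.mixture.sample().unsqueeze(-1)  # (batches, 1)
        return self.components.sample().gather(-1, idx).squeeze(-1)

    def argmax(self) -> torch.Tensor:
        # Deterministic proxy: mean of the most probable component
        idx = self.mixture.probs.argmax(dim=-1, keepdim=True)
        return self.components.mean.gather(-1, idx).squeeze(-1)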
def step(self, observations: List[ObservationType], action: Optional[np.ndarray] = None) -> dict:
    # atomic_feats: n_obs x n_atoms x n_afeats
    # focus_mask: n_obs x n_atoms
    # focus_mask_next: n_obs x n_atoms
    # element_count: n_obs x n_zs
    atomic_feats, focus_mask, focus_mask_next, element_count, action_mask = self.make_atomic_tensors(
        observations)
    element_mask = (element_count > 0).int()

    # stop: this agent does not stop
    stop = torch.zeros(size=(len(observations), 1), dtype=torch.float, device=self.device)

    # Latent state of the bag
    latent_bag = self.phi_beta(element_count)

    # Latent representation of atoms and bag
    latent_bag_tiled = latent_bag.unsqueeze(1)  # n_obs x 1 x n_zs
    latent_bag_tiled = latent_bag_tiled.expand(-1, self.num_atoms, -1)  # n_obs x n_atoms x n_zs
    latent_states = torch.cat([atomic_feats, latent_bag_tiled],
                              dim=-1)  # n_obs x n_atoms x (n_afeats + n_zs)

    # Focus
    focus_logits = self.phi_focus(latent_states)  # n_obs x n_atoms x 1
    focus_logits = focus_logits.squeeze(-1)  # n_obs x n_atoms
    focus_p = masked_softmax(focus_logits, mask=focus_mask)  # n_obs x n_atoms
    focus_dist = torch.distributions.Categorical(probs=focus_p)

    # Cast action to tensor
    if action is not None:
        action = torch.as_tensor(action, device=self.device)

    # focus: n_obs x 1
    if action is not None:
        focus = torch.round(action[:, 1:2]).long()
    elif self.training:
        focus = focus_dist.sample().unsqueeze(-1)
    else:
        focus = torch.argmax(focus_p, dim=-1).unsqueeze(-1)

    focus_oh = to_one_hot(focus, num_classes=self.num_atoms)  # n_obs x n_atoms

    # Focused atom is a hard (one-hot) selection over atoms
    focused_atom = (latent_states.transpose(1, 2) @ focus_oh[:, :, None]).squeeze(-1)  # n_obs x n_latent

    # Element
    element_logits = self.phi_element(focused_atom)  # n_obs x n_zs
    element_p = masked_softmax(element_logits, mask=element_mask)  # n_obs x n_zs
    element_dist = torch.distributions.Categorical(probs=element_p)

    # element: n_obs x 1
    if action is not None:
        element = torch.round(action[:, 2:3]).long()
    elif self.training:
        element = element_dist.sample().unsqueeze(-1)
    else:
        element = torch.argmax(element_p, dim=-1).unsqueeze(-1)

    element_oh = to_one_hot(element, self.num_zs)  # n_obs x n_zs

    # Continuous variables
    # f: n_obs x (n_latent + n_zs)
    f = torch.cat([focused_atom, element_oh], dim=-1)
    distance_mean, angle_mean, dihedral_mean = torch.split(torch.tanh(self.phi_continuous(f)), 1, dim=-1)

    # Distance
    distance_mean = (distance_mean * self.action_width[0] / 2) + self.action_center[0]
    distance_dist = torch.distributions.Normal(
        loc=distance_mean,
        scale=torch.exp(self.log_stds[0]) + 1e-6)  # keep the scale bounded away from zero

    # distance: n_obs x 1
    if action is not None:
        distance = action[:, 3:4]
    elif self.training:
        # Ensure that the sampled distance is > 0
        distance = distance_dist.sample().clamp(min=0.001)
    else:
        distance = distance_mean

    # Angle
    angle_mean = (angle_mean * self.action_width[1] / 2) + self.action_center[1]
    angle_dist = torch.distributions.Normal(
        loc=angle_mean,
        scale=torch.exp(self.log_stds[1]) + 1e-6)

    # angle: n_obs x 1
    if action is not None:
        angle = action[:, 4:5]
    elif self.training:
        angle = angle_dist.sample()
    else:
        angle = angle_mean

    # Dihedral
    dihedral_mean = (dihedral_mean * self.action_width[2] / 2) + self.action_center[2]
    dihedral_dist = torch.distributions.Normal(
        loc=dihedral_mean,
        scale=torch.exp(self.log_stds[2]) + 1e-6)

    # dihedral: n_obs x 1
    if action is not None:
        dihedral = action[:, 5:6]
    elif self.training:
        dihedral = dihedral_dist.sample()
    else:
        dihedral = dihedral_mean

    # Kappa: 0 = keep, 1 = flip
    # surrogate_features: n_obs x n_afeats
    element_count_next = element_count - element_oh
    latent_bag_next = self.phi_beta(element_count_next)
    atomic_feats_next_0 = self.surrogate_features(observations, focus, element, distance, angle, dihedral)
    atomic_feats_next_1 = self.surrogate_features(observations, focus, element, distance, angle,
                                                  -1 * dihedral)
    v0 = self.phi_kappa(torch.cat([atomic_feats_next_0, latent_bag_next], dim=-1))
    v1 = self.phi_kappa(torch.cat([atomic_feats_next_1, latent_bag_next], dim=-1))
    kappa_logits = torch.cat([v0, v1], dim=-1)
    kappa_dist = torch.distributions.Categorical(logits=kappa_logits)

    # kappa: n_obs x 1
    if action is not None:
        kappa = torch.round(action[:, 6:7])
    elif self.training:
        kappa = kappa_dist.sample().unsqueeze(-1)
    else:
        kappa = torch.argmax(kappa_logits, dim=-1).unsqueeze(-1)

    if action is None:
        action = torch.cat(
            [stop, focus.float(), element.float(), distance, angle, dihedral, kappa.float()],
            dim=-1)

    # Critic
    weights = focus_mask.unsqueeze(-1).float()  # n_obs x n_atoms x 1
    weights = weights.transpose(1, 2)  # n_obs x 1 x n_atoms
    sum_atomic_feats = (weights @ atomic_feats).squeeze(1)  # n_obs x n_afeats
    # mean_atomic_feats = sum_atomic_feats / torch.sum(focus_mask, dim=-1, keepdim=True)
    v = self.critic(torch.cat([sum_atomic_feats, latent_bag], dim=-1))

    # Log probabilities
    log_prob_list = [
        focus_dist.log_prob(focus.squeeze(-1)).unsqueeze(-1),
        element_dist.log_prob(element.squeeze(-1)).unsqueeze(-1),
        distance_dist.log_prob(distance),
        angle_dist.log_prob(angle),
        dihedral_dist.log_prob(dihedral),
        kappa_dist.log_prob(kappa.squeeze(-1)).unsqueeze(-1),
    ]
    log_prob = torch.cat(log_prob_list, dim=-1)

    # Mask
    log_prob = log_prob * action_mask

    # Entropies
    entropy_list = [
        focus_dist.entropy().unsqueeze(-1),
        element_dist.entropy().unsqueeze(-1),
        distance_dist.entropy(),
        angle_dist.entropy(),
        dihedral_dist.entropy(),
        kappa_dist.entropy().unsqueeze(-1),
    ]
    entropy = torch.cat(entropy_list, dim=-1)

    # Mask
    entropy = entropy * action_mask

    return {
        'a': action,  # n_obs x n_subactions
        'logp': log_prob.sum(dim=-1, keepdim=False),  # n_obs
        'ent': entropy[:, 0:2].sum(dim=-1, keepdim=False),  # n_obs
        'v': v.squeeze(-1),  # n_obs
        # Other
        'entropies': entropy,  # n_obs x n_entropies
    }
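
# Hedged usage sketch: how callers might drive step() for rollout collection
# and for replaying stored actions during a policy-gradient update. The names
# agent, env_observations, batch_obs and batch_actions are hypothetical; only
# step() and its returned keys come from the method above.
def collect_step(agent, env_observations):
    # action=None: the policy samples (training mode) or acts greedily (eval)
    with torch.no_grad():
        return agent.step(observations=env_observations)


def replay_logp(agent, batch_obs, batch_actions):
    # Replaying stored actions recomputes log-probs/values for the update,
    # e.g. the importance ratio pi_new(a|s) / pi_old(a|s) in PPO
    out = agent.step(observations=batch_obs, action=batch_actions)
    return out['logp'], out['v'], out['ent']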