Example #1
    def test_softmax(self):
        logits = torch.from_numpy(
            np.array([
                [0.5, 0.5],
                [1.0, 0.5],
            ], dtype=np.float64))

        mask_1 = torch.ones(size=logits.shape, dtype=torch.bool)

        y1 = masked_softmax(logits=logits, mask=mask_1)
        self.assertEqual(y1.shape, (2, 2))
        self.assertAlmostEqual(y1.sum().item(), 2.0)

        mask_2 = torch.from_numpy(np.array([[1, 0], [1, 0]], dtype=bool))
        y2 = masked_softmax(logits=logits, mask=mask_2)

        total = y2.sum(dim=0, keepdim=False)
        self.assertTrue(np.allclose(total, np.array([2, 0])))
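
The test above relies on masked_softmax zeroing the probability mass at positions where the mask is 0 and renormalizing over the remaining positions. A minimal sketch of that assumed behavior (masked_softmax_sketch is a hypothetical stand-in, not the project's implementation):

import torch

def masked_softmax_sketch(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Positions where mask == 0 are pushed to a very negative logit so they
    # receive (numerically) zero probability after the softmax over the last dim.
    neg_inf = torch.finfo(logits.dtype).min
    masked_logits = logits.masked_fill(~mask.bool(), neg_inf)
    return torch.softmax(masked_logits, dim=-1)

With a full mask every row sums to 1 (so the 2x2 example sums to 2.0); with mask [[1, 0], [1, 0]] all mass lands in the first column, giving column sums of [2, 0] as asserted above.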
Example #2
File: agent.py  Project: gncs/molgym
    def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict:
        data = self.parse_observations(observations)

        # Cast action to tensor
        if actions is not None:
            actions = torch.as_tensor(actions, dtype=torch.float, device=self.device)

        # SO3Vec (batches, atoms, taus, ms, 2)
        covariats = self.cg_model(data)

        # Compute invariants
        invariats = self.atomic_scalars(covariats)  # (batches, atoms, inv_feats)

        # Focus
        focus_logits = self.phi_focus(invariats)  # (batches, atoms, 1)
        focus_logits = focus_logits.squeeze(-1)  # (batches, atoms)
        focus_probs = masked_softmax(focus_logits, mask=data['focus_mask'])  # (batches, atoms)
        focus_dist = torch.distributions.Categorical(probs=focus_probs)

        # focus: (batches, 1)
        if actions is not None:
            focus = torch.round(actions[:, :1]).long()
        elif self.training:
            focus = focus_dist.sample().unsqueeze(-1)
        else:
            focus = torch.argmax(focus_probs, dim=-1).unsqueeze(-1)

        focus_oh = to_one_hot(focus, num_classes=self.observation_space.canvas_space.size,
                              device=self.device)  # (batches, atoms)

        focused_cov = so3_tools.select_atomic_covariats(covariats, focus_oh)  # (batches, taus, ms, 2)
        focused_inv = so3_tools.select_atomic_invariats(invariats, focus_oh)  # (batches, feats)

        # Element
        element_logits = self.phi_element(focused_inv)  # (batches, zs)
        element_probs = masked_softmax(element_logits, mask=data['element_mask'])  # (batches, zs)
        element_dist = torch.distributions.Categorical(probs=element_probs)

        # element: (batches, 1)
        if actions is not None:
            element = torch.round(actions[:, 1:2]).long()
        elif self.training:
            element = element_dist.sample().unsqueeze(-1)
        else:
            element = torch.argmax(element_probs, dim=-1).unsqueeze(-1)

        # Crop element
        offsets = self.channel_offsets.expand(len(observations), -1)  # (batches, channels_per_element)
        indices = offsets + element * self.num_channels_per_element
        element_cov = so3_tools.select_taus(focused_cov, indices=indices)
        element_inv = self.atomic_scalars(element_cov)  # (batches, inv_feats)

        # Distance: Gaussian mixture model
        # gmm_log_probs, d_mean_trans: (batches, gaussians)
        gmm_log_probs, d_mean_trans = self.phi_d(element_inv).split(self.num_gaussians, dim=-1)
        distance_mean = torch.tanh(d_mean_trans) * self.distance_half_width + self.distance_center
        distance_dist = GaussianMixtureModel(log_probs=gmm_log_probs,
                                             means=distance_mean,
                                             stds=torch.exp(self.distance_log_stds).clamp(1e-6))

        # distance: (batches, 1)
        if actions is not None:
            distance = actions[:, 2:3]
        elif self.training:
            # Ensure that the sampled distance is > 0
            distance = distance_dist.sample().clamp(0.001).unsqueeze(-1)
        else:
            distance = distance_dist.argmax().unsqueeze(-1)

        # Condition on distance
        transformed_d = distance.unsqueeze(1).unsqueeze(1).expand(-1, self.num_channels_per_element, 1, -1)
        transformed_d = self.pad_zeros(transformed_d)
        distance_so3 = SO3Vec([transformed_d])
        cond_cov = self.cg_mix(element_cov, distance_so3)

        so3_dist = self.get_so3_distribution(a_lms=cond_cov, empty=data['empty'])

        # so3: (batches, 3)
        if actions is not None:
            orientation = actions[..., 3:6]
        elif self.training:
            orientation = so3_dist.sample()
        else:
            orientation = so3_dist.argmax()

        # Log prob
        log_prob_list = [
            focus_dist.log_prob(focus.squeeze(-1)),
            element_dist.log_prob(element.squeeze(-1)),
            distance_dist.log_prob(distance.squeeze(-1)),
            so3_dist.log_prob(orientation),
        ]
        log_prob = torch.stack(log_prob_list, dim=-1).sum(dim=-1)  # (batches, )

        # Entropy
        entropy_list = [
            focus_dist.entropy(),
            element_dist.entropy(),
        ]
        entropy = torch.stack(entropy_list, dim=-1).sum(dim=-1)  # (batches, )

        # Value function
        # atom_mask: (batches, atoms)
        # invariants: (batches, atoms, feats)
        trans_invariats = self.phi_trans(invariats)
        value_feats = torch.einsum(  # type: ignore
            'ba,baf->bf', data['value_mask'].to(self.dtype), trans_invariats)  # (batches, inv_feats)
        value = self.phi_v(value_feats).squeeze(-1)  # (batches, )

        # Action
        response: Dict[str, Any] = {}
        if actions is None:
            actions = torch.cat([focus.float(), element.float(), distance, orientation], dim=-1)

            # Build the corresponding actions in the action space
            response['actions'] = [self.to_action_space(a, o) for a, o in zip(actions, observations)]

        response.update({
            'a': actions,  # (batches, subactions)
            'logp': log_prob,  # (batches, )
            'ent': entropy,  # (batches, )
            'v': value,  # (batches, )
            'dists': [focus_dist, element_dist, distance_dist, so3_dist],
        })

        return response
Example #3
    def step(self,
             observations: List[ObservationType],
             action: Optional[np.ndarray] = None) -> dict:
        # atomic_feats: n_obs x n_atoms x n_afeats
        # focus_mask: n_obs x n_atoms
        # focus_mask_next: n_obs x n_atoms
        # element_count: n_obs x n_zs

        atomic_feats, focus_mask, focus_mask_next, element_count, action_mask = self.make_atomic_tensors(
            observations)
        element_mask = (element_count > 0).int()

        # stop: this agent does not stop
        stop = torch.zeros(size=(len(observations), 1),
                           dtype=torch.float,
                           device=self.device)

        # latent states bag
        latent_bag = self.phi_beta(element_count)

        # latent representation of atoms and bag
        latent_bag_tiled = latent_bag.unsqueeze(1)  # n_obs x 1 x n_zs
        latent_bag_tiled = latent_bag_tiled.expand(
            -1, self.num_atoms, -1)  # n_obs x n_atoms x n_zs

        latent_states = torch.cat(
            [atomic_feats, latent_bag_tiled],
            dim=-1)  # n_obs x n_atoms x (n_afeats + n_zs)

        # Focus
        focus_logits = self.phi_focus(latent_states)  # n_obs x n_atoms x 1
        focus_logits = focus_logits.squeeze(-1)  # n_obs x n_atoms

        focus_p = masked_softmax(focus_logits,
                                 mask=focus_mask)  # n_obs x n_atoms
        focus_dist = torch.distributions.Categorical(probs=focus_p)

        # Cast action to Tensor
        if action is not None:
            action = torch.as_tensor(action, device=self.device)

        # focus: n_obs x 1
        if action is not None:
            focus = torch.round(action[:, 1:2]).long()
        elif self.training:
            focus = focus_dist.sample().unsqueeze(-1)
        else:
            focus = torch.argmax(focus_p, dim=-1).unsqueeze(-1)

        focus_oh = to_one_hot(focus,
                              num_classes=self.num_atoms)  # n_obs x n_atoms

        # Focused atom is a hard (one-hot) selection over atoms
        focused_atom = (latent_states.transpose(1, 2)
                        @ focus_oh[:, :, None]).squeeze(-1)  # n_obs x n_latent

        # Element
        element_logits = self.phi_element(focused_atom)  # n_obs x n_zs
        element_p = masked_softmax(element_logits,
                                   mask=element_mask)  # n_obs x n_zs
        element_dist = torch.distributions.Categorical(probs=element_p)

        # element: n_obs x 1
        if action is not None:
            element = torch.round(action[:, 2:3]).long()
        elif self.training:
            element = element_dist.sample().unsqueeze(-1)
        else:
            element = torch.argmax(element_p, dim=-1).unsqueeze(-1)

        element_oh = to_one_hot(element, self.num_zs)  # n_obs x n_zs

        # Continuous variables
        # f: n_obs x (n_latent + n_zs)
        f = torch.cat([focused_atom, element_oh], dim=-1)
        distance_mean, angle_mean, dihedral_mean = torch.split(
            torch.tanh(self.phi_continuous(f)), 1, dim=-1)

        # Distance
        distance_mean = (distance_mean * self.action_width[0] /
                         2) + self.action_center[0]
        distance_dist = torch.distributions.Normal(
            loc=distance_mean, scale=torch.exp(1e-6 + self.log_stds[0]))

        # distance: n_obs x 1
        if action is not None:
            distance = action[:, 3:4]
        elif self.training:
            # Ensure that the sampled distance is > 0
            distance = distance_dist.sample().clamp(0.001)
        else:
            distance = distance_mean

        # Angle
        angle_mean = (angle_mean * self.action_width[1] /
                      2) + self.action_center[1]
        angle_dist = torch.distributions.Normal(
            loc=angle_mean, scale=torch.exp(1e-6 + self.log_stds[1]))

        # angle: n_obs x 1
        if action is not None:
            angle = action[:, 4:5]
        elif self.training:
            angle = angle_dist.sample()
        else:
            angle = angle_mean

        # Dihedral
        dihedral_mean = (dihedral_mean * self.action_width[2] /
                         2) + self.action_center[2]
        dihedral_dist = torch.distributions.Normal(
            loc=dihedral_mean, scale=torch.exp(1e-6 + self.log_stds[2]))

        # dihedral: n_obs x 1
        if action is not None:
            dihedral = action[:, 5:6]
        elif self.training:
            dihedral = dihedral_dist.sample()
        else:
            dihedral = dihedral_mean

        # Kappa: 0 = keep, 1 = flip
        # surrogate_features: n_obs x n_afeats
        element_count_next = element_count - element_oh
        latent_bag_next = self.phi_beta(element_count_next)

        atomic_feats_next_0 = self.surrogate_features(observations, focus,
                                                      element, distance, angle,
                                                      dihedral)
        atomic_feats_next_1 = self.surrogate_features(observations, focus,
                                                      element, distance, angle,
                                                      -1 * dihedral)

        v0 = self.phi_kappa(
            torch.cat([atomic_feats_next_0, latent_bag_next], dim=-1))
        v1 = self.phi_kappa(
            torch.cat([atomic_feats_next_1, latent_bag_next], dim=-1))

        kappa_logits = torch.cat([v0, v1], dim=-1)
        kappa_dist = torch.distributions.Categorical(logits=kappa_logits)

        # kappa: n_obs x 1
        if action is not None:
            kappa = torch.round(action[:, 6:7])
        elif self.training:
            kappa = kappa_dist.sample().unsqueeze(-1)
        else:
            kappa = torch.argmax(kappa_logits, dim=-1).unsqueeze(-1)

        if action is None:
            action = torch.cat(
                [stop, focus.float(), element.float(), distance, angle, dihedral, kappa.float()],
                dim=-1)

        # Critic
        weights = focus_mask.unsqueeze(-1).float()  # n_obs x n_atoms x 1
        weights = weights.transpose(1, 2)  # n_obs x 1 x n_atoms
        sum_atomic_feats = (weights @ atomic_feats).squeeze(1)  # n_obs x n_afeats
        # mean_atomic_feats = sum_atomic_feats / torch.sum(focus_mask, dim=-1, keepdim=True)
        v = self.critic(torch.cat([sum_atomic_feats, latent_bag], dim=-1))

        # Log probabilities
        log_prob_list = [
            focus_dist.log_prob(focus.squeeze(-1)).unsqueeze(-1),
            element_dist.log_prob(element.squeeze(-1)).unsqueeze(-1),
            distance_dist.log_prob(distance),
            angle_dist.log_prob(angle),
            dihedral_dist.log_prob(dihedral),
            kappa_dist.log_prob(kappa.squeeze(-1)).unsqueeze(-1),
        ]
        log_prob = torch.cat(log_prob_list, dim=-1)

        # Mask
        log_prob = log_prob * action_mask

        # Entropies
        entropy_list = [
            focus_dist.entropy().unsqueeze(-1),
            element_dist.entropy().unsqueeze(-1),
            distance_dist.entropy(),
            angle_dist.entropy(),
            dihedral_dist.entropy(),
            kappa_dist.entropy().unsqueeze(-1),
        ]
        entropy = torch.cat(entropy_list, dim=-1)

        # Mask
        entropy = entropy * action_mask

        return {
            'a': action,  # n_obs x n_subactions
            'logp': log_prob.sum(dim=-1, keepdim=False),  # n_obs
            'ent': entropy[:, 0:2].sum(dim=-1, keepdim=False),  # n_obs
            'v': v.squeeze(-1),  # n_obs

            # Other
            'entropies': entropy,  # n_obs x n_entropies
        }