예제 #1
0
    def test_invariant(self):
        """Check that ``AtomicScalars`` produces rotation-invariant features.

        Coefficients are estimated from two reference directions. The test then:
        - checks the width of the invariant output against ``get_output_dim``;
        - rotates the coefficients with a random Wigner-D rotation;
        - asserts that the raw coefficients change but the invariant does not.
        """
        max_ell = 4
        sphs_conj = SphericalHarmonics(maxl=max_ell, conj=True, sh_norm='unit')
        atomic_scalars = AtomicScalars(maxl=max_ell)

        theta_phi = np.array([[np.pi / 3, np.pi / 4],
                              [2 * np.pi / 3, np.pi / 2]])
        xyz_refs = spherical_to_cartesian(theta_phi)
        y_lms_conj = sphs_conj.forward(
            torch.tensor(xyz_refs, dtype=torch.float))

        a_lms = estimate_alms(y_lms_conj)

        invariant = atomic_scalars(a_lms)

        # assertEqual, not assertTrue: assertTrue(x, y) treats y as the failure
        # message and passes for any non-zero x, so the width was never checked.
        self.assertEqual(invariant.shape[-1],
                         atomic_scalars.get_output_dim(channels=1))

        random_rotation = SO3WignerD.euler(maxl=max_ell, dtype=torch.float)
        a_lms_rotated = rotate_rep(random_rotation, a_lms)

        # Sanity check: the rotation must actually change the coefficients
        # (index 1, presumably l=1), otherwise the invariance check is vacuous.
        self.assertFalse(
            np.allclose(to_numpy(a_lms[1]), to_numpy(a_lms_rotated[1])))

        invariant_rotated = atomic_scalars(a_lms_rotated)

        self.assertTrue(
            np.allclose(to_numpy(invariant), to_numpy(invariant_rotated)))
예제 #2
0
    def surrogate_features(self, observations: List[ObservationType],
                           focus: torch.Tensor, element: torch.Tensor,
                           distance: torch.Tensor, angle: torch.Tensor,
                           dihedral: torch.Tensor) -> torch.Tensor:
        """Embed each observation extended by the atom its sub-actions describe.

        For every observation:
        - reconstruct the new atom's position from (focus, distance, angle,
          dihedral) with ``zmat.position_atom_helper``;
        - append the atom and embed the resulting structure;
        - write the embedding of the last atom into that observation's row.

        Returns a float32 tensor of shape ``(len(observations), self.num_afeats)``
        on ``self.device``.
        """
        features = torch.zeros(size=(len(observations), self.num_afeats),
                               dtype=torch.float32,
                               device=self.device)

        # Convert every sub-action tensor to NumPy once, before the loop.
        focus_np, element_np, distance_np, angle_np, dihedral_np = (
            to_numpy(tensor)
            for tensor in (focus, element, distance, angle, dihedral))

        for row, observation in enumerate(observations):
            atoms, _ = self.observation_space.parse(observation)
            placed_at = zmat.position_atom_helper(
                positions=[atom.position for atom in atoms],
                focus=int(round(focus_np[row, 0])),
                distance=distance_np[row, 0],
                angle=angle_np[row, 0],
                dihedral=dihedral_np[row, 0],
            )
            symbol = self.observation_space.bag_space.get_symbol(
                int(round(element_np[row, 0])))
            atoms.append(ase.Atom(symbol=symbol, position=placed_at))
            # NOTE(review): presumably the embedding is (1, n_atoms, num_afeats);
            # the slice keeps the appended (last) atom. TODO confirm shape.
            features[row] = self.embedding_fn(self.converter(atoms))[:, -1, :]

        return features
예제 #3
0
    def test_sample(self):
        """Sampling a batched SO3Distribution matches sampling each member alone.

        Checks the sample shape. Then compares the per-batch mean spherical
        angles with those of distributions built from ``self.a_lms_1`` and
        ``self.a_lms_2`` individually, within an absolute tolerance of 0.1.
        """
        torch.manual_seed(1)
        shape = (2048, )

        batched = SO3Distribution(
            a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]), sphs=self.sphs)
        draws = batched.sample(shape)

        self.assertEqual(draws.shape,
                         shape + batched.batch_shape + batched.event_shape)

        # Average over the leading sample axis: [S, B, 2] -> [B, 2].
        batched_means = np.mean(cartesian_to_spherical(to_numpy(draws)), axis=0)
        self.assertEqual(batched_means.shape, (2, 2))

        single_means = []
        for coeffs in (self.a_lms_1, self.a_lms_2):
            single = SO3Distribution(a_lms=coeffs, sphs=self.sphs)
            single_draws = single.sample(shape)
            single_means.append(
                np.mean(cartesian_to_spherical(to_numpy(single_draws)), axis=0))

        # Batching must not change the result.
        for index, expected in enumerate(single_means):
            self.assertTrue(
                np.allclose(batched_means[index], expected, atol=0.1))
예제 #4
0
파일: ppo.py 프로젝트: gncs/molgym
def batch_rollout(ac: AbstractActorCritic,
                  envs: VecEnv,
                  buffer_container: PPOBufferContainer,
                  num_steps: int = None,
                  num_episodes: int = None) -> dict:
    """Collect experience from a vectorized environment into ``buffer_container``.

    At least one of ``num_steps`` / ``num_episodes`` must be given.
    ``num_steps`` must be divisible by ``envs.get_size()``; the loop then runs
    ``num_steps // envs.get_size()`` iterations. On the last iteration,
    unfinished paths are closed with the value predicted for the current
    observations. ``num_episodes`` is only supported when
    ``envs.get_size() == 1``.

    Returns a dict with wall-clock ``time`` and the mean/std of the buffer's
    episodic returns and episode lengths.
    """
    assert num_steps is not None or num_episodes is not None

    num_iters = np.inf
    if num_steps is not None:
        assert num_steps % envs.get_size() == 0
        num_iters = num_steps // envs.get_size()

    if num_episodes is None:
        num_episodes = np.inf
    else:
        assert envs.get_size() == 1

    start_time = time.time()
    observations = envs.reset()

    counter = 0
    while counter < num_iters and buffer_container.get_num_episodes() < num_episodes:
        predictions = ac.step(observations)
        next_observations, rewards, terminals, _ = envs.step(predictions['actions'])

        buffer_container.store(
            observations=observations,
            actions=to_numpy(predictions['a']),
            rewards=rewards,
            next_observations=next_observations,
            terminals=terminals,
            values=to_numpy(predictions['v']),
            logps=to_numpy(predictions['logp']),
        )

        # Reset terminal environments so every slot holds a valid next observation.
        observations = envs.reset_if_terminal(next_observations, terminals)

        is_last_iter = counter == num_iters - 1
        counter += 1
        if is_last_iter:
            # Bootstrap the open paths; already-finished trajectories are unaffected.
            buffer_container.finish_paths(to_numpy(ac.step(observations)['v']))

    elapsed = time.time() - start_time
    return {
        'time': elapsed,
        'return_mean': np.mean(buffer_container.episodic_returns).item(),
        'return_std': np.std(buffer_container.episodic_returns).item(),
        'episode_length_mean': np.mean(buffer_container.episode_lengths).item(),
        'episode_length_std': np.std(buffer_container.episode_lengths).item(),
    }
예제 #5
0
파일: test_gmm.py 프로젝트: gncs/molgym
 def test_argmax(self):
     """``self.distr.argmax(128)`` returns a 2-vector near a known reference point."""
     torch.manual_seed(1)
     mode = self.distr.argmax(128)
     self.assertEqual(mode.shape, (2, ))
     reference = np.array([-0.495, 0.156])
     self.assertTrue(np.allclose(to_numpy(mode), reference, atol=1.e-2))
예제 #6
0
 def test_normalization(self):
     """The batched SO3Distribution density integrates to one over the sphere.

     The integral is approximated as 4*pi times the mean density over a
     1024-point Fibonacci grid.
     """
     coeffs = concat_so3vecs([self.a_lms_1, self.a_lms_2])
     distribution = SO3Distribution(a_lms=coeffs, sphs=self.sphs, dtype=torch.float)
     # Insert a singleton axis after the point axis (presumably so the grid
     # broadcasts over the batch dimension).
     points = torch.tensor(generate_fibonacci_grid(n=1024), dtype=torch.float).unsqueeze(1)
     densities = distribution.prob(points)
     total = 4 * np.pi * torch.mean(densities, dim=0)
     self.assertTrue(np.allclose(to_numpy(total), 1.0))
예제 #7
0
파일: agent.py 프로젝트: gncs/molgym
    def to_action_space(self, action: torch.Tensor, observation: ObservationType) -> ActionType:
        """Convert a flat 6-element action into an ``(element_index, position)`` pair.

        Layout read here:
        - index 0: the focus atom (rounded to int);
        - index 1: the element index (rounded to int);
        - index 2: a distance;
        - indices 3-5: a 3-vector (presumably a direction).

        The new atom is placed ``distance * vector`` away from the focus atom.
        If the observation holds no atoms, it is placed at the origin
        ``(0.0, 0.0, 0.0)``.
        """
        assert action.shape == (6, )
        values = to_numpy(action)

        focus = int(round(values[0].item()))
        element_index = int(round(values[1].item()))
        distance = values[2]
        offset_dir = values[-3:]

        atoms, _ = self.observation_space.parse(observation)

        if not len(atoms):
            return element_index, (0.0, 0.0, 0.0)
        return element_index, atoms[focus].position + distance * offset_dir
예제 #8
0
파일: ppo.py 프로젝트: gncs/molgym
def compute_loss(
    ac: AbstractActorCritic,
    data: dict,
    clip_ratio: float,
    vf_coef: float,
    entropy_coef: float,
    device=None,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the PPO clipped-surrogate loss plus entropy and value-function terms.

    ``data`` must provide 'obs', 'act', 'logp', 'adv' and 'ret'. The
    returned loss is ``policy_loss + entropy_loss + vf_loss``. The info dict
    holds each loss component plus ``approx_kl`` (mean of old minus new
    log-probability) and ``clip_fraction`` (share of ratios outside the
    clip range), all as Python floats.
    """
    pred = ac.step(data['obs'], data['act'])

    old_logp, adv, ret = (torch.as_tensor(data[key], device=device)
                          for key in ('logp', 'adv', 'ret'))

    # Clipped surrogate objective, negated because the optimiser minimises.
    ratio = torch.exp(pred['logp'] - old_logp)
    surrogate = torch.min(ratio * adv,
                          ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv)
    policy_loss = -surrogate.mean()

    entropy_loss = -entropy_coef * pred['ent'].mean()
    vf_loss = vf_coef * (pred['v'] - ret).pow(2).mean()
    loss = policy_loss + entropy_loss + vf_loss

    # Diagnostics: approximate KL (used for early stopping) and clip fraction.
    approx_kl = (old_logp - pred['logp']).mean()
    outside = ratio.lt(1 - clip_ratio) | ratio.gt(1 + clip_ratio)
    clip_fraction = torch.as_tensor(outside, dtype=torch.float32).mean()

    components = {
        'policy_loss': policy_loss,
        'entropy_loss': entropy_loss,
        'vf_loss': vf_loss,
        'total_loss': loss,
        'approx_kl': approx_kl,
        'clip_fraction': clip_fraction,
    }
    info = {name: to_numpy(value).item() for name, value in components.items()}

    return loss, info
예제 #9
0
    def to_action_space(self, action: torch.Tensor,
                        observation: ObservationType) -> ActionType:
        """Map a 7-element internal-coordinate action to ``(atomic_number_index, position)``.

        The action unpacks as (stop, focus, element, distance, angle,
        dihedral, kappa). If ``stop`` is truthy, the result of
        ``self.action_space.build(ase.Atoms())`` is returned instead.
        Otherwise:
        - focus, element and kappa are rounded to integers;
        - a non-zero rounded kappa flips the sign of the dihedral;
        - the position is reconstructed with ``zmat.position_atom_helper``
          and returned as a tuple.
        """
        stop, focus, element, distance, angle, dihedral, kappa = to_numpy(action)

        if stop:
            return self.action_space.build(ase.Atoms())

        # Discrete sub-actions arrive as floats; round them to indices.
        focus_index = int(round(focus))
        element_index = int(round(element))
        if int(round(kappa)):
            dihedral_sign = -1
        else:
            dihedral_sign = 1

        atoms, _ = self.observation_space.parse(observation)
        position = zmat.position_atom_helper(
            positions=[atom.position for atom in atoms],
            focus=focus_index,
            distance=distance,
            angle=angle,
            dihedral=dihedral_sign * dihedral,
        )
        atomic_number = self.observation_space.bag_space.zs[element_index]
        return self.action_space.zs.index(atomic_number), tuple(position)
예제 #10
0
 def test_multiplication_2(self):
     """``complex_product`` of [2, 0] and [3, 0] yields [6, 0].

     Assumes each tensor encodes a (real, imag) pair, i.e. (2+0i)*(3+0i) = 6+0i.
     """
     lhs, rhs = (torch.tensor([value, 0.0], dtype=torch.float) for value in (2.0, 3.0))
     product = to_numpy(complex_product(lhs, rhs))
     self.assertTrue(np.allclose(product, np.array([6.0, 0.0])))
예제 #11
0
def rollout(ac: AbstractActorCritic,
            env: AbstractMolecularEnvironment,
            buffer: PPOBuffer,
            num_steps: Optional[int] = None,
            num_episodes: Optional[int] = None) -> dict:
    """Collect experience from a single environment into ``buffer``.

    Runs until ``num_steps`` transitions have been stored or ``num_episodes``
    episodes have terminated, whichever comes first. At least one limit must
    be given. ``buffer.finish_path`` closes the current path in two cases:
    - the episode terminates (terminal value 0);
    - the step budget is exhausted (value bootstrapped from the next
      observation).
    After a path is closed, the environment is reset.

    Returns a dict containing:
    - wall-clock ``time``;
    - the step count summed with ``mpi_sum``;
    - mean/std (via ``mpi_mean_std``) of episodic returns, episode lengths,
      stored values and stored log-probabilities.
    """
    assert num_steps or num_episodes
    num_steps = num_steps if num_steps is not None else np.inf
    num_episodes = num_episodes if num_episodes is not None else np.inf

    obs = env.reset()

    ep_returns = []
    ep_lengths = []

    ep_length = 0
    ep_counter = 0
    step = 0

    start_time = time.time()

    while step < num_steps and ep_counter < num_episodes:
        pred = ac.step([obs])

        a = to_numpy(pred['a'][0])
        next_obs, reward, done, _ = env.step(ac.to_action_space(action=a, observation=obs))

        buffer.store(obs=obs,
                     act=a,
                     reward=reward,
                     next_obs=next_obs,
                     terminal=done,
                     value=pred['v'].item(),
                     logp=pred['logp'].item())

        obs = next_obs

        step += 1
        ep_length += 1

        # `step` was already incremented, so the budget is exhausted exactly when
        # step == num_steps. Comparing with num_steps - 1 closed the path one step
        # early and left the final stored transition in a never-finished path.
        last_step = step == num_steps
        if done or last_step:
            # If the trajectory didn't reach a terminal state, bootstrap the value
            # target from the next observation.
            if not done:
                pred = ac.step([obs])
                value = float(pred['v'])
            else:
                value = 0

            ep_return = buffer.finish_path(value)

            if done:
                ep_returns.append(ep_return)
                ep_lengths.append(ep_length)
                ep_counter += 1

            obs = env.reset()
            ep_length = 0

    # Compute statistics
    return_mean, return_std = mpi_mean_std(np.asarray(ep_returns), axis=0)
    ep_length_mean, ep_length_std = mpi_mean_std(np.asarray(ep_lengths), axis=0)

    value_mean, value_std = mpi_mean_std(buffer.val_buf[:buffer.ptr], axis=0)
    logp_mean, logp_std = mpi_mean_std(buffer.logp_buf[:buffer.ptr], axis=0)

    return {
        'time': time.time() - start_time,
        'num_steps': mpi_sum(np.asarray(step)).item(),
        'return_mean': return_mean.item(),
        'return_std': return_std.item(),
        'value_mean': value_mean.item(),
        'value_std': value_std.item(),
        'logp_mean': logp_mean.item(),
        'logp_std': logp_std.item(),
        'episode_length_mean': ep_length_mean.item(),
        'episode_length_std': ep_length_std.item(),
    }