def test_invariant(self):
    """Check that the atomic scalars computed from a_lms do not change under a rotation.

    Coefficients are estimated from the conjugated spherical harmonics of two
    reference directions, passed through ``AtomicScalars``, and compared with
    the scalars obtained after rotating the coefficients with ``rotate_rep``.
    """
    max_ell = 4
    sphs_conj = SphericalHarmonics(maxl=max_ell, conj=True, sh_norm='unit')
    atomic_scalars = AtomicScalars(maxl=max_ell)

    # Two reference directions given as (theta, phi) pairs
    theta_phi = np.array([[np.pi / 3, np.pi / 4], [2 * np.pi / 3, np.pi / 2]])
    xyz_refs = spherical_to_cartesian(theta_phi)
    y_lms_conj = sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float))
    a_lms = estimate_alms(y_lms_conj)

    invariant = atomic_scalars(a_lms)
    # assertEqual is required here: assertTrue(x, y) interprets y as the
    # failure message and therefore passes for every non-zero x.
    self.assertEqual(invariant.shape[-1], atomic_scalars.get_output_dim(channels=1))

    # NOTE(review): euler() is called without angles, presumably drawing a
    # random rotation - confirm against SO3WignerD.
    random_rotation = SO3WignerD.euler(maxl=max_ell, dtype=torch.float)
    a_lms_rotated = rotate_rep(random_rotation, a_lms)

    # Sanity check: the rotation must actually modify the entry at index 1
    self.assertFalse(np.allclose(to_numpy(a_lms[1]), to_numpy(a_lms_rotated[1])))

    invariant_rotated = atomic_scalars(a_lms_rotated)
    self.assertTrue(np.allclose(invariant, invariant_rotated))
def surrogate_features(self, observations: List[ObservationType], focus: torch.Tensor, element: torch.Tensor,
                       distance: torch.Tensor, angle: torch.Tensor, dihedral: torch.Tensor) -> torch.Tensor:
    """Embed every observation after hypothetically adding the proposed atom.

    For each observation the atoms are parsed, a new atom is placed with
    ``zmat.position_atom_helper`` from the i-th focus / distance / angle /
    dihedral entries, and the ``[:, -1, :]`` slice of the embedding of the
    extended structure is written into row ``i`` of the result.

    All sub-action tensors are indexed as ``[i, 0]``, so they need a leading
    dimension of at least ``len(observations)`` and a second dimension.

    :return: float32 tensor of shape ``(len(observations), self.num_afeats)``
        on ``self.device``.
    """
    focus_np = to_numpy(focus)
    element_np = to_numpy(element)
    distance_np = to_numpy(distance)
    angle_np = to_numpy(angle)
    dihedral_np = to_numpy(dihedral)

    result = torch.zeros(size=(len(observations), self.num_afeats), dtype=torch.float32, device=self.device)

    for row, obs in enumerate(observations):
        atoms, _ = self.observation_space.parse(obs)

        # Place the candidate atom relative to the atoms already present
        candidate_position = zmat.position_atom_helper(
            positions=[atom.position for atom in atoms],
            focus=int(round(focus_np[row, 0])),
            distance=distance_np[row, 0],
            angle=angle_np[row, 0],
            dihedral=dihedral_np[row, 0],
        )

        element_index = int(round(element_np[row, 0]))
        symbol = self.observation_space.bag_space.get_symbol(element_index)
        atoms.append(ase.Atom(symbol=symbol, position=candidate_position))

        # The candidate was appended last, so index -1 on the middle axis is used
        embedded = self.embedding_fn(self.converter(atoms))
        result[row] = embedded[:, -1, :]

    return result
def test_sample(self):
    """Sampling from a batched distribution must agree with sampling each member separately.

    The mean spherical angles of samples drawn from the batched distribution
    are compared (within ``atol=0.1``) against those drawn from the two
    single-element distributions.
    """
    torch.manual_seed(1)
    shape = (2048, )

    # Batched distribution built from both coefficient sets
    batched = SO3Distribution(a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]), sphs=self.sphs)
    batched_samples = batched.sample(shape)
    self.assertEqual(batched_samples.shape, shape + batched.batch_shape + batched.event_shape)

    batched_angles = cartesian_to_spherical(to_numpy(batched_samples))  # [S, B, 2]
    batched_means = np.mean(batched_angles, axis=0)  # [B, 2]
    self.assertEqual(batched_means.shape, (2, 2))

    # Same procedure for each distribution on its own, in the same order
    single_means = []
    for coefficients in (self.a_lms_1, self.a_lms_2):
        single = SO3Distribution(a_lms=coefficients, sphs=self.sphs)
        single_samples = single.sample(shape)
        single_angles = cartesian_to_spherical(to_numpy(single_samples))  # [S, 1, 2]
        single_means.append(np.mean(single_angles, axis=0))  # [1, 2]

    # Assert that batching does not affect the result
    for batched_mean, single_mean in zip(batched_means, single_means):
        self.assertTrue(np.allclose(batched_mean, single_mean, atol=0.1))
def batch_rollout(ac: AbstractActorCritic,
                  envs: VecEnv,
                  buffer_container: PPOBufferContainer,
                  num_steps: int = None,
                  num_episodes: int = None) -> dict:
    """Collect experience from a vector of environments into ``buffer_container``.

    At least one of ``num_steps`` / ``num_episodes`` must be given.

    :param ac: actor-critic; ``ac.step(observations)`` must return a mapping
        with the keys ``'actions'``, ``'a'``, ``'v'`` and ``'logp'``.
    :param envs: vectorised environments that are stepped in lockstep.
    :param buffer_container: receives every transition and is asked to finish
        its paths on the last iteration.
    :param num_steps: total number of environment steps; must be divisible by
        ``envs.get_size()``.
    :param num_episodes: number of episodes to collect; only allowed when
        ``envs.get_size() == 1``.
    :return: dict with the elapsed time and mean / std of the episodic returns
        and episode lengths recorded by ``buffer_container``.
    """
    assert num_steps is not None or num_episodes is not None

    if num_steps is None:
        max_iters = np.inf
    else:
        assert num_steps % envs.get_size() == 0
        max_iters = num_steps // envs.get_size()

    if num_episodes is None:
        max_episodes = np.inf
    else:
        assert envs.get_size() == 1
        max_episodes = num_episodes

    t_start = time.time()
    iteration = 0
    observations = envs.reset()

    while iteration < max_iters and buffer_container.get_num_episodes() < max_episodes:
        step_out = ac.step(observations)
        next_observations, rewards, terminals, _ = envs.step(step_out['actions'])

        buffer_container.store(
            observations=observations,
            actions=to_numpy(step_out['a']),
            rewards=rewards,
            next_observations=next_observations,
            terminals=terminals,
            values=to_numpy(step_out['v']),
            logps=to_numpy(step_out['logp']),
        )

        # Reset environment if state is terminal to get valid next observation
        observations = envs.reset_if_terminal(next_observations, terminals)

        is_final_iteration = iteration == max_iters - 1
        if is_final_iteration:
            # Note: finished trajectories will not be affected by this
            bootstrap = ac.step(observations)
            buffer_container.finish_paths(to_numpy(bootstrap['v']))

        iteration += 1

    returns = buffer_container.episodic_returns
    lengths = buffer_container.episode_lengths
    return {
        'time': time.time() - t_start,
        'return_mean': np.mean(returns).item(),
        'return_std': np.std(returns).item(),
        'episode_length_mean': np.mean(lengths).item(),
        'episode_length_std': np.std(lengths).item(),
    }
def test_argmax(self):
    """The argmax of ``self.distr`` (called with 128) must match the stored reference values."""
    torch.manual_seed(1)

    result = self.distr.argmax(128)
    self.assertEqual(result.shape, (2, ))

    reference = np.array([-0.495, 0.156])
    is_close = np.allclose(to_numpy(result), reference, atol=1.e-2)
    self.assertTrue(is_close)
def test_normalization(self):
    """The density of the batched distribution must integrate to one.

    The integral is approximated as ``4 * pi`` times the mean probability over
    a 1024-point Fibonacci grid.
    """
    combined = concat_so3vecs([self.a_lms_1, self.a_lms_2])
    distribution = SO3Distribution(a_lms=combined, sphs=self.sphs, dtype=torch.float)

    # unsqueeze(1) inserts an axis after the grid axis
    # (presumably so the points broadcast over the batch - TODO confirm)
    points = torch.tensor(generate_fibonacci_grid(n=1024), dtype=torch.float).unsqueeze(1)

    densities = distribution.prob(points)
    integral = 4 * np.pi * torch.mean(densities, dim=0)

    self.assertTrue(np.allclose(to_numpy(integral), 1.0))
def to_action_space(self, action: torch.Tensor, observation: ObservationType) -> ActionType:
    """Translate a 6-component internal action into an (element index, position) pair.

    Layout of ``action``: ``[focus, element index, distance, 3 direction components]``.
    With at least one atom in the observation, the position is the focus
    atom's position plus ``distance * direction``; with no atoms it is the origin.
    """
    assert action.shape == (6, )
    components = to_numpy(action)

    focus = int(round(components[0].item()))
    element_index = int(round(components[1].item()))
    distance = components[2]
    direction = components[-3:]

    atoms, _ = self.observation_space.parse(observation)

    # Nothing to attach to yet: the atom goes to the origin
    if len(atoms) == 0:
        return element_index, (0.0, 0.0, 0.0)

    return element_index, atoms[focus].position + distance * direction
def compute_loss(
        ac: AbstractActorCritic,
        data: dict,
        clip_ratio: float,
        vf_coef: float,
        entropy_coef: float,
        device=None,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the clipped-surrogate policy loss plus value and entropy terms.

    :param ac: actor-critic; ``ac.step(obs, act)`` must return a mapping with
        the keys ``'logp'``, ``'ent'`` and ``'v'``.
    :param data: batch providing ``'obs'``, ``'act'``, ``'logp'``, ``'adv'``
        and ``'ret'``.
    :param clip_ratio: the probability ratio is clamped to
        ``[1 - clip_ratio, 1 + clip_ratio]``.
    :param vf_coef: weight of the squared-error value loss.
    :param entropy_coef: weight of the (negated) mean entropy.
    :param device: device onto which the stored batch quantities are moved.
    :return: the total loss tensor and a dict of scalar diagnostics.
    """
    prediction = ac.step(data['obs'], data['act'])
    new_logp = prediction['logp']

    logp_old = torch.as_tensor(data['logp'], device=device)
    advantages = torch.as_tensor(data['adv'], device=device)
    returns = torch.as_tensor(data['ret'], device=device)

    lower = 1 - clip_ratio
    upper = 1 + clip_ratio

    # Policy loss
    ratio = torch.exp(new_logp - logp_old)
    unclipped = ratio * advantages
    clipped = ratio.clamp(lower, upper) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()

    # Entropy loss
    entropy_loss = -entropy_coef * prediction['ent'].mean()

    # Value loss
    value_error = prediction['v'] - returns
    vf_loss = vf_coef * value_error.pow(2).mean()

    # Total loss
    loss = policy_loss + entropy_loss + vf_loss

    # Approximate KL for early stopping
    approx_kl = (logp_old - new_logp).mean()

    # Extra info: fraction of samples whose ratio fell outside the clip range
    outside = ratio.lt(lower) | ratio.gt(upper)
    clip_fraction = torch.as_tensor(outside, dtype=torch.float32).mean()

    info = {
        'policy_loss': to_numpy(policy_loss).item(),
        'entropy_loss': to_numpy(entropy_loss).item(),
        'vf_loss': to_numpy(vf_loss).item(),
        'total_loss': to_numpy(loss).item(),
        'approx_kl': to_numpy(approx_kl).item(),
        'clip_fraction': to_numpy(clip_fraction).item(),
    }
    return loss, info
def to_action_space(self, action: torch.Tensor, observation: ObservationType) -> ActionType:
    """Translate a 7-component internal-coordinate action into an environment action.

    ``action`` unpacks into ``stop, focus, element, distance, angle, dihedral,
    kappa``. A truthy ``stop`` yields the action built from an empty
    ``ase.Atoms``. Otherwise the new position is computed with
    ``zmat.position_atom_helper``; the dihedral is negated when the rounded
    ``kappa`` is non-zero.

    :return: ``(index of the atomic number in self.action_space.zs, position tuple)``,
        or the result of ``self.action_space.build(ase.Atoms())`` on stop.
    """
    stop, raw_focus, raw_element, distance, angle, dihedral, kappa = to_numpy(action)

    if stop:
        return self.action_space.build(ase.Atoms())

    # Round to obtain discrete subactions
    focus_index = int(round(raw_focus))
    bag_index = int(round(raw_element))
    sign = 1
    if int(round(kappa)):
        sign = -1

    atoms, _ = self.observation_space.parse(observation)
    new_position = zmat.position_atom_helper(
        positions=[atom.position for atom in atoms],
        focus=focus_index,
        distance=distance,
        angle=angle,
        dihedral=sign * dihedral,
    )

    # Map the element's index in the bag space onto its index in the action space
    atomic_number = self.observation_space.bag_space.zs[bag_index]
    atomic_number_index = self.action_space.zs.index(atomic_number)

    return atomic_number_index, tuple(new_position)
def test_multiplication_2(self):
    """``complex_product`` of ``[2, 0]`` and ``[3, 0]`` must give ``[6, 0]``.

    NOTE(review): the two-element tensors are presumably (real, imaginary)
    pairs - confirm against ``complex_product``.
    """
    lhs = torch.tensor([2.0, 0.0], dtype=torch.float)
    rhs = torch.tensor([3.0, 0.0], dtype=torch.float)

    product = to_numpy(complex_product(lhs, rhs))

    self.assertTrue(np.allclose(product, np.array([6.0, 0.0])))
def rollout(ac: AbstractActorCritic,
            env: AbstractMolecularEnvironment,
            buffer: PPOBuffer,
            num_steps: Optional[int] = None,
            num_episodes: Optional[int] = None) -> dict:
    """Run the policy in a single environment and record the transitions in ``buffer``.

    The loop stops as soon as ``num_steps`` transitions have been stored or
    ``num_episodes`` episodes have terminated, whichever comes first.

    :param ac: actor-critic; ``ac.step([obs])`` must return a mapping with the
        keys ``'a'``, ``'v'`` and ``'logp'``, and ``ac.to_action_space``
        converts the internal action into an environment action.
    :param env: environment to interact with.
    :param buffer: receives each transition via ``store`` and closes a
        trajectory via ``finish_path``.
    :param num_steps: maximum number of environment steps (unbounded if None).
    :param num_episodes: maximum number of finished episodes (unbounded if None).
    :return: dict with elapsed time, the step count from ``mpi_sum`` and
        mean / std statistics from ``mpi_mean_std`` for returns, values,
        log-probabilities and episode lengths.
    """
    assert num_steps or num_episodes

    num_steps = num_steps if num_steps is not None else np.inf
    num_episodes = num_episodes if num_episodes is not None else np.inf

    obs = env.reset()

    ep_returns = []
    ep_lengths = []

    ep_length = 0
    ep_counter = 0
    step = 0

    start_time = time.time()

    while step < num_steps and ep_counter < num_episodes:
        pred = ac.step([obs])

        a = to_numpy(pred['a'][0])
        next_obs, reward, done, _ = env.step(ac.to_action_space(action=a, observation=obs))

        buffer.store(obs=obs,
                     act=a,
                     reward=reward,
                     next_obs=next_obs,
                     terminal=done,
                     value=pred['v'].item(),
                     logp=pred['logp'].item())

        obs = next_obs

        step += 1
        ep_length += 1

        # `step` was already incremented above, so the final iteration of the
        # loop is the one where it equals num_steps. Comparing against
        # num_steps - 1 would cut the trajectory one step early and leave the
        # last stored transition without a finish_path call.
        last_step = step == num_steps
        if done or last_step:
            # if trajectory didn't reach terminal state, bootstrap value target of next observation
            if not done:
                pred = ac.step([obs])
                value = float(pred['v'])
            else:
                value = 0

            ep_return = buffer.finish_path(value)

            # Only trajectories that really terminated count as episodes
            if done:
                ep_returns.append(ep_return)
                ep_lengths.append(ep_length)
                ep_counter += 1

            obs = env.reset()
            ep_length = 0

    # Compute statistics
    return_mean, return_std = mpi_mean_std(np.asarray(ep_returns), axis=0)
    ep_length_mean, ep_length_std = mpi_mean_std(np.asarray(ep_lengths), axis=0)
    value_mean, value_std = mpi_mean_std(buffer.val_buf[:buffer.ptr], axis=0)
    logp_mean, logp_std = mpi_mean_std(buffer.logp_buf[:buffer.ptr], axis=0)

    return {
        'time': time.time() - start_time,
        'num_steps': mpi_sum(np.asarray(step)).item(),
        'return_mean': return_mean.item(),
        'return_std': return_std.item(),
        'value_mean': value_mean.item(),
        'value_std': value_std.item(),
        'logp_mean': logp_mean.item(),
        'logp_std': logp_std.item(),
        'episode_length_mean': ep_length_mean.item(),
        'episode_length_std': ep_length_std.item(),
    }