def verify_alms(self, atoms):
    """Check that the predicted SO(3) coefficients transform with the Wigner-D matrices.

    The agent is run (with seeds reset to 0 each time) on ``atoms`` and on a rotated
    copy of it. Each part of the coefficients predicted for the rotated structure must
    match the corresponding part of the original coefficients after ``apply_wigner``,
    element-wise to within 1e-5.

    ``atoms`` itself is left unmodified; the rotation is applied to a copy.
    """
    observation = self.observation_space.build(atoms, formula=self.formula)
    util.set_seeds(0)
    so3_dist = self.agent.step([observation])['dists'][-1]

    # gen_rot yields Wigner-D matrices (applied to the coefficients) and a rotation
    # matrix (applied to the Cartesian positions); the third value is not needed here.
    wigner_d, rot_mat, _ = rotations.gen_rot(self.agent.max_sh, dtype=self.agent.dtype)

    # Rotate a copy so the caller's structure is not mutated by this check.
    atoms_rotated = atoms.copy()
    atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat, atoms.positions)
    observation_rot = self.observation_space.build(atoms_rotated, formula=self.formula)
    util.set_seeds(0)
    so3_dist_rot = self.agent.step([observation_rot])['dists'][-1]

    expected_parts = list(so3_dist.coefficients.apply_wigner(wigner_d))
    actual_parts = list(so3_dist_rot.coefficients)

    # zip() would silently ignore unmatched trailing parts; require equal counts.
    self.assertEqual(len(actual_parts), len(expected_parts))
    for actual, expected in zip(actual_parts, expected_parts):
        max_delta = torch.max(torch.abs(actual - expected))
        self.assertLess(max_delta.item(), 1e-5)
def setUp(self) -> None:
    """Create the CPU device, spaces, CovariantAC agent and formula used by the checks.

    Seeds are reset to 0 before anything is constructed.
    """
    util.set_seeds(0)
    self.device = torch.device('cpu')
    self.action_space = ActionSpace(zs=[1])
    self.observation_space = ObservationSpace(canvas_size=5, zs=[0, 1, 6, 8])

    # Network / distribution hyper-parameters of the agent under test.
    hyperparams = dict(
        min_max_distance=(0.9, 1.8),
        network_width=64,
        bag_scale=1,
        beta=100,
        maxl=4,
        num_cg_levels=3,
        num_channels_hidden=10,
        num_channels_per_element=4,
        num_gaussians=3,
    )
    self.agent = CovariantAC(
        observation_space=self.observation_space,
        action_space=self.action_space,
        device=self.device,
        **hyperparams,
    )

    # Formula passed to observation_space.build in every check.
    self.formula = ((1, 1), )
def verify_probs(self, atoms):
    """Check that the extreme log-probabilities of the policy are rotation-invariant.

    The agent is run (with seeds reset to 0 each time) on ``atoms`` and on a rotated
    copy. Both resulting distributions are evaluated on the same set of grid points
    (from ``generate_fibonacci_grid``, n=100_000). The per-batch maximum and minimum
    log-probabilities must agree to within atol=5e-3, since rotating the structure
    should only permute where the extremes occur, not their values.

    ``atoms`` itself is left unmodified; the rotation is applied to a copy.
    """
    grid_points = torch.tensor(generate_fibonacci_grid(n=100_000), dtype=torch.float, device=self.device)
    # Insert a singleton axis before the last dimension.
    grid_points = grid_points.unsqueeze(-2)

    observation = self.observation_space.build(atoms, formula=self.formula)
    util.set_seeds(0)
    so3_dist = self.agent.step([observation])['dists'][-1]

    # Only the rotation matrix is needed; the Wigner-D matrices and third value are unused.
    _, rot_mat, _ = rotations.gen_rot(self.agent.max_sh, dtype=self.agent.dtype)
    atoms_rotated = atoms.copy()
    atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat, atoms.positions)
    observation_rot = self.observation_space.build(atoms_rotated, formula=self.formula)
    util.set_seeds(0)
    so3_dist_rot = self.agent.step([observation_rot])['dists'][-1]

    log_probs = so3_dist.log_prob(grid_points)  # (samples, batches)
    log_probs_rot = so3_dist_rot.log_prob(grid_points)  # (samples, batches)

    # Extremes over the grid-point axis.
    maximum = torch.max(log_probs, dim=0).values
    minimum = torch.min(log_probs, dim=0).values
    maximum_rot = torch.max(log_probs_rot, dim=0).values
    minimum_rot = torch.min(log_probs_rot, dim=0).values

    self.assertTrue(
        torch.allclose(maximum, maximum_rot, atol=5e-3),
        msg=f'max log-prob changed under rotation: {maximum} vs {maximum_rot}',
    )
    self.assertTrue(
        torch.allclose(minimum, minimum_rot, atol=5e-3),
        msg=f'min log-prob changed under rotation: {minimum} vs {minimum_rot}',
    )
def verify_invariance(self, atoms):
    """Check that the AtomicScalars of the predicted coefficients are rotation-invariant.

    The scalars computed for ``atoms`` and for a rotated copy of it (seeds reset to 0
    before each agent step) must agree to within atol=1e-05.
    """
    to_scalars = AtomicScalars(maxl=self.agent.max_sh)

    def scalars_of(structure):
        # Build the observation, step the agent deterministically, reduce to scalars.
        obs = self.observation_space.build(structure, formula=self.formula)
        util.set_seeds(0)
        dist = self.agent.step([obs])['dists'][-1]
        return to_scalars(dist.coefficients)

    reference = scalars_of(atoms)

    _, rotation, _ = rotations.gen_rot(self.agent.max_sh, dtype=self.agent.dtype)
    turned = atoms.copy()
    turned.positions = np.einsum('ij,...j->...i', rotation, atoms.positions)

    self.assertTrue(torch.allclose(reference, scalars_of(turned), atol=1e-05))
def main() -> None:
    """Run PPO training: set up directories, logging and seeds, obtain a model, build envs, call ppo."""
    util.set_one_thread()

    config = get_config()
    util.create_directories([
        config[key] for key in ('log_dir', 'model_dir', 'data_dir', 'results_dir', 'structures_dir')
    ])

    tag = util.get_tag(config)
    util.setup_logger(config, directory=config['log_dir'], tag=tag)
    util.save_config(config, directory=config['log_dir'], tag=tag)
    # Each MPI process gets its own seed (base seed offset by its rank).
    util.set_seeds(seed=config['seed'] + mpi.get_proc_rank())

    model_handler = util.ModelIO(directory=config['model_dir'], tag=tag)

    bag_symbols = config['bag_symbols'].split(',')
    action_space = ActionSpace()
    observation_space = ObservationSpace(canvas_size=config['canvas_size'], symbols=bag_symbols)

    # Model source, in priority order: explicit model file, latest checkpoint, fresh build.
    start_num_steps = 0
    restored = True
    if config['loaded_model_name']:
        model = load_specific_model(model_path=config['loaded_model_name'])
    elif config['load_model']:
        model, start_num_steps = model_handler.load()
    else:
        model = build_model(config, observation_space=observation_space, action_space=action_space)
        restored = False

    if restored:
        # A restored model is given the spaces constructed for this run.
        model.action_space = action_space
        model.observation_space = observation_space

    mpi.sync_params(model)

    var_counts = util.count_vars(model)
    logging.info(f'Number of parameters: {var_counts}')

    reward = InteractionReward(config['rho'])

    # Evaluation formulas default to the training formulas.
    if not config['eval_formulas']:
        config['eval_formulas'] = config['formulas']

    # NOTE(review): the initial formulas are parsed a second time instead of reusing
    # train/eval_formulas, presumably so each env argument is an independent object.
    train_formulas = parse_formulas(config['formulas'])
    eval_formulas = parse_formulas(config['eval_formulas'])
    train_init_formulas = parse_formulas(config['formulas'])
    eval_init_formulas = parse_formulas(config['eval_formulas'])

    logging.info(f'Training bags: {train_formulas}')
    logging.info(f'Evaluation bags: {eval_formulas}')

    # Number of evaluation episodes defaults to one per evaluation formula.
    if not config['num_eval_episodes']:
        config['num_eval_episodes'] = len(eval_formulas)

    shared_env_kwargs = dict(
        reward=reward,
        observation_space=observation_space,
        action_space=action_space,
        min_atomic_distance=config['min_atomic_distance'],
        max_h_distance=config['max_h_distance'],
        min_reward=config['min_reward'],
        bag_refills=config['bag_refills'],
    )
    env = MolecularEnvironment(formulas=train_formulas, initial_formula=train_init_formulas, **shared_env_kwargs)
    eval_env = MolecularEnvironment(formulas=eval_formulas, initial_formula=eval_init_formulas, **shared_env_kwargs)

    rollout_saver = RolloutSaver(directory=config['data_dir'], tag=tag, all_ranks=config['all_ranks'])
    info_saver = InfoSaver(directory=config['results_dir'], tag=tag)
    image_saver = StructureSaver(directory=config['structures_dir'], tag=tag)

    ppo(
        env=env,
        eval_env=eval_env,
        ac=model,
        gamma=config['discount'],
        start_num_steps=start_num_steps,
        max_num_steps=config['max_num_steps'],
        num_steps_per_iter=config['num_steps_per_iter'],
        clip_ratio=config['clip_ratio'],
        learning_rate=config['learning_rate'],
        vf_coef=config['vf_coef'],
        entropy_coef=config['entropy_coef'],
        max_num_train_iters=config['max_num_train_iters'],
        lam=config['lam'],
        target_kl=config['target_kl'],
        gradient_clip=config['gradient_clip'],
        eval_freq=config['eval_freq'],
        model_handler=model_handler,
        save_freq=config['save_freq'],
        num_eval_episodes=config['num_eval_episodes'],
        rollout_saver=rollout_saver,
        save_train_rollout=config['save_rollouts'] in ('train', 'all'),
        save_eval_rollout=config['save_rollouts'] in ('eval', 'all'),
        info_saver=info_saver,
        structure_saver=image_saver,
    )
def main() -> None:
    """Run batched PPO training: set up logging/seeds/device, obtain a model, build envs, call batch_ppo."""
    config = get_config()
    util.create_directories([config[key] for key in ('log_dir', 'model_dir', 'data_dir', 'results_dir')])

    tag = util.get_tag(config)
    util.setup_logger(config, directory=config['log_dir'], tag=tag)
    util.save_config(config, directory=config['log_dir'], tag=tag)
    util.set_seeds(seed=config['seed'])
    device = util.init_device(config['device'])

    # Comma-separated element symbols -> atomic numbers.
    zs = [ase.data.atomic_numbers[symbol] for symbol in config['symbols'].split(',')]
    action_space = ActionSpace(zs=zs)
    observation_space = ObservationSpace(canvas_size=config['canvas_size'], zs=zs)

    # Evaluation formulas default to the training formulas.
    if not config['eval_formulas']:
        config['eval_formulas'] = config['formulas']
    train_formulas = util.split_formula_strings(config['formulas'])
    eval_formulas = util.split_formula_strings(config['eval_formulas'])
    logging.info(f'Training bags: {train_formulas}')
    logging.info(f'Evaluation bags: {eval_formulas}')

    model_handler = ModelIO(directory=config['model_dir'], tag=tag, keep=config['keep_models'])

    # Model source, in priority order: latest checkpoint, explicit path, fresh build.
    restored = True
    if config['load_latest']:
        model, start_num_steps = model_handler.load_latest(device=device)
    elif config['load_model'] is not None:
        model, start_num_steps = model_handler.load(device=device, path=config['load_model'])
    else:
        model = build_model(config, observation_space=observation_space, action_space=action_space, device=device)
        start_num_steps = 0
        restored = False

    if restored:
        # A restored model is given the spaces constructed for this run.
        model.action_space = action_space
        model.observation_space = observation_space

    var_counts = util.count_vars(model)
    logging.info(f'Number of parameters: {var_counts}')

    reward = InteractionReward()

    # Number of evaluation episodes defaults to one per evaluation formula.
    if not config['num_eval_episodes']:
        config['num_eval_episodes'] = len(eval_formulas)

    def make_env(formula_strings):
        # Each call converts the formula strings afresh, so every env gets its own list.
        return MolecularEnvironment(
            reward=reward,
            observation_space=observation_space,
            action_space=action_space,
            formulas=[util.string_to_formula(f) for f in formula_strings],
            min_atomic_distance=config['min_atomic_distance'],
            max_solo_distance=config['max_solo_distance'],
            min_reward=config['min_reward'],
        )

    training_envs = SimpleEnvContainer([make_env(train_formulas) for _ in range(config['num_envs'])])
    eval_envs = SimpleEnvContainer([make_env(eval_formulas)])

    optimizer = util.get_optimizer(
        name=config['optimizer'],
        learning_rate=config['learning_rate'],
        parameters=model.parameters(),
    )

    batch_ppo(
        envs=training_envs,
        eval_envs=eval_envs,
        ac=model,
        optimizer=optimizer,
        gamma=config['discount'],
        start_num_steps=start_num_steps,
        max_num_steps=config['max_num_steps'],
        num_steps_per_iter=config['num_steps_per_iter'],
        mini_batch_size=config['mini_batch_size'],
        clip_ratio=config['clip_ratio'],
        vf_coef=config['vf_coef'],
        entropy_coef=config['entropy_coef'],
        max_num_train_iters=config['max_num_train_iters'],
        lam=config['lam'],
        target_kl=config['target_kl'],
        gradient_clip=config['gradient_clip'],
        eval_freq=config['eval_freq'],
        model_handler=model_handler,
        save_freq=config['save_freq'],
        num_eval_episodes=config['num_eval_episodes'],
        rollout_saver=util.RolloutSaver(directory=config['data_dir'], tag=tag),
        save_train_rollout=config['save_rollouts'] in ('train', 'all'),
        save_eval_rollout=config['save_rollouts'] in ('eval', 'all'),
        info_saver=util.InfoSaver(directory=config['results_dir'], tag=tag),
        device=device,
    )