Beispiel #1
0
    def verify_alms(self, atoms):
        """Check that the agent's output coefficients transform consistently under rotation.

        The agent is stepped on ``atoms`` and on a copy rotated with the matrix from
        ``rotations.gen_rot``. ``util.set_seeds(0)`` is called before each step.
        The coefficients of the original run's last distribution are passed through
        ``apply_wigner`` with the matching ``wigner_d`` from ``gen_rot``. They must then
        match the rotated run's coefficients part by part, with a max abs difference < 1e-5.

        Args:
            atoms: Structure to test. It is not modified; only a copy is rotated.
        """
        observation = self.observation_space.build(atoms, formula=self.formula)
        util.set_seeds(0)
        action = self.agent.step([observation])
        so3_dist = action['dists'][-1]

        # Rotate a copy so the caller's structure is left untouched (previously the
        # positions of ``atoms`` itself were overwritten).
        wigner_d, rot_mat, _ = rotations.gen_rot(self.agent.max_sh,
                                                 dtype=self.agent.dtype)
        atoms_rotated = atoms.copy()
        atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                            atoms.positions)

        observation = self.observation_space.build(atoms_rotated,
                                                   formula=self.formula)
        util.set_seeds(0)
        action = self.agent.step([observation])
        so3_dist_rot = action['dists'][-1]

        rotated_b_lms = so3_dist.coefficients.apply_wigner(wigner_d)
        for part1, part2 in zip(so3_dist_rot.coefficients, rotated_b_lms):
            max_delta = torch.max(torch.abs(part1 - part2))
            # assertLess reports the offending delta on failure, unlike assertTrue.
            self.assertLess(max_delta.item(), 1e-5)
Beispiel #2
0
 def setUp(self) -> None:
     """Test fixture: seed via ``util.set_seeds(0)`` and build a CPU ``CovariantAC`` agent.

     Sets ``device``, ``action_space``, ``observation_space``, ``agent`` and
     ``formula`` on the test case.
     """
     util.set_seeds(0)
     self.device = torch.device('cpu')
     self.action_space = ActionSpace(zs=[1])
     self.observation_space = ObservationSpace(canvas_size=5, zs=[0, 1, 6, 8])

     # Network / distribution hyper-parameters for the agent under test.
     agent_hparams = dict(
         min_max_distance=(0.9, 1.8),
         network_width=64,
         bag_scale=1,
         beta=100,
         maxl=4,
         num_cg_levels=3,
         num_channels_hidden=10,
         num_channels_per_element=4,
         num_gaussians=3,
     )
     self.agent = CovariantAC(
         observation_space=self.observation_space,
         action_space=self.action_space,
         device=self.device,
         **agent_hparams,
     )
     self.formula = ((1, 1), )
Beispiel #3
0
    def verify_probs(self, atoms, num_points=100_000, atol=5e-3):
        """Check that extreme log-probabilities over a sphere grid are unchanged by rotation.

        The agent is stepped on ``atoms`` and on a rotated copy, with
        ``util.set_seeds(0)`` before each step. The log-probabilities of the last
        distribution in ``action['dists']`` are evaluated on a Fibonacci grid.
        Their maxima and minima over the grid points (dim 0) must agree between
        the two runs within ``atol``.

        Args:
            atoms: Structure to test. It is not modified; a copy is rotated.
            num_points: Number of points requested from ``generate_fibonacci_grid``.
            atol: Absolute tolerance passed to ``torch.allclose``.
        """
        grid_points = torch.tensor(generate_fibonacci_grid(n=num_points),
                                   dtype=torch.float,
                                   device=self.device)
        # Insert a singleton axis before the last one. This is presumably so the grid
        # broadcasts against the distribution's batch dimension (TODO confirm).
        grid_points = grid_points.unsqueeze(-2)

        observation = self.observation_space.build(atoms, formula=self.formula)
        util.set_seeds(0)
        action = self.agent.step([observation])
        so3_dist = action['dists'][-1]

        # Rotate a copy of the atoms; only the rotation matrix is needed here.
        _, rot_mat, _ = rotations.gen_rot(self.agent.max_sh,
                                          dtype=self.agent.dtype)
        atoms_rotated = atoms.copy()
        atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                            atoms.positions)

        observation = self.observation_space.build(atoms_rotated,
                                                   formula=self.formula)
        util.set_seeds(0)
        action = self.agent.step([observation])
        so3_dist_rot = action['dists'][-1]

        log_probs = so3_dist.log_prob(grid_points)  # (samples, batches)
        log_probs_rot = so3_dist_rot.log_prob(grid_points)  # (samples, batches)

        # Reduce over the grid points (dim 0); the argmax/argmin indices are unused.
        self.assertTrue(torch.allclose(log_probs.max(dim=0).values,
                                       log_probs_rot.max(dim=0).values,
                                       atol=atol))
        self.assertTrue(torch.allclose(log_probs.min(dim=0).values,
                                       log_probs_rot.min(dim=0).values,
                                       atol=atol))
Beispiel #4
0
    def verify_invariance(self, atoms):
        """Assert that ``AtomicScalars`` of the agent's output agree for ``atoms`` and a rotated copy.

        Both runs call ``util.set_seeds(0)`` right before ``agent.step``. The scalars
        computed from the last distribution's coefficients must be ``allclose``
        with ``atol=1e-05``. ``atoms`` itself is not modified.
        """
        scalar_fn = AtomicScalars(maxl=self.agent.max_sh)

        def last_dist(structure):
            # Build the observation, reseed, and return the final entry of action['dists'].
            obs = self.observation_space.build(structure, formula=self.formula)
            util.set_seeds(0)
            return self.agent.step([obs])['dists'][-1]

        reference = scalar_fn(last_dist(atoms).coefficients)

        # Apply the rotation from gen_rot to a copy of the structure.
        _wigner, rotation, _ = rotations.gen_rot(self.agent.max_sh,
                                                 dtype=self.agent.dtype)
        rotated = atoms.copy()
        rotated.positions = np.einsum('ij,...j->...i', rotation,
                                      atoms.positions)

        rotated_scalars = scalar_fn(last_dist(rotated).coefficients)
        self.assertTrue(torch.allclose(reference, rotated_scalars, atol=1e-05))
Beispiel #5
0
def main() -> None:
    """Run PPO training on a MolecularEnvironment.

    Steps:
      * read the config via ``get_config()`` and create the output directories;
      * set up logging, save the config and seed each MPI rank;
      * build a fresh model, or load one (from ``loaded_model_name`` or via
        ``ModelIO.load`` when ``load_model`` is set), then sync its parameters
        across ranks;
      * construct the training and evaluation environments and savers;
      * hand everything to ``ppo``.
    """
    util.set_one_thread()

    config = get_config()

    util.create_directories([config['log_dir'], config['model_dir'], config['data_dir'],
                             config['results_dir'], config['structures_dir']])

    tag = util.get_tag(config)
    util.setup_logger(config, directory=config['log_dir'], tag=tag)
    util.save_config(config, directory=config['log_dir'], tag=tag)

    # Offset the seed by the MPI rank so every process gets a different seed.
    util.set_seeds(seed=config['seed'] + mpi.get_proc_rank())

    model_handler = util.ModelIO(directory=config['model_dir'], tag=tag)

    bag_symbols = config['bag_symbols'].split(',')
    action_space = ActionSpace()
    observation_space = ObservationSpace(canvas_size=config['canvas_size'], symbols=bag_symbols)

    start_num_steps = 0

    if config['loaded_model_name']:
        # An explicitly named model file takes precedence over resuming from model_dir.
        model = load_specific_model(model_path=config['loaded_model_name'])
        model.action_space = action_space
        model.observation_space = observation_space
    elif config['load_model']:
        model, start_num_steps = model_handler.load()
        model.action_space = action_space
        model.observation_space = observation_space
    else:
        model = build_model(config, observation_space=observation_space, action_space=action_space)

    mpi.sync_params(model)

    var_counts = util.count_vars(model)
    logging.info(f'Number of parameters: {var_counts}')

    reward = InteractionReward(config['rho'])

    # Evaluation formulas default to the training formulas.
    if not config['eval_formulas']:
        config['eval_formulas'] = config['formulas']

    train_formulas = parse_formulas(config['formulas'])
    eval_formulas = parse_formulas(config['eval_formulas'])

    # Parsed again rather than reusing the objects above. This presumably keeps
    # ``formulas`` and ``initial_formula`` from aliasing each other (TODO confirm).
    train_init_formulas = parse_formulas(config['formulas'])
    eval_init_formulas = parse_formulas(config['eval_formulas'])

    logging.info(f'Training bags: {train_formulas}')
    logging.info(f'Evaluation bags: {eval_formulas}')

    # Number of episodes during evaluation defaults to one per evaluation formula.
    if not config['num_eval_episodes']:
        config['num_eval_episodes'] = len(eval_formulas)

    # Settings shared by the training and evaluation environments.
    common_env_kwargs = dict(
        reward=reward,
        observation_space=observation_space,
        action_space=action_space,
        min_atomic_distance=config['min_atomic_distance'],
        max_h_distance=config['max_h_distance'],
        min_reward=config['min_reward'],
        bag_refills=config['bag_refills'],
    )
    env = MolecularEnvironment(formulas=train_formulas,
                               initial_formula=train_init_formulas,
                               **common_env_kwargs)
    eval_env = MolecularEnvironment(formulas=eval_formulas,
                                    initial_formula=eval_init_formulas,
                                    **common_env_kwargs)

    rollout_saver = RolloutSaver(directory=config['data_dir'], tag=tag, all_ranks=config['all_ranks'])
    info_saver = InfoSaver(directory=config['results_dir'], tag=tag)
    structure_saver = StructureSaver(directory=config['structures_dir'], tag=tag)

    save_rollouts = config['save_rollouts']

    ppo(
        env=env,
        eval_env=eval_env,
        ac=model,
        gamma=config['discount'],
        start_num_steps=start_num_steps,
        max_num_steps=config['max_num_steps'],
        num_steps_per_iter=config['num_steps_per_iter'],
        clip_ratio=config['clip_ratio'],
        learning_rate=config['learning_rate'],
        vf_coef=config['vf_coef'],
        entropy_coef=config['entropy_coef'],
        max_num_train_iters=config['max_num_train_iters'],
        lam=config['lam'],
        target_kl=config['target_kl'],
        gradient_clip=config['gradient_clip'],
        eval_freq=config['eval_freq'],
        model_handler=model_handler,
        save_freq=config['save_freq'],
        num_eval_episodes=config['num_eval_episodes'],
        rollout_saver=rollout_saver,
        save_train_rollout=save_rollouts in ('train', 'all'),
        save_eval_rollout=save_rollouts in ('eval', 'all'),
        info_saver=info_saver,
        structure_saver=structure_saver,
    )
Beispiel #6
0
def main() -> None:
    """Train a molecular-design agent with batched PPO.

    Reads the config, prepares directories, logging and seeding, then builds or
    loads the actor-critic model. A model is loaded either from the latest
    checkpoint or from an explicit path. ``config['num_envs']`` training
    environments and one evaluation environment are created, and everything is
    passed to ``batch_ppo``.
    """
    config = get_config()

    util.create_directories([config['log_dir'], config['model_dir'], config['data_dir'], config['results_dir']])

    tag = util.get_tag(config)
    util.setup_logger(config, directory=config['log_dir'], tag=tag)
    util.save_config(config, directory=config['log_dir'], tag=tag)

    util.set_seeds(seed=config['seed'])
    device = util.init_device(config['device'])

    # Map the comma-separated element symbols to atomic numbers.
    zs = [ase.data.atomic_numbers[symbol] for symbol in config['symbols'].split(',')]
    action_space = ActionSpace(zs=zs)
    observation_space = ObservationSpace(canvas_size=config['canvas_size'], zs=zs)

    # Evaluation formulas fall back to the training formulas.
    if not config['eval_formulas']:
        config['eval_formulas'] = config['formulas']

    train_formulas = util.split_formula_strings(config['formulas'])
    eval_formulas = util.split_formula_strings(config['eval_formulas'])

    logging.info(f'Training bags: {train_formulas}')
    logging.info(f'Evaluation bags: {eval_formulas}')

    model_handler = ModelIO(directory=config['model_dir'], tag=tag, keep=config['keep_models'])

    if config['load_latest'] or config['load_model'] is not None:
        if config['load_latest']:
            model, start_num_steps = model_handler.load_latest(device=device)
        else:
            model, start_num_steps = model_handler.load(device=device, path=config['load_model'])
        # A loaded checkpoint is attached to the spaces built for this run.
        model.action_space = action_space
        model.observation_space = observation_space
    else:
        model = build_model(config, observation_space=observation_space, action_space=action_space, device=device)
        start_num_steps = 0

    var_counts = util.count_vars(model)
    logging.info(f'Number of parameters: {var_counts}')

    reward = InteractionReward()

    # One evaluation episode per evaluation formula unless configured otherwise.
    if not config['num_eval_episodes']:
        config['num_eval_episodes'] = len(eval_formulas)

    def make_env(formula_strings):
        # A fresh formula list is built for every environment instance.
        return MolecularEnvironment(
            reward=reward,
            observation_space=observation_space,
            action_space=action_space,
            formulas=[util.string_to_formula(s) for s in formula_strings],
            min_atomic_distance=config['min_atomic_distance'],
            max_solo_distance=config['max_solo_distance'],
            min_reward=config['min_reward'],
        )

    training_envs = SimpleEnvContainer([make_env(train_formulas) for _ in range(config['num_envs'])])
    eval_envs = SimpleEnvContainer([make_env(eval_formulas)])

    optimizer = util.get_optimizer(name=config['optimizer'],
                                   learning_rate=config['learning_rate'],
                                   parameters=model.parameters())
    rollout_saver = util.RolloutSaver(directory=config['data_dir'], tag=tag)
    info_saver = util.InfoSaver(directory=config['results_dir'], tag=tag)

    rollout_mode = config['save_rollouts']
    save_train = rollout_mode in ('train', 'all')
    save_eval = rollout_mode in ('eval', 'all')

    batch_ppo(
        envs=training_envs,
        eval_envs=eval_envs,
        ac=model,
        optimizer=optimizer,
        gamma=config['discount'],
        start_num_steps=start_num_steps,
        max_num_steps=config['max_num_steps'],
        num_steps_per_iter=config['num_steps_per_iter'],
        mini_batch_size=config['mini_batch_size'],
        clip_ratio=config['clip_ratio'],
        vf_coef=config['vf_coef'],
        entropy_coef=config['entropy_coef'],
        max_num_train_iters=config['max_num_train_iters'],
        lam=config['lam'],
        target_kl=config['target_kl'],
        gradient_clip=config['gradient_clip'],
        eval_freq=config['eval_freq'],
        model_handler=model_handler,
        save_freq=config['save_freq'],
        num_eval_episodes=config['num_eval_episodes'],
        rollout_saver=rollout_saver,
        save_train_rollout=save_train,
        save_eval_rollout=save_eval,
        info_saver=info_saver,
        device=device,
    )