Ejemplo n.º 1
0
 def test_invalid_action(self):
     """Stepping with an atom whose element is absent from the target
     formula (N is not in H2CO) must raise a RuntimeError."""
     env = MolecularEnvironment(
         reward=self.reward,
         observation_space=self.observation_space,
         action_space=self.action_space,
         formulas=[string_to_formula('H2CO')],
     )
     # Nitrogen cannot be placed for an H2CO bag.
     bad_action = self.action_space.from_atom(Atom(symbol='N', position=(0, 1, 0)))
     with self.assertRaises(RuntimeError):
         env.step(bad_action)
Ejemplo n.º 2
0
    def test_addition(self):
        """Adding a valid H atom to an empty canvas succeeds: the atom appears
        in the observation, one H is consumed from the bag, reward is 0.0 and
        the episode is not done."""
        env = MolecularEnvironment(
            reward=self.reward,
            observation_space=self.observation_space,
            action_space=self.action_space,
            formulas=[string_to_formula('H2CO')],
        )
        obs, reward, done, info = env.step(
            action=self.action_space.from_atom(Atom(symbol='H', position=(0.0, 1.0, 0.0))))

        atoms, remaining = self.observation_space.parse(obs)

        self.assertEqual(atoms[0].symbol, 'H')
        # One of the two hydrogens is gone; C and O are untouched.
        self.assertEqual(remaining, ((0, 0), (1, 1), (6, 1), (7, 0), (8, 1)))
        self.assertEqual(reward, 0.0)
        self.assertFalse(done)
Ejemplo n.º 3
0
    def test_solo_distance(self):
        """With max_solo_distance=1.0, the very first H may stand alone, but a
        second H placed 1.5 away from everything terminates the episode."""
        env = MolecularEnvironment(
            reward=self.reward,
            observation_space=self.observation_space,
            action_space=self.action_space,
            formulas=[string_to_formula('H2CO')],
            max_solo_distance=1.0,
        )

        # The first hydrogen is allowed to be isolated.
        first = self.action_space.from_atom(atom=Atom(symbol='H', position=(0, 0, 0)))
        obs, reward, done, info = env.step(action=first)
        self.assertFalse(done)

        # A second isolated hydrogen (1.5 > max_solo_distance) ends the episode.
        second = self.action_space.from_atom(atom=Atom(symbol='H', position=(0, 1.5, 0)))
        obs, reward, done, info = env.step(action=second)
        self.assertTrue(done)
Ejemplo n.º 4
0
 def test_invalid_formula(self):
     """Building a bag from a formula with an unsupported element (He)
     must fail with an AssertionError."""
     with self.assertRaises(AssertionError):
         self.observation_space.bag_space.from_formula(string_to_formula('He2'))
Ejemplo n.º 5
0
def main() -> None:
    """Entry point: read the run configuration, set up logging/seeding/device,
    build or load the actor-critic model, construct the training and
    evaluation environments, and run batch PPO.
    """
    config = get_config()

    util.create_directories([config['log_dir'], config['model_dir'], config['data_dir'], config['results_dir']])

    tag = util.get_tag(config)
    util.setup_logger(config, directory=config['log_dir'], tag=tag)
    util.save_config(config, directory=config['log_dir'], tag=tag)

    util.set_seeds(seed=config['seed'])
    device = util.init_device(config['device'])

    # Atomic numbers of the chemical symbols this run is allowed to place.
    zs = [ase.data.atomic_numbers[s] for s in config['symbols'].split(',')]
    action_space = ActionSpace(zs=zs)
    observation_space = ObservationSpace(canvas_size=config['canvas_size'], zs=zs)

    # Fall back to the training formulas when no evaluation formulas are given.
    if not config['eval_formulas']:
        config['eval_formulas'] = config['formulas']

    train_formulas = util.split_formula_strings(config['formulas'])
    eval_formulas = util.split_formula_strings(config['eval_formulas'])

    logging.info(f'Training bags: {train_formulas}')
    logging.info(f'Evaluation bags: {eval_formulas}')

    model_handler = ModelIO(directory=config['model_dir'], tag=tag, keep=config['keep_models'])

    # Either resume from a checkpoint or build a fresh model. The two load
    # paths previously duplicated the space-rebinding lines; do it once here.
    model = None
    if config['load_latest']:
        model, start_num_steps = model_handler.load_latest(device=device)
    elif config['load_model'] is not None:
        model, start_num_steps = model_handler.load(device=device, path=config['load_model'])

    if model is not None:
        # A loaded checkpoint carries pickled spaces; rebind to this run's.
        model.action_space = action_space
        model.observation_space = observation_space
    else:
        model = build_model(config, observation_space=observation_space, action_space=action_space, device=device)
        start_num_steps = 0

    var_counts = util.count_vars(model)
    logging.info(f'Number of parameters: {var_counts}')

    reward = InteractionReward()

    # Default the number of evaluation episodes to one per evaluation bag.
    if not config['num_eval_episodes']:
        config['num_eval_episodes'] = len(eval_formulas)

    def make_env(formula_strings) -> MolecularEnvironment:
        # Build one environment over the given formula strings; all other
        # environment settings come from the shared run configuration.
        return MolecularEnvironment(
            reward=reward,
            observation_space=observation_space,
            action_space=action_space,
            formulas=[util.string_to_formula(f) for f in formula_strings],
            min_atomic_distance=config['min_atomic_distance'],
            max_solo_distance=config['max_solo_distance'],
            min_reward=config['min_reward'],
        )

    training_envs = SimpleEnvContainer([make_env(train_formulas) for _ in range(config['num_envs'])])
    eval_envs = SimpleEnvContainer([make_env(eval_formulas)])

    # 'train' / 'eval' / 'all' selects which rollouts get persisted.
    save_rollouts = config['save_rollouts']

    batch_ppo(
        envs=training_envs,
        eval_envs=eval_envs,
        ac=model,
        optimizer=util.get_optimizer(name=config['optimizer'],
                                     learning_rate=config['learning_rate'],
                                     parameters=model.parameters()),
        gamma=config['discount'],
        start_num_steps=start_num_steps,
        max_num_steps=config['max_num_steps'],
        num_steps_per_iter=config['num_steps_per_iter'],
        mini_batch_size=config['mini_batch_size'],
        clip_ratio=config['clip_ratio'],
        vf_coef=config['vf_coef'],
        entropy_coef=config['entropy_coef'],
        max_num_train_iters=config['max_num_train_iters'],
        lam=config['lam'],
        target_kl=config['target_kl'],
        gradient_clip=config['gradient_clip'],
        eval_freq=config['eval_freq'],
        model_handler=model_handler,
        save_freq=config['save_freq'],
        num_eval_episodes=config['num_eval_episodes'],
        rollout_saver=util.RolloutSaver(directory=config['data_dir'], tag=tag),
        save_train_rollout=save_rollouts in ('train', 'all'),
        save_eval_rollout=save_rollouts in ('eval', 'all'),
        info_saver=util.InfoSaver(directory=config['results_dir'], tag=tag),
        device=device,
    )