Пример #1
0
 def test_do_not_allow_no_modification(self):
     """Only the unchanged start state 'C#C' depends on allow_no_modification."""
     def valid_actions_for(allow_no_modification):
         # Build a fresh environment for the flag value and list its actions.
         env = molecules.Molecule({'C', 'O'}, 'C#C',
                                  allow_no_modification=allow_no_modification)
         env.initialize()
         return env.get_valid_actions()

     without_noop = valid_actions_for(False)
     with_noop = valid_actions_for(True)
     self.assertSetEqual({'C#C'}, with_noop - without_noop)
Пример #2
0
 def test_benzene_action(self):
     """Stepping from benzene to 'Cc1ccccc1' gives zero reward and no termination."""
     benzene, toluene = 'c1ccccc1', 'Cc1ccccc1'
     env = molecules.Molecule({'C', 'O'}, benzene)
     env.initialize()
     outcome = env.step(toluene)
     # Check each field of the step result in turn.
     expectations = (('state', toluene), ('reward', 0), ('terminated', False))
     for field, expected in expectations:
         self.assertEqual(getattr(outcome, field), expected)
Пример #3
0
 def test_empty_action(self):
     """With no starting molecule given, stepping to a single 'C' is accepted."""
     env = molecules.Molecule({'C', 'O'})
     env.initialize()
     outcome = env.step('C')
     expected_fields = {'state': 'C', 'reward': 0, 'terminated': False}
     for field, expected in expected_fields.items():
         self.assertEqual(getattr(outcome, field), expected)
Пример #4
0
    def test_do_not_allow_bonding_between_rings(self):
        """Disallowing inter-ring bonds removes exactly two actions here."""
        atom_types = {'C'}
        start_smiles = 'CC12CC1C2'

        def actions_with(allow_bonds):
            # Fresh environment per flag value; returns its valid-action set.
            env = molecules.Molecule(atom_types,
                                     start_smiles,
                                     allow_bonds_between_rings=allow_bonds)
            env.initialize()
            return env.get_valid_actions()

        # Left operand is evaluated first, so the "allowed" env is built first.
        only_when_allowed = actions_with(True) - actions_with(False)
        self.assertSetEqual({'CC12C3C1C32', 'CC12C3=C1C32'}, only_when_allowed)
Пример #5
0
def get_optimized_mols(model_dir, ckpt=80000, mols=None):
    """Runs a trained DQN policy (epsilon = 0) on a set of starting molecules.

    Args:
      model_dir: String. Model directory containing `config.json` and the
        `ckpt-<N>` checkpoint files.
      ckpt: Integer. Step number of the checkpoint to restore; the prefix
        `ckpt-<ckpt>` under `model_dir` is loaded.
      mols: Optional iterable of starting SMILES strings. When None (the
        default) the module-level `all_mols` is used, as before.

    Returns:
      A list with one entry per starting molecule, in input order. Each entry
      is the environment's recorded path (`get_path()`): the starting molecule
      followed by the state reached after every step.
    """
    if mols is None:
        # NOTE(review): `all_mols` is not defined in this function; presumably
        # a module-level collection of starting SMILES -- confirm.
        mols = all_mols

    hparams_file = os.path.join(model_dir, 'config.json')
    with gfile.Open(hparams_file, 'r') as f:
        hp_dict = json.load(f)
    hparams = deep_q_networks.get_hparams(**hp_dict)
    max_steps = hparams.max_steps_per_episode

    # Reset the default graph *before* constructing the network, so anything
    # the DQN creates lives in the same fresh graph used by the session below.
    tf.reset_default_graph()
    dqn = deep_q_networks.DeepQNetwork(
        # +1: each observation is a fingerprint with `steps_left` appended.
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=0.0)

    optimized_mol = []
    with tf.Session() as sess:
        dqn.build()
        # The TF1 saver class is `tf.train.Saver`; there is no `tf.Saver`.
        model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
        model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
        for mol in mols:
            logging.info('Eval: %s', mol)
            environment = molecules_mdp.Molecule(
                atom_types=set(hparams.atom_types),
                init_mol=mol,
                allow_removal=hparams.allow_removal,
                allow_no_modification=hparams.allow_no_modification,
                allow_bonds_between_rings=hparams.allow_bonds_between_rings,
                allowed_ring_sizes=set(hparams.allowed_ring_sizes),
                max_steps=max_steps,
                record_path=True)
            environment.initialize()
            # With bootstrap heads, one head is sampled at random per episode.
            if hparams.num_bootstrap_heads:
                head = np.random.randint(hparams.num_bootstrap_heads)
            else:
                head = 0
            for _ in range(max_steps):
                steps_left = max_steps - environment.num_steps_taken
                valid_actions = list(environment.get_valid_actions())
                # One row per candidate action: fingerprint + steps remaining.
                observations = np.vstack([
                    np.append(deep_q_networks.get_fingerprint(act, hparams),
                              steps_left) for act in valid_actions
                ])
                action = valid_actions[dqn.get_action(observations,
                                                      head=head,
                                                      update_epsilon=0.0)]
                environment.step(action)
            optimized_mol.append(environment.get_path())
    return optimized_mol
Пример #6
0
    def test_limited_ring_formation(self):
        """Dropping 5 from allowed_ring_sizes removes only the 5-ring closures."""
        atom_types = {'C'}
        start_smiles = 'CCCCC'
        action_sets = []
        # Build the {3, 4, 5} environment first, then the {3, 4} one.
        for ring_sizes in ({3, 4, 5}, {3, 4}):
            env = molecules.Molecule(atom_types,
                                     start_smiles,
                                     allowed_ring_sizes=ring_sizes)
            env.initialize()
            action_sets.append(env.get_valid_actions())
        with_five, without_five = action_sets

        self.assertSetEqual({'C1CCCC1', 'C1#CCCC1', 'C1=CCCC1'},
                            with_five - without_five)
Пример #7
0
 def test_state_transition(self):
     """A step is reflected both in its returned result and in the env itself."""
     start, target = 'c1ccccc1', 'Cc1ccccc1'
     env = molecules.Molecule({'C', 'O'}, start)
     env.initialize()
     step_result = env.step(target)
     # Fields of the returned step result.
     self.assertEqual(step_result.state, target)
     self.assertEqual(step_result.reward, 0)
     self.assertEqual(step_result.terminated, False)
     # The environment's own bookkeeping after one step.
     self.assertEqual(env.state, target)
     self.assertEqual(env.num_steps_taken, 1)
Пример #8
0
 def test_end_episode(self):
     """The episode terminates after max_steps steps and rejects further steps."""
     max_steps = 3
     mol = molecules.Molecule({'C', 'O'}, 'c1ccccc1', max_steps=max_steps)
     mol.initialize()
     for _ in range(max_steps):
         # Any valid action will do; set.pop() returns an arbitrary element.
         action = mol.get_valid_actions().pop()
         result = mol.step(action)
     self.assertEqual(result.terminated, True)
     # assertRaisesRegexp is a deprecated alias that was removed in Python
     # 3.12; the '.' is escaped so it matches a literal period.
     with self.assertRaisesRegex(ValueError, r'This episode is terminated\.'):
         mol.step(mol.get_valid_actions().pop())
Пример #9
0
 def test_cyclobutane_init(self):
     """Cyclobutane's initial action set, including second-ring closures."""
     # A molecule that already contains a ring must still be able to form
     # another one (e.g. 'C1C2CC12' and 'C1C2=C1C2' below).
     expected_actions = {
         'C1CCC1', 'C=C1CCC1', 'C1C2CC12', 'C1=CCC1', 'CCCC',
         'O=C1CCC1', 'CC1CCC1', 'OC1CCC1', 'C1#CCC1', 'C1C2=C1C2',
     }
     env = molecules.Molecule({'C', 'O'}, 'C1CCC1')
     env.initialize()
     self.assertSetEqual(env.get_valid_actions(), expected_actions)
Пример #10
0
 def test_record(self):
     """With record_path=True, get_path() lists the start and every step taken."""
     start = 'c1ccccc1'
     trajectory = ['Cc1ccccc1', 'CCc1ccccc1', 'Cc1ccccc1', 'c1ccccc1']
     env = molecules.Molecule({'C', 'O'}, start, record_path=True)
     env.initialize()
     for smiles in trajectory:
         env.step(smiles)
     # Revisited states must be recorded again, not deduplicated.
     self.assertListEqual(env.get_path(), [start] + trajectory)
Пример #11
0
 def test_goal_settings(self):
     """Reaching a state accepted by target_fn terminates the episode."""
     goal = 'Cc1ccccc1'
     mol = molecules.Molecule({'C', 'O'},
                              'c1ccccc1',
                              target_fn=lambda x: x == goal)
     mol.initialize()
     result = mol.step(goal)
     self.assertEqual(result.state, goal)
     self.assertEqual(result.reward, 0)
     self.assertEqual(result.terminated, True)
     # Once terminated, further steps must be rejected. assertRaisesRegexp is a
     # deprecated alias removed in Python 3.12; '\.' matches a literal period.
     with self.assertRaisesRegex(ValueError, r'This episode is terminated\.'):
         mol.step(mol.get_valid_actions().pop())
Пример #12
0
 def test_initialize(self):
     """initialize() resets state, step count and recorded path to the start."""
     start = 'c1ccccc1'
     env = molecules.Molecule({'C', 'O'}, start, record_path=True)

     def check_at_start():
         # The environment is at its initial state with nothing recorded yet.
         self.assertEqual(env.state, start)
         self.assertEqual(env.num_steps_taken, 0)
         self.assertListEqual(env.get_path(), [start])

     env.initialize()
     check_at_start()
     # Move away from the start state.
     outcome = env.step('Cc1ccccc1')
     self.assertEqual(outcome.state, 'Cc1ccccc1')
     self.assertEqual(outcome.reward, 0)
     self.assertListEqual(env.get_path(), [start, 'Cc1ccccc1'])
     # A second initialize() must undo that step.
     env.initialize()
     check_at_start()
Пример #13
0
 def test_empty_init(self):
     """With no starting molecule, the valid actions are exactly the atom types."""
     env = molecules.Molecule({'C', 'O'})
     env.initialize()
     actions = env.get_valid_actions()
     self.assertSetEqual(actions, {'C', 'O'})
Пример #14
0
 def test_invalid_actions(self):
     """Stepping to a state outside the valid action set raises ValueError."""
     mol = molecules.Molecule({'C', 'O'}, 'c1ccccc1')
     mol.initialize()
     # 'C' is not among benzene's valid actions. assertRaisesRegexp is a
     # deprecated alias removed in Python 3.12; '\.' matches a literal period.
     with self.assertRaisesRegex(ValueError, r'Invalid action\.'):
         mol.step('C')
Пример #15
0
 def test_do_not_allow_removal(self):
     """Ethane's valid actions when allow_removal=False (no bare 'C')."""
     expected_actions = {
         'CC', 'CCC', 'C#CC', 'CCO', 'CC=O', 'C=CC', 'C=C', 'C#C',
     }
     env = molecules.Molecule({'C', 'O'}, 'CC', allow_removal=False)
     env.initialize()
     self.assertSetEqual(env.get_valid_actions(), expected_actions)
Пример #16
0
 def test_ethane_init(self):
     """Ethane's initial valid actions with default settings."""
     expected_actions = {
         'CC', 'C=C', 'CCC', 'C#CC', 'CCO', 'CC=O', 'C', 'C=CC', 'C#C',
     }
     env = molecules.Molecule({'C', 'O'}, 'CC')
     env.initialize()
     self.assertSetEqual(env.get_valid_actions(), expected_actions)
Пример #17
0
 def test_episode_not_started(self):
     """step() before initialize() is rejected with the 'terminated' error."""
     mol = molecules.Molecule({'C', 'O'}, 'c1ccccc1')
     # initialize() is deliberately not called. assertRaisesRegexp is a
     # deprecated alias removed in Python 3.12; '\.' matches a literal period.
     with self.assertRaisesRegex(ValueError, r'This episode is terminated\.'):
         mol.step('Cc1ccccc1')
Пример #18
0
 def test_benzene_init(self):
     """Benzene's initial valid actions are exactly these five SMILES."""
     expected_actions = {
         'Oc1ccccc1', 'c1ccccc1', 'Cc1ccccc1', 'c1cc2cc-2c1', 'c1cc2ccc1-2',
     }
     env = molecules.Molecule({'C', 'O'}, 'c1ccccc1')
     env.initialize()
     self.assertSetEqual(env.get_valid_actions(), expected_actions)
Пример #19
0
 def test_image_generation(self):
     """Rendering the initial state succeeds; the image itself is not inspected."""
     env = molecules.Molecule({'C', 'O'}, 'c1ccccc1', max_steps=3)
     env.initialize()
     # Only checks that visualize_state() does not raise; result is discarded.
     env.visualize_state()