def main(argv):
    """Run DQN training on a ``Molecule`` environment and save the config.

    Hyperparameters come from ``deep_q_networks.get_hparams``. When
    ``FLAGS.hparams`` is set, that file is parsed as JSON and its top-level
    entries are forwarded to ``get_hparams`` as keyword arguments.
    ``max_steps_per_episode`` is then always overridden from
    ``FLAGS.max_steps_per_episode``.

    After ``run_dqn.run_training`` returns, the final hparams are written to
    ``<FLAGS.model_dir>/config.json`` via ``core.write_hparams``.

    Args:
      argv: Command-line arguments; ignored (configuration comes from FLAGS).
    """
    del argv  # Everything is configured through FLAGS.

    if FLAGS.hparams is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
        hparams = deep_q_networks.get_hparams(**overrides)

    # Applied after loading, so the flag wins over any value in the JSON file.
    hparams.override_from_dict(
        {'max_steps_per_episode': FLAGS.max_steps_per_episode})

    environment = Molecule(
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        max_steps=hparams.max_steps_per_episode,
    )

    q_network = functools.partial(
        deep_q_networks.multi_layer_model, hparams=hparams)
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length),
        q_fn=q_network,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        # NOTE(review): presumably the initial exploration rate; confirm.
        epsilon=1.0,
    )

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
# Example #2
0
def main(argv):
    """Run DQN training with a similarity-constrained LogP reward environment.

    Hyperparameters come from ``deep_q_networks.get_hparams``, optionally
    overridden by the top-level entries of the JSON file named by
    ``FLAGS.hparams``.

    The molecule pool is read from the JSON file ``all_800_mols.json``. This
    is a relative path, presumably resolved against the current working
    directory; verify how it is launched. The pool is passed to
    ``LogPRewardWithSimilarityConstraintMolecule`` together with
    ``FLAGS.similarity_constraint``.

    After ``run_dqn.run_training`` returns, the final hparams are written to
    ``<FLAGS.model_dir>/config.json``.

    Args:
      argv: Command-line arguments; ignored (configuration comes from FLAGS).
    """
    del argv  # Unused; all configuration is read from FLAGS.

    if FLAGS.hparams is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
        hparams = deep_q_networks.get_hparams(**overrides)

    molecules_path = 'all_800_mols.json'
    with gfile.Open(molecules_path) as molecules_file:
        all_molecules = json.load(molecules_file)

    environment = LogPRewardWithSimilarityConstraintMolecule(
        similarity_constraint=FLAGS.similarity_constraint,
        discount_factor=hparams.discount_factor,
        all_molecules=all_molecules,
        atom_types=set(hparams.atom_types),
        init_mol=None,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode,
    )

    # NOTE(review): one input feature beyond the fingerprint. Presumably this
    # is a step counter appended to the observation; confirm against the env.
    observation_width = hparams.fingerprint_length + 1
    q_network = functools.partial(
        deep_q_networks.multi_layer_model, hparams=hparams)
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, observation_width),
        q_fn=q_network,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0,
    )

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
# Example #3
0
def main(argv):
    """Run DQN training with a noise-aware Q-network variant.

    Hyperparameters come from ``run_dqn.get_hparams``, optionally overridden
    by the top-level entries of the JSON file named by ``FLAGS.hparams``.

    The network class is ``deep_q_networks_noise.DeepQNetworkL2`` when
    ``FLAGS.error_type`` equals ``'l2'`` (case-insensitive). Otherwise it is
    ``deep_q_networks_noise.DeepQNetwork``.

    ``noise_std`` and ``error_type`` are added to the hparams only after
    training finishes. They are therefore recorded in
    ``<FLAGS.model_dir>/config.json`` but are not present on ``hparams``
    during ``run_dqn.run_training``.

    Args:
      argv: Command-line arguments; ignored (configuration comes from FLAGS).
    """
    del argv  # Unused; all configuration is read from FLAGS.

    if FLAGS.hparams is None:
        hparams = run_dqn.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
        hparams = run_dqn.get_hparams(**overrides)

    environment = Molecule(
        target_weight=FLAGS.target_weight,
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        max_steps=hparams.max_steps_per_episode,
    )

    use_l2 = FLAGS.error_type.lower() == 'l2'
    network_cls = (deep_q_networks_noise.DeepQNetworkL2 if use_l2
                   else deep_q_networks_noise.DeepQNetwork)
    q_network = functools.partial(
        deep_q_networks_noise.multi_layer_model, hparams=hparams)
    dqn = network_cls(
        input_shape=(hparams.batch_size, hparams.fingerprint_length),
        q_fn=q_network,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0,
    )

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    # Record the noise settings alongside the other hyperparameters.
    hparams.add_hparam('noise_std', FLAGS.noise_std)
    hparams.add_hparam('error_type', FLAGS.error_type)
    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
# Example #4
0
def main(argv):
  """Run DQN training from a fixed start molecule with QED/SAS targets.

  Hyperparameters come from ``deep_q_networks.get_hparams``, optionally
  overridden by the top-level entries of the JSON file named by
  ``FLAGS.hparams``. ``target_qed`` and ``target_sas`` are then added from
  the corresponding flags, before the environment is built.

  The environment always starts from a hard-coded SMILES string. Bonds
  between rings are disallowed, and rings are limited to sizes 3-6.

  After ``run_dqn.run_training`` returns, the final hparams are written to
  ``<FLAGS.model_dir>/config.json``.

  Args:
    argv: Command-line arguments; ignored (configuration comes from FLAGS).
  """
  del argv  # Unused; all configuration is read from FLAGS.

  if FLAGS.hparams is None:
    hparams = deep_q_networks.get_hparams()
  else:
    with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
      overrides = json.load(hparams_file)
    hparams = deep_q_networks.get_hparams(**overrides)

  hparams.add_hparam('target_qed', FLAGS.target_qed)
  hparams.add_hparam('target_sas', FLAGS.target_sas)

  start_smiles = 'CCc1c(C)[nH]c2CCC(CN3CCOCC3)C(=O)c12'
  ring_sizes = {3, 4, 5, 6}
  environment = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=start_smiles,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=False,
      allowed_ring_sizes=ring_sizes,
      max_steps=hparams.max_steps_per_episode,
  )

  # NOTE(review): one input feature beyond the fingerprint. Presumably this
  # is a step counter appended to the observation; confirm against the env.
  observation_width = hparams.fingerprint_length + 1
  q_network = functools.partial(
      deep_q_networks.multi_layer_model, hparams=hparams)
  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, observation_width),
      q_fn=q_network,
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0,
  )

  run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

  config_path = os.path.join(FLAGS.model_dir, 'config.json')
  core.write_hparams(hparams, config_path)