Example #1
def get_optimized_mols(model_dir, ckpt=80000):
    """Get optimized Molecules.

  Args:
    model_dir: String. model directory.
    ckpt: the checkpoint to load.

  Returns:
    List of 800 optimized molecules
  """
    hparams_file = os.path.join(model_dir, 'config.json')
    with gfile.Open(hparams_file, 'r') as f:
        hp_dict = json.load(f)
        hparams = deep_q_networks.get_hparams(**hp_dict)

    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=0.0)

    tf.reset_default_graph()
    optimized_mol = []
    with tf.Session() as sess:
        dqn.build()
        model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
        model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
        # `all_mols` is assumed to be a module-level list of molecules
        # (e.g. SMILES strings) loaded elsewhere in the surrounding script.
        for mol in all_mols:
            logging.info('Eval: %s', mol)
            environment = molecules_mdp.Molecule(
                atom_types=set(hparams.atom_types),
                init_mol=mol,
                allow_removal=hparams.allow_removal,
                allow_no_modification=hparams.allow_no_modification,
                allow_bonds_between_rings=hparams.allow_bonds_between_rings,
                allowed_ring_sizes=set(hparams.allowed_ring_sizes),
                max_steps=hparams.max_steps_per_episode,
                record_path=True)
            environment.initialize()
            if hparams.num_bootstrap_heads:
                head = np.random.randint(hparams.num_bootstrap_heads)
            else:
                head = 0
            for _ in range(hparams.max_steps_per_episode):
                steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
                valid_actions = list(environment.get_valid_actions())
                observations = np.vstack([
                    np.append(deep_q_networks.get_fingerprint(act, hparams),
                              steps_left) for act in valid_actions
                ])
                action = valid_actions[dqn.get_action(observations,
                                                      head=head,
                                                      update_epsilon=0.0)]
                environment.step(action)
            optimized_mol.append(environment.get_path())
    return optimized_mol
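A minimal usage sketch for this example, assuming the surrounding script defines FLAGS.model_dir as the training code below does; the output filename optimized_mols.json is illustrative, not taken from the original source:

if __name__ == '__main__':
    # Load the trained DQN and write the optimization trajectories to JSON.
    optimized = get_optimized_mols(FLAGS.model_dir, ckpt=80000)
    with gfile.Open(os.path.join(FLAGS.model_dir, 'optimized_mols.json'), 'w') as f:
        json.dump(optimized, f)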
Example #2
    def _restore(self, sess):
        """Restore this evaluator's checkpoint into a tf.Session."""
        global_step = self._create_global_step()
        tf.train.Saver().restore(sess, self.checkpoint_path)

        if self.global_step is None:
            self.global_step = global_step.eval(sess)
Example #3
def run_training(hparams, environment, dqn):
    """Runs the training procedure.

  Briefly, the agent runs the action network to get an action to take in
  the environment. The state transition and reward are stored in the memory.
  Periodically the agent samples a batch of transitions from the memory to
  update (train) its Q network. Note that the Q network and the action network
  share the same set of parameters, so the action network is also updated by
  the samples of (state, action, next_state, reward) batches.


  Args:
    hparams: tf.HParams. The hyperparameters of the model.
    environment: molecules.Molecule. The environment to run on.
    dqn: An instance of the DeepQNetwork class.

  Returns:
    None
  """
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    tf.reset_default_graph()
    with tf.Session() as sess:
        dqn.build()
        model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
        # The schedule for the epsilon in epsilon greedy policy.
        exploration = schedules.PiecewiseSchedule(
            [(0, 1.0), (int(hparams.num_episodes / 2), 0.1),
             (hparams.num_episodes, 0.01)],
            outside_value=0.01)
        if hparams.prioritized:
            memory = replay_buffer.PrioritizedReplayBuffer(
                hparams.replay_buffer_size, hparams.prioritized_alpha)
            beta_schedule = schedules.LinearSchedule(
                hparams.num_episodes,
                initial_p=hparams.prioritized_beta,
                final_p=0)
        else:
            memory = replay_buffer.ReplayBuffer(hparams.replay_buffer_size)
            beta_schedule = None
        sess.run(tf.global_variables_initializer())
        sess.run(dqn.update_op)
        global_step = 0
        for episode in range(hparams.num_episodes):
            global_step = _episode(environment=environment,
                                   dqn=dqn,
                                   memory=memory,
                                   episode=episode,
                                   global_step=global_step,
                                   hparams=hparams,
                                   summary_writer=summary_writer,
                                   exploration=exploration,
                                   beta_schedule=beta_schedule)
            if (episode + 1) % hparams.update_frequency == 0:
                sess.run(dqn.update_op)
            if (episode + 1) % hparams.save_frequency == 0:
                model_saver.save(sess,
                                 os.path.join(FLAGS.model_dir, 'ckpt'),
                                 global_step=global_step)
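A hedged sketch of how run_training might be invoked, reusing the DeepQNetwork and Molecule constructors shown in Example #1. Calling deep_q_networks.get_hparams() without overrides is assumed to return defaults, and starting epsilon at 1.0 is illustrative since the PiecewiseSchedule above anneals it during training:

def example_main():
    hparams = deep_q_networks.get_hparams()
    environment = molecules_mdp.Molecule(
        atom_types=set(hparams.atom_types),
        init_mol=None,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode)
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)
    run_training(hparams, environment, dqn)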
Example #4
def run_test():
  """Estimates the homography between two input images.
  """
  image1 = cv2.imread(FLAGS.image1)
  image2 = cv2.imread(FLAGS.image2)
  image_list = [image1, image2]
  image_norm_list = []
  for i in range(2):
    if FLAGS.network_id == 'fmask_sem':
      image_scale = cv2.resize(image_list[i],
                               (FLAGS.train_width, FLAGS.train_height),
                               interpolation=cv2.INTER_LANCZOS4)
    else:
      image_gray = cv2.cvtColor(image_list[i], cv2.COLOR_BGR2GRAY)
      image_scale = cv2.resize(image_gray,
                               (FLAGS.train_width, FLAGS.train_height),
                               interpolation=cv2.INTER_LANCZOS4)

    image_norm = image_scale / 256.0 - 0.5
    image_norm_list.append(image_norm)
  if FLAGS.network_id == 'fmask_sem':
    norm_image_pair = np.expand_dims(np.concatenate(image_norm_list, 2), axis=0)
    num_channel = 3
  else:
    norm_image_pair = np.expand_dims(np.stack(image_norm_list, -1), axis=0)
    num_channel = 1

  batch_pairs = tf.placeholder(tf.float32,
                               [1, FLAGS.train_height, FLAGS.train_width,
                                2 * num_channel])
  with slim.arg_scope(models.homography_arg_scope()):
    if FLAGS.network_id == 'fmask_sem':
      batch_hmg_prediction, _ = models.hier_homography_fmask_estimator(
          batch_pairs, num_param=8, num_layer=FLAGS.num_layer,
          num_level=FLAGS.num_level, is_training=False)
    else:
      batch_hmg_prediction, _ = models.hier_homography_estimator(
          batch_pairs, num_param=8, num_layer=FLAGS.num_layer,
          num_level=FLAGS.num_level, is_training=False)

  batch_warped_result, _ = hmg_util.homography_warp_per_batch(
      batch_pairs[Ellipsis, 0 : num_channel],
      batch_hmg_prediction[FLAGS.num_level - 1])

  saver = tf.train.Saver()
  with tf.Session() as sess:
    saver.restore(sess, FLAGS.model_path)
    image_warp, homography_list = sess.run(
        [batch_warped_result, batch_hmg_prediction],
        feed_dict={batch_pairs: norm_image_pair})
    for i in range(8):
      logging.info('%f ', homography_list[FLAGS.num_level - 1][0][i])
    cv2.imwrite('%s/input0.jpg' % FLAGS.out_dir,
                (image_norm_list[0] + 0.5) * 256)
    cv2.imwrite('%s/input1.jpg' % FLAGS.out_dir,
                (image_norm_list[1] + 0.5) * 256)
    cv2.imwrite('%s/result.jpg' % FLAGS.out_dir, (image_warp[0] + 0.5) * 256)
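The per-image preprocessing in run_test can be read as one small helper; the sketch below is a restatement for clarity, and the name _preprocess is not part of the original module:

def _preprocess(image, width, height, to_gray):
  """Resize to the training resolution and map pixel values into [-0.5, 0.5)."""
  if to_gray:
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  image = cv2.resize(image, (width, height),
                     interpolation=cv2.INTER_LANCZOS4)
  return image / 256.0 - 0.5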
Example #5
def run_training(hps,
                 experiment_proto,
                 train_dir,
                 train_input_paths,
                 val_input_paths,
                 tuner=None,
                 master='',
                 metrics_targets=None,
                 metrics_measures=None):
  """Main training function.

  Trains the model given a directory to write to and a logfile to write to.

  Args:
    hps: tf.HParams with training parameters.
    experiment_proto: selection_pb2.Experiment proto for training.
    train_dir: str path to train directory.
    train_input_paths: List[str] giving paths to input sstables for training.
    val_input_paths: List[str] giving paths to input sstable(s) for validation.
    tuner: optional hp_tuner.HPTuner.
    master: optional string to pass to a tf.Supervisor.
    metrics_targets: String list of network targets to report metrics for.
    metrics_measures: Measurements about the performance of the network to
        report, e.g. 'auc/top_1p'.

  Returns:
    None.

  Raises:
    Error: if the hyperparameter combination in hps is infeasible and there is
    no tuner. (If the hyperparameter combination is infeasible and there is
    a tuner then the params are reported back to the tuner as infeasible.)
  """
  hps_infeasible, infeasible_reason = hps_is_infeasible(
      hps, experiment_proto.sequence_length)
  if hps_infeasible:
    if tuner:
      tuner.report_done(True, infeasible_reason)
      logger.info('report_done(infeasible=%r)', hps_infeasible)
      return
    else:
      raise Error('Hyperparams are infeasible: %s' % infeasible_reason)

  logger.info('Starting training.')
  if tuner:
    logger.info('Using tuner: loaded HParams from Vizier')
  else:
    logger.info('No tuner: using default HParams')
  logger.info('experiment_proto: %s', experiment_proto)
  logger.info('train_dir: %s', train_dir)
  logger.info('train_input_paths[0]: %s', train_input_paths[0])
  logger.info('val_input_paths[0]: %s', val_input_paths[0])
  logger.info('%r', list(hps.values()))
  generationinfo.to_file(os.path.join(train_dir, 'geninfo.pbtxt'))
  with gfile.Open(os.path.join(train_dir, config.hparams_name), 'w') as f:
    f.write(str(hps.to_proto()))

  eval_size = hps.eval_size or None

  def make_subdir(subdirectory_name):
    path = os.path.join(train_dir, subdirectory_name)
    gfile.MakeDirs(path)
    return path

  logger.info('Computing preprocessing statistics')
  # TODO(shoyer): move this over into preprocessing instead?
  experiment_proto = dataset_stats.compute_experiment_statistics(
      experiment_proto,
      train_input_paths,
      os.path.join(
          hps.input_dir,
          six.ensure_str(
              config.wetlab_experiment_train_pbtxt_path[hps.val_fold]) +
          '.wstats'),
      preprocess_mode=hps.preprocess_mode,
      max_size=eval_size,
      logdir=make_subdir('compute-statistics'),
      save_stats=hps.save_stats)

  logger.info('Saving experiment proto with statistics')
  with gfile.Open(
      os.path.join(train_dir, config.wetlab_experiment_train_name), 'w') as f:
    f.write(str(experiment_proto))

  logger.debug(str(hps.to_proto()))
  logger.debug(hps.run_name)

  tr_entries = len(sstable.MergedSSTable(train_input_paths))
  logger.info('Training sstable size: %d', tr_entries)
  val_entries = len(sstable.MergedSSTable(val_input_paths))
  logger.info('Validation sstable size: %d', val_entries)

  epoch_size = hps.epoch_size or int(tr_entries * (1 + hps.ratio_random_dna))
  num_batches_per_epoch = int(float(epoch_size) / hps.mbsz)

  eval_ff.config_pandas_display(FLAGS.interactive_display)
  tr_evaluator = eval_ff.Evaluator(
      hps,
      experiment_proto,
      train_input_paths,
      make_subdir(config.experiment_training_dir),
      verbose=FLAGS.verbose_eval)
  val_evaluator = eval_ff.Evaluator(
      hps,
      experiment_proto,
      val_input_paths,
      make_subdir(config.experiment_validation_dir),
      verbose=FLAGS.verbose_eval)

  with tf.Graph().as_default():
    # we need to use the registered key 'hparams'
    tf.add_to_collection('hparams', hps)

    # TODO(shoyer): collect these into a Model class:
    dummy_inputs = data.dummy_inputs(
        experiment_proto,
        input_features=hps.input_features,
        kmer_k_max=hps.kmer_k_max,
        additional_output=six.ensure_str(hps.additional_output).split(','))
    output_layer = output_layers.create_output_layer(experiment_proto, hps)
    net = ff.FeedForward(dummy_inputs, output_layer.logit_axis, hps)

    trainer = FeedForwardTrainer(hps, net, output_layer, experiment_proto,
                                 train_input_paths)

    summary_writer = tf.summary.FileWriter(make_subdir('training'), flush_secs=30)

    # TODO(shoyer): file a bug to figure out why write_version=2 (now the
    # default) doesn't work.
    saver = tf.train.Saver(write_version=1)

    # We are always the chief since we do not do distributed training.
    # Every replica with a different task id is completely independent and all
    # must be their own chief.
    sv = tf.train.Supervisor(
        logdir=train_dir,
        is_chief=True,
        summary_writer=summary_writer,
        save_summaries_secs=10,
        save_model_secs=180,
        saver=saver)

    logger.info('Preparing session')

    train_report_dir = os.path.join(train_dir, config.experiment_training_dir)
    cur_train_report = os.path.join(train_report_dir,
                                    config.experiment_report_name)
    best_train_report = os.path.join(train_report_dir,
                                     config.experiment_best_report_name)

    valid_report_dir = os.path.join(train_dir, config.experiment_validation_dir)
    cur_valid_report = os.path.join(valid_report_dir,
                                    config.experiment_report_name)
    best_valid_report = os.path.join(valid_report_dir,
                                     config.experiment_best_report_name)

    best_checkpoint = os.path.join(train_dir, 'model.ckpt-lowest_val_loss')
    best_checkpoint_meta = best_checkpoint + '.meta'
    best_epoch_file = os.path.join(train_dir, 'best_epoch.txt')

    with sv.managed_session(master) as sess:

      logger.info('Starting queue runners')
      sv.start_queue_runners(sess)

      def save_and_evaluate():
        """Save and evaluate the current model.

        Returns:
          path: the path string to the checkpoint.
          summary_df: pandas.DataFrame storing the evaluation result on the
            validation dataset with rows for each output name and columns for
            each metric value
        """
        logger.info('Saving model checkpoint')
        path = sv.saver.save(
            sess,
            sv.save_path,
            global_step=sv.global_step,
            write_meta_graph=True)
        tr_evaluator.run(path, eval_size)
        summary_df, _ = val_evaluator.run_and_report(
            tuner,
            path,
            eval_size,
            metrics_targets=metrics_targets,
            metrics_measures=metrics_measures)
        return path, summary_df

      def update_best_model(path, cur_epoch):
        """Update the records of the model with the lowest validation error.

        Args:
          path: the path to the checkpoint of the current model.
          cur_epoch: an integer giving the current epoch.
        """

        cur_checkpoint = path
        cur_checkpoint_meta = six.ensure_str(cur_checkpoint) + '.meta'

        gfile.Copy(cur_train_report, best_train_report, overwrite=True)
        gfile.Copy(cur_valid_report, best_valid_report, overwrite=True)
        gfile.Copy(cur_checkpoint, best_checkpoint, overwrite=True)
        gfile.Copy(cur_checkpoint_meta, best_checkpoint_meta, overwrite=True)
        with gfile.Open(best_epoch_file, 'w') as f:
          f.write(str(cur_epoch)+'\n')

      def compare_with_best_model(checkpoint_path, summary_df, cur_epoch):
        logger.info('Comparing current val loss with the best model')

        if not gfile.Exists(best_train_report):
          logger.info('No best model saved. Adding current model...')
          update_best_model(checkpoint_path, cur_epoch)
        else:
          with gfile.GFile(best_valid_report) as f:
            with xarray.open_dataset(f) as best_ds:
              best_ds.load()
          cur_loss = summary_df['loss'].loc['mean']
          best_loss = best_ds['loss'].mean('output')
          logger.info('Current val loss: %f', cur_loss)
          logger.info('The best val loss: %f', best_loss)
          if cur_loss < best_loss:
            logger.info(
                'Current model has lower loss. Updating the best model.')
            update_best_model(checkpoint_path, cur_epoch)
          else:
            logger.info('The best model has lower loss.')

      logger.info('Running eval before starting training')
      save_and_evaluate()

      try:
        for cur_epoch in trainer.train(sess, hps.epochs, num_batches_per_epoch):
          checkpoint_path, val_summary_df = save_and_evaluate()
          if (cur_epoch+1) % hps.epoch_interval_to_save_best == 0:
            compare_with_best_model(checkpoint_path, val_summary_df, cur_epoch)
          if tuner and tuner.should_trial_stop():
            break
      except eval_ff.TrainingDivergedException as error:
        logger.error('Training diverged: %s', str(error))
        infeasible = True
      else:
        infeasible = False

      logger.info('Saving final checkpoint')
      sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

  if tuner:
    # should be at the very end of execution, to avoid possible race conditions
    tuner.report_done(infeasible=infeasible)
    logger.info('report_done(infeasible=%r)', infeasible)

  logger.info('Done.')
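A hedged post-training sketch showing how the artifacts written above could be read back; the paths mirror best_checkpoint and best_epoch_file, and the helper names are illustrative:

def load_best_epoch(train_dir):
  """Return the epoch recorded for the lowest-validation-loss model."""
  with gfile.Open(os.path.join(train_dir, 'best_epoch.txt')) as f:
    return int(f.read().strip())


def restore_best_model(sess, train_dir):
  """Restore the lowest-validation-loss checkpoint into `sess`."""
  saver = tf.train.Saver()
  saver.restore(sess, os.path.join(train_dir, 'model.ckpt-lowest_val_loss'))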