示例#1
0
  def __init__(self, split, synset_or_cat, mesh_hash=None, dynamic=False,
               verbose=True):
    """Sets up an example and, unless dynamic, checks that its .npz exists.

    Args:
      split: The split name; stored as `self.split`.
      synset_or_cat: Parsed by `_parse_synset_or_cat` into `self.synset` and
        `self.cat`.
      mesh_hash: Stored as `self.mesh_hash`.
      dynamic: If True, skip the up-front existence check on `self.npz_path`.
      verbose: Only consulted when `dynamic` is True; whether to emit a
        verbose log line saying the check is skipped.

    Raises:
      ValueError: If `dynamic` is False and `self.npz_path` does not exist.
    """
    self.split = split
    self.synset, self.cat = _parse_synset_or_cat(synset_or_cat)
    self.mesh_hash = mesh_hash

    # Cached values, all initially unset.
    self._rgb_path = self._rgb_image = None
    # Double underscore: Python name-mangles this attribute per class.
    self.__archive = None
    self._uniform_samples = self._near_surface_samples = None
    self._grid = self._world2grid = None
    self._gt_path = self._tx = self._gaps_to_occnet = None
    self._gt_mesh = self._tx_path = None
    self._surface_samples = self._normalized_gt_mesh = None
    self._r2n2_images = None
    self.depth_native_res = 224

    self.is_from_directory = False

    if not dynamic:
      if not file_util.exists(self.npz_path):
        raise ValueError('Expected a .npz at %s.' % self.npz_path)
      log.info(self.npz_path)
    elif verbose:
      log.verbose(
          'Using dynamic files, not checking ahead for file existence.')
示例#2
0
    def __init__(self, model_dir, model_name, experiment_name):
        """Locates an experiment directory and enumerates its jobs.

        The root is first taken to be '{model_dir}/{model_name}-{experiment_name}'.
        If that path does not exist, paths starting with it are globbed; exactly
        one match must be found, and the root and `experiment_name` are then
        updated from that match.

        Args:
          model_dir: Directory containing the experiment directories.
          model_name: Model name; experiment directory names start with
            '{model_name}-'.
          experiment_name: The experiment name, or a prefix of it.

        Raises:
          ValueError: If the root does not exist and the glob does not yield
            exactly one candidate.
        """
        self.root = f'{model_dir}/{model_name}-{experiment_name}'
        self.model_name = model_name
        self.experiment_name = experiment_name

        if not file_util.exists(self.root):
            log.verbose('Regex expanding root to find experiment ID')
            # Strip only a trailing separator (if any) before globbing. Slicing
            # off the last character would truncate the experiment name itself,
            # because the root is built without a trailing '/'.
            options = file_util.glob(self.root.rstrip('/') + '*')
            if len(options) != 1:
                log.verbose(
                    "Tried to glob for directory but didn't find one path. Found:"
                )
                log.verbose(options)
                raise ValueError('Directory not found: %s' % self.root)
            self.root = options[0] + '/'
            expanded_name = os.path.basename(self.root.strip('/'))
            prefix = self.model_name + '-'
            # Remove only the leading '{model_name}-'; str.replace would also
            # delete any later occurrence inside the experiment name.
            if expanded_name.startswith(prefix):
                expanded_name = expanded_name[len(prefix):]
            self.experiment_name = expanded_name
            log.verbose('Expanded experiment name with regex to root: %s' %
                        self.root)

        # Entries under the root that are not job directories.
        banned = {'log', 'mldash_config.txt', 'snapshot', 'mldash_config'}
        job_strs = [
            os.path.basename(n) for n in file_util.glob(f'{self.root}/*')
        ]
        job_strs = sorted((p for p in job_strs if p not in banned), key=to_xid)
        log.verbose('Job strings: %s' % repr(job_strs))
        self.all_jobs = [Job(self, job_str) for job_str in job_strs]
        self._visible_jobs = self.all_jobs[:]
示例#3
0
 def root_dir(self):
     """Returns '<experiment.root>/<job_str>/', cached after first success.

     Raises:
       ValueError: If the job directory does not exist.
     """
     if self._root_dir is None:
         candidate = '%s/%s/' % (self.experiment.root, self.job_str)
         # Validate before caching: caching first would make every later call
         # return the missing path instead of raising again.
         if not file_util.exists(candidate):
             raise ValueError("Couldn't find job directory at %s." %
                              candidate)
         self._root_dir = candidate
     return self._root_dir
示例#4
0
def main(argv):
  """Extracts per-mesh OccNet metrics and/or aggregates them into a summary.

  Writes '<occnet_dir>/extracted/metrics_ub-v2.csv' when FLAGS.write_metrics is
  set, and reads it back to write 'metric_summary_ub-v2.csv' next to it when
  FLAGS.write_metric_summaries is set.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  extracted_dir = FLAGS.occnet_dir + '/extracted'
  if not file_util.exists(extracted_dir):
    file_util.mkdir(extracted_dir)
  metrics_csv = extracted_dir + '/metrics_ub-v2.csv'

  if FLAGS.write_metrics:
    # TODO(ldif-user): Set up your own pipeline runner
    # TODO(ldif-user) Replace lambda x: None with a proto reader.
    with beam.Pipeline() as pipeline:
      protos = pipeline | 'ReadResults' >> (lambda x: None)
      per_mesh = protos | 'ExtractMetrics' >> beam.FlatMap(make_metrics)
      as_list = per_mesh | 'MakeMetricList' >> beam.combiners.ToList()
      as_str = as_list | 'MakeMetricStr' >> beam.Map(save_metrics)
      _ = as_str | 'WriteMetrics' >> beam.io.WriteToText(
          metrics_csv, num_shards=1, shard_name_template='')

  if not FLAGS.write_metric_summaries:
    return
  log.info('Aggregating results locally.')
  final_results = metrics.aggregate_extracted(metrics_csv)
  summary_csv = metrics_csv.replace('/metrics_ub-v2.csv',
                                    '/metric_summary_ub-v2.csv')
  file_util.writetxt(summary_csv, final_results.to_csv())
示例#5
0
def make_metrics(proto):
  """Computes mesh metrics for one serialized result.

  Also writes the ground-truth and predicted meshes into
  '<occnet_dir>/metrics-out-gt/<key>/'.

  Args:
    proto: A `(key, serialized_results)` pair. The second element is parsed as
      a `results_pb2.Results`; `key` must split on '/' into exactly three
      parts (split, synset, mesh hash), otherwise the item is skipped.

  Returns:
    A single-element list holding a dict with keys 'key',
    'Normal Consistency', 'F-Score (tau)', 'F-Score (2*tau)' and 'Chamfer',
    or an empty list if the predicted mesh could not be loaded/transformed.
  """
  key, s = proto
  p = results_pb2.Results.FromString(s)
  # NOTE(review): no separator is inserted between occnet_dir and the key, so
  # this assumes FLAGS.occnet_dir ends with '/' -- confirm.
  mesh_path = f"{FLAGS.occnet_dir}{key.replace('test/', '')}.ply"
  log.warning('Mesh path: %s' % mesh_path)
  try:
    mesh = file_util.read_mesh(mesh_path)
    _, synset, mesh_hash = key.split('/')
    if FLAGS.transform:
      ex = example.InferenceExample('test', synset, mesh_hash)
      tx = ex.gaps_to_occnet
      mesh.apply_transform(tx)
    log.info('Succeeded on %s' % mesh_path)
  # pylint:disable=broad-except
  except Exception as e:
    # pylint:enable=broad-except
    log.error(f"Couldn't load {mesh_path}, skipping due to {repr(e)}.")
    return []

  gt_mesh = mesh_util.deserialize(p.gt_mesh)
  dir_out = FLAGS.occnet_dir + '/metrics-out-gt/%s' % key
  if not file_util.exists(dir_out):
    file_util.makedirs(dir_out)
  # Write inside dir_out (the directory ensured above); without the '/' the
  # files would be created next to it with the key as a filename prefix.
  file_util.write_mesh(f'{dir_out}/gt_mesh.ply', gt_mesh)
  file_util.write_mesh(f'{dir_out}/occnet_pred.ply', mesh)

  nc, fst, fs2t, chamfer = metrics.all_mesh_metrics(mesh, gt_mesh)
  return [{
      'key': key,
      'Normal Consistency': nc,
      'F-Score (tau)': fst,
      'F-Score (2*tau)': fs2t,
      'Chamfer': chamfer,
  }]
示例#6
0
def make_metrics(proto):
  """Computes mesh metrics for one serialized result.

  Also writes the ground-truth and predicted meshes into
  '<occnet_dir>/metrics-out-gt/<key>/'.

  Args:
    proto: A `(key, serialized_results)` pair; the second element is parsed as
      a `results_pb2.Results`.

  Returns:
    A single-element list holding a dict with keys 'key',
    'Normal Consistency', 'F-Score (tau)', 'F-Score (2*tau)' and 'Chamfer',
    or an empty list if the predicted mesh (or its transform) could not be
    loaded.
  """
  key, s = proto
  p = results_pb2.Results.FromString(s)
  # NOTE(review): no separator is inserted between occnet_dir and the key, so
  # this assumes FLAGS.occnet_dir ends with '/' -- confirm.
  mesh_path = FLAGS.occnet_dir + key.replace('test/', '') + '.ply'
  log.warning('Mesh path: %s' % mesh_path)
  try:
    mesh = file_util.read_mesh(mesh_path)
    if FLAGS.transform:
      # TODO(ldif-user) Set up the path to the transformation:
      tx_path = 'ROOT_DIR/%s/occnet_to_gaps.txt' % key
      occnet_to_gaps = file_util.read_txt_to_np(tx_path).reshape([4, 4])
      # Invert the stored occnet->gaps matrix and apply it to the mesh.
      gaps_to_occnet = np.linalg.inv(occnet_to_gaps)
      mesh.apply_transform(gaps_to_occnet)
  # pylint: disable=broad-except
  except Exception as e:
    # pylint: enable=broad-except
    log.error("Couldn't load %s, skipping due to %s." % (mesh_path, repr(e)))
    return []

  gt_mesh = mesh_util.deserialize(p.gt_mesh)
  dir_out = FLAGS.occnet_dir + '/metrics-out-gt/%s' % key
  if not file_util.exists(dir_out):
    file_util.makedirs(dir_out)
  # Write inside dir_out (the directory ensured above); without the '/' the
  # files would be created next to it with the key as a filename prefix.
  file_util.write_mesh(f'{dir_out}/gt_mesh.ply', gt_mesh)
  file_util.write_mesh(f'{dir_out}/occnet_pred.ply', mesh)

  nc, fst, fs2t, chamfer = metrics.all_mesh_metrics(mesh, gt_mesh)
  return [{
      'key': key,
      'Normal Consistency': nc,
      'F-Score (tau)': fst,
      'F-Score (2*tau)': fs2t,
      'Chamfer': chamfer,
  }]
示例#7
0
 def remote_result_base(self, xid, split):
     """Returns the base path for remote results for the xid-split pair.

     The path is '<remote_result_ckpt_dir(xid)>/<ckpt>/<split>', where ckpt
     comes from `self.ckpt_for_xid(xid, split)`. Paths are memoized in
     `self.remote_base_dirs[split]`, and only after they are confirmed to exist.

     Args:
       xid: The experiment id (formatted with %i).
       split: The split name; also the key into `self.remote_base_dirs`.

     Returns:
       The base path string.

     Raises:
       ValueError: If the directory does not exist.
     """
     if xid in self.remote_base_dirs[split]:
         return self.remote_base_dirs[split][xid]
     ckpt = self.ckpt_for_xid(xid, split)
     s = '%s/%i/%s' % (self.remote_result_ckpt_dir(xid), ckpt, split)
     # Ensure it's a valid directory before caching it:
     if not file_util.exists(s):
         raise ValueError(
             ('No directory for split %s and ckpt %i for xid %i.'
              ' Expected path was: %s') % (split, ckpt, xid, s))
     self.remote_base_dirs[split][xid] = s
     return s
示例#8
0
 def copy_meshes(self, mesh_names, split, overwrite_if_present=True):
     """Copies the named meshes for every visible xid into the local cache.

     Args:
       mesh_names: Names of the meshes to copy.
       split: The split the meshes belong to.
       overwrite_if_present: If False, meshes whose local file already exists
         are skipped; otherwise any existing local copy is removed first and
         the mesh is copied again from the remote store.
     """
     for xid in self.experiment.visible_xids:
         log.verbose('Copying filelist for xid #%i...' % xid)
         for name in mesh_names:
             dst = self.local_path_to_mesh(name, xid, split)
             if not overwrite_if_present and os.path.isfile(dst):
                 continue
             # Create the destination's parent directory on demand.
             os.makedirs(os.path.dirname(dst), exist_ok=True)
             src = self.remote_path_to_mesh(name, xid, split)
             if file_util.exists(dst):
                 file_util.rm(dst)
             file_util.cp(src, dst)
示例#9
0
 def hparams(self):
     """Returns the tf.HParams deserialized from '<root_dir>/train/hparam_pickle.txt'.

     The result is loaded on first access and cached in `self._hparams`.

     Raises:
       ValueError: If the serialized hparams file does not exist.
     """
     if self._hparams is not None:
         return self._hparams
     path = '%s/train/hparam_pickle.txt' % self.root_dir
     if not file_util.exists(path):
         raise ValueError('No serialized hparam file found at %s' % path)
     log.info('Found serialized hparams. Loading from %s' % path)
     loader = hparams_util.read_hparams_with_new_backwards_compatible_additions
     self._hparams = loader(path)
     return self._hparams
示例#10
0
def main(argv):
    """Extracts results and/or metrics per XID, then optionally summarizes them.

    For each XID in FLAGS.xids, metrics are written to
    '<input_dir>/extracted/XID<n>_metrics-v2.csv' and, when requested, summarized
    into 'XID<n>_metric_summary-v2.csv' alongside.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are passed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    xids = inference_util.parse_xid_str(FLAGS.xids)

    root_out = FLAGS.input_dir + '/extracted'
    if not file_util.exists(root_out):
        file_util.mkdir(root_out)

    if FLAGS.write_metrics or FLAGS.write_results:
        # TODO(ldif-user): Set up your own pipeline runner
        with beam.Pipeline() as p:
            for xid in xids:
                name = 'XID%i' % xid
                # Intended input for the proto reader below (unused until one
                # is plugged in).
                path = get_result_path(xid)
                # TODO(ldif-user) Replace lambda x: None with a proto reader.
                protos = p | 'ReadResults%s' % name >> (lambda x: None)

                if FLAGS.write_results:
                    map_fun = functools.partial(write_results, xid=xid)
                    _ = protos | 'ExtractResults%s' % name >> beam.FlatMap(
                        map_fun)
                if FLAGS.write_metrics:
                    # NOTE(review): if make_metrics returns a list of dicts
                    # rather than a single dict, this should be beam.FlatMap so
                    # ToList produces a flat list -- confirm.
                    with_metrics = protos | 'ExtractMetrics%s' % name >> beam.Map(
                        make_metrics)
                    result_pcoll = with_metrics | 'MakeMetricList%s' % name >> (
                        beam.combiners.ToList())
                    result_str = result_pcoll | 'MakeMetricStr%s' % name >> beam.Map(
                        save_metrics)
                    out_path = FLAGS.input_dir + '/extracted/%s_metrics-v2.csv' % name
                    _ = result_str | 'WriteMetrics%s' % name >> beam.io.WriteToText(
                        out_path, num_shards=1, shard_name_template='')
    if FLAGS.write_metric_summaries:
        log.info('Aggregating results locally.')
        for xid in tqdm.tqdm(xids):
            result_path = FLAGS.input_dir + '/extracted/XID%i_metrics-v2.csv' % xid
            final_results = metrics.aggregate_extracted(result_path)
            summary_out_path = result_path.replace('_metrics-v2.csv',
                                                   '_metric_summary-v2.csv')
            # DataFrame serialization is `to_csv()`; there is no `.csv()`.
            file_util.writetxt(summary_out_path, final_results.to_csv())
示例#11
0
    def eval_step(session, global_step, desired_num_examples, eval_tag,
                  eval_checkpoint):
        """Runs a single eval step.

        Reads batches until the dataset raises OutOfRangeError or 500 batches
        have been read. For every batch, the structured-implicit parameters and
        mesh names are collected and saved as .npy files under
        '<checkpoint dir>/eval-step-<global_step>/'. Only the first `vis_count`
        batches are rendered for summaries. The hparams are also written next to
        the checkpoint if no hparam_pickle.txt exists there yet.

        Args:
          session: A tf.Session instance.
          global_step: The global step; converted with int().
          desired_num_examples: Ignored (deleted on entry).
          eval_tag: Ignored (deleted on entry).
          eval_checkpoint: The path of the checkpoint being evaluated; its
            parent directory is the root for eval outputs.

        Returns:
          A list of tf.Summary objects computed during the eval.
        """
        step_start_time = time.time()
        del eval_tag, desired_num_examples
        global_step_int = int(global_step)
        big_render_images = []
        all_centers_np = []
        all_radii_np = []
        all_constants_np = []
        all_quadrics_np = []
        all_iparams_np = []
        all_mesh_names_np = []
        all_depth_images_np = []
        tf.logging.info('The eval checkpoint str is %s', eval_checkpoint)

        eval_dir = '/'.join(eval_checkpoint.split('/')[:-1])

        hparam_path = eval_dir + '/hparam_pickle.txt'
        if not file_util.exists(hparam_path):
            hparams.write_hparams(model_config.hparams, hparam_path)
        output_dir = (eval_dir + '/eval-step-' + str(global_step_int) + '/')

        def to_uint8(np_im):
            return (np.clip(255.0 * np_im, 0, 255.0)).astype(np.uint8)

        ran_count = 0
        max_run_count = 500
        ious = np.zeros(max_run_count, dtype=np.float32)
        # Run until the end of the dataset:
        for vi in range(max_run_count):
            tf.logging.info(
                'Starting eval item %i, total elapsed time is %0.2f...', vi,
                time.time() - step_start_time)
            try:
                vis_start_time = time.time()
                if vi < vis_count:
                    misc_tensors_to_eval = [
                        model_config.summary_op,
                        optimization_target,
                        sample_locations,
                        in_out_image_big,
                        training_example.mesh_name,
                        example_iou,
                    ]
                    np_out = session.run(
                        misc_tensors_to_eval +
                        prediction.structured_implicit.tensor_list)
                    (summaries, optimization_target_np, samples_np,
                     in_out_image_big_np, mesh_names_np,
                     example_iou_np) = np_out[:len(misc_tensors_to_eval)]
                    in_out_image_big_np = np.reshape(in_out_image_big_np,
                                                     [256, 256, 1])
                    in_out_image_big_np = image_util.get_pil_formatted_image(
                        in_out_image_big_np)
                    tf.logging.info('in_out_image_big_np shape: %s',
                                    str(in_out_image_big_np.shape))
                    in_out_image_big_np = np.reshape(in_out_image_big_np,
                                                     [1, 256, 256, 4])
                    implicit_np_list = np_out[len(misc_tensors_to_eval):]
                    tf.logging.info('\tElapsed after first sess run: %0.2f',
                                    time.time() - vis_start_time)
                else:
                    np_out = session.run(
                        [training_example.mesh_name, example_iou] +
                        prediction.structured_implicit.tensor_list)
                    mesh_names_np = np_out[0]
                    example_iou_np = np_out[1]
                    implicit_np_list = np_out[2:]

                # TODO(kgenova) It would be nice to move all this functionality into
                # a numpy StructuredImplicitNp class, and hide these internals.

                ious[ran_count] = example_iou_np
                ran_count += 1

                constants_np, centers_np, radii_np = implicit_np_list[:3]
                if len(implicit_np_list) == 4:
                    iparams_np = implicit_np_list[3]
                else:
                    iparams_np = None
                # For now, just map to quadrics and move on:
                # NOTE(review): only batch element 0 receives its constants;
                # confirm this is intended when the batch size is > 1.
                quadrics_np = np.zeros(
                    [constants_np.shape[0], constants_np.shape[1], 4, 4])
                quadrics_np[0, :, 3, 3] = np.reshape(constants_np[0, :], [
                    model_config.hparams.sc,
                ])

                all_centers_np.append(np.copy(centers_np))
                all_radii_np.append(np.copy(radii_np))
                all_constants_np.append(np.copy(constants_np))
                all_quadrics_np.append(np.copy(quadrics_np))
                all_mesh_names_np.append(mesh_names_np)
                if iparams_np is not None:
                    all_iparams_np.append(iparams_np)

                # For most of the dataset, just do inference to get the representation.
                # Everything afterwards is just for tensorboard.
                if vi >= vis_count:
                    continue

                visualize_with_marching_cubes = False
                if visualize_with_marching_cubes:
                    # TODO(kgenova) This code is quite wrong now. If we want to enable it
                    # it should be rewritten to call a structured_implicit_function to
                    # handle evaluation (for instance the lset subtraction is bad).
                    marching_cubes_ims_np, output_volumes = np_util.visualize_prediction(
                        quadrics_np,
                        centers_np,
                        radii_np,
                        renormalize=model_config.hparams.pou == 't',
                        thresh=model_config.hparams.lset)
                    tf.logging.info(
                        '\tElapsed after first visualize_prediction: %0.2f',
                        time.time() - vis_start_time)
                    offset_marching_cubes_ims_np, _ = np_util.visualize_prediction(
                        quadrics_np,
                        centers_np,
                        radii_np,
                        renormalize=model_config.hparams.pou == 't',
                        thresh=0.1,
                        input_volumes=output_volumes)
                    tf.logging.info(
                        '\tElapsed after second visualize_prediction: %0.2f',
                        time.time() - vis_start_time)
                    tf.logging.info('About to concatenate shapes: %s, %s, %s',
                                    str(in_out_image_big_np.shape),
                                    str(marching_cubes_ims_np.shape),
                                    str(offset_marching_cubes_ims_np.shape))
                    in_out_image_big_np = np.concatenate(
                        [in_out_image_big_np, marching_cubes_ims_np,
                         offset_marching_cubes_ims_np],
                        axis=2)

                if do_iterative_update:
                    # This code will fail (it's left unasserted to give a helpful tf error
                    # message). The tensor it creates will now be the wrong size.
                    render_summary, iterated_render_np = refine.refine(
                        structured_implicit_ph, optimization_target_ph,
                        samples_ph, original_vis_ph, gradients,
                        implicit_np_list, optimization_target_np, samples_np,
                        in_out_image_big_np, session, render_summary_op,
                        iterated_render_tf)
                    render_summary = [render_summary]
                    in_out_with_iterated = np.concatenate(
                        [in_out_image_big_np, iterated_render_np], axis=2)
                    big_render_images.append(to_uint8(in_out_with_iterated))
                else:
                    big_render_images.append(to_uint8(in_out_image_big_np))

                    # TODO(kgenova) Is this really the depth image?
                    depth_image_np = in_out_image_big_np[:, :, :256, :]
                    all_depth_images_np.append(depth_image_np)

                    render_summary = []
            except tf.errors.OutOfRangeError:
                break
        tf.logging.info('Elapsed after vis loop: %0.2f',
                        time.time() - step_start_time)
        ious = ious[:ran_count]
        mean_iou_summary, iou_hist_summary = session.run(
            [mean_iou_summary_op, iou_hist_summary_op],
            feed_dict={iou_ph: ious})

        # Decide whether iparams exist while the variable is still a list. After
        # np.concatenate it is an ndarray, and `ndarray != []` is an elementwise
        # (broadcast) comparison rather than an emptiness test.
        has_iparams = bool(all_iparams_np)
        all_centers_np = np.concatenate(all_centers_np)
        all_radii_np = np.concatenate(all_radii_np)
        all_constants_np = np.concatenate(all_constants_np)
        all_quadrics_np = np.concatenate(all_quadrics_np)
        all_mesh_names_np = np.concatenate(all_mesh_names_np)
        all_depth_images_np = np.concatenate(all_depth_images_np)
        if has_iparams:
            all_iparams_np = np.concatenate(all_iparams_np)

        file_util.mkdir(output_dir, exist_ok=True)
        file_util.write_np(
            '%s/%s-constants.npy' % (output_dir, training_example.split),
            all_constants_np)
        file_util.write_np(
            '%s/%s-quadrics.npy' % (output_dir, training_example.split),
            all_quadrics_np)
        file_util.write_np(
            '%s/%s-centers.npy' % (output_dir, training_example.split),
            all_centers_np)
        file_util.write_np(
            '%s/%s-radii.npy' % (output_dir, training_example.split),
            all_radii_np)
        file_util.write_np(
            '%s/%s-mesh_names.npy' % (output_dir, training_example.split),
            all_mesh_names_np)
        if has_iparams:
            file_util.write_np(
                '%s/%s-iparams.npy' % (output_dir, training_example.split),
                all_iparams_np)

        # Now that the full set predictions have been saved to disk, scrap
        # everything after the first vis_count:
        all_centers_np = all_centers_np[:vis_count, ...]
        all_radii_np = all_radii_np[:vis_count, ...]
        all_constants_np = all_constants_np[:vis_count, ...]
        all_mesh_names_np = all_mesh_names_np[:vis_count, ...]

        tf.logging.info('Elapsed after .npy save: %0.2f',
                        time.time() - step_start_time)

        rbf_renders_at_half = np_util.plot_rbfs_at_thresh(all_centers_np,
                                                          all_radii_np,
                                                          thresh=0.5)
        rbf_renders_at_half_summary = session.run(
            rbf_render_at_half_summary_op,
            feed_dict={rbf_render_at_half_ph: rbf_renders_at_half})
        tf.logging.info('Elapsed after rbf_at_half summary: %0.2f',
                        time.time() - step_start_time)
        tf.logging.info('All depth images shape: %s',
                        str(all_depth_images_np.shape))
        depth_gt_out_summary = session.run(
            depth_gt_out_summary_op,
            feed_dict={
                depth_gt_out_ph:
                np.concatenate([all_depth_images_np, all_depth_images_np],
                               axis=2)
            })
        tf.logging.info('Elapsed after depth_gt_out summary: %0.2f',
                        time.time() - step_start_time)

        big_render_summary = session.run(big_render_summary_op,
                                         feed_dict={
                                             big_render_ph:
                                             np.concatenate(big_render_images,
                                                            axis=0)
                                         })
        # Report the number of batches actually run, not the visualized count.
        tf.logging.info('Evaluated %d batches of size %d.', ran_count,
                        model_config.hparams.bs)
        tf.logging.info('Elapsed at end of step: %0.2f',
                        time.time() - step_start_time)
        return [
            summaries, big_render_summary, rbf_renders_at_half_summary,
            depth_gt_out_summary, mean_iou_summary, iou_hist_summary
        ] + render_summary
示例#12
0
文件: train.py 项目: chengjieniu/ldif
def main(argv):
    """Trains the SIF transcoder, resuming from the latest checkpoint if any.

    Checkpoints and serialized hparams go under
    '<model_root>/sif-transcoder-<experiment_name>/1-hparams/train/', and
    summaries under '<model_root>/sif-transcoder-<experiment_name>/log'.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are passed.
      ValueError: If no dataset directory is given, or if not enough free GPU
        memory remains to reserve space for the inference kernel.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    tf.disable_v2_behavior()
    log.set_level(FLAGS.log_level)

    log.info('Making dataset...')
    if not FLAGS.dataset_directory:
        raise ValueError('A dataset directory must be provided.')
    # TODO(kgenova) This batch size should match.
    dataset = local_inputs.make_dataset(FLAGS.dataset_directory,
                                        mode='train',
                                        batch_size=FLAGS.batch_size,
                                        split=FLAGS.split)

    # Sets up the hyperparameters and tf.Dataset
    model_config = build_model_config(dataset)

    # Generates the graph for a single train step, including summaries
    shared_launcher.sif_transcoder(model_config)
    summary_op = tf.summary.merge_all()
    global_step_op = tf.compat.v1.train.get_global_step()

    saver = tf.train.Saver(max_to_keep=5,
                           pad_step_number=False,
                           save_relative_paths=True)

    # `initialize_all_variables` is a deprecated alias of this initializer.
    init_op = tf.global_variables_initializer()

    model_root = get_model_root()

    experiment_dir = f'{model_root}/sif-transcoder-{FLAGS.experiment_name}'
    checkpoint_dir = f'{experiment_dir}/1-hparams/train/'

    if FLAGS.reserve_memory_for_inference_kernel and sys.platform != "darwin":
        # Amount held back for the inference kernel, in the units returned by
        # gpu_util.get_free_gpu_memory (presumably MB, i.e. ~1.5GB -- confirm).
        reserved_for_inference = 1024 + 512
        current_free = gpu_util.get_free_gpu_memory(0)
        allowable = current_free - reserved_for_inference
        # Checking `allowable` before dividing also avoids a ZeroDivisionError
        # when no free memory is reported.
        if allowable <= 0:
            raise ValueError(
                f"Can't leave {reserved_for_inference} over for the inference"
                f" kernel, because there is only {current_free} total free"
                ' GPU memory.')
        allowable_fraction = allowable / current_free
        log.info(
            f'TensorFlow can use up to {allowable_fraction*100}% of the total'
            ' GPU memory.')
    else:
        allowable_fraction = 1.0
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=allowable_fraction)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:
        writer = tf.summary.FileWriter(f'{experiment_dir}/log', session.graph)
        try:
            log.info('Initializing variables...')
            session.run([init_op])

            if FLAGS.visualize:
                visualize_data(session, model_config.inputs['dataset'])

            # Check whether the checkpoint directory already exists (resuming)
            # or needs to be created (new model).
            latest_checkpoint = None
            if os.path.isdir(checkpoint_dir):
                log.info(
                    f'Checkpoint root {checkpoint_dir} exists, attempting to resume.'
                )
                latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
                log.info(f'Latest checkpoint: {latest_checkpoint}')
                if latest_checkpoint is None:
                    log.info('No checkpoint found in the checkpoint root, '
                             'training from scratch.')
            else:
                log.info('No previous checkpoint detected, training from scratch.')
                os.makedirs(checkpoint_dir)

            if latest_checkpoint is None:
                # Fresh start. This also covers a checkpoint directory left by a
                # run that stopped before saving, where restoring from None
                # would otherwise fail.
                # Serialize hparams so eval can load them:
                hparam_path = f'{checkpoint_dir}/hparam_pickle.txt'
                if not file_util.exists(hparam_path):
                    hparams.write_hparams(model_config.hparams, hparam_path)
                initial_index = 0
            else:
                saver.restore(session, latest_checkpoint)
                initial_index = session.run(global_step_op)
                log.info(f'The global step is {initial_index}')
                initial_index = int(initial_index)
                log.info(f'Parsed to {initial_index}')
            for i in range(initial_index, FLAGS.train_step_count):
                start_time = time.time()
                log.info(f'Step {i}')
                is_summary_step = i % FLAGS.summary_step_interval == 0
                if is_summary_step:
                    _, summaries, loss = session.run(
                        [model_config.train_op, summary_op, model_config.loss])
                    writer.add_summary(summaries, i)
                else:
                    _, loss = session.run(
                        [model_config.train_op, model_config.loss])
                end_time = time.time()
                steps_per_second = 1.0 / (end_time - start_time)
                log.info(f'Loss: {loss}\tSteps/second: {steps_per_second}')

                is_checkpoint_step = i % FLAGS.checkpoint_interval == 0
                if is_checkpoint_step or i == FLAGS.train_step_count - 1:
                    ckpt_path = os.path.join(checkpoint_dir, 'model.ckpt')
                    log.info(f'Writing checkpoint to {ckpt_path}...')
                    saver.save(session, ckpt_path, global_step=i)
            log.info('Done training!')
        finally:
            # Flush pending summaries even if training is interrupted.
            writer.close()
示例#13
0
文件: train.py 项目: trisct/ldif
def _gpu_memory_fraction():
    """Return the per-process GPU memory fraction TensorFlow may claim.

    When --reserve_memory_for_inference_kernel is set (and the platform is
    not macOS), a fixed amount of the currently free GPU memory is held back
    for the inference kernel. The rest, capped at 10000, is returned as a
    fraction of the free memory. Otherwise 1.0 is returned.

    Raises:
      ValueError: if the free GPU memory cannot cover the reservation.
    """
    if not FLAGS.reserve_memory_for_inference_kernel or sys.platform == 'darwin':
        return 1.0
    log.verbose('--reserve_memory_for_inference_kernel specified.')
    # NOTE(review): this queries GPU index 2 specifically; confirm it matches
    # the device TensorFlow actually runs on.
    current_free = gpu_util.get_free_gpu_memory(2)
    # Units assumed to be MB (1024 + 512 ~= 1.5GB) -- TODO confirm in gpu_util.
    reserve = 1024 + 512
    allowable = min(current_free - reserve, 10000)
    log.info(f'GPU memory usage planning: current_free = {current_free}, '
             f'allowable = {allowable}')
    # Checking current_free too avoids a ZeroDivisionError below.
    if current_free <= 0 or allowable <= 0:
        raise ValueError(
            f"Can't leave {reserve} MB free for the inference kernel, because"
            f" there is only {current_free} MB of free GPU memory.")
    allowable_fraction = allowable / current_free
    log.info(
        f'TensorFlow can use up to {allowable_fraction*100}% of the total'
        ' GPU memory.')
    return allowable_fraction


def _restore_or_start(session, saver, global_step_op, checkpoint_dir,
                      model_config):
    """Resume from the newest checkpoint in checkpoint_dir, or start fresh.

    Creates checkpoint_dir if it is missing. On a fresh start it also writes
    the hparams file (if absent) so that eval can load it.

    Returns:
      The step index the training loop should start from.
    """
    if not os.path.isdir(checkpoint_dir):
        log.info('No previous checkpoint detected, training from scratch.')
        os.makedirs(checkpoint_dir)
    else:
        log.info(
            f'Checkpoint root {checkpoint_dir} exists, attempting to resume.')
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        log.info(f'Latest checkpoint: {latest_checkpoint}')
        if latest_checkpoint is not None:
            saver.restore(session, latest_checkpoint)
            initial_index = session.run(global_step_op)
            log.info(f'The global step is {initial_index}')
            initial_index = int(initial_index)
            log.info(f'Parsed to {initial_index}')
            return initial_index
        # The directory exists but holds no checkpoint (e.g. a previous run
        # died before its first save); restoring None would raise.
        log.info('No checkpoint found in checkpoint root, training from scratch.')
    # Serialize hparams so eval can load them:
    hparam_path = f'{checkpoint_dir}/hparam_pickle.txt'
    if not file_util.exists(hparam_path):
        hparams.write_hparams(model_config.hparams, hparam_path)
    return 0


def main(argv):
    """Train the SIF transcoder on FLAGS.dataset_directory.

    Builds the dataset and model graph and restores the latest checkpoint if
    one exists. It then runs FLAGS.train_step_count steps, writing summaries
    every FLAGS.summary_step_interval steps and checkpoints every
    FLAGS.checkpoint_interval steps, plus one at the final step.

    Raises:
      app.UsageError: if extra command-line arguments are given.
      ValueError: if the dataset directory is missing or GPU memory
        planning fails.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    tf.disable_v2_behavior()
    log.set_level(FLAGS.log_level)

    log.info('Making dataset...')
    if not FLAGS.dataset_directory:
        raise ValueError('A dataset directory must be provided.')
    if not os.path.isdir(FLAGS.dataset_directory):
        raise ValueError(
            f'No dataset directory found at {FLAGS.dataset_directory}')
    # TODO(kgenova) This batch size should match.
    dataset = local_inputs.make_dataset(FLAGS.dataset_directory,
                                        mode='train',
                                        batch_size=FLAGS.batch_size,
                                        split=FLAGS.split)

    # Sets up the hyperparameters and tf.Dataset.
    model_config = build_model_config(dataset)

    # Generates the graph for a single train step, including summaries;
    # shared_launcher.sif_transcoder sets further fields of model_config.
    shared_launcher.sif_transcoder(model_config)
    log.verbose(f'model_config type: {type(model_config)}')
    log.verbose(f'model_config attributes: {dir(model_config)}')
    log.verbose(f'Type of model_config.train_op: {type(model_config.train_op)}')
    log.verbose(f'Type of model_config.loss: {type(model_config.loss)}')
    log.info(f'Losses used: {model_config.hparams.loss}')
    log.info(f'Hparams: {model_config.hparams}')

    summary_op = tf.summary.merge_all()
    global_step_op = tf.compat.v1.train.get_global_step()

    saver = tf.train.Saver(max_to_keep=5,
                           pad_step_number=False,
                           save_relative_paths=True)

    # initialize_all_variables is deprecated in favor of this.
    init_op = tf.global_variables_initializer()

    model_root = get_model_root()

    experiment_dir = f'{model_root}/sif-transcoder-{FLAGS.experiment_name}'
    checkpoint_dir = f'{experiment_dir}/1-hparams/train/'

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=_gpu_memory_fraction())

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:
        writer = tf.summary.FileWriter(f'{experiment_dir}/log', session.graph)
        try:
            log.info('Initializing variables...')
            session.run([init_op])

            if FLAGS.visualize:
                visualize_data(session, model_config.inputs['dataset'])

            initial_index = _restore_or_start(session, saver, global_step_op,
                                              checkpoint_dir, model_config)

            log.info('Starting training...')
            ckpt_path = os.path.join(checkpoint_dir, 'model.ckpt')
            log_every = 10
            last_log_time = time.time()
            # Count steps actually run since the last log. The first log
            # (and any log after resuming mid-interval) covers fewer than
            # log_every steps.
            steps_since_log = 0

            for i in range(initial_index, FLAGS.train_step_count):
                log.verbose(f'Starting step {i}...')
                if i % FLAGS.summary_step_interval == 0:
                    _, summaries, loss = session.run(
                        [model_config.train_op, summary_op, model_config.loss])
                    writer.add_summary(summaries, i)
                else:
                    _, loss = session.run(
                        [model_config.train_op, model_config.loss])
                steps_since_log += 1

                if i % log_every == 0:
                    now = time.time()
                    steps_per_second = steps_since_log / (now - last_log_time)
                    last_log_time = now
                    steps_since_log = 0
                    log.info(
                        f'Step: {i}\tLoss: {loss}\tSteps/second: {steps_per_second}'
                    )

                is_checkpoint_step = i % FLAGS.checkpoint_interval == 0
                if is_checkpoint_step or i == FLAGS.train_step_count - 1:
                    log.info(f'Writing checkpoint to {ckpt_path}...')
                    saver.save(session, ckpt_path, global_step=i)
        finally:
            # Flush buffered summaries to disk; otherwise the tail can be lost.
            writer.close()
        log.info('Done training!')