# Example #1
def main(argv=()):
  """Evaluate a pretrained PTN model with slim's evaluation loop.

  Checkpoints are read from <checkpoint_dir>/<model_name>/train_bkp and
  evaluation summaries are written to
  <checkpoint_dir>/<model_name>/eval_<eval_set>.
  """
  del argv  # Unused.
  model_root = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name)
  eval_dir = os.path.join(model_root, 'train_bkp')
  log_dir = os.path.join(model_root, 'eval_%s' % FLAGS.eval_set)
  for directory in (eval_dir, log_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  graph = tf.Graph()
  with graph.as_default():
    # NOTE(review): eval_params aliases FLAGS, so these assignments mutate
    # the global flags object in place.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views

    # Model.
    model = model_ptn.model_PTN(eval_params)

    # Data: the full evaluation set, one sample per batch.
    eval_data = model.get_inputs(
        FLAGS.inp_dir,
        FLAGS.dataset_name,
        eval_params.eval_set,
        eval_params.batch_size,
        eval_params.image_size,
        eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)

    # Model fn (inference only, no projection).
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)

    # Metrics: only the update ops are needed by the evaluation loop; the
    # value tensors are discarded.
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values

    # Evaluation loop: re-evaluates whenever a new checkpoint appears.
    num_batches = eval_data['num_samples']
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=eval_dir,
        logdir=log_dir,
        num_evals=num_batches,
        eval_op=names_to_updates.values(),
        eval_interval_secs=FLAGS.eval_interval_secs)
# Example #2
def main(argv=()):
  """Evaluate a pretrained PTN model with slim's evaluation loop.

  Checkpoints are read from <checkpoint_dir>/<model_name>/train and
  evaluation summaries are written to
  <checkpoint_dir>/<model_name>/eval_<eval_set>.
  """
  del argv  # Unused.
  model_root = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name)
  eval_dir = os.path.join(model_root, 'train')
  log_dir = os.path.join(model_root, 'eval_%s' % FLAGS.eval_set)
  for directory in (eval_dir, log_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  graph = tf.Graph()
  with graph.as_default():
    # NOTE(review): eval_params aliases FLAGS, so these assignments mutate
    # the global flags object in place.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views

    # Model.
    model = model_ptn.model_PTN(eval_params)

    # Data: the full evaluation set, one sample per batch.
    eval_data = model.get_inputs(
        FLAGS.inp_dir,
        FLAGS.dataset_name,
        eval_params.eval_set,
        eval_params.batch_size,
        eval_params.image_size,
        eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)

    # Model fn (inference only, no projection).
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)

    # Metrics: only the update ops are needed by the evaluation loop; the
    # value tensors are discarded.
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values

    # Evaluation loop: re-evaluates whenever a new checkpoint appears.
    num_batches = eval_data['num_samples']
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=eval_dir,
        logdir=log_dir,
        num_evals=num_batches,
        eval_op=names_to_updates.values(),
        eval_interval_secs=FLAGS.eval_interval_secs)
# Example #3
def main(_):
    """Train the PTN model with slim, periodically dumping validation grids.

    Checkpoints/summaries go to <checkpoint_dir>/<model_name>/train_bkp;
    on the chief worker, validation visualizations are written to its
    'images' subdirectory as a side effect of each train step.
    """
    train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
                             'train_bkp')
    save_image_dir = os.path.join(train_dir, 'images')
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    if not os.path.exists(save_image_dir):
        os.makedirs(save_image_dir)

    g = tf.Graph()
    with g.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            global_step = slim.get_or_create_global_step()
            ###########
            ## model ##
            ###########
            model = model_ptn.model_PTN(FLAGS)
            ##########
            ## data ##
            ##########
            train_data = model.get_inputs(FLAGS.inp_dir,
                                          FLAGS.dataset_name,
                                          'train_bkp',
                                          FLAGS.batch_size,
                                          FLAGS.image_size,
                                          FLAGS.vox_size,
                                          is_training=True)
            inputs = model.preprocess(train_data, FLAGS.step_size)
            ##############
            ## model_fn ##
            ##############
            model_fn = model.get_model_fn(is_training=True,
                                          reuse=False,
                                          run_projection=True)
            outputs = model_fn(inputs)
            ##################
            ## train_scopes ##
            ##################
            # With a pretrained init model, freeze the encoder (init only)
            # and train just the decoder; otherwise train both.
            if FLAGS.init_model:
                train_scopes = ['decoder']
                init_scopes = ['encoder']
            else:
                train_scopes = ['encoder', 'decoder']

            ##########
            ## loss ##
            ##########
            task_loss = model.get_loss(inputs, outputs)

            regularization_loss = model.get_regularization_loss(train_scopes)
            loss = task_loss + regularization_loss
            ###############
            ## optimizer ##
            ###############
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            if FLAGS.sync_replicas:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    # Bug fix: was `FLAGS.workers_replicas`, a flag name that
                    # appears nowhere else; every other read in this function
                    # uses `FLAGS.worker_replicas`.
                    replicas_to_aggregate=FLAGS.worker_replicas -
                    FLAGS.backup_workers,
                    total_num_replicas=FLAGS.worker_replicas)

            ##############
            ## train_op ##
            ##############
            train_op = model.get_train_op_for_scope(loss, optimizer,
                                                    train_scopes)
            ###########
            ## saver ##
            ###########
            saver = tf.train.Saver(
                max_to_keep=np.minimum(5, FLAGS.worker_replicas + 1))

            # Only the chief (task 0) builds the validation/visualization
            # branch; it reuses the training variables (reuse=True).
            if FLAGS.task == 0:
                # NOTE(review): `params` aliases FLAGS, so these assignments
                # mutate the global flags object in place.
                params = FLAGS
                params.batch_size = params.num_views
                params.step_size = 1
                model.set_params(params)
                val_data = model.get_inputs(params.inp_dir,
                                            params.dataset_name,
                                            'val',
                                            params.batch_size,
                                            params.image_size,
                                            params.vox_size,
                                            is_training=False)
                val_inputs = model.preprocess(val_data, params.step_size)
                # Note: don't compute loss here
                reused_model_fn = model.get_model_fn(is_training=False,
                                                     reuse=True)
                val_outputs = reused_model_fn(val_inputs)

                with tf.device(tf.DeviceSpec(device_type='CPU')):
                    vis_input_images = val_inputs['images_1'] * 255.0
                    # Masks/projections are inverted (1 - x) then scaled to
                    # [0, 255] so foreground renders dark on white.
                    vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
                    vis_pred_projs = (val_outputs['projs_1'] *
                                      (-1) + 1) * 255.0

                    # Replicate the single channel to 3 for image output.
                    vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
                    vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
                    # rescale
                    new_size = [FLAGS.image_size] * 2
                    vis_gt_projs = tf.image.resize_nearest_neighbor(
                        vis_gt_projs, new_size)
                    vis_pred_projs = tf.image.resize_nearest_neighbor(
                        vis_pred_projs, new_size)
                    # flip
                    # vis_gt_projs = utils.image_flipud(vis_gt_projs)
                    # vis_pred_projs = utils.image_flipud(vis_pred_projs)
                    # vis_gt_projs is of shape [batch, height, width, channels]
                    write_disk_op = model.write_disk_grid(
                        global_step=global_step,
                        log_dir=save_image_dir,
                        input_images=vis_input_images,
                        gt_projs=vis_gt_projs,
                        pred_projs=vis_pred_projs,
                        input_voxels=val_inputs['voxels'],
                        output_voxels=val_outputs['voxels_1'])
                # Force the image dump to run with every train step.
                with tf.control_dependencies([write_disk_op]):
                    train_op = tf.identity(train_op)

            #############
            ## init_fn ##
            #############
            if FLAGS.init_model:
                init_fn = model.get_init_fn(init_scopes)
            else:
                init_fn = None

            ##############
            ## training ##
            ##############
            slim.learning.train(train_op=train_op,
                                logdir=train_dir,
                                init_fn=init_fn,
                                master=FLAGS.master,
                                is_chief=(FLAGS.task == 0),
                                number_of_steps=FLAGS.max_number_of_steps,
                                saver=saver,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
# Example #4
def main(_):
  """Train the PTN model with slim, periodically dumping validation grids.

  Checkpoints/summaries go to <checkpoint_dir>/<model_name>/train; on the
  chief worker, validation visualizations are written to its 'images'
  subdirectory as a side effect of each train step.
  """
  train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
  save_image_dir = os.path.join(train_dir, 'images')
  if not os.path.exists(train_dir):
    os.makedirs(train_dir)
  if not os.path.exists(save_image_dir):
    os.makedirs(save_image_dir)

  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      global_step = slim.get_or_create_global_step()
      ###########
      ## model ##
      ###########
      model = model_ptn.model_PTN(FLAGS)
      ##########
      ## data ##
      ##########
      train_data = model.get_inputs(
          FLAGS.inp_dir,
          FLAGS.dataset_name,
          'train',
          FLAGS.batch_size,
          FLAGS.image_size,
          FLAGS.vox_size,
          is_training=True)
      inputs = model.preprocess(train_data, FLAGS.step_size)
      ##############
      ## model_fn ##
      ##############
      model_fn = model.get_model_fn(
          is_training=True, reuse=False, run_projection=True)
      outputs = model_fn(inputs)
      ##################
      ## train_scopes ##
      ##################
      # With a pretrained init model, freeze the encoder (init only) and
      # train just the decoder; otherwise train both.
      if FLAGS.init_model:
        train_scopes = ['decoder']
        init_scopes = ['encoder']
      else:
        train_scopes = ['encoder', 'decoder']

      ##########
      ## loss ##
      ##########
      task_loss = model.get_loss(inputs, outputs)

      regularization_loss = model.get_regularization_loss(train_scopes)
      loss = task_loss + regularization_loss
      ###############
      ## optimizer ##
      ###############
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      if FLAGS.sync_replicas:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            # Bug fix: was `FLAGS.workers_replicas`, a flag name that appears
            # nowhere else; every other read in this function uses
            # `FLAGS.worker_replicas`.
            replicas_to_aggregate=FLAGS.worker_replicas - FLAGS.backup_workers,
            total_num_replicas=FLAGS.worker_replicas)

      ##############
      ## train_op ##
      ##############
      train_op = model.get_train_op_for_scope(loss, optimizer, train_scopes)
      ###########
      ## saver ##
      ###########
      saver = tf.train.Saver(max_to_keep=np.minimum(5,
                                                    FLAGS.worker_replicas + 1))

      # Only the chief (task 0) builds the validation/visualization branch;
      # it reuses the training variables (reuse=True).
      if FLAGS.task == 0:
        # NOTE(review): `params` aliases FLAGS, so these assignments mutate
        # the global flags object in place.
        params = FLAGS
        params.batch_size = params.num_views
        params.step_size = 1
        model.set_params(params)
        val_data = model.get_inputs(
            params.inp_dir,
            params.dataset_name,
            'val',
            params.batch_size,
            params.image_size,
            params.vox_size,
            is_training=False)
        val_inputs = model.preprocess(val_data, params.step_size)
        # Note: don't compute loss here
        reused_model_fn = model.get_model_fn(is_training=False, reuse=True)
        val_outputs = reused_model_fn(val_inputs)

        with tf.device(tf.DeviceSpec(device_type='CPU')):
          vis_input_images = val_inputs['images_1'] * 255.0
          # Masks/projections are inverted (1 - x) then scaled to [0, 255]
          # so foreground renders dark on white.
          vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
          vis_pred_projs = (val_outputs['projs_1'] * (-1) + 1) * 255.0

          # Replicate the single channel to 3 for image output.
          vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
          vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
          # rescale
          new_size = [FLAGS.image_size] * 2
          vis_gt_projs = tf.image.resize_nearest_neighbor(
              vis_gt_projs, new_size)
          vis_pred_projs = tf.image.resize_nearest_neighbor(
              vis_pred_projs, new_size)
          # flip
          # vis_gt_projs = utils.image_flipud(vis_gt_projs)
          # vis_pred_projs = utils.image_flipud(vis_pred_projs)
          # vis_gt_projs is of shape [batch, height, width, channels]
          write_disk_op = model.write_disk_grid(
              global_step=global_step,
              log_dir=save_image_dir,
              input_images=vis_input_images,
              gt_projs=vis_gt_projs,
              pred_projs=vis_pred_projs,
              input_voxels=val_inputs['voxels'],
              output_voxels=val_outputs['voxels_1'])
        # Force the image dump to run with every train step.
        with tf.control_dependencies([write_disk_op]):
          train_op = tf.identity(train_op)

      #############
      ## init_fn ##
      #############
      if FLAGS.init_model:
        init_fn = model.get_init_fn(init_scopes)
      else:
        init_fn = None

      ##############
      ## training ##
      ##############
      slim.learning.train(
          train_op=train_op,
          logdir=train_dir,
          init_fn=init_fn,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          number_of_steps=FLAGS.max_number_of_steps,
          saver=saver,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
# Example #5
def main(argv=()):
  """Run one PTN evaluation batch with tracing and dump a graph analysis.

  Restores the latest checkpoint from FLAGS.checkpoint_dir, runs the metric
  update ops once with full run tracing, writes the op analysis to
  'ptn.pickle', and exits after the first batch.

  Raises:
    NotImplementedError: if no checkpoint is found in FLAGS.checkpoint_dir.
  """
  del argv  # Unused.
  g = tf.Graph()

  with g.as_default():
    # NOTE(review): eval_params aliases FLAGS, so these assignments mutate
    # the global flags object in place.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views
    ###########
    ## model ##
    ###########
    model = model_ptn.model_PTN(eval_params)
    ##########
    ## data ##
    ##########
    eval_data = model.get_inputs(
        FLAGS.inp_dir,
        FLAGS.dataset_name,
        eval_params.eval_set,
        eval_params.batch_size,
        eval_params.image_size,
        eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)
    ##############
    ## model_fn ##
    ##############
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)
    #############
    ## metrics ##
    #############
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values
    ################
    ## evaluation ##
    ################
    num_batches = eval_data['num_samples']

    sess = tf.Session()
    saver = tf.train.Saver()

    # Bug fix: the initializers must run BEFORE restoring the checkpoint.
    # The original ran tf.global_variables_initializer() after
    # saver.restore(), which overwrote the restored weights with fresh
    # random values.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    def restore_from_checkpoint(sess, saver):
      """Restore the latest checkpoint; return True on success."""
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      if not ckpt or not ckpt.model_checkpoint_path:
        return False

      saver.restore(sess, ckpt.model_checkpoint_path)
      return True

    if not restore_from_checkpoint(sess, saver):
      raise NotImplementedError  # No checkpoint available to evaluate.

    # Start the input-pipeline queue runners only after all variables are
    # initialized and restored.
    tf.train.start_queue_runners(sess=sess)

    for i in range(num_batches):
      print('Running {} batch out of {} batches.'.format(i, num_batches))

      # Full tracing so CompGraph can inspect the executed ops.
      options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()

      sess.run(list(names_to_updates.values()), options=options,
               run_metadata=run_metadata)
      cg = CompGraph('ptn', run_metadata, tf.get_default_graph())

      cg_tensor_dict = cg.get_tensors()
      cg_sorted_keys = sorted(cg_tensor_dict.keys())
      # Fetch the dynamic shape of every traced tensor in one session run.
      cg_sorted_items = []
      for cg_key in cg_sorted_keys:
        cg_sorted_items.append(tf.shape(cg_tensor_dict[cg_key]))

      cg_sorted_shape = sess.run(cg_sorted_items)
      cg.op_analysis(dict(zip(cg_sorted_keys, cg_sorted_shape)),
                     'ptn.pickle')

      # Only the first batch is needed for the analysis.
      exit(0)