# Example 1
def build_mixed_batch(t,
                      dataset3d,
                      dataset2d,
                      examples3d,
                      examples2d,
                      learning_phase,
                      batch_size3d=None,
                      batch_size2d=None,
                      shuffle=None,
                      rng=None,
                      max_unconsumed=256,
                      n_done_steps=0,
                      n_total_steps=None):
    """Build the 2D and 3D input batches and attach the result tensors to `t`.

    The 2D batch tensors are stored under `t.*_2d`-style attributes, the 3D
    batch tensors under their plain names (`t.x`, `t.coords3d_true`, ...).
    """
    if shuffle is None:
        shuffle = learning_phase == TRAIN

    if rng is None:
        rng = np.random.RandomState()
    # Derive two independent streams so the 2D and 3D pipelines are decoupled.
    rng_for_2d = util.new_rng(rng)
    rng_for_3d = util.new_rng(rng)

    # Keyword arguments shared by both pipeline builders.
    common_kwargs = dict(
        shuffle=shuffle,
        max_unconsumed=max_unconsumed,
        n_done_steps=n_done_steps,
        n_total_steps=n_total_steps)

    batch2d = helpers.build_input_batch(
        t, examples2d, data.data_loading.load_and_transform2d,
        (dataset2d.joint_info, learning_phase), learning_phase, batch_size2d,
        FLAGS.workers, rng=rng_for_2d, **common_kwargs)
    (t.image_path_2d, t.x_2d, t.coords2d_true2d,
     t.joint_validity_mask2d) = batch2d

    batch3d = helpers.build_input_batch(
        t, examples3d, data.data_loading.load_and_transform3d,
        (dataset3d.joint_info, learning_phase), learning_phase, batch_size3d,
        FLAGS.workers, rng=rng_for_3d, **common_kwargs)
    (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
     t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
     t.is_joint_in_fov, t.activity_name, t.scene_name) = batch3d
# Example 2
def build_graph(learning_phase,
                n_epochs=None,
                shuffle=None,
                drop_remainder=None,
                rng=None,
                n_done_steps=0):
    """Assemble the input pipeline, model, metrics and (when training) the
    train op and summaries for the given phase; returns the tensor dict."""
    tfu.set_is_training(learning_phase == TRAIN)

    t = AttrDict(global_step=tf.compat.v1.train.get_or_create_global_step())

    dataset3d = data.datasets3d.get_dataset(FLAGS.dataset)
    examples = helpers.get_examples(dataset3d, learning_phase, FLAGS)
    t.n_examples = len(examples)
    if tfu.is_training():
        phase_name = 'training'
    else:
        phase_name = 'validation'
    logging.info(f'Number of {phase_name} examples: {t.n_examples:,}')

    # Total step count is only known when an epoch budget is given.
    if n_epochs is None:
        n_total_steps = None
    else:
        batch_size = (FLAGS.batch_size
                      if learning_phase == TRAIN else FLAGS.batch_size_test)
        n_total_steps = len(examples) * n_epochs // batch_size

    rng = np.random.RandomState() if rng is None else rng

    if tfu.is_training() and FLAGS.train_mixed:
        # Mixed training: combine the 3D dataset with an auxiliary 2D one.
        dataset2d = data.datasets2d.get_dataset(FLAGS.dataset2d)
        examples2d = (list(dataset2d.examples[tfu.TRAIN]) +
                      list(dataset2d.examples[tfu.VALID]))
        build_mixed_batch(
            t, dataset3d, dataset2d, examples, examples2d, learning_phase,
            batch_size3d=FLAGS.batch_size, batch_size2d=FLAGS.batch_size_2d,
            shuffle=shuffle, rng=rng, max_unconsumed=FLAGS.max_unconsumed,
            n_done_steps=n_done_steps, n_total_steps=n_total_steps)
    else:
        batch_size = (FLAGS.batch_size
                      if learning_phase == TRAIN else FLAGS.batch_size_test)
        helpers.build_input_batch(
            t, examples, data.data_loading.load_and_transform3d,
            (dataset3d.joint_info, learning_phase), learning_phase, batch_size,
            FLAGS.workers, shuffle=shuffle, drop_remainder=drop_remainder,
            rng=rng, max_unconsumed=FLAGS.max_unconsumed,
            n_done_steps=n_done_steps, n_total_steps=n_total_steps)
        (t.image_path, t.x, t.coords3d_true, t.coords2d_true,
         t.inv_intrinsics, t.rot_to_orig_cam, t.rot_to_world, t.cam_loc,
         t.joint_validity_mask, t.is_joint_in_fov, t.activity_name,
         t.scene_name) = t.batch

    # Select the model builder according to the scale-recovery variant.
    if FLAGS.scale_recovery == 'metrabs':
        build_model = model.metrabs.build_metrabs_model
    elif FLAGS.scale_recovery == 'metro':
        build_model = model.metro.build_metro_model
    else:
        build_model = model.twofive.build_25d_model
    build_model(dataset3d.joint_info, t)

    if 'coords3d_true' in t:
        build_eval_metrics(t)

    if learning_phase == TRAIN:
        build_train_op(t)
        build_summaries(t)
    return t
# Example 3
def build_graph(learning_phase,
                reuse=tf.AUTO_REUSE,
                n_epochs=None,
                shuffle=None,
                drop_remainder=None,
                rng=None,
                n_completed_steps=0):
    """Build the input pipeline, model, metrics and (when training) the
    train op and summaries for one phase of the experiment.

    Args:
        learning_phase: One of the tfu phase constants (e.g. TRAIN); controls
            training mode, default shuffling and batch-size selection.
        reuse: Variable-reuse setting forwarded to the model builders.
        n_epochs: If given, used together with the batch size to derive the
            total number of steps; otherwise the step total stays unknown.
        shuffle: Whether to shuffle examples (forwarded to the batch builder).
        drop_remainder: Whether to drop the final partial batch.
        rng: Optional numpy RandomState; a fresh one is created when None.
        n_completed_steps: Number of already-completed steps (for resuming).

    Returns:
        An AttrDict `t` holding the batch tensors, model outputs, metrics
        and the global step.
    """
    tfu.set_is_training(learning_phase == TRAIN)

    dataset3d = data.datasets.current_dataset()
    # The auxiliary 2D dataset is only needed for mixed training.
    dataset2d = (data.datasets2d.get_dataset(FLAGS.dataset2d)
                 if FLAGS.train_mixed else None)

    t = AttrDict()
    t.global_step = tf.train.get_or_create_global_step()

    examples = helpers.get_examples(dataset3d, learning_phase, FLAGS)
    phase_name = tfu.PHASE_NAME[learning_phase]
    t.n_examples = len(examples)
    logging.info(f'Number of {phase_name} examples: {t.n_examples:,}')

    if n_epochs is None:
        n_total_steps = None
    else:
        batch_size = FLAGS.batch_size if learning_phase == TRAIN else FLAGS.batch_size_test
        n_total_steps = (len(examples) * n_epochs) // batch_size

    if rng is None:
        rng = np.random.RandomState()

    # Independent RNG streams so the 2D and 3D pipelines don't interfere.
    rng_2d = util.new_rng(rng)
    rng_3d = util.new_rng(rng)

    # Only training gets a dedicated name scope; otherwise use the stdlib
    # no-op context instead of a hand-rolled empty contextmanager.
    name_scope = (tf.name_scope(None, 'training')
                  if learning_phase == TRAIN else contextlib.nullcontext())
    with name_scope:
        if learning_phase == TRAIN and FLAGS.train_mixed:
            examples2d = [
                *dataset2d.examples[tfu.TRAIN], *dataset2d.examples[tfu.VALID]
            ]
            build_mixed_batch(t,
                              dataset3d,
                              dataset2d,
                              examples,
                              examples2d,
                              learning_phase,
                              batch_size3d=FLAGS.batch_size,
                              batch_size2d=FLAGS.batch_size_2d,
                              shuffle=shuffle,
                              rng_2d=rng_2d,
                              rng_3d=rng_3d,
                              max_unconsumed=FLAGS.max_unconsumed,
                              n_completed_steps=n_completed_steps,
                              n_total_steps=n_total_steps)
        else:
            # The multi-epoch-test and regular paths are identical except for
            # the extra n_test_epochs argument, so build the call once.
            batch_size = FLAGS.batch_size if learning_phase == TRAIN else FLAGS.batch_size_test
            extra_kwargs = (dict(n_test_epochs=FLAGS.epochs)
                            if FLAGS.multiepoch_test else {})
            helpers.build_input_batch(t,
                                      examples,
                                      data.data_loading.load_and_transform3d,
                                      (dataset3d.joint_info, learning_phase),
                                      learning_phase,
                                      batch_size,
                                      FLAGS.workers,
                                      shuffle=shuffle,
                                      drop_remainder=drop_remainder,
                                      rng=rng_3d,
                                      max_unconsumed=FLAGS.max_unconsumed,
                                      n_completed_steps=n_completed_steps,
                                      n_total_steps=n_total_steps,
                                      **extra_kwargs)
            (t.image_path, t.x, t.coords3d_true, t.coords2d_true,
             t.inv_intrinsics, t.rot_to_orig_cam, t.rot_to_world, t.cam_loc,
             t.joint_validity_mask, t.is_joint_in_fov, t.activity_name,
             t.scene_name) = t.batch

        # Choose the model variant based on the scale-recovery strategy.
        if FLAGS.scale_recovery == 'metro':
            model.volumetric.build_metro_model(dataset3d.joint_info,
                                               learning_phase,
                                               t,
                                               reuse=reuse)
        else:
            model.volumetric.build_25d_model(dataset3d.joint_info,
                                             learning_phase,
                                             t,
                                             reuse=reuse)

        if 'coords3d_true' in t:
            build_eval_metrics(t)

        if learning_phase == TRAIN:
            build_train_op(t)
            build_summaries(t)
        return t