def build_mixed_batch(
        t, dataset3d, dataset2d, examples3d, examples2d, learning_phase,
        batch_size3d=None, batch_size2d=None, shuffle=None, rng=None,
        max_unconsumed=256, n_done_steps=0, n_total_steps=None,
        rng_2d=None, rng_3d=None, n_completed_steps=None):
    """Build a 2D and a 3D input pipeline side by side and store their outputs on `t`.

    The 2D pipeline feeds ``examples2d`` through
    ``data.data_loading.load_and_transform2d`` and sets ``t.image_path_2d``,
    ``t.x_2d``, ``t.coords2d_true2d`` and ``t.joint_validity_mask2d``.
    The 3D pipeline feeds ``examples3d`` through
    ``data.data_loading.load_and_transform3d`` and sets ``t.image_path``, ``t.x``,
    ``t.coords3d_true``, ``t.coords2d_true``, ``t.inv_intrinsics``,
    ``t.rot_to_orig_cam``, ``t.rot_to_world``, ``t.cam_loc``,
    ``t.joint_validity_mask``, ``t.is_joint_in_fov``, ``t.activity_name`` and
    ``t.scene_name``.

    Args:
        t: attribute container that receives the batch outputs.
        dataset3d, dataset2d: datasets whose ``joint_info`` is passed to the loaders.
        examples3d, examples2d: example lists fed to the 3D / 2D pipeline.
        learning_phase: passed to the loaders and to ``helpers.build_input_batch``;
            also determines the default of ``shuffle``.
        batch_size3d, batch_size2d: batch size of the 3D / 2D pipeline.
        shuffle: whether to shuffle; defaults to ``learning_phase == TRAIN``.
        rng: parent random state from which any RNG not given explicitly is
            derived; a fresh ``np.random.RandomState()`` is created if needed.
        max_unconsumed, n_total_steps: forwarded unchanged to both pipelines.
        n_done_steps: forwarded to both pipelines.
        rng_2d, rng_3d: optional explicit RNGs for the 2D / 3D pipeline. When
            omitted they are derived from ``rng`` (2D first, then 3D), exactly
            as in earlier versions of this function.
        n_completed_steps: alias for ``n_done_steps``; takes precedence when given.
            Accepted so that callers using this keyword do not raise ``TypeError``.
    """
    if shuffle is None:
        shuffle = learning_phase == TRAIN
    if n_completed_steps is not None:
        n_done_steps = n_completed_steps

    if rng_2d is None or rng_3d is None:
        if rng is None:
            rng = np.random.RandomState()
        # Derive in a fixed order (2D, then 3D) so seeded runs stay reproducible.
        if rng_2d is None:
            rng_2d = util.new_rng(rng)
        if rng_3d is None:
            rng_3d = util.new_rng(rng)

    # NOTE(review): the later `build_graph` in this module passes `n_completed_steps`
    # directly to helpers.build_input_batch, whereas this function passes
    # `n_done_steps`; confirm which keyword helpers.build_input_batch accepts.
    (t.image_path_2d, t.x_2d, t.coords2d_true2d,
     t.joint_validity_mask2d) = helpers.build_input_batch(
        t, examples2d, data.data_loading.load_and_transform2d,
        (dataset2d.joint_info, learning_phase), learning_phase, batch_size2d,
        FLAGS.workers, shuffle=shuffle, rng=rng_2d, max_unconsumed=max_unconsumed,
        n_done_steps=n_done_steps, n_total_steps=n_total_steps)

    (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
     t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
     t.is_joint_in_fov, t.activity_name, t.scene_name) = helpers.build_input_batch(
        t, examples3d, data.data_loading.load_and_transform3d,
        (dataset3d.joint_info, learning_phase), learning_phase, batch_size3d,
        FLAGS.workers, shuffle=shuffle, rng=rng_3d, max_unconsumed=max_unconsumed,
        n_done_steps=n_done_steps, n_total_steps=n_total_steps)
def build_graph(learning_phase, n_epochs=None, shuffle=None, drop_remainder=None,
                rng=None, n_done_steps=0):
    """Assemble the input pipeline, model, eval metrics and (in training) the train op.

    NOTE(review): a second ``build_graph`` is defined further down in this module
    and rebinds the name, so this definition is shadowed once the module is loaded.

    Args:
        learning_phase: compared against ``TRAIN`` to select training behaviour.
        n_epochs: if given, used to compute the total step count for the pipeline.
        shuffle: forwarded to the batch builder(s).
        drop_remainder: forwarded to the 3D-only batch builder.
        rng: parent random state; a fresh ``np.random.RandomState()`` if omitted.
        n_done_steps: forwarded to the batch builder(s).

    Returns:
        The ``AttrDict`` ``t`` that holds everything built here.
    """
    is_train_phase = learning_phase == TRAIN
    tfu.set_is_training(is_train_phase)

    t = AttrDict(global_step=tf.compat.v1.train.get_or_create_global_step())
    dataset3d = data.datasets3d.get_dataset(FLAGS.dataset)
    examples = helpers.get_examples(dataset3d, learning_phase, FLAGS)
    t.n_examples = len(examples)
    phase_name = 'training' if tfu.is_training() else 'validation'
    logging.info(f'Number of {phase_name} examples: {t.n_examples:,}')

    def phase_batch_size():
        # Looked up lazily, only at the points where a batch size is needed.
        return FLAGS.batch_size if is_train_phase else FLAGS.batch_size_test

    n_total_steps = (
        None if n_epochs is None else (len(examples) * n_epochs) // phase_batch_size())

    if rng is None:
        rng = np.random.RandomState()

    if not (tfu.is_training() and FLAGS.train_mixed):
        # 3D-only input pipeline; the builder leaves its outputs in `t.batch`.
        helpers.build_input_batch(
            t, examples, data.data_loading.load_and_transform3d,
            (dataset3d.joint_info, learning_phase), learning_phase, phase_batch_size(),
            FLAGS.workers, shuffle=shuffle, drop_remainder=drop_remainder, rng=rng,
            max_unconsumed=FLAGS.max_unconsumed, n_done_steps=n_done_steps,
            n_total_steps=n_total_steps)
        (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
         t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
         t.is_joint_in_fov, t.activity_name, t.scene_name) = t.batch
    else:
        # Mixed 2D + 3D training: every 2D example (train and valid splits) is used.
        dataset2d = data.datasets2d.get_dataset(FLAGS.dataset2d)
        examples2d = (list(dataset2d.examples[tfu.TRAIN]) +
                      list(dataset2d.examples[tfu.VALID]))
        build_mixed_batch(
            t, dataset3d, dataset2d, examples, examples2d, learning_phase,
            batch_size3d=FLAGS.batch_size, batch_size2d=FLAGS.batch_size_2d,
            shuffle=shuffle, rng=rng, max_unconsumed=FLAGS.max_unconsumed,
            n_done_steps=n_done_steps, n_total_steps=n_total_steps)

    scale_recovery = FLAGS.scale_recovery
    if scale_recovery == 'metrabs':
        model.metrabs.build_metrabs_model(dataset3d.joint_info, t)
    elif scale_recovery == 'metro':
        model.metro.build_metro_model(dataset3d.joint_info, t)
    else:
        model.twofive.build_25d_model(dataset3d.joint_info, t)

    if 'coords3d_true' in t:
        build_eval_metrics(t)
    if is_train_phase:
        build_train_op(t)
        build_summaries(t)
    return t
def build_graph(
        learning_phase, reuse=tf.AUTO_REUSE, n_epochs=None, shuffle=None,
        drop_remainder=None, rng=None, n_completed_steps=0):
    """Assemble the input pipeline, model, eval metrics and (in training) the train op.

    When ``learning_phase == TRAIN`` the graph is built inside
    ``tf.name_scope(None, 'training')``; otherwise no extra name scope is opened.

    Args:
        learning_phase: compared against ``TRAIN``; also used to index
            ``tfu.PHASE_NAME`` for logging.
        reuse: forwarded to the ``model.volumetric`` builder.
        n_epochs: if given, used to compute the total step count for the pipeline.
        shuffle: forwarded to the batch builder(s).
        drop_remainder: forwarded to the 3D-only batch builder.
        rng: parent random state; a fresh ``np.random.RandomState()`` if omitted.
        n_completed_steps: forwarded to the batch builder(s).

    Returns:
        The ``AttrDict`` ``t`` that holds everything built here.
    """
    is_train_phase = learning_phase == TRAIN
    tfu.set_is_training(is_train_phase)
    dataset3d = data.datasets.current_dataset()

    t = AttrDict()
    t.global_step = tf.train.get_or_create_global_step()
    examples = helpers.get_examples(dataset3d, learning_phase, FLAGS)
    phase_name = tfu.PHASE_NAME[learning_phase]
    t.n_examples = len(examples)
    logging.info(f'Number of {phase_name} examples: {t.n_examples:,}')

    # Computed once; previously repeated in every branch that needed it.
    batch_size = FLAGS.batch_size if is_train_phase else FLAGS.batch_size_test
    n_total_steps = None if n_epochs is None else (len(examples) * n_epochs) // batch_size

    if rng is None:
        rng = np.random.RandomState()
    # Both child RNGs are drawn up front in a fixed order, so rng_3d does not
    # depend on which input branch is taken below.
    rng_2d = util.new_rng(rng)
    rng_3d = util.new_rng(rng)

    @contextlib.contextmanager
    def empty_context():
        yield

    name_scope = tf.name_scope(None, 'training') if is_train_phase else empty_context()
    with name_scope:
        if is_train_phase and FLAGS.train_mixed:
            # The 2D dataset is only used here, so load it here rather than whenever
            # FLAGS.train_mixed is set (which also covered non-training graphs).
            dataset2d = data.datasets2d.get_dataset(FLAGS.dataset2d)
            examples2d = [*dataset2d.examples[tfu.TRAIN], *dataset2d.examples[tfu.VALID]]
            build_mixed_batch(
                t, dataset3d, dataset2d, examples, examples2d, learning_phase,
                batch_size3d=FLAGS.batch_size, batch_size2d=FLAGS.batch_size_2d,
                shuffle=shuffle, rng_2d=rng_2d, rng_3d=rng_3d,
                max_unconsumed=FLAGS.max_unconsumed,
                n_completed_steps=n_completed_steps, n_total_steps=n_total_steps)
        else:
            # The multi-epoch test mode differs from the plain 3D pipeline only by
            # the extra `n_test_epochs` argument.
            extra_kwargs = {'n_test_epochs': FLAGS.epochs} if FLAGS.multiepoch_test else {}
            helpers.build_input_batch(
                t, examples, data.data_loading.load_and_transform3d,
                (dataset3d.joint_info, learning_phase), learning_phase, batch_size,
                FLAGS.workers, shuffle=shuffle, drop_remainder=drop_remainder,
                rng=rng_3d, max_unconsumed=FLAGS.max_unconsumed,
                n_completed_steps=n_completed_steps, n_total_steps=n_total_steps,
                **extra_kwargs)
            (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
             t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
             t.is_joint_in_fov, t.activity_name, t.scene_name) = t.batch

        if FLAGS.scale_recovery == 'metro':
            model.volumetric.build_metro_model(
                dataset3d.joint_info, learning_phase, t, reuse=reuse)
        else:
            model.volumetric.build_25d_model(
                dataset3d.joint_info, learning_phase, t, reuse=reuse)

        if 'coords3d_true' in t:
            build_eval_metrics(t)
        if is_train_phase:
            build_train_op(t)
            build_summaries(t)
    return t