def augment_appearance(im, learning_phase, rng):
    occlusion_rng = util.new_rng(rng)
    color_rng = util.new_rng(rng)

    if learning_phase == TRAIN or FLAGS.test_aug:
        if FLAGS.occlude_aug_prob > 0:
            occlude_type = str(occlusion_rng.choice(['objects', 'random-erase']))
        else:
            occlude_type = None

        if occlude_type == 'objects':
            # For object occlusion augmentation, do the occlusion first, then the filtering,
            # so that the occluder blends into the image better.
            if occlusion_rng.uniform(0.0, 1.0) < FLAGS.occlude_aug_prob:
                im = object_occlude(im, occlusion_rng, inplace=True)
            if FLAGS.color_aug:
                im = augmentation.color.augment_color(im, color_rng)
        else:
            # For random erasing, do color aug first, to keep the random block distributed
            # uniformly in 0-255, as in the Random Erasing paper
            if FLAGS.color_aug:
                im = augmentation.color.augment_color(im, color_rng)
            if occlude_type and occlusion_rng.uniform(0.0, 1.0) < FLAGS.occlude_aug_prob:
                im = random_erase(im, 0, 1 / 3, 0.3, 1.0 / 0.3, occlusion_rng, inplace=True)

    return im

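# The `random_erase` call above follows the parameterization of the Random Erasing paper
# (Zhong et al., 2017): the positional arguments are the min/max erased-area ratio
# (0 to 1/3) and the min/max aspect ratio (0.3 to 1/0.3) of the erased block. The sketch
# below is an assumption for illustration (assuming a uint8 HWC image), not the actual
# implementation used above:
def random_erase_sketch(im, area_low, area_high, aspect_low, aspect_high, rng, inplace=True):
    """Fills a random rectangle of the image with uniform random noise."""
    if not inplace:
        im = im.copy()
    height, width = im.shape[:2]
    for _ in range(100):  # rejection-sample a rectangle that fits into the image
        area = rng.uniform(area_low, area_high) * height * width
        aspect = rng.uniform(aspect_low, aspect_high)  # aspect = h / w
        h = int(np.round(np.sqrt(area * aspect)))
        w = int(np.round(np.sqrt(area / aspect)))
        if 0 < h < height and 0 < w < width:
            y = rng.randint(0, height - h)
            x = rng.randint(0, width - w)
            im[y:y + h, x:x + w] = rng.randint(
                0, 256, size=(h, w, im.shape[2]), dtype=im.dtype)
            break
    return im
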
def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None,
        shuffle_before_each_epoch=False, extra_args=None, n_workers=10, rng=None,
        max_unconsumed=256, n_completed_items=0, n_total_items=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence as
    a TF Dataset. Elements are processed by parallel workers using multiprocessing.

    Args:
        fun: A function that takes an element of `iterable` plus `extra_args` and returns a
            sequence of numpy arrays.
        iterable: An iterable holding the inputs.
        output_types: A list of types, describing each output numpy array from `fun`.
            If None, then it is automatically determined by calling `fun` on the first
            element.
        output_shapes: A list of array shapes, describing each output numpy array from `fun`.
            If None, then it is automatically determined by calling `fun` on the first
            element.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: Extra arguments, in addition to an element of `iterable`, passed to `fun`
            at each call.
        n_workers: Number of worker processes for parallelism.
        rng: Source random number generator from which the per-item RNGs are derived.
        max_unconsumed: Maximum number of items the workers may produce ahead of the
            consumer.
        n_completed_items: Number of items already processed before a checkpoint restore.
        n_total_items: Total number of items to process, or None for no limit.

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """
    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element
    iterable = list(iterable)
    first_elem = iterable[0]
    if output_types is None or output_shapes is None:
        sample_output = fun(first_elem, *extra_args, rng=np.random.RandomState(0))
        output_shapes, output_types = tfu.get_shapes_and_tf_dtypes(sample_output)

    items = util.iterate_repeatedly(iterable, shuffle_before_each_epoch, util.new_rng(rng))

    # If we are restoring from a checkpoint and have already completed some
    # training steps for that checkpoint, then we need to advance the RNG
    # accordingly, to continue exactly where we left off.
    iter_rng = util.new_rng(rng)
    util.advance_rng(iter_rng, n_completed_items)
    logging.debug(f'n_total_items: {n_total_items}, n_completed_items: {n_completed_items}')
    items = itertools.islice(items, n_completed_items, n_total_items)

    if n_workers == 1:
        def gen():
            for item in items:
                logging.debug('yielding')
                yield fun(item, *extra_args, util.new_rng(iter_rng))
            logging.debug('ended')
    else:
        pool = tfu.get_pool(n_workers)
        gen = parallel_map_as_generator(
            fun, items, extra_args, pool, rng=iter_rng, max_unconsumed=max_unconsumed)

    return tf.data.Dataset.from_generator(gen, output_types, output_shapes)

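# `util.new_rng` and `util.advance_rng` are helpers from elsewhere in the codebase; minimal
# sketches of the behavior the code above relies on (deriving an independently seeded child
# RNG from a parent, and deterministically fast-forwarding an RNG) might look as follows.
# These are assumptions for illustration, not the actual implementations:
def new_rng(rng):
    """Derives a new, independently seeded RandomState from `rng` (fresh one if None)."""
    if rng is None:
        return np.random.RandomState()
    return np.random.RandomState(rng.randint(2 ** 32))

def advance_rng(rng, n_generated_ints):
    """Fast-forwards `rng` by drawing and discarding `n_generated_ints` integers.

    Each yielded item above consumes exactly one randint from `iter_rng` via `new_rng`,
    so discarding `n_completed_items` randints resumes the stream where it left off.
    """
    for _ in range(n_generated_ints):
        rng.randint(2 ** 32)
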
def build_mixed_batch(
        t, dataset3d, dataset2d, examples3d, examples2d, learning_phase,
        batch_size3d=None, batch_size2d=None, shuffle=None, rng=None,
        max_unconsumed=256, n_done_steps=0, n_total_steps=None):
    if shuffle is None:
        shuffle = learning_phase == TRAIN

    if rng is None:
        rng = np.random.RandomState()
    rng_2d = util.new_rng(rng)
    rng_3d = util.new_rng(rng)

    (t.image_path_2d, t.x_2d, t.coords2d_true2d,
     t.joint_validity_mask2d) = helpers.build_input_batch(
        t, examples2d, data.data_loading.load_and_transform2d,
        (dataset2d.joint_info, learning_phase), learning_phase, batch_size2d,
        FLAGS.workers, shuffle=shuffle, rng=rng_2d, max_unconsumed=max_unconsumed,
        n_done_steps=n_done_steps, n_total_steps=n_total_steps)

    (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
     t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
     t.is_joint_in_fov, t.activity_name, t.scene_name) = helpers.build_input_batch(
        t, examples3d, data.data_loading.load_and_transform3d,
        (dataset3d.joint_info, learning_phase), learning_phase, batch_size3d,
        FLAGS.workers, shuffle=shuffle, rng=rng_3d, max_unconsumed=max_unconsumed,
        n_done_steps=n_done_steps, n_total_steps=n_total_steps)

def producer():
    for i_item, item in enumerate(items):
        semaphore.acquire()
        if _must_stop:
            return
        q.put(pool.apply_async(fun, (item, *extra_args, util.new_rng(rng))))
    logging.debug('Putting end-of-seq')
    q.put(end_of_sequence_marker)

def producer():
    for i_item, item in enumerate(items):
        if should_stop:
            break
        semaphore.acquire()
        q.put(pool.apply_async(fun, (item, *extra_args, util.new_rng(rng))))
    logger.debug('Putting end-of-seq')
    q.put(end_of_sequence_marker)

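# For context, the two `producer` variants above live inside a function like the following.
# This is a hedged sketch of `parallel_map_as_generator`, reconstructed from how the
# surrounding code uses it; the queue/semaphore wiring and names such as
# `end_of_sequence_marker` are assumptions:
import queue
import threading

def parallel_map_as_generator_sketch(fun, items, extra_args, pool, rng=None,
                                     max_unconsumed=256):
    semaphore = threading.Semaphore(max_unconsumed)
    q = queue.Queue()
    end_of_sequence_marker = object()
    should_stop = False  # a real implementation would set this from a shutdown hook

    def producer():
        for item in items:
            if should_stop:
                break
            semaphore.acquire()  # block while too many results are unconsumed
            q.put(pool.apply_async(fun, (item, *extra_args, util.new_rng(rng))))
        q.put(end_of_sequence_marker)

    threading.Thread(target=producer, daemon=True).start()

    def consumer():
        while (future := q.get()) is not end_of_sequence_marker:
            value = future.get()  # wait for the async task to finish
            semaphore.release()  # let the producer submit another item
            yield value

    # Returned as a callable generator function, suitable for
    # tf.data.Dataset.from_generator.
    return consumer
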
def train():
    logging.info('Training phase.')
    rng = np.random.RandomState(FLAGS.seed)
    n_completed_steps = get_number_of_already_completed_steps(FLAGS.logdir)

    t_train = build_graph(
        TRAIN, rng=util.new_rng(rng), n_epochs=FLAGS.epochs,
        n_completed_steps=n_completed_steps)
    logging.info(f'Number of trainable parameters: {tfu.count_trainable_params():,}')

    t_valid = (build_graph(VALID, shuffle=True, rng=util.new_rng(rng))
               if FLAGS.validate_period else None)

    helpers.run_train_loop(
        t_train.train_op, checkpoint_dir=FLAGS.checkpoint_dir, load_path=FLAGS.load_path,
        hooks=make_training_hooks(t_train, t_valid), init_fn=get_init_fn())
    logging.info('Ended training.')

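# `get_number_of_already_completed_steps` is assumed to recover the global step of the
# latest checkpoint in the log directory, which is what drives the RNG fast-forwarding
# above. A minimal sketch under that assumption (not the actual implementation):
def get_number_of_already_completed_steps_sketch(logdir):
    ckpt = tf.train.latest_checkpoint(logdir)
    if ckpt is None:
        return 0
    # TF checkpoint paths conventionally end in '-<global_step>', e.g. 'model.ckpt-12000'
    return int(ckpt.split('-')[-1])
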
def gen():
    for item in items:
        logging.debug('yielding')
        yield fun(item, *extra_args, util.new_rng(iter_rng))
    logging.debug('ended')

def load_and_transform3d(ex, joint_info, learning_phase, rng=None):
    appearance_rng = util.new_rng(rng)
    background_rng = util.new_rng(rng)
    geom_rng = util.new_rng(rng)
    partial_visi_rng = util.new_rng(rng)

    output_side = FLAGS.proc_side
    output_imshape = (output_side, output_side)

    box = ex.bbox
    if FLAGS.partial_visibility:
        box = util.random_partial_subbox(boxlib.expand_to_square(box), partial_visi_rng)

    crop_side = np.max(box[2:])
    center_point = boxlib.center(box)
    if ((learning_phase == TRAIN and FLAGS.geom_aug) or
            (learning_phase != TRAIN and FLAGS.test_aug and FLAGS.geom_aug)):
        center_point += util.random_uniform_disc(geom_rng) * FLAGS.shift_aug / 100 * crop_side

    if box[2] < box[3]:
        delta_y = np.array([0, box[3] / 2])
        sidepoints = center_point + np.stack([-delta_y, delta_y])
    else:
        delta_x = np.array([box[2] / 2, 0])
        sidepoints = center_point + np.stack([-delta_x, delta_x])

    cam = ex.camera.copy()
    cam.turn_towards(target_image_point=center_point)
    cam.undistort()
    cam.square_pixels()
    world_sidepoints = ex.camera.image_to_world(sidepoints)
    cam_sidepoints = cam.world_to_image(world_sidepoints)
    crop_side = np.linalg.norm(cam_sidepoints[0] - cam_sidepoints[1])
    cam.zoom(output_side / crop_side)
    cam.center_principal_point(output_imshape)

    if FLAGS.geom_aug and (learning_phase == TRAIN or FLAGS.test_aug):
        s1 = FLAGS.scale_aug_down / 100
        s2 = FLAGS.scale_aug_up / 100
        r = FLAGS.rot_aug * np.pi / 180
        zoom = geom_rng.uniform(1 - s1, 1 + s2)
        cam.zoom(zoom)
        cam.rotate(roll=geom_rng.uniform(-r, r))

    world_coords = ex.univ_coords if FLAGS.universal_skeleton else ex.world_coords
    metric_world_coords = ex.world_coords

    if learning_phase == TRAIN and geom_rng.rand() < 0.5:
        cam.horizontal_flip()
        camcoords = cam.world_to_camera(world_coords)[joint_info.mirror_mapping]
        metric_world_coords = metric_world_coords[joint_info.mirror_mapping]
    else:
        camcoords = cam.world_to_camera(world_coords)
    imcoords = cam.world_to_image(metric_world_coords)

    image_path = util.ensure_absolute_path(ex.image_path)
    origsize_im = improc.imread_jpeg(image_path)
    interp_str = (FLAGS.image_interpolation_train
                  if learning_phase == TRAIN else FLAGS.image_interpolation_test)
    antialias = (FLAGS.antialias_train if learning_phase == TRAIN else FLAGS.antialias_test)
    interp = getattr(cv2, 'INTER_' + interp_str.upper())
    im = cameralib.reproject_image(
        origsize_im, ex.camera, cam, output_imshape, antialias_factor=antialias,
        interp=interp)

    if re.match('.+/mupots/TS[1-5]/.+', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
    elif '3dhp' in ex.image_path and re.match('.+/(TS[1-4])/', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
        im = improc.white_balance(im, 110, 145)

    if (FLAGS.background_aug_prob and hasattr(ex, 'mask') and ex.mask is not None and
            background_rng.rand() < FLAGS.background_aug_prob and
            (learning_phase == TRAIN or FLAGS.test_aug)):
        fgmask = improc.decode_mask(ex.mask)
        fgmask = cameralib.reproject_image(
            fgmask, ex.camera, cam, output_imshape, antialias_factor=antialias,
            interp=interp)
        im = augmentation.background.augment_background(im, fgmask, background_rng)

    im = augmentation.appearance.augment_appearance(im, learning_phase, appearance_rng)
    im = tfu.nhwc_to_std(im)
    im = improc.normalize01(im)

    # Joints with NaN coordinates are invalid
    is_joint_in_fov = ~np.logical_or(
        np.any(imcoords < 0, axis=-1), np.any(imcoords >= FLAGS.proc_side, axis=-1))
    joint_validity_mask = ~np.any(np.isnan(camcoords), axis=-1)
    rot_to_orig_cam = ex.camera.R @ cam.R.T
    rot_to_world = cam.R.T
    inv_intrinsics = np.linalg.inv(cam.intrinsic_matrix)

    return (
        ex.image_path, im,
        np.nan_to_num(camcoords).astype(np.float32),
        np.nan_to_num(imcoords).astype(np.float32),
        inv_intrinsics.astype(np.float32),
        rot_to_orig_cam.astype(np.float32),
        rot_to_world.astype(np.float32),
        cam.t.astype(np.float32),
        joint_validity_mask,
        np.float32(is_joint_in_fov),
        ex.activity_name, ex.scene_name)

def load_and_transform2d(example, joint_info, learning_phase, rng):
    # Get the random number generators for the different augmentations to make it
    # reproducible
    appearance_rng = util.new_rng(rng)
    geom_rng = util.new_rng(rng)
    partial_visi_rng = util.new_rng(rng)

    # Load the image
    image_path = util.ensure_absolute_path(example.image_path)
    im_from_file = improc.imread_jpeg(image_path)

    # Determine bounding box
    bbox = example.bbox
    if FLAGS.partial_visibility:
        bbox = util.random_partial_subbox(boxlib.expand_to_square(bbox), partial_visi_rng)

    crop_side = np.max(bbox[2:])
    center_point = boxlib.center(bbox)
    orig_cam = cameralib.Camera.create2D(im_from_file.shape)
    cam = orig_cam.copy()
    cam.zoom(FLAGS.proc_side / crop_side)

    if FLAGS.geom_aug:
        center_point += util.random_uniform_disc(geom_rng) * FLAGS.shift_aug / 100 * crop_side
        s1 = FLAGS.scale_aug_down / 100
        s2 = FLAGS.scale_aug_up / 100
        cam.zoom(geom_rng.uniform(1 - s1, 1 + s2))
        r = FLAGS.rot_aug * np.pi / 180
        cam.rotate(roll=geom_rng.uniform(-r, r))

    if FLAGS.geom_aug and geom_rng.rand() < 0.5:
        # Horizontal flipping
        cam.horizontal_flip()
        # Must also permute the joints to exchange e.g. left wrist and right wrist!
        imcoords = example.coords[joint_info.mirror_mapping]
    else:
        imcoords = example.coords

    new_center_point = cameralib.reproject_image_points(center_point, orig_cam, cam)
    cam.shift_to_center(new_center_point, (FLAGS.proc_side, FLAGS.proc_side))

    is_annotation_invalid = (np.nan_to_num(imcoords[:, 1]) > im_from_file.shape[0] * 0.95)
    imcoords[is_annotation_invalid] = np.nan
    imcoords = cameralib.reproject_image_points(imcoords, orig_cam, cam)

    interp_str = (FLAGS.image_interpolation_train
                  if learning_phase == TRAIN else FLAGS.image_interpolation_test)
    antialias = (FLAGS.antialias_train if learning_phase == TRAIN else FLAGS.antialias_test)
    interp = getattr(cv2, 'INTER_' + interp_str.upper())
    im = cameralib.reproject_image(
        im_from_file, orig_cam, cam, (FLAGS.proc_side, FLAGS.proc_side),
        antialias_factor=antialias, interp=interp)
    im = augmentation.appearance.augment_appearance(im, learning_phase, appearance_rng)
    im = tfu.nhwc_to_std(im)
    im = improc.normalize01(im)

    joint_validity_mask = ~np.any(np.isnan(imcoords), axis=1)
    # We must eliminate NaNs because some TensorFlow ops can't deal with any NaNs touching
    # them, even if they would not influence the result. Therefore we use a separate
    # "joint_validity_mask" to indicate which joint coords are valid.
    imcoords = np.nan_to_num(imcoords)

    return example.image_path, np.float32(im), np.float32(imcoords), joint_validity_mask

def load_and_transform3d(ex, joint_info, learning_phase, rng):
    # Get the random number generators for the different augmentations to make it
    # reproducible
    appearance_rng = util.new_rng(rng)
    background_rng = util.new_rng(rng)
    geom_rng = util.new_rng(rng)
    partial_visi_rng = util.new_rng(rng)

    output_side = FLAGS.proc_side
    output_imshape = (output_side, output_side)

    if 'sailvos' in ex.image_path.lower():
        # This is needed in order not to lose precision in later operations.
        # Background: in the SAIL-VOS dataset (GTA V), some world coordinates
        # are crazy large (several kilometers, i.e. millions of millimeters, which becomes
        # hard to process with the limited simultaneous dynamic range of float32).
        # They are stored in float64, but the processing is done in float32 here.
        ex.world_coords -= ex.camera.t
        ex.camera.t[:] = 0

    box = ex.bbox
    if 'surreal' in ex.image_path.lower():
        # Surreal images are flipped wrong in the official dataset release
        box = box.copy()
        box[0] = 320 - (box[0] + box[2])

    # Partial visibility
    if 'surreal' in ex.image_path.lower() and 'surmuco' not in FLAGS.dataset:
        partial_visi_prob = 0.5
    elif 'h36m' in ex.image_path.lower() and 'many' in FLAGS.dataset:
        partial_visi_prob = 0.5
    else:
        partial_visi_prob = FLAGS.partial_visibility_prob

    use_partial_visi_aug = (
            (learning_phase == TRAIN or FLAGS.test_aug) and
            partial_visi_rng.rand() < partial_visi_prob)
    if use_partial_visi_aug:
        box = util.random_partial_subbox(boxlib.expand_to_square(box), partial_visi_rng)

    # Geometric transformation and augmentation
    crop_side = np.max(box[2:])
    center_point = boxlib.center(box)
    if ((learning_phase == TRAIN and FLAGS.geom_aug) or
            (learning_phase != TRAIN and FLAGS.test_aug and FLAGS.geom_aug)):
        center_point += util.random_uniform_disc(geom_rng) * FLAGS.shift_aug / 100 * crop_side

    # The homographic reprojection of a rectangle (bounding box) will not be another
    # rectangle. Hence, instead we transform the side midpoints of the short sides of the
    # box and determine an appropriate zoom factor by taking the projected distance of these
    # two points and scaling that to the desired output image side length.
    if box[2] < box[3]:
        # Tall box: take midpoints of the top and bottom sides
        delta_y = np.array([0, box[3] / 2])
        sidepoints = center_point + np.stack([-delta_y, delta_y])
    else:
        # Wide box: take midpoints of the left and right sides
        delta_x = np.array([box[2] / 2, 0])
        sidepoints = center_point + np.stack([-delta_x, delta_x])

    cam = ex.camera.copy()
    cam.turn_towards(target_image_point=center_point)
    cam.undistort()
    cam.square_pixels()
    cam_sidepoints = cameralib.reproject_image_points(sidepoints, ex.camera, cam)
    crop_side = np.linalg.norm(cam_sidepoints[0] - cam_sidepoints[1])
    cam.zoom(output_side / crop_side)
    cam.center_principal_point(output_imshape)

    if FLAGS.geom_aug and (learning_phase == TRAIN or FLAGS.test_aug):
        s1 = FLAGS.scale_aug_down / 100
        s2 = FLAGS.scale_aug_up / 100
        zoom = geom_rng.uniform(1 - s1, 1 + s2)
        cam.zoom(zoom)
        r = np.deg2rad(FLAGS.rot_aug)
        cam.rotate(roll=geom_rng.uniform(-r, r))

    world_coords = ex.univ_coords if FLAGS.universal_skeleton else ex.world_coords
    metric_world_coords = ex.world_coords

    if learning_phase == TRAIN and geom_rng.rand() < 0.5:
        cam.horizontal_flip()
        # Must reorder the joints due to the left-right flip
        camcoords = cam.world_to_camera(world_coords)[joint_info.mirror_mapping]
        metric_world_coords = metric_world_coords[joint_info.mirror_mapping]
    else:
        camcoords = cam.world_to_camera(world_coords)
    imcoords = cam.world_to_image(metric_world_coords)

    # Load and reproject the image
    image_path = util.ensure_absolute_path(ex.image_path)
    origsize_im = improc.imread_jpeg(image_path)
    if 'surreal' in ex.image_path.lower():
        # Surreal images are flipped wrong in the official dataset release
        origsize_im = origsize_im[:, ::-1]

    interp_str = (FLAGS.image_interpolation_train
                  if learning_phase == TRAIN else FLAGS.image_interpolation_test)
    antialias = (FLAGS.antialias_train if learning_phase == TRAIN else FLAGS.antialias_test)
    interp = getattr(cv2, 'INTER_' + interp_str.upper())
    im = cameralib.reproject_image(
        origsize_im, ex.camera, cam, output_imshape, antialias_factor=antialias,
        interp=interp)

    # Color adjustment
    if re.match('.*mupots/TS[1-5]/.+', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
    elif '3dhp' in ex.image_path and re.match('.+/(TS[1-4])/', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
        im = improc.white_balance(im, 110, 145)
    elif 'panoptic' in ex.image_path.lower():
        im = improc.white_balance(im, 120, 138)

    # Background augmentation
    if hasattr(ex, 'mask') and ex.mask is not None:
        bg_aug_prob = 0.2 if 'sailvos' in ex.image_path.lower() else FLAGS.background_aug_prob
        if (FLAGS.background_aug_prob and (learning_phase == TRAIN or FLAGS.test_aug) and
                background_rng.rand() < bg_aug_prob):
            fgmask = improc.decode_mask(ex.mask)
            if 'surreal' in ex.image_path:
                # Surreal images are flipped wrong in the official dataset release
                fgmask = fgmask[:, ::-1]
            fgmask = cameralib.reproject_image(
                fgmask, ex.camera, cam, output_imshape, antialias_factor=antialias,
                interp=interp)
            im = augmentation.background.augment_background(im, fgmask, background_rng)

    # Occlusion and color augmentation
    im = augmentation.appearance.augment_appearance(
        im, learning_phase, FLAGS.occlude_aug_prob, appearance_rng)
    im = tfu.nhwc_to_std(im)
    im = improc.normalize01(im)

    # Joints with NaN coordinates are invalid
    is_joint_in_fov = ~np.logical_or(
        np.any(imcoords < 0, axis=-1), np.any(imcoords >= FLAGS.proc_side, axis=-1))
    joint_validity_mask = ~np.any(np.isnan(camcoords), axis=-1)
    rot_to_orig_cam = ex.camera.R @ cam.R.T
    rot_to_world = cam.R.T

    return dict(
        image=im,
        intrinsics=np.float32(cam.intrinsic_matrix),
        image_path=ex.image_path,
        coords3d_true=np.nan_to_num(camcoords).astype(np.float32),
        coords2d_true=np.nan_to_num(imcoords).astype(np.float32),
        rot_to_orig_cam=rot_to_orig_cam.astype(np.float32),
        rot_to_world=rot_to_world.astype(np.float32),
        cam_loc=cam.t.astype(np.float32),
        joint_validity_mask=joint_validity_mask,
        is_joint_in_fov=np.float32(is_joint_in_fov))

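# The `boxlib` helpers used throughout operate on boxes in [x, y, width, height] format.
# Minimal sketches of the two simple ones, under that format assumption (the actual
# implementations, and `util.random_partial_subbox`, live elsewhere in the codebase):
def center_sketch(box):
    """Returns the center point (x, y) of an [x, y, w, h] box."""
    return box[:2] + box[2:] / 2

def expand_to_square_sketch(box):
    """Returns the smallest square box with the same center as the input box."""
    side = np.max(box[2:])
    return np.array([*(center_sketch(box) - side / 2), side, side])
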
def roundrobin_iterate_repeatedly(
        seqs, roundrobin_sizes, shuffle_before_each_epoch=False, rng=None):
    iters = [
        iterate_repeatedly(seq, shuffle_before_each_epoch, util.new_rng(rng))
        for seq in seqs]
    return roundrobin(iters, roundrobin_sizes)

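# `roundrobin` cycles through the given iterators, yielding a fixed-size chunk from each in
# turn (e.g. sizes [3, 1] interleave three elements of the first sequence with one of the
# second). A minimal sketch consistent with that usage:
def roundrobin_sketch(iterables, sizes):
    iterators = [iter(it) for it in iterables]
    while True:
        for iterator, size in zip(iterators, sizes):
            for _ in range(size):
                try:
                    yield next(iterator)
                except StopIteration:
                    # With iterate_repeatedly as input, the iterators never end;
                    # finite inputs simply terminate the round-robin here.
                    return
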
def gen():
    for item in items:
        yield fun(item, *extra_args, util.new_rng(iter_rng))

def parallel_map_as_tf_dataset(
        fun, iterable, *, shuffle_before_each_epoch=False, extra_args=None, n_workers=10,
        rng=None, max_unconsumed=256, n_completed_items=0, n_total_items=None,
        roundrobin_sizes=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence as
    a TensorFlow Dataset. Elements are processed by parallel workers using
    `multiprocessing`.

    Args:
        fun: A function that takes an element of `iterable` plus `extra_args` and returns a
            sequence of numpy arrays.
        iterable: An iterable holding the inputs.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: Extra arguments, in addition to an element of `iterable`, passed to `fun`
            at each call.
        n_workers: Number of worker processes for parallelism. If None, use up to 12
            processes, depending on the number of available CPU cores.
        rng: Source random number generator from which the per-item RNGs are derived.
        max_unconsumed: Maximum number of items the workers may produce ahead of the
            consumer.
        n_completed_items: Number of items already processed before a checkpoint restore.
        n_total_items: Total number of items to process, or None for no limit.
        roundrobin_sizes: If given, `iterable` is interpreted as a list of iterables that
            are cycled through in chunks of these sizes.

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """
    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element
    if not roundrobin_sizes:
        iterable = more_itertools.peekable(iterable)
        first_elem = iterable.peek()
    else:
        iterable[0] = more_itertools.peekable(iterable[0])
        first_elem = iterable[0].peek()
    sample_output = fun(first_elem, *extra_args, rng=np.random.RandomState(0))
    output_signature = tf.nest.map_structure(tf.type_spec_from_value, sample_output)

    if not roundrobin_sizes:
        items = my_itertools.iterate_repeatedly(
            iterable, shuffle_before_each_epoch, util.new_rng(rng))
    else:
        items = my_itertools.roundrobin_iterate_repeatedly(
            iterable, roundrobin_sizes, shuffle_before_each_epoch, rng)

    # If we are restoring from a checkpoint and have already completed some
    # training steps for that checkpoint, then we need to advance the RNG
    # accordingly, to continue exactly where we left off.
    iter_rng = util.new_rng(rng)
    util.advance_rng(iter_rng, n_completed_items)
    items = itertools.islice(items, n_completed_items, n_total_items)

    if n_workers is None:
        n_workers = min(len(os.sched_getaffinity(0)), 12)
    if n_workers == 0:
        def gen():
            for item in items:
                yield fun(item, *extra_args, util.new_rng(iter_rng))
    else:
        gen = parallel_map_as_generator(
            fun, items, extra_args, n_workers, rng=iter_rng, max_unconsumed=max_unconsumed)

    ds = tf.data.Dataset.from_generator(gen, output_signature=output_signature)
    # Make the cardinality of the dataset known to TF.
    if n_total_items is not None:
        ds = ds.take(n_total_items - n_completed_items)
    return ds

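# A hedged usage sketch of the function above (the wrapper name and the batch/prefetch
# choices are illustrative, not taken from the codebase): it wires the per-example loader
# into a batched, prefetched input pipeline, the same way the graph-building code does via
# helpers.build_input_batch.
def make_training_dataset_sketch(examples, joint_info, batch_size, rng):
    ds = parallel_map_as_tf_dataset(
        data.data_loading.load_and_transform3d, examples,
        extra_args=(joint_info, TRAIN), shuffle_before_each_epoch=True,
        n_workers=FLAGS.workers, rng=rng, max_unconsumed=256)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
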
def build_graph(
        learning_phase, reuse=tf.AUTO_REUSE, n_epochs=None, shuffle=None,
        drop_remainder=None, rng=None, n_completed_steps=0):
    tfu.set_is_training(learning_phase == TRAIN)
    dataset3d = data.datasets.current_dataset()
    if FLAGS.train_mixed:
        dataset2d = data.datasets2d.get_dataset(FLAGS.dataset2d)
    else:
        dataset2d = None

    t = AttrDict()
    t.global_step = tf.train.get_or_create_global_step()

    examples = helpers.get_examples(dataset3d, learning_phase, FLAGS)
    phase_name = tfu.PHASE_NAME[learning_phase]
    t.n_examples = len(examples)
    logging.info(f'Number of {phase_name} examples: {t.n_examples:,}')

    if n_epochs is None:
        n_total_steps = None
    else:
        batch_size = FLAGS.batch_size if learning_phase == TRAIN else FLAGS.batch_size_test
        n_total_steps = (len(examples) * n_epochs) // batch_size

    if rng is None:
        rng = np.random.RandomState()
    rng_2d = util.new_rng(rng)
    rng_3d = util.new_rng(rng)

    @contextlib.contextmanager
    def empty_context():
        yield

    name_scope = (tf.name_scope(None, 'training')
                  if learning_phase == TRAIN else empty_context())
    with name_scope:
        if learning_phase == TRAIN and FLAGS.train_mixed:
            examples2d = [*dataset2d.examples[tfu.TRAIN], *dataset2d.examples[tfu.VALID]]
            build_mixed_batch(
                t, dataset3d, dataset2d, examples, examples2d, learning_phase,
                batch_size3d=FLAGS.batch_size, batch_size2d=FLAGS.batch_size_2d,
                shuffle=shuffle, rng_2d=rng_2d, rng_3d=rng_3d,
                max_unconsumed=FLAGS.max_unconsumed,
                n_completed_steps=n_completed_steps, n_total_steps=n_total_steps)
        elif FLAGS.multiepoch_test:
            batch_size = (FLAGS.batch_size if learning_phase == TRAIN
                          else FLAGS.batch_size_test)
            helpers.build_input_batch(
                t, examples, data.data_loading.load_and_transform3d,
                (dataset3d.joint_info, learning_phase), learning_phase, batch_size,
                FLAGS.workers, shuffle=shuffle, drop_remainder=drop_remainder, rng=rng_3d,
                max_unconsumed=FLAGS.max_unconsumed, n_completed_steps=n_completed_steps,
                n_total_steps=n_total_steps, n_test_epochs=FLAGS.epochs)
            (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
             t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
             t.is_joint_in_fov, t.activity_name, t.scene_name) = t.batch
        else:
            batch_size = (FLAGS.batch_size if learning_phase == TRAIN
                          else FLAGS.batch_size_test)
            helpers.build_input_batch(
                t, examples, data.data_loading.load_and_transform3d,
                (dataset3d.joint_info, learning_phase), learning_phase, batch_size,
                FLAGS.workers, shuffle=shuffle, drop_remainder=drop_remainder, rng=rng_3d,
                max_unconsumed=FLAGS.max_unconsumed, n_completed_steps=n_completed_steps,
                n_total_steps=n_total_steps)
            (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
             t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
             t.is_joint_in_fov, t.activity_name, t.scene_name) = t.batch

        if FLAGS.scale_recovery == 'metro':
            model.volumetric.build_metro_model(
                dataset3d.joint_info, learning_phase, t, reuse=reuse)
        else:
            model.volumetric.build_25d_model(
                dataset3d.joint_info, learning_phase, t, reuse=reuse)

        if 'coords3d_true' in t:
            build_eval_metrics(t)

        if learning_phase == TRAIN:
            build_train_op(t)
            build_summaries(t)

    return t