Example no. 1
def augment_appearance(im, learning_phase, rng):
    occlusion_rng = util.new_rng(rng)
    color_rng = util.new_rng(rng)

    if learning_phase == TRAIN or FLAGS.test_aug:
        if FLAGS.occlude_aug_prob > 0:
            occlude_type = str(
                occlusion_rng.choice(['objects', 'random-erase']))
        else:
            occlude_type = None

        if occlude_type == 'objects':
            # For object occlusion augmentation, apply the occlusion first and the color
            # augmentation (filtering) afterwards, so that the occluder blends into the image better.
            if occlusion_rng.uniform(0.0, 1.0) < FLAGS.occlude_aug_prob:
                im = object_occlude(im, occlusion_rng, inplace=True)
            if FLAGS.color_aug:
                im = augmentation.color.augment_color(im, color_rng)
        elif occlude_type == 'random-erase':
            # For random erasing, apply the color augmentation first, so that the erased block
            # stays uniformly distributed in 0-255, as in the Random Erasing paper.
            if FLAGS.color_aug:
                im = augmentation.color.augment_color(im, color_rng)
            if occlusion_rng.uniform(0.0, 1.0) < FLAGS.occlude_aug_prob:
                im = random_erase(im,
                                  0,
                                  1 / 3,
                                  0.3,
                                  1.0 / 0.3,
                                  occlusion_rng,
                                  inplace=True)

    return im
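All of these examples derive independent child generators with `util.new_rng`, so that each augmentation stream (occlusion, color, geometry, ...) gets its own reproducible randomness. The helper itself is not shown on this page; a minimal sketch, assuming NumPy's legacy RandomState API, could look like this:

import numpy as np

def new_rng(rng=None):
    # Derive an independent child generator, deterministically seeded from the parent.
    # Without a parent, fall back to a nondeterministically seeded generator.
    if rng is not None:
        return np.random.RandomState(rng.randint(2 ** 32, dtype=np.uint32))
    return np.random.RandomState()

Calling it several times on the same parent state always yields the same sequence of child generators, which is what makes the per-example augmentations repeatable.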
Example no. 2
def parallel_map_as_tf_dataset(
        fun, iterable, *, output_types=None, output_shapes=None, shuffle_before_each_epoch=False,
        extra_args=None, n_workers=10, rng=None, max_unconsumed=256, n_completed_items=0,
        n_total_items=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence as
    as a TF Dataset. Elements are processed by parallel workers using mp.

    Args:
        fun: A function that takes an element from seq plus `extra_args` and returns a sequence of
        numpy arrays.
        seq: An iterable holding the inputs.
        output_types: A list of types, describing each output numpy array from `fun`.
            If None, then it is automatically determined by calling `fun` on the first element.
        output_shapes: A list of array shapes, describing each output numpy array from `fun`.
            If None, then it is automatically determined by calling `fun` on the first element.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: extra arguments in addition to an element from `seq`,
            given to `fun` at each call
        n_workers: Number of worker processes for parallelity.
        n_epochs: Number of times to iterate over the `iterable`.

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """

    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element
    iterable = list(iterable)
    first_elem = iterable[0]
    if output_types is None or output_shapes is None:
        sample_output = fun(first_elem, *extra_args, rng=np.random.RandomState(0))
        output_shapes, output_types = tfu.get_shapes_and_tf_dtypes(sample_output)

    items = util.iterate_repeatedly(iterable, shuffle_before_each_epoch, util.new_rng(rng))

    # If we are restoring from a checkpoint and have already completed some
    # training steps for that checkpoint, then we need to advance the RNG
    # accordingly, to continue exactly where we left off.
    iter_rng = util.new_rng(rng)
    util.advance_rng(iter_rng, n_completed_items)
    logging.debug(f'n_total_items: {n_total_items}, n_completed_items: {n_completed_items}')
    items = itertools.islice(items, n_completed_items, n_total_items)

    if n_workers == 1:
        def gen():
            for item in items:
                logging.debug('yielding')
                yield fun(item, *extra_args, util.new_rng(iter_rng))
            logging.debug('ended')
    else:
        pool = tfu.get_pool(n_workers)
        gen = parallel_map_as_generator(
            fun, items, extra_args, pool, rng=iter_rng, max_unconsumed=max_unconsumed)

    return tf.data.Dataset.from_generator(gen, output_types, output_shapes)
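A minimal usage sketch of the function above (the item function `square_item` and its arguments are made up for illustration): each item is mapped to a tuple of numpy arrays, and the result is consumed as an ordinary tf.data.Dataset.

import numpy as np

def square_item(x, offset, rng):
    # Hypothetical per-item function: returns a sequence (tuple) of numpy arrays.
    # The per-item `rng` can drive any randomized preprocessing.
    noise = rng.uniform(-0.01, 0.01)
    return np.float32([x ** 2 + offset + noise]),

ds = parallel_map_as_tf_dataset(
    square_item, range(100), extra_args=[10.0], n_workers=1,
    shuffle_before_each_epoch=True, rng=np.random.RandomState(42))

for batch in ds.batch(8).take(2):
    print(batch)

With n_workers=1 the items are processed inside the generator itself; larger values dispatch the work to the multiprocessing pool.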
Example no. 3
def build_mixed_batch(t,
                      dataset3d,
                      dataset2d,
                      examples3d,
                      examples2d,
                      learning_phase,
                      batch_size3d=None,
                      batch_size2d=None,
                      shuffle=None,
                      rng=None,
                      max_unconsumed=256,
                      n_done_steps=0,
                      n_total_steps=None):
    if shuffle is None:
        shuffle = learning_phase == TRAIN

    if rng is None:
        rng = np.random.RandomState()
    rng_2d = util.new_rng(rng)
    rng_3d = util.new_rng(rng)

    (t.image_path_2d, t.x_2d, t.coords2d_true2d,
     t.joint_validity_mask2d) = helpers.build_input_batch(
         t,
         examples2d,
         data.data_loading.load_and_transform2d,
         (dataset2d.joint_info, learning_phase),
         learning_phase,
         batch_size2d,
         FLAGS.workers,
         shuffle=shuffle,
         rng=rng_2d,
         max_unconsumed=max_unconsumed,
         n_done_steps=n_done_steps,
         n_total_steps=n_total_steps)

    (t.image_path, t.x, t.coords3d_true, t.coords2d_true, t.inv_intrinsics,
     t.rot_to_orig_cam, t.rot_to_world, t.cam_loc, t.joint_validity_mask,
     t.is_joint_in_fov, t.activity_name,
     t.scene_name) = helpers.build_input_batch(
         t,
         examples3d,
         data.data_loading.load_and_transform3d,
         (dataset3d.joint_info, learning_phase),
         learning_phase,
         batch_size3d,
         FLAGS.workers,
         shuffle=shuffle,
         rng=rng_3d,
         max_unconsumed=max_unconsumed,
         n_done_steps=n_done_steps,
         n_total_steps=n_total_steps)
Example no. 4
    def producer():
        for i_item, item in enumerate(items):
            semaphore.acquire()
            if _must_stop:
                return
            q.put(pool.apply_async(fun, (item, *extra_args, util.new_rng(rng))))

        logging.debug('Putting end-of-seq')
        q.put(end_of_sequence_marker)
Example no. 5
    def producer():
        for i_item, item in enumerate(items):
            if should_stop:
                break
            semaphore.acquire()
            q.put(pool.apply_async(fun,
                                   (item, *extra_args, util.new_rng(rng))))

        logger.debug('Putting end-of-seq')
        q.put(end_of_sequence_marker)
Example no. 6
def train():
    logging.info('Training phase.')
    rng = np.random.RandomState(FLAGS.seed)
    n_completed_steps = get_number_of_already_completed_steps(FLAGS.logdir)

    t_train = build_graph(TRAIN,
                          rng=util.new_rng(rng),
                          n_epochs=FLAGS.epochs,
                          n_completed_steps=n_completed_steps)
    logging.info(
        f'Number of trainable parameters: {tfu.count_trainable_params():,}')
    t_valid = (build_graph(VALID, shuffle=True, rng=util.new_rng(rng))
               if FLAGS.validate_period else None)

    helpers.run_train_loop(t_train.train_op,
                           checkpoint_dir=FLAGS.checkpoint_dir,
                           load_path=FLAGS.load_path,
                           hooks=make_training_hooks(t_train, t_valid),
                           init_fn=get_init_fn())
    logging.info('Ended training.')
Example no. 7
def gen():
    for item in items:
        logging.debug('yielding')
        yield fun(item, *extra_args, util.new_rng(iter_rng))
    logging.debug('ended')
Example no. 8
def load_and_transform3d(ex, joint_info, learning_phase, rng=None):
    appearance_rng = util.new_rng(rng)
    background_rng = util.new_rng(rng)
    geom_rng = util.new_rng(rng)
    partial_visi_rng = util.new_rng(rng)

    output_side = FLAGS.proc_side
    output_imshape = (output_side, output_side)

    box = ex.bbox
    if FLAGS.partial_visibility:
        box = util.random_partial_subbox(boxlib.expand_to_square(box), partial_visi_rng)

    crop_side = np.max(box[2:])
    center_point = boxlib.center(box)
    if ((learning_phase == TRAIN and FLAGS.geom_aug) or
            (learning_phase != TRAIN and FLAGS.test_aug and FLAGS.geom_aug)):
        center_point += util.random_uniform_disc(geom_rng) * FLAGS.shift_aug / 100 * crop_side

    if box[2] < box[3]:
        delta_y = np.array([0, box[3] / 2])
        sidepoints = center_point + np.stack([-delta_y, delta_y])
    else:
        delta_x = np.array([box[2] / 2, 0])
        sidepoints = center_point + np.stack([-delta_x, delta_x])

    cam = ex.camera.copy()
    cam.turn_towards(target_image_point=center_point)
    cam.undistort()
    cam.square_pixels()
    world_sidepoints = ex.camera.image_to_world(sidepoints)
    cam_sidepoints = cam.world_to_image(world_sidepoints)
    crop_side = np.linalg.norm(cam_sidepoints[0] - cam_sidepoints[1])
    cam.zoom(output_side / crop_side)
    cam.center_principal_point(output_imshape)

    if FLAGS.geom_aug and (learning_phase == TRAIN or FLAGS.test_aug):
        s1 = FLAGS.scale_aug_down / 100
        s2 = FLAGS.scale_aug_up / 100
        r = FLAGS.rot_aug * np.pi / 180
        zoom = geom_rng.uniform(1 - s1, 1 + s2)
        cam.zoom(zoom)
        cam.rotate(roll=geom_rng.uniform(-r, r))

    world_coords = ex.univ_coords if FLAGS.universal_skeleton else ex.world_coords
    metric_world_coords = ex.world_coords

    if learning_phase == TRAIN and geom_rng.rand() < 0.5:
        cam.horizontal_flip()
        camcoords = cam.world_to_camera(world_coords)[joint_info.mirror_mapping]
        metric_world_coords = metric_world_coords[joint_info.mirror_mapping]
    else:
        camcoords = cam.world_to_camera(world_coords)

    imcoords = cam.world_to_image(metric_world_coords)

    image_path = util.ensure_absolute_path(ex.image_path)
    origsize_im = improc.imread_jpeg(image_path)

    interp_str = (FLAGS.image_interpolation_train
                  if learning_phase == TRAIN else FLAGS.image_interpolation_test)
    antialias = (FLAGS.antialias_train if learning_phase == TRAIN else FLAGS.antialias_test)
    interp = getattr(cv2, 'INTER_' + interp_str.upper())
    im = cameralib.reproject_image(
        origsize_im, ex.camera, cam, output_imshape, antialias_factor=antialias, interp=interp)

    if re.match('.+/mupots/TS[1-5]/.+', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
    elif '3dhp' in ex.image_path and re.match('.+/(TS[1-4])/', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
        im = improc.white_balance(im, 110, 145)

    if (FLAGS.background_aug_prob and hasattr(ex, 'mask') and ex.mask is not None and
            background_rng.rand() < FLAGS.background_aug_prob and
            (learning_phase == TRAIN or FLAGS.test_aug)):
        fgmask = improc.decode_mask(ex.mask)
        fgmask = cameralib.reproject_image(
            fgmask, ex.camera, cam, output_imshape, antialias_factor=antialias, interp=interp)
        im = augmentation.background.augment_background(im, fgmask, background_rng)

    im = augmentation.appearance.augment_appearance(im, learning_phase, appearance_rng)
    im = tfu.nhwc_to_std(im)
    im = improc.normalize01(im)

    # Joints with NaN coordinates are invalid
    is_joint_in_fov = ~np.logical_or(np.any(imcoords < 0, axis=-1),
                                     np.any(imcoords >= FLAGS.proc_side, axis=-1))
    joint_validity_mask = ~np.any(np.isnan(camcoords), axis=-1)

    rot_to_orig_cam = ex.camera.R @ cam.R.T
    rot_to_world = cam.R.T
    inv_intrinsics = np.linalg.inv(cam.intrinsic_matrix)

    return (
        ex.image_path, im, np.nan_to_num(camcoords).astype(np.float32),
        np.nan_to_num(imcoords).astype(np.float32), inv_intrinsics.astype(np.float32),
        rot_to_orig_cam.astype(np.float32), rot_to_world.astype(np.float32),
        cam.t.astype(np.float32), joint_validity_mask,
        np.float32(is_joint_in_fov), ex.activity_name, ex.scene_name)
Example no. 9
def load_and_transform2d(example, joint_info, learning_phase, rng):
    # Create separate random number generators for the different augmentations to make them reproducible
    appearance_rng = util.new_rng(rng)
    geom_rng = util.new_rng(rng)
    partial_visi_rng = util.new_rng(rng)

    # Load the image
    image_path = util.ensure_absolute_path(example.image_path)
    im_from_file = improc.imread_jpeg(image_path)

    # Determine bounding box
    bbox = example.bbox
    if FLAGS.partial_visibility:
        bbox = util.random_partial_subbox(boxlib.expand_to_square(bbox), partial_visi_rng)

    crop_side = np.max(bbox)
    center_point = boxlib.center(bbox)
    orig_cam = cameralib.Camera.create2D(im_from_file.shape)
    cam = orig_cam.copy()
    cam.zoom(FLAGS.proc_side / crop_side)

    if FLAGS.geom_aug:
        center_point += util.random_uniform_disc(geom_rng) * FLAGS.shift_aug / 100 * crop_side
        s1 = FLAGS.scale_aug_down / 100
        s2 = FLAGS.scale_aug_up / 100
        cam.zoom(geom_rng.uniform(1 - s1, 1 + s2))
        r = FLAGS.rot_aug * np.pi / 180
        cam.rotate(roll=geom_rng.uniform(-r, r))

    if FLAGS.geom_aug and geom_rng.rand() < 0.5:
        # Horizontal flipping
        cam.horizontal_flip()
        # Must also permute the joints to exchange e.g. left wrist and right wrist!
        imcoords = example.coords[joint_info.mirror_mapping]
    else:
        imcoords = example.coords

    new_center_point = cameralib.reproject_image_points(center_point, orig_cam, cam)
    cam.shift_to_center(new_center_point, (FLAGS.proc_side, FLAGS.proc_side))

    is_annotation_invalid = (np.nan_to_num(imcoords[:, 1]) > im_from_file.shape[0] * 0.95)
    imcoords[is_annotation_invalid] = np.nan
    imcoords = cameralib.reproject_image_points(imcoords, orig_cam, cam)

    interp_str = (FLAGS.image_interpolation_train
                  if learning_phase == TRAIN else FLAGS.image_interpolation_test)
    antialias = (FLAGS.antialias_train if learning_phase == TRAIN else FLAGS.antialias_test)
    interp = getattr(cv2, 'INTER_' + interp_str.upper())
    im = cameralib.reproject_image(
        im_from_file, orig_cam, cam, (FLAGS.proc_side, FLAGS.proc_side),
        antialias_factor=antialias, interp=interp)
    im = augmentation.appearance.augment_appearance(im, learning_phase, appearance_rng)
    im = tfu.nhwc_to_std(im)
    im = improc.normalize01(im)

    joint_validity_mask = ~np.any(np.isnan(imcoords), axis=1)
    # We must eliminate NaNs because some TensorFlow ops cannot handle inputs containing NaNs,
    # even when those values would not influence the result. Therefore we use a separate
    # "joint_validity_mask" to indicate which joint coordinates are valid.
    imcoords = np.nan_to_num(imcoords)
    return example.image_path, np.float32(im), np.float32(imcoords), joint_validity_mask
Example no. 10
def load_and_transform3d(ex, joint_info, learning_phase, rng):
    # Create separate random number generators for the different augmentations to make them reproducible
    appearance_rng = util.new_rng(rng)
    background_rng = util.new_rng(rng)
    geom_rng = util.new_rng(rng)
    partial_visi_rng = util.new_rng(rng)

    output_side = FLAGS.proc_side
    output_imshape = (output_side, output_side)

    if 'sailvos' in ex.image_path.lower():
        # This is needed in order not to lose precision in later operations.
        # Background: In the Sailvos dataset (GTA V), some world coordinates
        # are crazy large (several kilometers, i.e. millions of millimeters, which becomes
        # hard to process with the limited simultaneous dynamic range of float32).
        # They are stored in float64 but the processing is done in float32 here.
        ex.world_coords -= ex.camera.t
        ex.camera.t[:] = 0

    box = ex.bbox
    if 'surreal' in ex.image_path.lower():
        # Surreal images are flipped wrong in the official dataset release
        box = box.copy()
        box[0] = 320 - (box[0] + box[2])

    # Partial visibility
    if 'surreal' in ex.image_path.lower() and 'surmuco' not in FLAGS.dataset:
        partial_visi_prob = 0.5
    elif 'h36m' in ex.image_path.lower() and 'many' in FLAGS.dataset:
        partial_visi_prob = 0.5
    else:
        partial_visi_prob = FLAGS.partial_visibility_prob

    use_partial_visi_aug = ((learning_phase == TRAIN or FLAGS.test_aug)
                            and partial_visi_rng.rand() < partial_visi_prob)
    if use_partial_visi_aug:
        box = util.random_partial_subbox(boxlib.expand_to_square(box),
                                         partial_visi_rng)

    # Geometric transformation and augmentation
    crop_side = np.max(box[2:])
    center_point = boxlib.center(box)
    if ((learning_phase == TRAIN and FLAGS.geom_aug) or
        (learning_phase != TRAIN and FLAGS.test_aug and FLAGS.geom_aug)):
        center_point += util.random_uniform_disc(
            geom_rng) * FLAGS.shift_aug / 100 * crop_side

    # The homographic reprojection of a rectangle (bounding box) is in general not another
    # rectangle. Hence, we instead transform the midpoints of the short sides of the box and
    # determine an appropriate zoom factor from the projected distance of these two points,
    # scaled to the desired output image side length.
    if box[2] < box[3]:
        # Tall box: take midpoints of top and bottom sides
        delta_y = np.array([0, box[3] / 2])
        sidepoints = center_point + np.stack([-delta_y, delta_y])
    else:
        # Wide box: take midpoints of left and right sides
        delta_x = np.array([box[2] / 2, 0])
        sidepoints = center_point + np.stack([-delta_x, delta_x])

    cam = ex.camera.copy()
    cam.turn_towards(target_image_point=center_point)
    cam.undistort()
    cam.square_pixels()
    cam_sidepoints = cameralib.reproject_image_points(sidepoints, ex.camera,
                                                      cam)
    crop_side = np.linalg.norm(cam_sidepoints[0] - cam_sidepoints[1])
    cam.zoom(output_side / crop_side)
    cam.center_principal_point(output_imshape)

    if FLAGS.geom_aug and (learning_phase == TRAIN or FLAGS.test_aug):
        s1 = FLAGS.scale_aug_down / 100
        s2 = FLAGS.scale_aug_up / 100
        zoom = geom_rng.uniform(1 - s1, 1 + s2)
        cam.zoom(zoom)
        r = np.deg2rad(FLAGS.rot_aug)
        cam.rotate(roll=geom_rng.uniform(-r, r))

    world_coords = ex.univ_coords if FLAGS.universal_skeleton else ex.world_coords
    metric_world_coords = ex.world_coords

    if learning_phase == TRAIN and geom_rng.rand() < 0.5:
        cam.horizontal_flip()
        # Must reorder the joints due to left and right flip
        camcoords = cam.world_to_camera(world_coords)[
            joint_info.mirror_mapping]
        metric_world_coords = metric_world_coords[joint_info.mirror_mapping]
    else:
        camcoords = cam.world_to_camera(world_coords)

    imcoords = cam.world_to_image(metric_world_coords)

    # Load and reproject image
    image_path = util.ensure_absolute_path(ex.image_path)
    origsize_im = improc.imread_jpeg(image_path)
    if 'surreal' in ex.image_path.lower():
        # Surreal images are flipped wrong in the official dataset release
        origsize_im = origsize_im[:, ::-1]

    interp_str = (FLAGS.image_interpolation_train if learning_phase == TRAIN
                  else FLAGS.image_interpolation_test)
    antialias = (FLAGS.antialias_train
                 if learning_phase == TRAIN else FLAGS.antialias_test)
    interp = getattr(cv2, 'INTER_' + interp_str.upper())
    im = cameralib.reproject_image(origsize_im,
                                   ex.camera,
                                   cam,
                                   output_imshape,
                                   antialias_factor=antialias,
                                   interp=interp)

    # Color adjustment
    if re.match('.*mupots/TS[1-5]/.+', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
    elif '3dhp' in ex.image_path and re.match('.+/(TS[1-4])/', ex.image_path):
        im = improc.adjust_gamma(im, 0.67, inplace=True)
        im = improc.white_balance(im, 110, 145)
    elif 'panoptic' in ex.image_path.lower():
        im = improc.white_balance(im, 120, 138)

    # Background augmentation
    if hasattr(ex, 'mask') and ex.mask is not None:
        bg_aug_prob = (0.2 if 'sailvos' in ex.image_path.lower()
                       else FLAGS.background_aug_prob)
        if (FLAGS.background_aug_prob
                and (learning_phase == TRAIN or FLAGS.test_aug)
                and background_rng.rand() < bg_aug_prob):
            fgmask = improc.decode_mask(ex.mask)
            if 'surreal' in ex.image_path:
                # Surreal images are flipped wrong in the official dataset release
                fgmask = fgmask[:, ::-1]
            fgmask = cameralib.reproject_image(fgmask,
                                               ex.camera,
                                               cam,
                                               output_imshape,
                                               antialias_factor=antialias,
                                               interp=interp)
            im = augmentation.background.augment_background(
                im, fgmask, background_rng)

    # Occlusion and color augmentation
    im = augmentation.appearance.augment_appearance(im, learning_phase,
                                                    FLAGS.occlude_aug_prob,
                                                    appearance_rng)
    im = tfu.nhwc_to_std(im)
    im = improc.normalize01(im)

    # Joints with NaN coordinates are invalid
    is_joint_in_fov = ~np.logical_or(
        np.any(imcoords < 0, axis=-1),
        np.any(imcoords >= FLAGS.proc_side, axis=-1))
    joint_validity_mask = ~np.any(np.isnan(camcoords), axis=-1)

    rot_to_orig_cam = ex.camera.R @ cam.R.T
    rot_to_world = cam.R.T

    return dict(image=im,
                intrinsics=np.float32(cam.intrinsic_matrix),
                image_path=ex.image_path,
                coords3d_true=np.nan_to_num(camcoords).astype(np.float32),
                coords2d_true=np.nan_to_num(imcoords).astype(np.float32),
                rot_to_orig_cam=rot_to_orig_cam.astype(np.float32),
                rot_to_world=rot_to_world.astype(np.float32),
                cam_loc=cam.t.astype(np.float32),
                joint_validity_mask=joint_validity_mask,
                is_joint_in_fov=np.float32(is_joint_in_fov))
Example no. 11
def roundrobin_iterate_repeatedly(
        seqs, roundrobin_sizes, shuffle_before_each_epoch=False, rng=None):
    iters = [iterate_repeatedly(seq, shuffle_before_each_epoch, util.new_rng(rng)) for seq in seqs]
    return roundrobin(iters, roundrobin_sizes)
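`roundrobin_iterate_repeatedly` composes two helpers that are not shown in this snippet. A rough sketch of what they might look like, assuming `iterate_repeatedly` loops over a materialized sequence forever and `roundrobin` interleaves fixed-size chunks from several iterators:

import itertools

def iterate_repeatedly(seq, shuffle_before_each_epoch=False, rng=None):
    # Yield the elements of `seq` over and over, optionally reshuffling
    # the materialized list before every pass (epoch).
    items = list(seq)
    while True:
        if shuffle_before_each_epoch and rng is not None:
            rng.shuffle(items)
        yield from items

def roundrobin(iterables, sizes):
    # Cycle over the iterables, taking `sizes[i]` consecutive items from
    # iterator `i` on each turn.
    iterators = [iter(it) for it in iterables]
    while True:
        for iterator, size in zip(iterators, sizes):
            yield from itertools.islice(iterator, size)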
Example no. 12
def gen():
    for item in items:
        yield fun(item, *extra_args, util.new_rng(iter_rng))
Example no. 13
def parallel_map_as_tf_dataset(fun,
                               iterable,
                               *,
                               shuffle_before_each_epoch=False,
                               extra_args=None,
                               n_workers=10,
                               rng=None,
                               max_unconsumed=256,
                               n_completed_items=0,
                               n_total_items=None,
                               roundrobin_sizes=None):
    """Maps `fun` to each element of `iterable` and wraps the resulting sequence as
    as a TensorFlow Dataset. Elements are processed by parallel workers using `multiprocessing`.

    Args:
        fun: A function that takes an element from seq plus `extra_args` and returns a sequence of
        numpy arrays.
        seq: An iterable holding the inputs.
        shuffle_before_each_epoch: Shuffle the input elements before each epoch. Converts
            `iterable` to a list internally.
        extra_args: extra arguments in addition to an element from `seq`,
            given to `fun` at each call
        n_workers: Number of worker processes for parallelity.

    Returns:
        tf.data.Dataset based on the arrays returned by `fun`.
    """

    extra_args = extra_args or []

    # Automatically determine the output tensor types and shapes by calling the function on
    # the first element
    if not roundrobin_sizes:
        iterable = more_itertools.peekable(iterable)
        first_elem = iterable.peek()
    else:
        iterable[0] = more_itertools.peekable(iterable[0])
        first_elem = iterable[0].peek()

    sample_output = fun(first_elem, *extra_args, rng=np.random.RandomState(0))
    output_signature = tf.nest.map_structure(tf.type_spec_from_value,
                                             sample_output)

    if not roundrobin_sizes:
        items = my_itertools.iterate_repeatedly(iterable,
                                                shuffle_before_each_epoch,
                                                util.new_rng(rng))
    else:
        items = my_itertools.roundrobin_iterate_repeatedly(
            iterable, roundrobin_sizes, shuffle_before_each_epoch, rng)

    # If we are restoring from a checkpoint and have already completed some
    # training steps for that checkpoint, then we need to advance the RNG
    # accordingly, to continue exactly where we left off.
    iter_rng = util.new_rng(rng)
    util.advance_rng(iter_rng, n_completed_items)
    items = itertools.islice(items, n_completed_items, n_total_items)

    if n_workers is None:
        n_workers = min(len(os.sched_getaffinity(0)), 12)
    if n_workers == 0:

        def gen():
            for item in items:
                yield fun(item, *extra_args, util.new_rng(iter_rng))
    else:
        gen = parallel_map_as_generator(fun,
                                        items,
                                        extra_args,
                                        n_workers,
                                        rng=iter_rng,
                                        max_unconsumed=max_unconsumed)

    ds = tf.data.Dataset.from_generator(gen, output_signature=output_signature)

    # Make the cardinality of the dataset known to TF.
    if n_total_items is not None:
        ds = ds.take(n_total_items - n_completed_items)
    return ds
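When `roundrobin_sizes` is given, `iterable` is expected to be a list of example pools that are interleaved in fixed-size chunks, which is how mixed batches (e.g. 3D plus 2D examples) can be assembled. A self-contained toy call, with a made-up loader and integer stand-ins for the examples:

import numpy as np

def load_stub(ex, offset, rng):
    # Hypothetical loader: turns an integer "example" into a fixed-size array,
    # using the per-item RNG for a bit of jitter.
    return np.full([3], ex + offset + rng.uniform(), np.float32),

examples3d = list(range(100))           # stand-ins for 3D examples
examples2d = list(range(1000, 1100))    # stand-ins for 2D examples

ds = parallel_map_as_tf_dataset(
    load_stub, [examples3d, examples2d], extra_args=[0.5],
    roundrobin_sizes=[24, 8],           # 24 "3D" + 8 "2D" items per round
    shuffle_before_each_epoch=True, n_workers=0,
    rng=np.random.RandomState(0))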
Example no. 14
def build_graph(learning_phase,
                reuse=tf.AUTO_REUSE,
                n_epochs=None,
                shuffle=None,
                drop_remainder=None,
                rng=None,
                n_completed_steps=0):
    tfu.set_is_training(learning_phase == TRAIN)

    dataset3d = data.datasets.current_dataset()
    if FLAGS.train_mixed:
        dataset2d = data.datasets2d.get_dataset(FLAGS.dataset2d)
    else:
        dataset2d = None

    t = AttrDict()
    t.global_step = tf.train.get_or_create_global_step()

    examples = helpers.get_examples(dataset3d, learning_phase, FLAGS)
    phase_name = tfu.PHASE_NAME[learning_phase]
    t.n_examples = len(examples)
    logging.info(f'Number of {phase_name} examples: {t.n_examples:,}')

    if n_epochs is None:
        n_total_steps = None
    else:
        batch_size = FLAGS.batch_size if learning_phase == TRAIN else FLAGS.batch_size_test
        n_total_steps = (len(examples) * n_epochs) // batch_size

    if rng is None:
        rng = np.random.RandomState()

    rng_2d = util.new_rng(rng)
    rng_3d = util.new_rng(rng)

    @contextlib.contextmanager
    def empty_context():
        yield

    name_scope = tf.name_scope(
        None, 'training') if learning_phase == TRAIN else empty_context()
    with name_scope:
        if learning_phase == TRAIN and FLAGS.train_mixed:
            examples2d = [
                *dataset2d.examples[tfu.TRAIN], *dataset2d.examples[tfu.VALID]
            ]
            build_mixed_batch(t,
                              dataset3d,
                              dataset2d,
                              examples,
                              examples2d,
                              learning_phase,
                              batch_size3d=FLAGS.batch_size,
                              batch_size2d=FLAGS.batch_size_2d,
                              shuffle=shuffle,
                              rng_2d=rng_2d,
                              rng_3d=rng_3d,
                              max_unconsumed=FLAGS.max_unconsumed,
                              n_completed_steps=n_completed_steps,
                              n_total_steps=n_total_steps)
        elif FLAGS.multiepoch_test:
            batch_size = FLAGS.batch_size if learning_phase == TRAIN else FLAGS.batch_size_test
            helpers.build_input_batch(t,
                                      examples,
                                      data.data_loading.load_and_transform3d,
                                      (dataset3d.joint_info, learning_phase),
                                      learning_phase,
                                      batch_size,
                                      FLAGS.workers,
                                      shuffle=shuffle,
                                      drop_remainder=drop_remainder,
                                      rng=rng_3d,
                                      max_unconsumed=FLAGS.max_unconsumed,
                                      n_completed_steps=n_completed_steps,
                                      n_total_steps=n_total_steps,
                                      n_test_epochs=FLAGS.epochs)
            (t.image_path, t.x, t.coords3d_true, t.coords2d_true,
             t.inv_intrinsics, t.rot_to_orig_cam, t.rot_to_world, t.cam_loc,
             t.joint_validity_mask, t.is_joint_in_fov, t.activity_name,
             t.scene_name) = t.batch
        else:
            batch_size = FLAGS.batch_size if learning_phase == TRAIN else FLAGS.batch_size_test
            helpers.build_input_batch(t,
                                      examples,
                                      data.data_loading.load_and_transform3d,
                                      (dataset3d.joint_info, learning_phase),
                                      learning_phase,
                                      batch_size,
                                      FLAGS.workers,
                                      shuffle=shuffle,
                                      drop_remainder=drop_remainder,
                                      rng=rng_3d,
                                      max_unconsumed=FLAGS.max_unconsumed,
                                      n_completed_steps=n_completed_steps,
                                      n_total_steps=n_total_steps)
            (t.image_path, t.x, t.coords3d_true, t.coords2d_true,
             t.inv_intrinsics, t.rot_to_orig_cam, t.rot_to_world, t.cam_loc,
             t.joint_validity_mask, t.is_joint_in_fov, t.activity_name,
             t.scene_name) = t.batch

        if FLAGS.scale_recovery == 'metro':
            model.volumetric.build_metro_model(dataset3d.joint_info,
                                               learning_phase,
                                               t,
                                               reuse=reuse)
        else:
            model.volumetric.build_25d_model(dataset3d.joint_info,
                                             learning_phase,
                                             t,
                                             reuse=reuse)

        if 'coords3d_true' in t:
            build_eval_metrics(t)

        if learning_phase == TRAIN:
            build_train_op(t)
            build_summaries(t)
        return t