Пример #1
0
 def __init__(self, metro_model, joint_info, joint_info2d=None,
              global_step=None):
     """Wrap ``metro_model`` and run this model once on a symbolic input.

     Args:
         metro_model: The wrapped model, stored as ``self.model``.
         joint_info: Stored as ``self.joint_info``.
         joint_info2d: Optional; stored as ``self.joint_info_2d``.
         global_step: Forwarded to the parent initializer and also kept as
             ``self.global_step``.
     """
     super().__init__(global_step)
     # Assignment order is intentionally unchanged: Keras tracks attributes
     # in the order they are set.
     self.global_step = global_step
     self.joint_info = joint_info
     self.joint_info_2d = joint_info2d
     self.model = metro_model
     # One symbolic forward pass (any height/width, 3 channels, dtype from
     # tfu.get_dtype()), presumably so Keras builds the model up front.
     symbolic_image = keras.Input(shape=(None, None, 3),
                                  dtype=tfu.get_dtype())
     self(symbolic_image, training=False)
Пример #2
0
def resnet(inp,
           n_outs,
           stride=16,
           centered_stride=False,
           global_pool=False,
           resnet_name='resnet_v2_50'):
    """Run a slim ResNet-v2 backbone and return its outputs as float32.

    Args:
        inp: Input image batch. Unless ``FLAGS.compatibility_mode`` is set, it
            is first mapped with ``inp * 2 - 1`` (presumably [0, 1] -> [-1, 1];
            confirm the caller's value range). It is then cast to
            ``tfu.get_dtype()``.
        n_outs: Passed to the backbone as ``num_classes``.
        stride: Passed to the backbone as ``output_stride``.
        centered_stride: Passed through as ``centered_stride``.
        global_pool: Passed through as ``global_pool``.
        resnet_name: Name of the constructor looked up on
            ``model.resnet_v2``.

    Returns:
        A list holding each backbone output cast to ``tf.float32``.
    """
    with slim.arg_scope(resnet_arg_scope()):
        if FLAGS.compatibility_mode:
            net_input = tf.cast(inp, tfu.get_dtype())
        else:
            net_input = tf.cast(inp * 2 - 1, tfu.get_dtype())

        resnet_fn = getattr(model.resnet_v2, resnet_name)
        # The second return value (end points) is not needed here.
        outputs, _ = resnet_fn(net_input,
                               num_classes=n_outs,
                               is_training=tfu.is_training(),
                               global_pool=global_pool,
                               output_stride=stride,
                               centered_stride=centered_stride)
        # Cast back to float32 so callers need not care about the (possibly
        # reduced-precision) network dtype.
        return [tf.cast(output, tf.float32) for output in outputs]
Пример #3
0
    def __init__(self, crop_model, joint_info):
        """Wrap a loaded crop-level SavedModel for re-export.

        Args:
            crop_model: A loaded SavedModel whose ``serving_default``
                signature accepts ``image`` and ``intrinsics`` keyword inputs
                and returns a dict containing ``'poses'``.
            joint_info: Object providing ``names`` and ``stick_figure_edges``.
                Both are stored as non-trainable variables.
        """
        super().__init__()
        self.crop_model = crop_model
        self.predict_crop = self.crop_model.signatures['serving_default']
        # Joint metadata kept as non-trainable variables, presumably so it is
        # saved inside the exported SavedModel alongside the weights.
        self.joint_names = tf.Variable(np.array(joint_info.names),
                                       trainable=False)
        self.joint_edges = tf.Variable(np.array(joint_info.stick_figure_edges),
                                       trainable=False)

        # Fixed signature: a 4-D image batch with 3 channels in dtype
        # tfu.get_dtype(), and a batch of 3x3 float32 intrinsic matrices.
        @tf.function(input_signature=[
            tf.TensorSpec(shape=(None, None, None, 3), dtype=tfu.get_dtype()),
            tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32)
        ])
        def __call__(image, intrinsic_matrix):
            # Delegate to the wrapped signature and keep only the poses.
            return self.predict_crop(image=image,
                                     intrinsics=intrinsic_matrix)['poses']

        # NOTE(review): Python looks up special methods on the type, so this
        # instance attribute does not make the object callable in plain
        # Python. It is presumably attached so tf.saved_model.save exports it
        # as a tracked function that the loaded object exposes as __call__.
        # Confirm against TF SavedModel loading behavior.
        self.__call__ = __call__
Пример #4
0
def resnet(inp,
           n_out,
           stride=16,
           centered_stride=False,
           resnet_name='resnet_v2_50'):
    """Run a slim ResNet-v2 backbone on ``inp`` and return a float32 output.

    ``inp`` is only cast to ``tfu.get_dtype()``; no value rescaling happens
    here. Global pooling is disabled.

    Args:
        inp: Input image batch.
        n_out: Passed to the backbone as ``num_classes``.
        stride: Passed to the backbone as ``output_stride``.
        centered_stride: Passed through as ``centered_stride``.
        resnet_name: Name of the constructor looked up on
            ``model.resnet_v2``.

    Returns:
        The backbone's first return value cast to ``tf.float32``.
    """
    with slim.arg_scope(resnet_arg_scope()):
        net_input = tf.cast(inp, tfu.get_dtype())
        resnet_fn = getattr(model.resnet_v2, resnet_name)
        # The second return value (end points) is not needed here.
        net_output, _ = resnet_fn(net_input,
                                  num_classes=n_out,
                                  is_training=tfu.is_training(),
                                  global_pool=False,
                                  output_stride=stride,
                                  centered_stride=centered_stride)
        return tf.cast(net_output, tf.float32)
Пример #5
0
def mobilenet_preproc(x):
    """Return ``255 * x``, with the factor in dtype ``tfu.get_dtype()``.

    Presumably maps [0, 1] images to [0, 255]; confirm the caller's range.
    """
    scale = tf.cast(255, tfu.get_dtype())
    return scale * x
Пример #6
0
def tf_preproc(x):
    """Return ``2 * x - 1``, with both constants in dtype ``tfu.get_dtype()``.

    Presumably maps [0, 1] images to [-1, 1]; confirm the caller's range.
    """
    dtype = tfu.get_dtype()
    two = tf.cast(2, dtype)
    one = tf.cast(1, dtype)
    return two * x - one
Пример #7
0
def caffe_preproc(x):
    """Scale ``x`` by 255 and subtract fixed per-channel means.

    Constants are built in dtype ``tfu.get_dtype()``.

    NOTE(review): ``[103.939, 116.779, 123.68]`` matches the ImageNet channel
    means in BGR order used by Keras' ``mode='caffe'`` preprocessing, which
    also flips RGB -> BGR first. No channel flip happens here, so ``x`` is
    presumably expected to already be BGR. Confirm against callers.
    """
    dtype = tfu.get_dtype()
    channel_means = tf.convert_to_tensor(
        np.array([103.939, 116.779, 123.68]), dtype)
    return tf.cast(255, dtype) * x - channel_means
Пример #8
0
def torch_preproc(x):
    """Standardize ``x`` per channel: ``(x - mean) / std``.

    The mean/std constants are built in dtype ``tfu.get_dtype()``. They
    match the torchvision ImageNet statistics, which presumably assume RGB
    input in [0, 1]. No scaling is applied here; confirm the caller's range.
    """
    dtype = tfu.get_dtype()
    channel_mean = tf.convert_to_tensor(np.array([0.485, 0.456, 0.406]), dtype)
    channel_std = tf.convert_to_tensor(np.array([0.229, 0.224, 0.225]), dtype)
    return (x - channel_mean) / channel_std
Пример #9
0
def _select_exported_joint_ids():
    """Return the joint indices, in export order, for ``FLAGS.dataset``.

    This codebase stores the pelvis last. For 'h36m', 'mpi_inf_3dhp',
    'mupots', 'muco*' and SMPL-style 'many' exports, the pelvis is moved back
    to its position in the original dataset ordering. For non-SMPL 'many'
    exports, all 73 joints are kept in stored order.

    Raises:
        ValueError: If ``FLAGS.dataset`` is not a supported export dataset.
    """
    if FLAGS.dataset == 'many':
        if FLAGS.export_smpl:
            return [23, *range(23)]
        return [*range(73)]
    if FLAGS.dataset == 'h36m':
        return [16, *range(16)]
    if FLAGS.dataset in ('mpi_inf_3dhp', 'mupots') or 'muco' in FLAGS.dataset:
        return [*range(14), 17, 14, 15]
    raise ValueError(f'Unsupported dataset for export: {FLAGS.dataset!r}')


def _resolve_load_path():
    """Return the checkpoint path to restore for export.

    If given, ``FLAGS.load_path`` is used, resolved against
    ``FLAGS.checkpoint_dir``. Otherwise the checkpoint recorded in the
    checkpoint state of ``FLAGS.checkpoint_dir`` is used.

    Raises:
        FileNotFoundError: If no load path is given and no checkpoint state
            exists in ``FLAGS.checkpoint_dir``.
    """
    if FLAGS.load_path:
        return util.ensure_absolute_path(FLAGS.load_path, FLAGS.checkpoint_dir)
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if checkpoint is None:
        raise FileNotFoundError(
            f'No checkpoint found in {FLAGS.checkpoint_dir!r} and no '
            f'load_path was given.')
    return checkpoint.model_checkpoint_path


def export():
    """Export the trained model as a wrapped SavedModel.

    1. Build the TF1 inference graph selected by ``FLAGS.scale_recovery`` on
       an NHWC image placeholder. For 'metrabs', an intrinsics placeholder is
       also created.
    2. Select/reorder the predicted joints for ``FLAGS.dataset``.
    3. Restore the checkpoint and write a temporary SavedModel to the export
       path (``FLAGS.export_file``, resolved against the checkpoint's
       directory).
    4. Reload it eagerly and delete the temporary directory. Wrap it in
       ``ExportedAbsoluteModel`` or ``ExportedRootRelativeModel`` and save the
       wrapper to the same path.

    Raises:
        ValueError: If ``FLAGS.dataset`` is not supported for export.
        FileNotFoundError: If no checkpoint can be located.
        FileExistsError: If the export path already exists (from
            ``os.mkdir``).
    """
    logging.info('Exporting model file.')
    tf.compat.v1.reset_default_graph()

    t = attrdict.AttrDict()
    image_placeholder = tf.compat.v1.placeholder(
        shape=[None, FLAGS.proc_side, FLAGS.proc_side, 3],
        dtype=tfu.get_dtype())
    t.x = tfu.nhwc_to_std(image_placeholder)

    # Only the METRAbs variant consumes camera intrinsics.
    is_absolute_model = FLAGS.scale_recovery == 'metrabs'

    if is_absolute_model:
        intrinsics_tensor = tf.compat.v1.placeholder(shape=[None, 3, 3],
                                                     dtype=tf.float32)
        t.inv_intrinsics = tf.linalg.inv(intrinsics_tensor)
    else:
        intrinsics_tensor = None

    joint_info = data.datasets3d.get_dataset(FLAGS.dataset).joint_info

    if FLAGS.scale_recovery == 'metrabs':
        model.metrabs.build_metrabs_inference_model(joint_info, t)
    elif FLAGS.scale_recovery == 'metro':
        model.metro.build_metro_inference_model(joint_info, t)
    else:
        model.twofive.build_25d_inference_model(joint_info, t)

    # Convert back to each dataset's original joint order (this codebase
    # normally keeps the pelvis last for consistency).
    selected_joint_ids = _select_exported_joint_ids()
    t.coords3d_pred = tf.gather(t.coords3d_pred, selected_joint_ids, axis=1)
    joint_info = joint_info.select_joints(selected_joint_ids)

    load_path = _resolve_load_path()
    checkpoint_dir = os.path.dirname(load_path)
    out_path = util.ensure_absolute_path(FLAGS.export_file, checkpoint_dir)

    sm = tf.compat.v1.saved_model
    with tf.compat.v1.Session() as sess:
        saver = tf.compat.v1.train.Saver()
        saver.restore(sess, load_path)
        # Export the raw NHWC placeholder as the 'image' input, not t.x,
        # which may have been layout-converted by tfu.nhwc_to_std.
        inputs = dict(image=image_placeholder)
        if is_absolute_model:
            inputs['intrinsics'] = intrinsics_tensor

        signature_def = sm.signature_def_utils.predict_signature_def(
            inputs=inputs, outputs=dict(poses=t.coords3d_pred))
        os.mkdir(out_path)
        builder = sm.builder.SavedModelBuilder(out_path)
        builder.add_meta_graph_and_variables(
            sess, ['serve'],
            signature_def_map=dict(serving_default=signature_def))
        builder.save()

    # Reload the graph-mode export eagerly, remove the temporary copy, and
    # save the wrapped model in its place.
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
    crop_model = tf.saved_model.load(out_path)
    shutil.rmtree(out_path)

    wrapper_class = (ExportedAbsoluteModel
                     if is_absolute_model else ExportedRootRelativeModel)
    wrapped_model = wrapper_class(crop_model, joint_info)
    tf.saved_model.save(wrapped_model, out_path)