def __init__(self, metro_model, joint_info, joint_info2d=None, global_step=None):
    super().__init__(global_step)
    self.global_step = global_step
    self.joint_info = joint_info
    self.joint_info_2d = joint_info2d
    self.model = metro_model
    # Call the model once on a symbolic input so that all variables are created up front.
    inp = keras.Input(shape=(None, None, 3), dtype=tfu.get_dtype())
    self(inp, training=False)
def resnet(inp, n_outs, stride=16, centered_stride=False, global_pool=False,
           resnet_name='resnet_v2_50'):
    with slim.arg_scope(resnet_arg_scope()):
        if FLAGS.compatibility_mode:
            x = tf.cast(inp, tfu.get_dtype())
        else:
            # Rescale [0, 1] input to [-1, 1] as expected by the TF-Slim ResNet-V2 models.
            x = tf.cast(inp * 2 - 1, tfu.get_dtype())
        resnet_fn = getattr(model.resnet_v2, resnet_name)
        xs, end_points = resnet_fn(
            x, num_classes=n_outs, is_training=tfu.is_training(), global_pool=global_pool,
            output_stride=stride, centered_stride=centered_stride)
        xs = [tf.cast(x, tf.float32) for x in xs]
        return xs
def __init__(self, crop_model, joint_info):
    super().__init__()
    self.crop_model = crop_model
    self.predict_crop = self.crop_model.signatures['serving_default']
    self.joint_names = tf.Variable(np.array(joint_info.names), trainable=False)
    self.joint_edges = tf.Variable(
        np.array(joint_info.stick_figure_edges), trainable=False)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tfu.get_dtype()),
        tf.TensorSpec(shape=(None, 3, 3), dtype=tf.float32)])
    def __call__(image, intrinsic_matrix):
        return self.predict_crop(image=image, intrinsics=intrinsic_matrix)['poses']

    # Attach the tf.function as an instance attribute so that it is tracked and saved
    # with a fixed input signature when the wrapper is exported as a SavedModel.
    self.__call__ = __call__
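# Usage sketch (assumption, not part of the original code): calling the absolute-pose wrapper
# above after it has been saved and reloaded. The directory name, crop size and intrinsics
# values are placeholders, and the image dtype must match the exported input signature
# (tfu.get_dtype() at export time, assumed float32 here). The expected pixel value range
# depends on the preprocessing the model was trained with.
import tensorflow as tf

loaded = tf.saved_model.load('exported_model_dir')
images = tf.zeros([1, 256, 256, 3], dtype=tf.float32)  # batch of person-centered image crops
intrinsics = tf.constant([[[1200., 0., 128.],
                           [0., 1200., 128.],
                           [0., 0., 1.]]])              # example 3x3 camera intrinsic matrices
poses = loaded.__call__(images, intrinsics)             # [batch, n_joints, 3] 3D joint coordinates
print(loaded.joint_names.numpy(), poses.shape)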
def resnet(inp, n_out, stride=16, centered_stride=False, resnet_name='resnet_v2_50'):
    # if resnet_name == 'mobilenet':
    #     return mobilenet(inp, n_out, stride)
    with slim.arg_scope(resnet_arg_scope()):
        x = tf.cast(inp, tfu.get_dtype())
        resnet_fn = getattr(model.resnet_v2, resnet_name)
        x, end_points = resnet_fn(
            x, num_classes=n_out, is_training=tfu.is_training(), global_pool=False,
            output_stride=stride, centered_stride=centered_stride)
        x = tf.cast(x, tf.float32)
        return x
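# Worked example (assumption about typical settings, not from the original file): the
# relationship between the crop size and the spatial size of the dense feature map that
# resnet() above returns when global pooling is disabled. proc_side = 256 is a hypothetical
# value for FLAGS.proc_side.
proc_side = 256       # side length of the square input crop
output_stride = 16    # default 'stride' argument above
feature_side = proc_side // output_stride
print(feature_side)   # 16: the per-joint feature/heatmap grid is 16x16 for a 256x256 crop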
def mobilenet_preproc(x):
    # Scale [0, 1] input up to [0, 255].
    return tf.cast(255, tfu.get_dtype()) * x
def tf_preproc(x):
    # TF-Slim-style preprocessing: rescale [0, 1] input to [-1, 1].
    x = tf.cast(2, tfu.get_dtype()) * x - tf.cast(1, tfu.get_dtype())
    return x
def caffe_preproc(x):
    # Caffe-style preprocessing: scale to [0, 255] and subtract the per-channel mean.
    mean_rgb = tf.convert_to_tensor(np.array([103.939, 116.779, 123.68]), tfu.get_dtype())
    return tf.cast(255, tfu.get_dtype()) * x - mean_rgb
def torch_preproc(x):
    # PyTorch-style preprocessing: normalize with the ImageNet channel means and stdevs.
    mean_rgb = tf.convert_to_tensor(np.array([0.485, 0.456, 0.406]), tfu.get_dtype())
    stdev_rgb = tf.convert_to_tensor(np.array([0.229, 0.224, 0.225]), tfu.get_dtype())
    normalized = (x - mean_rgb) / stdev_rgb
    return normalized
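# Sketch (assumption, not part of the original file): selecting one of the preprocessing
# functions above based on a backbone name. The function get_preproc_fn and its
# name-matching rules are hypothetical and purely illustrative of how the four
# conventions might be dispatched.
def get_preproc_fn(backbone_name):
    """Return the input preprocessing matching a backbone's training convention."""
    if 'mobilenet' in backbone_name:
        return mobilenet_preproc
    if backbone_name.startswith('resnet_v2'):
        return tf_preproc
    if 'caffe' in backbone_name:
        return caffe_preproc
    return torch_preproc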
def export():
    logging.info('Exporting model file.')
    tf.compat.v1.reset_default_graph()
    t = attrdict.AttrDict()
    t.x = tf.compat.v1.placeholder(
        shape=[None, FLAGS.proc_side, FLAGS.proc_side, 3], dtype=tfu.get_dtype())
    t.x = tfu.nhwc_to_std(t.x)

    is_absolute_model = FLAGS.scale_recovery in ('metrabs',)
    if is_absolute_model:
        intrinsics_tensor = tf.compat.v1.placeholder(shape=[None, 3, 3], dtype=tf.float32)
        t.inv_intrinsics = tf.linalg.inv(intrinsics_tensor)
    else:
        intrinsics_tensor = None

    joint_info = data.datasets3d.get_dataset(FLAGS.dataset).joint_info
    if FLAGS.scale_recovery == 'metrabs':
        model.metrabs.build_metrabs_inference_model(joint_info, t)
    elif FLAGS.scale_recovery == 'metro':
        model.metro.build_metro_inference_model(joint_info, t)
    else:
        model.twofive.build_25d_inference_model(joint_info, t)

    # Convert back to the original joint order as defined in the respective datasets
    # (i.e. move the pelvis back from the last position to its original place, because
    # this codebase normally keeps the pelvis at the last position in all cases, for
    # consistency).
    if FLAGS.dataset == 'many':
        selected_joint_ids = [23, *range(23)] if FLAGS.export_smpl else [*range(73)]
    elif FLAGS.dataset == 'h36m':
        selected_joint_ids = [16, *range(16)]
    else:
        assert FLAGS.dataset in ('mpi_inf_3dhp', 'mupots') or 'muco' in FLAGS.dataset
        selected_joint_ids = [*range(14), 17, 14, 15]
    t.coords3d_pred = tf.gather(t.coords3d_pred, selected_joint_ids, axis=1)
    joint_info = joint_info.select_joints(selected_joint_ids)

    if FLAGS.load_path:
        load_path = util.ensure_absolute_path(FLAGS.load_path, FLAGS.checkpoint_dir)
    else:
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        load_path = checkpoint.model_checkpoint_path

    checkpoint_dir = os.path.dirname(load_path)
    out_path = util.ensure_absolute_path(FLAGS.export_file, checkpoint_dir)

    # First export the crop-level graph as a TF1-style SavedModel.
    sm = tf.compat.v1.saved_model
    with tf.compat.v1.Session() as sess:
        saver = tf.compat.v1.train.Saver()
        saver.restore(sess, load_path)
        inputs = (dict(image=t.x, intrinsics=intrinsics_tensor)
                  if is_absolute_model else dict(image=t.x))
        signature_def = sm.signature_def_utils.predict_signature_def(
            inputs=inputs, outputs=dict(poses=t.coords3d_pred))
        os.mkdir(out_path)
        builder = sm.builder.SavedModelBuilder(out_path)
        builder.add_meta_graph_and_variables(
            sess, ['serve'], signature_def_map=dict(serving_default=signature_def))
        builder.save()

    # Then reload it eagerly and re-export it wrapped in a class that carries the joint
    # metadata and a fixed-signature __call__.
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
    crop_model = tf.saved_model.load(out_path)
    shutil.rmtree(out_path)
    wrapper_class = (ExportedAbsoluteModel if is_absolute_model else ExportedRootRelativeModel)
    wrapped_model = wrapper_class(crop_model, joint_info)
    tf.saved_model.save(wrapped_model, out_path)
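# Worked example (assumption, only illustrating the joint-reordering step in export() above):
# for the 'h36m' dataset, selected_joint_ids = [16, *range(16)] takes the joint that this
# codebase keeps at the last index (the pelvis) and moves it to index 0, shifting the other
# 16 joints up by one, so the exported poses follow the dataset's own joint order.
selected_joint_ids = [16, *range(16)]
print(selected_joint_ids)  # [16, 0, 1, 2, ..., 15]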