def predict():
    """Predict 3D poses on the test split and save them in world coordinates.

    Builds the model and trainer classes named by ``FLAGS.model_class``,
    restores a checkpoint from ``FLAGS.checkpoint_dir`` if one is available,
    runs prediction over the test examples of ``FLAGS.dataset`` and writes
    ``image_path`` and ``coords3d_pred_world`` to ``FLAGS.pred_path`` with
    ``np.savez``.
    """
    test_dataset = data.datasets3d.get_dataset(FLAGS.dataset)
    backbone = backbones.builder.build_backbone()
    model_cls = getattr(models, FLAGS.model_class)
    trainer_cls = getattr(models, FLAGS.model_class + 'Trainer')
    model_joint_info = data.datasets3d.get_joint_info(FLAGS.model_joints)

    # Model classes whose name starts with 'Model25D' take bone lengths as
    # one extra constructor argument.
    model_extra_args = []
    if FLAGS.model_class.startswith('Model25D'):
        bone_source = data.datasets3d.get_dataset(FLAGS.bone_length_dataset)
        if FLAGS.train_on == 'trainval':
            model_extra_args.append(bone_source.trainval_bones)
        else:
            model_extra_args.append(bone_source.train_bones)

    model = model_cls(backbone, model_joint_info, *model_extra_args)
    trainer = trainer_cls(model, model_joint_info)

    tensor_names = [
        'coords3d_rel_pred', 'coords3d_pred_abs', 'rot_to_world', 'cam_loc',
        'image_path',
    ]
    if FLAGS.viz:
        tensor_names = tensor_names + ['image', 'coords3d_true']
    trainer.predict_tensor_names = tensor_names

    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, directory=FLAGS.checkpoint_dir, max_to_keep=None)
    restore_if_ckpt_available(
        checkpoint, checkpoint_manager, expect_partial=True)

    test_examples = get_examples(test_dataset, tfu.TEST, FLAGS)
    test_data = build_dataflow(
        test_examples,
        data.data_loading.load_and_transform3d,
        (test_dataset.joint_info, TEST),
        TEST,
        batch_size=FLAGS.batch_size_test,
        n_workers=FLAGS.workers)

    # Round up so that the last, possibly partial, batch is predicted too.
    n_batches = int(np.ceil(len(test_examples) / FLAGS.batch_size_test))
    show_progress = 1 if sys.stdout.isatty() else 0
    results = attrdict.AttrDict(
        trainer.predict(test_data, verbose=show_progress, steps=n_batches))

    util.ensure_path_exists(FLAGS.pred_path)
    logger.info(f'Saving predictions to {FLAGS.pred_path}')

    # Use the absolute prediction when the results have it, otherwise fall
    # back to the relative one.
    try:
        coords3d_cam = results.coords3d_pred_abs
    except AttributeError:
        coords3d_cam = results.coords3d_rel_pred

    # Apply the per-example rotation, then add the camera location
    # (broadcast over the joint axis).
    rotated = tf.einsum('nCc, njc->njC', results.rot_to_world, coords3d_cam)
    coords3d_world = rotated + tf.expand_dims(results.cam_loc, 1)
    coords3d_world = models.util.select_skeleton(
        coords3d_world, model_joint_info, FLAGS.output_joints).numpy()

    np.savez(
        FLAGS.pred_path,
        image_path=results.image_path,
        coords3d_pred_world=coords3d_world)
def main():
    """Parse the export flags, build a ``Pose3dEstimator`` and save it.

    The estimator is written as a SavedModel to ``FLAGS.output_model_path``
    with custom gradients included.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    flag_specs = (
        ('--input-model-path', dict(type=str, required=True)),
        ('--output-model-path', dict(type=str, required=True)),
        ('--detector-path', dict(type=str)),
        ('--bone-length-dataset', dict(type=str)),
        ('--rot-aug', dict(type=float, default=25)),
        ('--rot-aug-linspace-noend', dict(action=options.BoolAction)),
        ('--crop-side', dict(type=int, default=256)),
        ('--detector-flip-vertical-too', dict(action=options.BoolAction)),
    )
    parser = argparse.ArgumentParser()
    for flag, add_kwargs in flag_specs:
        parser.add_argument(flag, **add_kwargs)
    options.initialize(parser)

    estimator = Pose3dEstimator()
    save_options = tf.saved_model.SaveOptions(
        experimental_custom_gradients=True)
    tf.saved_model.save(
        estimator, FLAGS.output_model_path, options=save_options)
    logger.info(
        f'Full image model has been exported to {FLAGS.output_model_path}')
def main():
    """Predict 3D poses for subjects 9 and 11 and save them to one file.

    Loads the SavedModel at ``FLAGS.model_path`` and, for every activity of
    each subject and each of the four camera ids, predicts poses for the
    sequence returned by ``get_sequence``. The image paths and world-space
    predictions of all sequences are concatenated and written to
    ``FLAGS.output_path`` with ``np.savez``.

    When ``FLAGS.viz`` is set, predictions are also shown through PoseViz
    and, if ``FLAGS.out_video_dir`` is set, recorded to one video per
    sequence. The visualizer is closed even if prediction or saving fails.
    """
    initialize()
    model = tf.saved_model.load(FLAGS.model_path)
    skeleton = 'h36m_17'
    joint_names = model.per_skeleton_joint_names[skeleton].numpy().astype(str)
    joint_edges = model.per_skeleton_joint_edges[skeleton].numpy()
    predict_fn = functools.partial(
        model.estimate_poses_batched,
        internal_batch_size=0,
        num_aug=FLAGS.num_aug,
        antialias_factor=2,
        skeleton=skeleton)

    viz = poseviz.PoseViz(
        joint_names,
        joint_edges,
        write_video=bool(FLAGS.out_video_dir),
        world_up=(0, 0, 1),
        ground_plane_height=0,
        queue_size=2 * FLAGS.batch_size) if FLAGS.viz else None

    image_relpaths_all = []
    coords_all = []
    try:
        for i_subject in (9, 11):
            for activity_name, camera_id in itertools.product(
                    data.h36m.get_activity_names(i_subject), range(4)):
                if viz is not None:
                    viz.new_sequence()
                    if FLAGS.out_video_dir:
                        viz.start_new_video(
                            f'{FLAGS.out_video_dir}/S{i_subject}/{activity_name}.{camera_id}.mp4',
                            fps=max(50 / FLAGS.frame_step, 2))

                logger.info(
                    f'Predicting S{i_subject} {activity_name} {camera_id}...')
                frame_relpaths, bboxes, camera = get_sequence(
                    i_subject, activity_name, camera_id)
                frame_paths = [
                    f'{paths.DATA_ROOT}/{p}' for p in frame_relpaths]
                box_ds = tf.data.Dataset.from_tensor_slices(bboxes)
                # tee_cpu follows FLAGS.viz; the resulting CPU-side frame
                # batches are handed to predict_sequence along with viz.
                ds, frame_batches_cpu = video_io.image_files_as_tf_dataset(
                    frame_paths,
                    extra_data=box_ds,
                    batch_size=FLAGS.batch_size,
                    tee_cpu=FLAGS.viz)

                coords3d_pred_world = predict_sequence(
                    predict_fn, ds, frame_batches_cpu, camera, viz)
                image_relpaths_all.append(frame_relpaths)
                coords_all.append(coords3d_pred_world)

        np.savez(
            FLAGS.output_path,
            image_path=np.concatenate(image_relpaths_all, axis=0),
            coords3d_pred_world=np.concatenate(coords_all, axis=0))
    finally:
        # Close the visualizer on failure too, not only on the success path.
        if viz is not None:
            viz.close()
def initialize(args=None):
    """Parse the flags, log run information and configure the process.

    After parsing, unset flags are filled in from related flags, the
    checkpoint, prediction and load paths are made absolute, the checkpoint
    directory is created, and TensorFlow's seed, GPU memory growth and
    (for float16) mixed-precision policy are set.

    Args:
        args: Passed through to ``options.initialize_with_logfiles``
            together with the parser from ``get_parser()``.
    """
    options.initialize_with_logfiles(get_parser(), args)

    # Log basic information about this run.
    logger.info('-- Starting --')
    logger.info(f'Host: {socket.gethostname()}')
    logger.info(f'Process id (pid): {os.getpid()}')
    if FLAGS.comment:
        logger.info(f'Comment: {FLAGS.comment}')
    quoted_argv = ' '.join(shlex.quote(arg) for arg in sys.argv)
    logger.info(f'Raw command: {quoted_argv}')
    logger.info(f'Parsed flags: {FLAGS}')

    tfu.set_data_format(FLAGS.data_format)
    if FLAGS.dtype == 'float32':
        tfu.set_dtype(tf.float32)
    else:
        tfu.set_dtype(tf.float16)

    if FLAGS.batch_size_test is None:
        FLAGS.batch_size_test = FLAGS.batch_size

    # The checkpoint directory defaults to the log directory.
    if FLAGS.checkpoint_dir is None:
        FLAGS.checkpoint_dir = FLAGS.logdir
    experiments_root = f'{paths.DATA_ROOT}/experiments'
    FLAGS.checkpoint_dir = util.ensure_absolute_path(
        FLAGS.checkpoint_dir, root=experiments_root)
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)

    # The prediction path is resolved against the directory of load_path
    # when that flag is given, otherwise against the checkpoint directory.
    # load_path is read here before it is normalized further below.
    if not FLAGS.pred_path:
        FLAGS.pred_path = f'predictions_{FLAGS.dataset}.npz'
    if FLAGS.load_path:
        pred_base = os.path.dirname(FLAGS.load_path)
    else:
        pred_base = FLAGS.checkpoint_dir
    FLAGS.pred_path = util.ensure_absolute_path(FLAGS.pred_path, pred_base)

    # These three flags default to the main dataset name.
    if FLAGS.bone_length_dataset is None:
        FLAGS.bone_length_dataset = FLAGS.dataset
    if FLAGS.model_joints is None:
        FLAGS.model_joints = FLAGS.dataset
    if FLAGS.output_joints is None:
        FLAGS.output_joints = FLAGS.dataset

    if FLAGS.load_path:
        # Drop a trailing '.index' or '.meta' extension from the load path.
        if FLAGS.load_path.endswith(('.index', '.meta')):
            FLAGS.load_path = os.path.splitext(FLAGS.load_path)[0]
        FLAGS.load_path = util.ensure_absolute_path(
            FLAGS.load_path, FLAGS.checkpoint_dir)

    tf.random.set_seed(FLAGS.seed)

    if FLAGS.viz:
        plt.switch_backend('TkAgg')

    FLAGS.backbone = FLAGS.backbone.replace('_', '-')

    gpu_devices = tf.config.experimental.list_physical_devices('GPU')
    for gpu_device in gpu_devices:
        tf.config.experimental.set_memory_growth(gpu_device, True)

    if FLAGS.dtype == 'float16':
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
def train():
    """Train the model named by ``FLAGS.model_class`` on 3D and 2D data.

    Builds the 3D and 2D training dataflows (and a validation dataflow when
    ``FLAGS.validate_period`` is set), creates the model and trainer,
    restores a checkpoint if one is available, fits for
    ``FLAGS.training_steps`` steps and then saves the model under
    ``FLAGS.checkpoint_dir``. A checkpoint is saved at the end in every
    case, including after a keyboard interrupt or a resource-exhausted
    error.
    """
    if FLAGS.multi_gpu:
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = dummy_strategy()
    n_repl = strategy.num_replicas_in_sync

    #######
    # TRAINING DATA
    #######
    dataset3d = data.datasets3d.get_dataset(FLAGS.dataset)
    joint_info3d = dataset3d.joint_info
    train_examples3d = get_examples(dataset3d, TRAIN, FLAGS)

    dataset2d = data.datasets2d.get_dataset(FLAGS.dataset2d)
    joint_info2d = dataset2d.joint_info
    # The 2D stream uses both the TRAIN and VALID examples of the 2D dataset.
    train_examples2d = [
        *dataset2d.examples[TRAIN], *dataset2d.examples[VALID]]

    if 'many' not in FLAGS.dataset:
        example_sections = [train_examples3d]
        roundrobin_sizes = [FLAGS.batch_size]
    else:
        if 'aist' in FLAGS.dataset:
            section_names = (
                'h36m muco-3dhp surreal panoptic aist_ sailvos'.split())
            roundrobin_sizes = [size * 2 for size in [8, 8, 8, 8, 8, 8]]
        else:
            section_names = 'h36m muco-3dhp panoptic surreal sailvos'.split()
            roundrobin_sizes = [9, 9, 9, 9, 9]
        example_sections = build_dataset_sections(
            train_examples3d, section_names)

    n_completed_steps = get_n_completed_steps(
        FLAGS.checkpoint_dir, FLAGS.load_path)

    # One child generator is drawn from the master RNG per dataflow, in the
    # order 2D, 3D, validation.
    master_rng = np.random.RandomState(FLAGS.seed)
    data2d = build_dataflow(
        train_examples2d,
        data.data_loading.load_and_transform2d,
        (joint_info2d, TRAIN),
        TRAIN,
        batch_size=FLAGS.batch_size_2d * n_repl,
        n_workers=FLAGS.workers,
        rng=util.new_rng(master_rng),
        n_completed_steps=n_completed_steps,
        n_total_steps=FLAGS.training_steps)

    data3d = build_dataflow(
        example_sections,
        data.data_loading.load_and_transform3d,
        (joint_info3d, TRAIN),
        tfu.TRAIN,
        batch_size=sum(roundrobin_sizes) // 2 * n_repl,
        n_workers=FLAGS.workers,
        rng=util.new_rng(master_rng),
        n_completed_steps=n_completed_steps,
        n_total_steps=FLAGS.training_steps,
        roundrobin_sizes=roundrobin_sizes)

    def merge_batches(batch3d, batch2d):
        # Combine both batches into one dict; 2D entries win on key clashes.
        return {**batch3d, **batch2d}

    data_train = tf.data.Dataset.zip((data3d, data2d)).map(merge_batches)
    if not FLAGS.multi_gpu:
        data_train = data_train.apply(
            tf.data.experimental.prefetch_to_device('GPU:0', 2))

    shard_options = tf.data.Options()
    shard_options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    data_train = data_train.with_options(shard_options)

    #######
    # VALIDATION DATA
    #######
    val_examples3d = get_examples(dataset3d, VALID, FLAGS)

    data_val = None
    validation_steps = None
    if FLAGS.validate_period:
        val_batch_size = FLAGS.batch_size_test * n_repl
        data_val = build_dataflow(
            val_examples3d,
            data.data_loading.load_and_transform3d,
            (joint_info3d, VALID),
            VALID,
            batch_size=val_batch_size,
            n_workers=FLAGS.workers,
            rng=util.new_rng(master_rng))
        data_val = data_val.with_options(shard_options)
        validation_steps = int(
            np.ceil(len(val_examples3d) / (FLAGS.batch_size_test * n_repl)))

    #######
    # MODEL
    #######
    # NOTE(review): the extent of this scope is inferred; confirm that
    # everything up to the optimizer assignment belongs inside it.
    with strategy.scope():
        global_step = tf.Variable(
            n_completed_steps, dtype=tf.int32, trainable=False)
        backbone = backbones.builder.build_backbone()

        model_cls = getattr(models, FLAGS.model_class)
        trainer_cls = getattr(models, FLAGS.model_class + 'Trainer')

        if FLAGS.train_on == 'trainval':
            bone_lengths = dataset3d.trainval_bones
        else:
            bone_lengths = dataset3d.train_bones
        # Only 'Model25D*' classes take bone lengths as an extra argument.
        if FLAGS.model_class.startswith('Model25D'):
            model_extra_args = [bone_lengths]
        else:
            model_extra_args = []

        model = model_cls(backbone, joint_info3d, *model_extra_args)
        trainer = trainer_cls(model, joint_info3d, joint_info2d, global_step)
        trainer.compile(optimizer=build_optimizer(global_step, n_repl))
        model.optimizer = trainer.optimizer

    #######
    # CHECKPOINTING
    #######
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=FLAGS.checkpoint_dir,
        max_to_keep=2,
        step_counter=global_step,
        checkpoint_interval=FLAGS.checkpoint_period)
    restore_if_ckpt_available(
        checkpoint, checkpoint_manager, global_step, FLAGS.init_path)
    trainer.optimizer.iterations.assign(n_completed_steps)

    #######
    # CALLBACKS
    #######
    def sync_train_counter(logs):
        # Set the trainer's own step counter to the resumed step count.
        trainer._train_counter.assign(n_completed_steps)

    def save_checkpoint(batch, logs):
        checkpoint_manager.save(global_step)

    callback_list = [
        keras.callbacks.LambdaCallback(
            on_train_begin=sync_train_counter,
            on_train_batch_end=save_checkpoint),
        callbacks.ProgbarCallback(n_completed_steps, FLAGS.training_steps),
        callbacks.WandbCallback(global_step),
        callbacks.TensorBoardCallback(global_step),
    ]

    if FLAGS.finetune_in_inference_mode:
        switch_step = (
            FLAGS.training_steps - FLAGS.finetune_in_inference_mode)
        callback_list.append(
            callbacks.SwitchToInferenceModeCallback(global_step, switch_step))

    #######
    # FITTING
    #######
    # With steps_per_epoch=1 every Keras epoch is one training step, so the
    # epoch arguments below are step counts.
    progress_verbosity = 1 if sys.stdout.isatty() else 0
    try:
        trainer.fit(
            data_train,
            steps_per_epoch=1,
            initial_epoch=n_completed_steps,
            epochs=FLAGS.training_steps,
            verbose=progress_verbosity,
            callbacks=callback_list,
            validation_data=data_val,
            validation_freq=FLAGS.validate_period,
            validation_steps=validation_steps)
        model.save(
            f'{FLAGS.checkpoint_dir}/model',
            include_optimizer=False,
            overwrite=True,
            options=tf.saved_model.SaveOptions(
                experimental_custom_gradients=True))
    except KeyboardInterrupt:
        logger.info('Training interrupted.')
    except tf.errors.ResourceExhaustedError:
        logger.info('Resource Exhausted!')
    finally:
        # check_interval=False is passed so that this final save is not
        # subject to the manager's checkpoint interval.
        checkpoint_manager.save(global_step, check_interval=False)
        logger.info('Saved checkpoint.')