def create_openpose_singlenet(pretrained=False):
    """Create the single-network OpenPose model, optionally with released weights.

    The model takes one 3-channel input and is built against a
    ``(None, 224, 224, 3)`` input shape before any weights are loaded.

    Args:
        pretrained: when True, fetch the v1.0 release checkpoint through
            ``download_checkpoint`` and load it into the model.

    Returns:
        The built ``OpenPoseSingleNet`` instance.
    """
    weights_url = "https://github.com/michalfaber/tensorflow_Realtime_Multi-Person_Pose_Estimation/releases/download/v1.0/openpose_singlenet_v1.zip"

    net = OpenPoseSingleNet(in_chs=[3])
    input_shape = tf.TensorShape((None, 224, 224, 3))
    net.build([input_shape])

    # Guard clause: nothing more to do for a randomly-initialized model.
    if not pretrained:
        return net

    checkpoint_path = download_checkpoint(weights_url)
    net.load_weights(checkpoint_path)
    return net
def create_mobilenet_v3_224_1x(pretrained=False):
    """Create a MobileNetV3 classifier built for 224x224 inputs.

    The network is constructed with 3 input channels and 1001 output classes,
    then built against a ``(None, 224, 224, 3)`` input shape.

    Args:
        pretrained: when True, fetch the v1.0 release checkpoint through
            ``download_checkpoint`` and load it into the network.

    Returns:
        The built ``MobilenetV3`` instance.
    """
    checkpoint_url = "https://github.com/michalfaber/tf_netbuilder/releases/download/v1.0/mobilenet_v3_224_1_0.zip"

    classifier = MobilenetV3(in_chs=3, num_classes=1001)
    batch_shape = tf.TensorShape((None, 224, 224, 3))
    classifier.build([batch_shape])

    if not pretrained:
        return classifier

    local_path = download_checkpoint(checkpoint_url)
    classifier.load_weights(local_path)
    return classifier
def create_openpose_2branches_vgg(pretrained=False, training=False):
    """Create the two-branch VGG OpenPose model, optionally with released weights.

    In training mode the model receives three inputs with 3, 12 and 6
    channels; otherwise it receives a single 3-channel input. Every input is
    built with unspecified batch, height and width dimensions.

    Args:
        pretrained: when True, fetch the v1.0 release checkpoint through
            ``download_checkpoint`` and load it into the model.
        training: forwarded to ``OpenPose2BranchesVGG`` and selects the
            input configuration described above.

    Returns:
        The built ``OpenPose2BranchesVGG`` instance.
    """
    weights_url = "https://github.com/michalfaber/tensorflow_Realtime_Multi-Person_Pose_Estimation/releases/download/v1.0/openpose_2br_vgg.zip"

    if training:
        # TODO: the 6-channel third input was flagged for follow-up; the
        # reason is not recorded here.
        channels = [3, 12, 6]
    else:
        channels = [3]

    net = OpenPose2BranchesVGG(in_chs=channels, training=training)
    # One fully-dynamic (batch, height, width) shape per input, in channel order.
    net.build([tf.TensorShape((None, None, None, ch)) for ch in channels])

    if not pretrained:
        return net

    checkpoint_path = download_checkpoint(weights_url)
    net.load_weights(checkpoint_path)
    return net
# Training entry point: build the single-net model, resume from the latest
# checkpoint in `checkpoints_folder` if one exists, otherwise warm-start from
# the weights at `pretrained_mobilenet_v3_url`, then run `train(...)`.
# NOTE(review): lr, checkpoints_folder, pretrained_mobilenet_v3_url, ds_train,
# ds_val, max_epochs and steps_per_epoch are defined outside this chunk.
model = create_openpose_singlenet(pretrained=False)
optimizer = Adam(lr)
# loading previous state if required
# The checkpoint tracks step/epoch counters together with optimizer and model
# state, so a resumed run continues from where the previous one stopped.
ckpt = tf.train.Checkpoint(step=tf.Variable(0), epoch=tf.Variable(0), optimizer=optimizer, net=model)
# Only the 3 most recent checkpoints are retained in checkpoints_folder.
manager = tf.train.CheckpointManager(ckpt, checkpoints_folder, max_to_keep=3)
# latest_checkpoint is None when nothing has been saved yet; restoring None
# leaves step/epoch at their initial value of 0.
ckpt.restore(manager.latest_checkpoint)
last_step = int(ckpt.step)
last_epoch = int(ckpt.epoch)
if manager.latest_checkpoint:
    print(f"Restored from {manager.latest_checkpoint}")
    print(f"Resumed from epoch {last_epoch}, step {last_step}")
else:
    print("Initializing from scratch.")
    # Fresh run: warm-start from the downloaded weights. expect_partial()
    # silences warnings for checkpoint values that do not match the model
    # (presumably only a backbone portion lines up -- confirm).
    path = download_checkpoint(pretrained_mobilenet_v3_url)
    model.load_weights(path).expect_partial()
# training loop
train(ds_train, ds_val, model, optimizer, ckpt, last_epoch, last_step, max_epochs, steps_per_epoch)