Ejemplo n.º 1
0
def create_openpose_singlenet(pretrained=False):
    """Create an ``OpenPoseSingleNet`` built for a single 3-channel input.

    Args:
        pretrained: If truthy, fetch the released v1.0 checkpoint archive with
            ``download_checkpoint`` and load it via ``load_weights``.

    Returns:
        The built model instance (with released weights when ``pretrained``
        is truthy).
    """
    net = OpenPoseSingleNet(in_chs=[3])
    # Build up front with a batch-agnostic 224x224x3 shape so the model's
    # variables exist before any checkpoint is loaded into them.
    input_shapes = [tf.TensorShape((None, 224, 224, 3))]
    net.build(input_shapes)

    if not pretrained:
        return net

    weights_url = "https://github.com/michalfaber/tensorflow_Realtime_Multi-Person_Pose_Estimation/releases/download/v1.0/openpose_singlenet_v1.zip"
    checkpoint_path = download_checkpoint(weights_url)
    net.load_weights(checkpoint_path)
    return net
Ejemplo n.º 2
0
def create_mobilenet_v3_224_1x(pretrained=False):
    """Create a ``MobilenetV3`` classifier with 3 input channels and 1001 outputs.

    The model is built for 224x224x3 inputs with a dynamic batch dimension.
    (1001 classes presumably means ImageNet plus a background class - confirm.)

    Args:
        pretrained: If truthy, fetch the released v1.0 ``mobilenet_v3_224_1_0``
            checkpoint archive with ``download_checkpoint`` and load it.

    Returns:
        The built model instance.
    """
    net = MobilenetV3(in_chs=3, num_classes=1001)
    net.build([tf.TensorShape((None, 224, 224, 3))])

    if pretrained:
        # Weights are loaded only after build(), once the variables exist.
        checkpoint_path = download_checkpoint(
            "https://github.com/michalfaber/tf_netbuilder/releases/download/v1.0/mobilenet_v3_224_1_0.zip"
        )
        net.load_weights(checkpoint_path)

    return net
def create_openpose_2branches_vgg(pretrained=False, training=False):
    """Create an ``OpenPose2BranchesVGG`` model built for inference or training.

    Args:
        pretrained: If truthy, fetch the released v1.0 ``openpose_2br_vgg``
            checkpoint archive with ``download_checkpoint`` and load it.
        training: Forwarded to the model constructor. When truthy, the model
            is built with three inputs of 3, 12 and 6 channels; otherwise it
            is built with a single 3-channel input.

    Returns:
        The built model instance.
    """
    # Channel count of each model input, in the order the model receives them.
    # NOTE(review): the original code flagged the 6-channel input with a bare
    # "#TODO" and no explanation - confirm what that input should be.
    input_channels = (3, 12, 6) if training else (3,)

    net = OpenPose2BranchesVGG(in_chs=list(input_channels), training=training)
    # Batch, height and width stay dynamic (None); only the trailing channel
    # dimension is fixed for each input.
    net.build([tf.TensorShape((None, None, None, ch)) for ch in input_channels])

    if pretrained:
        archive_url = "https://github.com/michalfaber/tensorflow_Realtime_Multi-Person_Pose_Estimation/releases/download/v1.0/openpose_2br_vgg.zip"
        net.load_weights(download_checkpoint(archive_url))

    return net
    # NOTE(review): this run of statements looks like the body of a separate
    # training entry point whose `def` line is missing from this file. As laid
    # out here it follows `return model` at the same indentation, so Python
    # treats it as unreachable code inside create_openpose_2branches_vgg.
    # The names lr, checkpoints_folder, pretrained_mobilenet_v3_url, ds_train,
    # ds_val, max_epochs, steps_per_epoch, Adam and train are not defined in
    # the visible code - confirm where they come from.
    model = create_openpose_singlenet(pretrained=False)
    optimizer = Adam(lr)

    # Load the previous training state if a checkpoint exists.
    # The checkpoint bundles the step/epoch counters, the optimizer state and
    # the network, so a resumed run continues from where the last one stopped.

    ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                               epoch=tf.Variable(0),
                               optimizer=optimizer,
                               net=model)
    # Only the 3 most recent checkpoints in checkpoints_folder are kept.
    manager = tf.train.CheckpointManager(ckpt,
                                         checkpoints_folder,
                                         max_to_keep=3)
    # latest_checkpoint is None on a fresh run; per the TF checkpoint guide,
    # restoring None leaves the variables at their initial values, so the
    # counters below then read 0 (their tf.Variable(0) initial value).
    ckpt.restore(manager.latest_checkpoint)
    last_step = int(ckpt.step)
    last_epoch = int(ckpt.epoch)

    if manager.latest_checkpoint:
        print(f"Restored from {manager.latest_checkpoint}")
        print(f"Resumed from epoch {last_epoch}, step {last_step}")
    else:
        print("Initializing from scratch.")

        # Fresh run: seed the network from a MobileNetV3 checkpoint.
        # expect_partial() suppresses warnings about checkpoint values that
        # have no matching variable in this model. Presumably only the
        # backbone variables match - confirm.
        path = download_checkpoint(pretrained_mobilenet_v3_url)
        model.load_weights(path).expect_partial()

    # Run the training loop, resuming from last_epoch / last_step.

    train(ds_train, ds_val, model, optimizer, ckpt, last_epoch, last_step,
          max_epochs, steps_per_epoch)