コード例 #1
0
def train_and_evaluate(args):
    """Fetch the data, train the SPDNet model and print its validation accuracy.

    Side effects: calls ``utils.download_data`` for ``DATA_URL`` into
    ``args.data_dir``, creates ``args.job_dir`` if needed, writes weight
    checkpoints to ``<job_dir>/afew-spdnet.ckpt`` and TensorBoard logs to
    ``<job_dir>/logs``, and prints the final accuracy to stdout.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``data_dir``, ``num_epochs``, ``shuffle_buffer``, ``batch_size``,
            ``learning_rate`` and ``job_dir``.
    """
    utils.download_data(args.data_dir, DATA_URL, unpack=True)
    train = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "train")
    val = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "val")

    # No .repeat(num_epochs) here. Keras re-iterates a finite dataset at the
    # start of every epoch, so repeating it while also passing
    # epochs=num_epochs to fit() would train for num_epochs**2 passes.
    # shuffle() reshuffles on each iteration by default, i.e. once per epoch.
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(train)
        .shuffle(args.shuffle_buffer)
        .batch(args.batch_size, drop_remainder=True)
    )
    # NOTE(review): drop_remainder=True discards the last partial validation
    # batch, so the reported accuracy skips up to batch_size - 1 samples.
    # Presumably kept because the model needs a static batch size. Confirm.
    val_dataset = tf.data.Dataset.from_tensor_slices(val).batch(
        args.batch_size, drop_remainder=True
    )

    spdnet = model.create_model(args.learning_rate, num_classes=AFEW_CLASSES)

    os.makedirs(args.job_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.job_dir, "afew-spdnet.ckpt")
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, save_weights_only=True, verbose=1
    )
    log_dir = os.path.join(args.job_dir, "logs")
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    spdnet.fit(
        train_dataset,
        epochs=args.num_epochs,
        validation_data=val_dataset,
        callbacks=[cp_callback, tb_callback],
    )
    # evaluate() returns [loss, *metrics]. This unpacking assumes the model is
    # compiled with exactly one metric (accuracy). Confirm in create_model.
    _, acc = spdnet.evaluate(val_dataset, verbose=2)
    print("Final accuracy: {}%".format(acc * 100))
コード例 #2
0
def train_and_evaluate(args):
    """Fetch the data, train the LieNet model and print its validation accuracy.

    Side effects: calls ``utils.download_data`` for ``DATA_URL`` into
    ``args.data_dir``, creates ``args.job_dir`` if needed, writes weight
    checkpoints to ``<job_dir>/g3d-lienet.ckpt``, and prints the final
    accuracy to stdout.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``data_dir``, ``num_epochs``, ``shuffle_buffer``, ``batch_size``,
            ``learning_rate`` and ``job_dir``.
    """
    utils.download_data(args.data_dir, DATA_URL, unpack=True)
    train, val = prepare_data()

    # No .repeat(num_epochs) here. Keras re-iterates a finite dataset at the
    # start of every epoch, so repeating it while also passing
    # epochs=num_epochs to fit() would train for num_epochs**2 passes.
    # shuffle() reshuffles on each iteration by default, i.e. once per epoch.
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(train)
        .shuffle(args.shuffle_buffer)
        .batch(args.batch_size, drop_remainder=True)
    )
    # NOTE(review): drop_remainder=True discards the last partial validation
    # batch, so the reported accuracy skips up to batch_size - 1 samples.
    # Presumably kept because the model needs a static batch size. Confirm.
    val_dataset = tf.data.Dataset.from_tensor_slices(val).batch(
        args.batch_size, drop_remainder=True
    )

    lienet = model.create_model(args.learning_rate, num_classes=G3D_CLASSES)

    # Make sure the output directory exists before ModelCheckpoint writes to it.
    os.makedirs(args.job_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.job_dir, "g3d-lienet.ckpt")
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, save_weights_only=True, verbose=1
    )
    lienet.fit(
        train_dataset,
        epochs=args.num_epochs,
        validation_data=val_dataset,
        callbacks=[cp_callback],
    )
    # evaluate() returns [loss, *metrics]. This unpacking assumes the model is
    # compiled with exactly one metric (accuracy). Confirm in create_model.
    _, acc = lienet.evaluate(val_dataset, verbose=2)
    print("Final accuracy: {}%".format(acc * 100))