Example #1
File: eval.py  Project: wzzju/hapi
def main(FLAGS):
    device = set_device("gpu" if FLAGS.use_gpu else "cpu")
    if FLAGS.dynamic:
        fluid.enable_dygraph(device)
    model = Seq2SeqAttModel(encoder_size=FLAGS.encoder_size,
                            decoder_size=FLAGS.decoder_size,
                            emb_dim=FLAGS.embedding_dim,
                            num_classes=FLAGS.num_classes)

    # yapf: disable
    inputs = [
        Input([None, 1, 48, 384], "float32", name="pixel"),
        Input([None, None], "int64", name="label_in")
    ]
    labels = [
        Input([None, None], "int64", name="label_out"),
        Input([None, None], "float32", name="mask")
    ]
    # yapf: enable

    model.prepare(loss_function=WeightCrossEntropy(),
                  metrics=SeqAccuracy(),
                  inputs=inputs,
                  labels=labels,
                  device=device)
    model.load(FLAGS.init_model)

    test_dataset = data.test()
    test_collate_fn = BatchCompose(
        [data.Resize(), data.Normalize(),
         data.PadTarget()])
    test_sampler = data.BatchSampler(test_dataset,
                                     batch_size=FLAGS.batch_size,
                                     drop_last=False,
                                     shuffle=False)
    test_loader = fluid.io.DataLoader(test_dataset,
                                      batch_sampler=test_sampler,
                                      places=device,
                                      num_workers=0,
                                      return_list=True,
                                      collate_fn=test_collate_fn)

    model.evaluate(eval_data=test_loader,
                   callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
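
The FLAGS object above is an argparse-style namespace. A minimal sketch of the CLI wiring that could drive this main() follows; the flag names are hypothetical, inferred only from the attributes accessed in the snippet (use_gpu, dynamic, encoder_size, decoder_size, embedding_dim, num_classes, init_model, batch_size):

# Hypothetical CLI wiring (not part of the project); flag names are
# assumptions inferred from the FLAGS attributes used in main() above.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser("seq2seq-attention evaluation")
    parser.add_argument("--use_gpu", action="store_true", help="run on GPU")
    parser.add_argument("--dynamic", action="store_true", help="enable dygraph mode")
    parser.add_argument("--encoder_size", type=int, required=True)
    parser.add_argument("--decoder_size", type=int, required=True)
    parser.add_argument("--embedding_dim", type=int, required=True)
    parser.add_argument("--num_classes", type=int, required=True)
    parser.add_argument("--init_model", type=str, required=True,
                        help="checkpoint to load before evaluation")
    parser.add_argument("--batch_size", type=int, required=True)
    main(parser.parse_args())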
Example #2
def main(FLAGS):
    if FLAGS.static:
        paddle.enable_static()
    device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu")

    # yapf: disable
    inputs = [
        Input([None,1,48,384], "float32", name="pixel"),
        Input([None, None], "int64", name="label_in"),
    ]
    labels = [
        Input([None, None], "int64", name="label_out"),
        Input([None, None], "float32", name="mask"),
    ]
    # yapf: enable

    model = paddle.Model(
        Seq2SeqAttModel(
            encoder_size=FLAGS.encoder_size,
            decoder_size=FLAGS.decoder_size,
            emb_dim=FLAGS.embedding_dim,
            num_classes=FLAGS.num_classes),
        inputs,
        labels)

    lr = FLAGS.lr
    if FLAGS.lr_decay_strategy == "piecewise_decay":
        learning_rate = fluid.layers.piecewise_decay(
            [200000, 250000], [lr, lr * 0.1, lr * 0.01])
    else:
        learning_rate = lr
    grad_clip = fluid.clip.GradientClipByGlobalNorm(FLAGS.gradient_clip)
    optimizer = fluid.optimizer.Adam(
        learning_rate=learning_rate,
        parameter_list=model.parameters(),
        grad_clip=grad_clip)

    model.prepare(optimizer, WeightCrossEntropy(), SeqAccuracy())

    train_dataset = data.train()
    train_collate_fn = BatchCompose(
        [data.Resize(), data.Normalize(), data.PadTarget()])
    train_sampler = data.BatchSampler(
        train_dataset, batch_size=FLAGS.batch_size, shuffle=True)
    train_loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        places=device,
        num_workers=FLAGS.num_workers,
        return_list=True,
        collate_fn=train_collate_fn)
    test_dataset = data.test()
    test_collate_fn = BatchCompose(
        [data.Resize(), data.Normalize(), data.PadTarget()])
    test_sampler = data.BatchSampler(
        test_dataset,
        batch_size=FLAGS.batch_size,
        drop_last=False,
        shuffle=False)
    test_loader = paddle.io.DataLoader(
        test_dataset,
        batch_sampler=test_sampler,
        places=device,
        num_workers=0,
        return_list=True,
        collate_fn=test_collate_fn)

    model.fit(train_data=train_loader,
              eval_data=test_loader,
              epochs=FLAGS.epoch,
              save_dir=FLAGS.checkpoint_path,
              callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
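
For reference, the piecewise schedule configured above holds the base learning rate until step 200000, then drops it to lr * 0.1 and, from step 250000 on, to lr * 0.01. A minimal plain-Python sketch of that lookup, purely illustrative and not part of the project:

# Illustrative re-implementation of the piecewise decay used above:
# boundaries [200000, 250000], values [lr, lr * 0.1, lr * 0.01].
def piecewise_lr(step, lr):
    boundaries = [200000, 250000]
    values = [lr, lr * 0.1, lr * 0.01]
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]

# With lr = 0.001:
#   piecewise_lr(100000, 0.001) -> 0.001
#   piecewise_lr(220000, 0.001) -> 0.0001
#   piecewise_lr(300000, 0.001) -> 0.00001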