コード例 #1
0
ファイル: vgg16.py プロジェクト: hellowaywewe/tinyms
    # Compile the VGG16 model with its loss function, optimizer and an
    # accuracy metric, then either evaluate or train it depending on
    # args_opt.do_eval. `model`, `net_loss`, `net_opt`, `args_opt` and
    # `create_dataset` come from the surrounding (not shown) script.
    model.compile(loss_fn=net_loss,
                  optimizer=net_opt,
                  metrics={"Accuracy": Accuracy()})

    epoch_size = args_opt.epoch_size
    batch_size = args_opt.batch_size
    cifar10_path = args_opt.dataset_path
    save_checkpoint_epochs = args_opt.save_checkpoint_epochs
    # Dataset sink mode is enabled on every device target except CPU.
    dataset_sink_mode = args_opt.device_target != "CPU"
    if args_opt.do_eval:
        # Evaluation only: build the non-training dataset, optionally restore
        # weights from a local checkpoint, then print the metric results.
        ds_eval = create_dataset(cifar10_path,
                                 batch_size=batch_size,
                                 is_training=False)
        if args_opt.load_pretrained == 'local':
            # NOTE(review): when checkpoint_path is empty nothing is loaded and
            # the model is evaluated with whatever weights it currently holds;
            # confirm this silent fallback is intended.
            if args_opt.checkpoint_path:
                model.load_checkpoint(args_opt.checkpoint_path)
        acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
        print("============== Accuracy:{} ==============".format(acc))
    else:
        # Training: the checkpoint interval is save_checkpoint_epochs times the
        # dataset size (presumably one dataset pass per epoch — confirm),
        # keeping at most 10 checkpoint files.
        ds_train = create_dataset(cifar10_path, batch_size=batch_size)
        ckpoint_cb = ModelCheckpoint(
            prefix="vgg16_cifar10",
            config=CheckpointConfig(
                save_checkpoint_steps=save_checkpoint_epochs *
                ds_train.get_dataset_size(),
                keep_checkpoint_max=10))
        model.train(epoch_size,
                    ds_train,
                    callbacks=[ckpoint_cb, LossMonitor()],
                    dataset_sink_mode=dataset_sink_mode)
コード例 #2
0
ファイル: deepfm.py プロジェクト: hellowaywewe/tinyms
    # Build the DeepFM network with fixed hyperparameters and wrap it twice:
    # once with its loss for training, once for evaluation. `data_size`,
    # `checkpoint_dir`, `epoch_size`, `train_ds`, `eval_ds` and
    # `dataset_sink_mode` come from the surrounding (not shown) script.
    net = DeepFM(field_size=39,
                 vocab_size=184965,
                 embed_size=80,
                 convert_dtype=True)
    # build train network (network + loss)
    train_net = DeepFMTrainModel(DeepFMWithLoss(net))
    # build eval network (shares `net` with the train network)
    eval_net = DeepFMEvalModel(net)
    # build model
    model = Model(train_net)

    # Callbacks: loss/time logging and periodic checkpointing
    # (at most 10 checkpoint files kept).
    loss_tm = LossTimeMonitorV2()
    # NOTE(review): the interval is data_size // 100 steps, which is 0 when
    # data_size < 100; confirm data_size is always >= 100 or that a 0 interval
    # is handled as intended by CheckpointConfig/ModelCheckpoint.
    config_ckpt = CheckpointConfig(save_checkpoint_steps=data_size // 100,
                                   keep_checkpoint_max=10)
    model_ckpt = ModelCheckpoint(prefix='deepfm',
                                 directory=checkpoint_dir,
                                 config=config_ckpt)
    auc_metric = AUCMetric()

    model.compile(eval_network=eval_net,
                  metrics={"auc_metric": auc_metric},
                  amp_level='O0')
    print("====== start train model ======", flush=True)
    model.train(epoch=epoch_size,
                train_dataset=train_ds,
                callbacks=[loss_tm, model_ckpt],
                dataset_sink_mode=dataset_sink_mode)
    print("====== start eval model ======", flush=True)
    # Evaluate with the same sink mode chosen for training instead of the
    # library default, so both phases run the same way on this device.
    acc = model.eval(eval_ds, dataset_sink_mode=dataset_sink_mode)
    print("====== eval acc: {} ======".format(acc), flush=True)