def get_train_hook(self):
    hooks = []
    callback = LossCallBack()
    hooks.append(callback)
    if int(os.getenv('DEVICE_ID', '0')) == 0:  # default avoids int(None) when DEVICE_ID is unset
        pass  # placeholder: rank-0-only hooks would be appended here
    return hooks
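A hypothetical caller sketch: the hooks returned above are ordinary MindSpore callbacks, so they can be passed straight to Model.train (trainer, net, epochs and ds_train are assumed here, not part of the snippet).

# Minimal usage sketch, assuming `trainer` is the object exposing get_train_hook().
hooks = trainer.get_train_hook()
model = Model(net)
model.train(epochs, ds_train, callbacks=hooks)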
Example #2
def test_train(configure):
    """
    test_train
    """
    data_path = configure.data_path
    batch_size = configure.batch_size
    epochs = configure.epochs
    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=epochs,
                              batch_size=batch_size)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))

    net_builder = ModelBuilder()
    train_net, _ = net_builder.get_net(configure)
    train_net.set_train()

    model = Model(train_net)
    callback = LossCallBack(config=configure)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=configure.ckpt_path,
                                 config=ckptconfig)
    model.train(epochs,
                ds_train,
                callbacks=[
                    TimeMonitor(ds_train.get_dataset_size()), callback,
                    ckpoint_cb
                ])
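test_train expects a `configure` object carrying data_path, batch_size, epochs and ckpt_path. A hypothetical stand-in, just to make the example callable (the real project builds this from WideDeepConfig):

class FakeConfig:
    # Hypothetical values; the real defaults live in WideDeepConfig.
    data_path = './data/mindrecord/'
    batch_size = 16000
    epochs = 1
    ckpt_path = './ckpt/'

test_train(FakeConfig())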
Example #3
def train_and_eval(config):
    """
    train_and_eval
    """
    np.random.seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
                              batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
                             batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path, config=ckptconfig)
    out = model.eval(ds_eval)
    print("=====" * 5 + "model.eval() initialized: {}".format(out))
    model.train(epochs, ds_train,
                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
Example #4
def test_train_eval():
    """
    test_train_eval
    """
    config = WideDeepConfig()
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
                              data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
                             data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
    model.train(epochs, ds_train,
                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback])
    eval_values = list(eval_callback.eval_values)
    assert eval_values[0] > 0.78
Example #5
def test_train_eval():
    """
    test_train_eval
    """
    np.random.seed(1000)
    config = WideDeepConfig()
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=batch_size,
                              rank_id=get_rank(),
                              rank_size=get_group_size())
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=batch_size,
                             rank_id=get_rank(),
                             rank_size=get_group_size())
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path,
                                 config=ckptconfig)
    out = model.eval(ds_eval)
    print("=====" * 5 + "model.eval() initialized: {}".format(out))
    model.train(epochs,
                ds_train,
                callbacks=[
                    TimeMonitor(ds_train.get_dataset_size()), eval_callback,
                    callback, ckpoint_cb
                ])
    expect_out0 = [0.792634, 0.799862, 0.803324]
    expect_out6 = [0.796580, 0.803908, 0.807262]
    if get_rank() == 0:
        assert np.allclose(eval_callback.eval_values, expect_out0)
    if get_rank() == 6:
        assert np.allclose(eval_callback.eval_values, expect_out6)
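The distributed examples above rely on get_rank() and get_group_size(), which only work after the communication group has been initialized. A minimal init sketch, assuming Ascend devices in graph mode (the device ID comes from the launch environment):

import os
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    device_id=int(os.getenv('DEVICE_ID', '0')))
init()  # sets up the collective communication group (HCCL on Ascend)
print("rank {} of {}".format(get_rank(), get_group_size()))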
Example #6
def train_and_eval(config):
    """
    train_and_eval
    """
    set_seed(1000)
    data_path = config.data_path
    epochs = config.epochs
    print("epochs is {}".format(epochs))

    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=config.batch_size,
                              is_tf_dataset=config.is_tf_dataset,
                              rank_id=get_rank(),
                              rank_size=get_group_size())
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=config.batch_size,
                             is_tf_dataset=config.is_tf_dataset,
                             rank_id=get_rank(),
                             rank_size=get_group_size())

    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
    callback = LossCallBack(config)
    # Only save the last checkpoint, at the last epoch. To save a checkpoint
    # every epoch instead, set save_checkpoint_steps=ds_train.get_dataset_size().
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size() * config.epochs,
        keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' +
                                 str(get_rank()) + '/',
                                 config=ckptconfig)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                sink_size=ds_train.get_dataset_size())
Example #7
def test_train_eval(config):
    """
    test_train_eval
    """
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    sparse = config.sparse
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=batch_size,
                              data_type=dataset_type)
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=batch_size,
                             data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path,
                                 config=ckptconfig)

    out = model.eval(ds_eval)
    print("=====" * 5 + "model.eval() initialized: {}".format(out))
    model.train(epochs,
                ds_train,
                callbacks=[
                    TimeMonitor(ds_train.get_dataset_size()), eval_callback,
                    callback, ckpoint_cb
                ],
                dataset_sink_mode=(not sparse))
Example #8
def train_and_eval(config):
    """
    train_and_eval
    """
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    host_device_mix = bool(config.host_device_mix)
    print("epochs is {}".format(epochs))
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        de.config.set_seed(1)
        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                  batch_size=batch_size*get_group_size(), data_type=dataset_type)
        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                 batch_size=batch_size*get_group_size(), data_type=dataset_type)
    else:
        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                  batch_size=batch_size, rank_id=get_rank(),
                                  rank_size=get_group_size(), data_type=dataset_type)
        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                 batch_size=batch_size, rank_id=get_rank(),
                                 rank_size=get_group_size(), data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path, config=ckptconfig)
    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
    if not host_device_mix:
        callback_list.append(ckpoint_cb)
    model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
Example #9
def train_and_eval(config):
    """
    train_and_eval
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                              batch_size=batch_size, rank_id=get_rank(),
                              rank_size=get_group_size(), data_type=dataset_type)
    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                             batch_size=batch_size, rank_id=get_rank(),
                             rank_size=get_group_size(), data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
                                 config=ckptconfig)
    out = model.eval(ds_eval)
    print("=====" * 5 + "model.eval() initialized: {}".format(out))
    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs, ds_train,
                callbacks=callback_list,
                sink_size=ds_train.get_dataset_size())
Example #10
def train_and_eval(config):
    """
    train_and_eval.
    """
    data_path = config.data_path
    epochs = config.epochs
    print("epochs is {}".format(epochs))

    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=config.batch_size,
                              is_tf_dataset=config.is_tf_dataset)
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=config.batch_size,
                             is_tf_dataset=config.is_tf_dataset)

    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
    callback = LossCallBack(config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path,
                                 config=ckptconfig)

    model.train(epochs,
                ds_train,
                callbacks=[
                    TimeMonitor(ds_train.get_dataset_size()), eval_callback,
                    callback, ckpoint_cb
                ])
Example #11
def train_and_eval(config):
    """
    train_and_eval
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    cache_enable = config.vocab_cache_size > 0
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                              batch_size=batch_size, data_type=dataset_type)
    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                             batch_size=batch_size, data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)
    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]

    model.train(epochs, ds_train,
                callbacks=callback_list,
                dataset_sink_mode=(parameter_server and cache_enable))
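The example above gates dataset_sink_mode on parameter_server and cache_enable; parameter-server mode itself must be switched on in the context before the network is built. A hedged sketch (the MS_ROLE and scheduler environment variables are expected to be set by the launch scripts, not here):

from mindspore import context

# Enable parameter-server training; in this model, vocab_cache_size > 0
# additionally enables the embedding cache (cache_enable above).
context.set_ps_context(enable_ps=True)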
Example #12
def train_and_eval(config):
    """
    train_and_eval
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    cache_enable = config.vocab_cache_size > 0  # was used below but never defined in this snippet
    if cache_enable:
        config.full_batch = True
    print("epochs is {}".format(epochs))
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        ds.config.set_seed(1)
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size * get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size * get_group_size(),
                                 data_type=dataset_type)
    else:
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    if cache_enable:
        config.stra_ckpt = os.path.join(
            config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
        context.set_auto_parallel_context(
            strategy_ckpt_save_file=config.stra_ckpt)

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    if _is_role_worker():
        if cache_enable:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size() * epochs,
                keep_checkpoint_max=1,
                integrated_save=False)
        else:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size(),
                keep_checkpoint_max=5)
    else:
        ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
                                      keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' +
                                 str(get_rank()) + '/',
                                 config=ckptconfig)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=bool(parameter_server and cache_enable))
Example #13
def train_and_eval(config):
    """
    train_and_eval
    """
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    print("epochs is {}".format(epochs))
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        de.config.set_seed(1)
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=epochs,
                                  batch_size=batch_size * get_group_size())
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=epochs + 1,
                                 batch_size=batch_size * get_group_size())
    else:
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=epochs,
                                  batch_size=batch_size,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size())
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=epochs + 1,
                                 batch_size=batch_size,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size())
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path,
                                 config=ckptconfig)
    context.set_auto_parallel_context(
        strategy_ckpt_save_file="./strategy_train.ckpt")
    model.train(epochs,
                ds_train,
                callbacks=[
                    TimeMonitor(ds_train.get_dataset_size()), eval_callback,
                    callback, ckpoint_cb
                ])
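full_batch=True only takes effect under auto or semi-auto parallel. A minimal context sketch for that branch, assuming semi-auto parallel on Ascend (exact flags vary across MindSpore versions):

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                  device_num=get_group_size(),
                                  full_batch=True)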
Example #14
# This example is a fragment of a larger training script; the enclosing
# function header is not shown. A hypothetical signature, inferred from the
# names used below:
def build_and_train(config, net_with_loss, lr, device_num,
                    train_dataset_batch_num, logger, save_checkpoint_path):
    opt = SGD(params=net_with_loss.trainable_params(),
              learning_rate=lr,
              momentum=config.momentum,
              weight_decay=config.weight_decay,
              loss_scale=config.loss_scale)

    if device_num > 1:
        net = TrainOneStepCell(net_with_loss,
                               opt,
                               reduce_flag=True,
                               mean=True,
                               degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, opt)

    loss_cb = LossCallBack(data_size=train_dataset_batch_num, logger=logger)

    cb = [loss_cb]

    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(
            save_checkpoint_steps=train_dataset_batch_num,
            keep_checkpoint_max=config.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix='AVA',
                                     directory=save_checkpoint_path,
                                     config=ckptconfig)
        cb += [ckpoint_cb]

    model = Model(net)

    print("save configs...")
Example #15
def train_eval(config):
    """
    train, then evaluate from the saved checkpoint
    """
    data_path = config.data_path + config.dataset_type
    ckpt_path = config.ckpt_path
    epochs = config.epochs
    batch_size = config.batch_size
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5

    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=batch_size,
                              data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=batch_size,
                             data_type=dataset_type)
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()
    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()

    train_model = Model(train_net)
    train_callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path,
                                 config=ckptconfig)
    train_model.train(epochs,
                      ds_train,
                      callbacks=[
                          TimeMonitor(ds_train.get_dataset_size()),
                          train_callback, ckpoint_cb
                      ])

    # Copy the trained checkpoints from the ModelArts container to OBS.
    print('Upload checkpoints from ModelArts to OBS.')
    mox.file.copy_parallel(src_url=config.ckpt_path, dst_url=config.train_url)

    param_dict = load_checkpoint(find_ckpt(ckpt_path))
    load_param_into_net(eval_net, param_dict)

    auc_metric = AUCMetric()
    eval_model = Model(train_net,
                       eval_network=eval_net,
                       metrics={"auc": auc_metric})
    eval_callback = EvalCallBack(eval_model, ds_eval, auc_metric, config)

    eval_model.eval(ds_eval, callbacks=eval_callback)
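train_eval above calls find_ckpt, which is not shown in this snippet. A minimal sketch, assuming it simply returns the newest checkpoint file in the directory:

import glob
import os

def find_ckpt(ckpt_path):
    # Hypothetical helper: return the most recently written .ckpt file.
    ckpts = glob.glob(os.path.join(ckpt_path, '*.ckpt'))
    return max(ckpts, key=os.path.getmtime)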
Example #16
def train_and_eval(config):
    """
    train_and_eval
    """
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    host_device_mix = bool(config.host_device_mix)
    sparse = config.sparse
    print("epochs is {}".format(epochs))
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        ds.config.set_seed(1)
        if config.field_slice:
            compute_manual_shape(config, get_group_size())
            ds_train = create_dataset(data_path,
                                      train_mode=True,
                                      epochs=1,
                                      batch_size=batch_size * get_group_size(),
                                      data_type=dataset_type,
                                      manual_shape=config.manual_shape,
                                      target_column=config.field_size)
            ds_eval = create_dataset(data_path,
                                     train_mode=False,
                                     epochs=1,
                                     batch_size=batch_size * get_group_size(),
                                     data_type=dataset_type,
                                     manual_shape=config.manual_shape,
                                     target_column=config.field_size)
        else:
            ds_train = create_dataset(data_path,
                                      train_mode=True,
                                      epochs=1,
                                      batch_size=batch_size * get_group_size(),
                                      data_type=dataset_type)
            ds_eval = create_dataset(data_path,
                                     train_mode=False,
                                     epochs=1,
                                     batch_size=batch_size * get_group_size(),
                                     data_type=dataset_type)
    else:
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    # Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
    config.stra_ckpt = os.path.join(
        config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config, per_print_times=20)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size() * epochs,
        keep_checkpoint_max=5,
        integrated_save=False)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=os.path.join(
                                     config.ckpt_path,
                                     'ckpt_' + str(get_rank())),
                                 config=ckptconfig)

    context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if not host_device_mix:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=(not sparse))