def get_train_hook(self):
    """Build the list of callbacks (hooks) to attach to a training run.

    No device-specific hooks are registered, so the ``DEVICE_ID``
    environment variable is not consulted; the function therefore works
    whether or not that variable is set.

    Returns:
        list: A new list holding a single ``LossCallBack`` instance.
    """
    return [LossCallBack()]
def test_train(configure):
    """Train the network without evaluation, checkpointing as it goes.

    Args:
        configure: Run configuration. ``data_path``, ``batch_size``,
            ``epochs`` and ``ckpt_path`` are read here; the object itself is
            also handed to ``ModelBuilder.get_net`` and ``LossCallBack``.
    """
    num_epochs = configure.epochs
    train_ds = create_dataset(configure.data_path,
                              train_mode=True,
                              epochs=num_epochs,
                              batch_size=configure.batch_size)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))

    # Only the training network is needed; the second one is discarded.
    train_net, _ = ModelBuilder().get_net(configure)
    train_net.set_train()
    model = Model(train_net)

    loss_cb = LossCallBack(config=configure)
    # Checkpoint interval equals the number of batches in the dataset.
    ckpt_cb = ModelCheckpoint(
        prefix='widedeep_train',
        directory=configure.ckpt_path,
        config=CheckpointConfig(
            save_checkpoint_steps=train_ds.get_dataset_size(),
            keep_checkpoint_max=5))

    hooks = [TimeMonitor(train_ds.get_dataset_size()), loss_cb, ckpt_cb]
    model.train(num_epochs, train_ds, callbacks=hooks)
def train_and_eval(config):
    """Train the network and evaluate it via an ``EvalCallBack``.

    NumPy is seeded with 1000. The train and eval datasets are created with
    this process's rank and the group size. An initial ``model.eval`` is run
    and printed before training starts.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``
            and ``ckpt_path`` are read here; the object is also passed to
            ``ModelBuilder.get_net``, ``EvalCallBack`` and ``LossCallBack``.
    """
    np.random.seed(1000)
    num_epochs = config.epochs
    batch = config.batch_size
    path = config.data_path
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=num_epochs,
                              batch_size=batch,
                              rank_id=get_rank(),
                              rank_size=get_group_size())
    # The eval dataset is created with one more epoch than the train one.
    eval_ds = create_dataset(path, train_mode=False, epochs=num_epochs + 1,
                             batch_size=batch,
                             rank_id=get_rank(),
                             rank_size=get_group_size())
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)

    # Evaluate once before any training step and report the result.
    initial = model.eval(eval_ds)
    print("=====" * 5 + "model.eval() initialized: {}".format(initial))

    hooks = [TimeMonitor(train_ds.get_dataset_size()),
             eval_cb, loss_cb, ckpt_cb]
    model.train(num_epochs, train_ds, callbacks=hooks)
def test_train_eval():
    """Train on MindRecord data and check the first recorded eval value.

    Uses a default ``WideDeepConfig``. After training, the first entry of
    ``eval_callback.eval_values`` must be greater than 0.78.
    """
    cfg = WideDeepConfig()
    path = cfg.data_path
    batch = cfg.batch_size
    num_epochs = cfg.epochs
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=num_epochs,
                              batch_size=batch,
                              data_type=DataType.MINDRECORD,
                              rank_id=get_rank(),
                              rank_size=get_group_size())
    # The eval dataset is created with one more epoch than the train one.
    eval_ds = create_dataset(path, train_mode=False, epochs=num_epochs + 1,
                             batch_size=batch,
                             data_type=DataType.MINDRECORD,
                             rank_id=get_rank(),
                             rank_size=get_group_size())
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(cfg)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, cfg)
    loss_cb = LossCallBack(config=cfg)
    context.set_auto_parallel_context(
        strategy_ckpt_save_file="./strategy_train.ckpt")

    hooks = [TimeMonitor(train_ds.get_dataset_size()), eval_cb, loss_cb]
    model.train(num_epochs, train_ds, callbacks=hooks)

    recorded = list(eval_cb.eval_values)
    assert recorded[0] > 0.78
def test_train_eval():
    """Train with a fixed NumPy seed and compare eval values per rank.

    Uses a default ``WideDeepConfig``. After training, the values recorded by
    the eval callback are compared with ``np.allclose`` against hard-coded
    expectations, but only on ranks 0 and 6.
    """
    np.random.seed(1000)
    cfg = WideDeepConfig()
    path = cfg.data_path
    batch = cfg.batch_size
    num_epochs = cfg.epochs
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=batch,
                              rank_id=get_rank(),
                              rank_size=get_group_size())
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=batch,
                             rank_id=get_rank(),
                             rank_size=get_group_size())
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(cfg)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, cfg)
    loss_cb = LossCallBack(config=cfg)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=cfg.ckpt_path,
                              config=ckpt_cfg)

    # Evaluate once before any training step and report the result.
    initial = model.eval(eval_ds)
    print("=====" * 5 + "model.eval() initialized: {}".format(initial))

    hooks = [TimeMonitor(train_ds.get_dataset_size()),
             eval_cb, loss_cb, ckpt_cb]
    model.train(num_epochs, train_ds, callbacks=hooks)

    # Expected eval values, checked only on the listed ranks.
    expected_by_rank = (
        (0, [0.792634, 0.799862, 0.803324]),
        (6, [0.796580, 0.803908, 0.807262]),
    )
    for rank, expected in expected_by_rank:
        if get_rank() == rank:
            assert np.allclose(eval_cb.eval_values, expected)
def train_and_eval(config):
    """Train and evaluate, attaching the checkpoint callback on rank 0 only.

    Args:
        config: Run configuration. ``data_path``, ``epochs``, ``batch_size``,
            ``is_tf_dataset`` and ``ckpt_path`` are read here; the object is
            also passed to ``ModelBuilder.get_net``, ``EvalCallBack`` and
            ``LossCallBack``.
    """
    set_seed(1000)
    path = config.data_path
    num_epochs = config.epochs
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=config.batch_size,
                              is_tf_dataset=config.is_tf_dataset,
                              rank_id=get_rank(),
                              rank_size=get_group_size())
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=config.batch_size,
                             is_tf_dataset=config.is_tf_dataset,
                             rank_id=get_rank(),
                             rank_size=get_group_size())
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config)

    # The checkpoint interval spans the whole run (dataset size times the
    # configured epoch count), so a checkpoint is due only at the very end.
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size() * config.epochs,
        keep_checkpoint_max=10)
    # Each rank gets its own checkpoint sub-directory.
    ckpt_dir = config.ckpt_path + '/ckpt_' + str(get_rank()) + '/'
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=ckpt_dir,
                              config=ckpt_cfg)

    hooks = [TimeMonitor(train_ds.get_dataset_size()), eval_cb, loss_cb]
    if int(get_rank()) == 0:
        hooks.append(ckpt_cb)
    model.train(num_epochs, train_ds, callbacks=hooks,
                sink_size=train_ds.get_dataset_size())
def test_train_eval(config):
    """Train and evaluate on a single dataset of the configured type.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``sparse``, ``dataset_type`` and ``ckpt_path`` are read here; the
            object is also passed to ``ModelBuilder.get_net``,
            ``EvalCallBack`` and ``LossCallBack``.
    """
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs
    is_sparse = config.sparse

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=batch, data_type=data_type)
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=batch, data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)

    # Evaluate once before any training step and report the result.
    initial = model.eval(eval_ds)
    print("=====" * 5 + "model.eval() initialized: {}".format(initial))

    hooks = [TimeMonitor(train_ds.get_dataset_size()),
             eval_cb, loss_cb, ckpt_cb]
    # Dataset sink mode is turned off when the sparse flag is set.
    model.train(num_epochs, train_ds, callbacks=hooks,
                dataset_sink_mode=(not is_sparse))
def train_and_eval(config):
    """Train and evaluate with optional full-batch and host/device-mix modes.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``dataset_type``, ``host_device_mix``, ``full_batch`` and
            ``ckpt_path`` are read here; the object is also passed to
            ``ModelBuilder.get_net``, ``EvalCallBack`` and ``LossCallBack``.
    """
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)

    mix_mode = bool(config.host_device_mix)
    print("epochs is {}".format(num_epochs))

    if config.full_batch:
        # Full-batch: the batch size is multiplied by the group size and no
        # rank information is passed to create_dataset.
        context.set_auto_parallel_context(full_batch=True)
        de.config.set_seed(1)
        train_ds = create_dataset(path, train_mode=True, epochs=1,
                                  batch_size=batch * get_group_size(),
                                  data_type=data_type)
        eval_ds = create_dataset(path, train_mode=False, epochs=1,
                                 batch_size=batch * get_group_size(),
                                 data_type=data_type)
    else:
        train_ds = create_dataset(path, train_mode=True, epochs=1,
                                  batch_size=batch,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=data_type)
        eval_ds = create_dataset(path, train_mode=False, epochs=1,
                                 batch_size=batch,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config,
                           host_device_mix=mix_mode)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)
    context.set_auto_parallel_context(
        strategy_ckpt_save_file="./strategy_train.ckpt")

    hooks = [TimeMonitor(train_ds.get_dataset_size()), eval_cb, loss_cb]
    # In host/device-mix mode the checkpoint callback is left out and dataset
    # sink mode is switched off.
    if not mix_mode:
        hooks.append(ckpt_cb)
    model.train(num_epochs, train_ds, callbacks=hooks,
                dataset_sink_mode=(not mix_mode))
def train_and_eval(config):
    """Train and evaluate with per-rank datasets; checkpoint on rank 0 only.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``dataset_type`` and ``ckpt_path`` are read here; the object is
            also passed to ``ModelBuilder.get_net``, ``EvalCallBack`` and
            ``LossCallBack``.
    """
    set_seed(1000)
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=batch,
                              rank_id=get_rank(),
                              rank_size=get_group_size(),
                              data_type=data_type)
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=batch,
                             rank_id=get_rank(),
                             rank_size=get_group_size(),
                             data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    # Each rank gets its own checkpoint sub-directory.
    ckpt_dir = config.ckpt_path + '/ckpt_' + str(get_rank()) + '/'
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=ckpt_dir,
                              config=ckpt_cfg)

    # Evaluate once before any training step and report the result.
    initial = model.eval(eval_ds)
    print("=====" * 5 + "model.eval() initialized: {}".format(initial))

    hooks = [TimeMonitor(train_ds.get_dataset_size()), eval_cb, loss_cb]
    if get_rank() == 0:
        hooks.append(ckpt_cb)
    model.train(num_epochs, train_ds, callbacks=hooks,
                sink_size=train_ds.get_dataset_size())
def train_and_eval(config):
    """Train and evaluate on a single dataset, with checkpointing.

    Args:
        config: Run configuration. ``data_path``, ``epochs``, ``batch_size``,
            ``is_tf_dataset`` and ``ckpt_path`` are read here; the object is
            also passed to ``ModelBuilder.get_net``, ``EvalCallBack`` and
            ``LossCallBack``.
    """
    path = config.data_path
    num_epochs = config.epochs
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=config.batch_size,
                              is_tf_dataset=config.is_tf_dataset)
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=config.batch_size,
                             is_tf_dataset=config.is_tf_dataset)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config)
    # Checkpoint interval equals the number of batches in the dataset.
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=10)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)

    hooks = [TimeMonitor(train_ds.get_dataset_size()),
             eval_cb, loss_cb, ckpt_cb]
    model.train(num_epochs, train_ds, callbacks=hooks)
def train_and_eval(config):
    """Train and evaluate, choosing sink mode from the parameter-server setup.

    Dataset sink mode is enabled only when ``config.parameter_server`` is
    truthy and ``config.vocab_cache_size`` is greater than zero.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``dataset_type``, ``parameter_server``, ``vocab_cache_size`` and
            ``ckpt_path`` are read here; the object is also passed to
            ``ModelBuilder.get_net``, ``EvalCallBack`` and ``LossCallBack``.
    """
    set_seed(1000)
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)

    use_ps = bool(config.parameter_server)
    use_cache = config.vocab_cache_size > 0
    print("epochs is {}".format(num_epochs))

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=batch, data_type=data_type)
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=batch, data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)

    hooks = [TimeMonitor(train_ds.get_dataset_size()),
             eval_cb, loss_cb, ckpt_cb]
    model.train(num_epochs, train_ds, callbacks=hooks,
                dataset_sink_mode=(use_ps and use_cache))
def train_and_eval(config):
    """Train and evaluate in a distributed parameter-server style setup.

    ``cache_enable`` is not bound in this function.
    NOTE(review): it is presumably a module-level flag defined elsewhere in
    the file - confirm.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``dataset_type``, ``parameter_server``, ``full_batch``,
            ``stra_ckpt`` and ``ckpt_path`` are read here. ``full_batch`` and
            ``stra_ckpt`` are overwritten when ``cache_enable`` is truthy.
            The object is also passed to ``ModelBuilder.get_net``,
            ``EvalCallBack`` and ``LossCallBack``.
    """
    set_seed(1000)
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)

    use_ps = bool(config.parameter_server)
    # Enabling the cache forces full-batch mode.
    if cache_enable:
        config.full_batch = True
    print("epochs is {}".format(num_epochs))

    if config.full_batch:
        # Full-batch: the batch size is multiplied by the group size and no
        # rank information is passed to create_dataset.
        context.set_auto_parallel_context(full_batch=True)
        ds.config.set_seed(1)
        train_ds = create_dataset(path, train_mode=True, epochs=1,
                                  batch_size=batch * get_group_size(),
                                  data_type=data_type)
        eval_ds = create_dataset(path, train_mode=False, epochs=1,
                                 batch_size=batch * get_group_size(),
                                 data_type=data_type)
    else:
        train_ds = create_dataset(path, train_mode=True, epochs=1,
                                  batch_size=batch,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=data_type)
        eval_ds = create_dataset(path, train_mode=False, epochs=1,
                                 batch_size=batch,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    if cache_enable:
        # Rewrite the strategy checkpoint path to a per-rank location before
        # the callbacks (which receive config) are created.
        per_rank_dir = config.stra_ckpt + "-{}".format(get_rank())
        config.stra_ckpt = os.path.join(per_rank_dir, "strategy.ckpt")
        context.set_auto_parallel_context(
            strategy_ckpt_save_file=config.stra_ckpt)

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config)

    # Checkpoint policy depends on the process role and the cache flag.
    if not _is_role_worker():
        ckpt_cfg = CheckpointConfig(save_checkpoint_steps=1,
                                    keep_checkpoint_max=1)
    elif cache_enable:
        ckpt_cfg = CheckpointConfig(
            save_checkpoint_steps=train_ds.get_dataset_size() * num_epochs,
            keep_checkpoint_max=1,
            integrated_save=False)
    else:
        ckpt_cfg = CheckpointConfig(
            save_checkpoint_steps=train_ds.get_dataset_size(),
            keep_checkpoint_max=5)

    # Each rank gets its own checkpoint sub-directory.
    ckpt_dir = config.ckpt_path + '/ckpt_' + str(get_rank()) + '/'
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=ckpt_dir,
                              config=ckpt_cfg)

    hooks = [TimeMonitor(train_ds.get_dataset_size()), eval_cb, loss_cb]
    if get_rank() == 0:
        hooks.append(ckpt_cb)
    model.train(num_epochs, train_ds, callbacks=hooks,
                dataset_sink_mode=bool(use_ps and cache_enable))
def train_and_eval(config):
    """Train and evaluate with optional full-batch auto-parallel mode.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``full_batch`` and ``ckpt_path`` are read here; the object is
            also passed to ``ModelBuilder.get_net``, ``EvalCallBack`` and
            ``LossCallBack``.
    """
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs
    print("epochs is {}".format(num_epochs))

    # In both branches the eval dataset is created with one more epoch than
    # the train dataset.
    if config.full_batch:
        # Full-batch: the batch size is multiplied by the group size and no
        # rank information is passed to create_dataset.
        context.set_auto_parallel_context(full_batch=True)
        de.config.set_seed(1)
        train_ds = create_dataset(path, train_mode=True, epochs=num_epochs,
                                  batch_size=batch * get_group_size())
        eval_ds = create_dataset(path, train_mode=False,
                                 epochs=num_epochs + 1,
                                 batch_size=batch * get_group_size())
    else:
        train_ds = create_dataset(path, train_mode=True, epochs=num_epochs,
                                  batch_size=batch,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size())
        eval_ds = create_dataset(path, train_mode=False,
                                 epochs=num_epochs + 1,
                                 batch_size=batch,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size())
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)
    context.set_auto_parallel_context(
        strategy_ckpt_save_file="./strategy_train.ckpt")

    hooks = [TimeMonitor(train_ds.get_dataset_size()),
             eval_cb, loss_cb, ckpt_cb]
    model.train(num_epochs, train_ds, callbacks=hooks)
# Fragment of a larger training setup: net_with_loss, lr, config, device_num,
# train_dataset_batch_num, logger and save_checkpoint_path are bound outside
# this view.

# SGD optimizer over the trainable parameters of the loss-wrapped network.
opt = SGD(params=net_with_loss.trainable_params(), learning_rate=lr, momentum=config.momentum, weight_decay=config.weight_decay, loss_scale=config.loss_scale)
# With more than one device, extra reduction arguments are passed.
# NOTE(review): presumably this averages gradients across device_num devices -
# confirm against the TrainOneStepCell definition.
if device_num > 1:
    net = TrainOneStepCell(net_with_loss, opt, reduce_flag=True, mean=True, degree=device_num)
else:
    net = TrainOneStepCell(net_with_loss, opt)
# The loss callback is always registered; cb collects every callback.
loss_cb = LossCallBack(data_size=train_dataset_batch_num, logger=logger)
cb = [loss_cb]
# Checkpointing is optional and driven by config.save_checkpoint; the
# interval equals train_dataset_batch_num steps.
if config.save_checkpoint:
    ckptconfig = CheckpointConfig( save_checkpoint_steps=train_dataset_batch_num, keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix='AVA', directory=save_checkpoint_path, config=ckptconfig)
    cb += [ckpoint_cb]
# net already bundles loss and optimizer, so Model receives only the cell.
model = Model(net)
print("save configs...")
def train_eval(config):
    """Train, copy the checkpoints out, then evaluate the saved weights.

    After training, the checkpoint directory is copied to
    ``config.train_url`` with ``mox.file.copy_parallel``. The checkpoint
    located by ``find_ckpt`` is then loaded into the eval network, which is
    evaluated.

    Args:
        config: Run configuration. ``data_path``, ``dataset_type``,
            ``ckpt_path``, ``epochs``, ``batch_size`` and ``train_url`` are
            read here; the object is also passed to ``ModelBuilder.get_net``,
            ``LossCallBack`` and ``EvalCallBack``.
    """
    # The dataset type name is appended directly to the base data path.
    path = config.data_path + config.dataset_type
    ckpt_dir = config.ckpt_path
    num_epochs = config.epochs
    batch = config.batch_size

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)

    train_ds = create_dataset(path, train_mode=True, epochs=1,
                              batch_size=batch, data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    eval_ds = create_dataset(path, train_mode=False, epochs=1,
                             batch_size=batch, data_type=data_type)
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()

    # Training phase.
    trainer = Model(train_net)
    loss_cb = LossCallBack(config=config)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size(),
        keep_checkpoint_max=1)
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=config.ckpt_path,
                              config=ckpt_cfg)
    train_hooks = [TimeMonitor(train_ds.get_dataset_size()),
                   loss_cb, ckpt_cb]
    trainer.train(num_epochs, train_ds, callbacks=train_hooks)

    # Copy the checkpoint directory to the configured output URL.
    print('Download data from modelarts server to obs.')
    mox.file.copy_parallel(src_url=config.ckpt_path,
                           dst_url=config.train_url)

    # Evaluation phase: load the saved weights into the eval network.
    weights = load_checkpoint(find_ckpt(ckpt_dir))
    load_param_into_net(eval_net, weights)
    auc = AUCMetric()
    evaluator = Model(train_net, eval_network=eval_net,
                      metrics={"auc": auc})
    eval_cb = EvalCallBack(evaluator, eval_ds, auc, config)
    evaluator.eval(eval_ds, callbacks=eval_cb)
def train_and_eval(config):
    """Train and evaluate with full-batch, field-slice and host/device-mix options.

    Args:
        config: Run configuration. ``data_path``, ``batch_size``, ``epochs``,
            ``dataset_type``, ``host_device_mix``, ``sparse``, ``full_batch``,
            ``field_slice``, ``manual_shape``, ``field_size``, ``stra_ckpt``
            and ``ckpt_path`` are read here. ``stra_ckpt`` is overwritten
            with a per-rank path. The object is also passed to
            ``compute_manual_shape``, ``ModelBuilder.get_net``,
            ``EvalCallBack`` and ``LossCallBack``.
    """
    path = config.data_path
    batch = config.batch_size
    num_epochs = config.epochs

    # "tfrecord" and "mindrecord" map to their own type; anything else is H5.
    kind = config.dataset_type
    data_type = (DataType.TFRECORD if kind == "tfrecord"
                 else DataType.MINDRECORD if kind == "mindrecord"
                 else DataType.H5)

    mix_mode = bool(config.host_device_mix)
    is_sparse = config.sparse
    print("epochs is {}".format(num_epochs))

    if config.full_batch:
        # Full-batch: the batch size is multiplied by the group size and no
        # rank information is passed to create_dataset.
        context.set_auto_parallel_context(full_batch=True)
        ds.config.set_seed(1)
        slice_kwargs = {}
        if config.field_slice:
            # Field slicing passes two extra dataset arguments, read after
            # compute_manual_shape has run.
            compute_manual_shape(config, get_group_size())
            slice_kwargs = dict(manual_shape=config.manual_shape,
                                target_column=config.field_size)
        train_ds = create_dataset(path, train_mode=True, epochs=1,
                                  batch_size=batch * get_group_size(),
                                  data_type=data_type, **slice_kwargs)
        eval_ds = create_dataset(path, train_mode=False, epochs=1,
                                 batch_size=batch * get_group_size(),
                                 data_type=data_type, **slice_kwargs)
    else:
        train_ds = create_dataset(path, train_mode=True, epochs=1,
                                  batch_size=batch,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=data_type)
        eval_ds = create_dataset(path, train_mode=False, epochs=1,
                                 batch_size=batch,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=data_type)
    print("ds_train.size: {}".format(train_ds.get_dataset_size()))
    print("ds_eval.size: {}".format(eval_ds.get_dataset_size()))

    train_net, eval_net = ModelBuilder().get_net(config)
    train_net.set_train()
    auc = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc})

    # The strategy checkpoint path becomes rank-specific; this has to happen
    # before the callbacks are constructed.
    per_rank_dir = config.stra_ckpt + "-{}".format(get_rank())
    config.stra_ckpt = os.path.join(per_rank_dir, "strategy.ckpt")

    eval_cb = EvalCallBack(model, eval_ds, auc, config)
    loss_cb = LossCallBack(config=config, per_print_times=20)
    # The checkpoint interval spans the whole run (dataset size * epochs).
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=train_ds.get_dataset_size() * num_epochs,
        keep_checkpoint_max=5,
        integrated_save=False)
    ckpt_dir = os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank()))
    ckpt_cb = ModelCheckpoint(prefix='widedeep_train',
                              directory=ckpt_dir,
                              config=ckpt_cfg)
    context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)

    hooks = [TimeMonitor(train_ds.get_dataset_size()), eval_cb, loss_cb]
    # The checkpoint callback is left out in host/device-mix mode.
    if not mix_mode:
        hooks.append(ckpt_cb)
    # Dataset sink mode is turned off when the sparse flag is set.
    model.train(num_epochs, train_ds, callbacks=hooks,
                dataset_sink_mode=(not is_sparse))