def test_deepfm():
    """
    Train DeepFM on the Criteo dataset on Ascend and check loss and step time.

    The device id is read from the ``DEVICE_ID`` environment variable and falls
    back to 0 when the variable is unset, so the test no longer fails with an
    opaque ``TypeError`` from ``int(None)``.

    The test builds the train/eval datasets from a fixed dataset path, trains
    for ``train_config.train_epochs`` epochs with time, loss and eval
    callbacks, and then checks the values recorded by the callbacks.

    Raises:
        AssertionError: If the loss recorded by the loss callback is not below
            0.51, or the per-step time recorded by the time callback is not
            below 40.0.
    """
    data_config = DataConfig()
    train_config = TrainConfig()

    # Default to device 0 when DEVICE_ID is not exported.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)

    # No rank information is passed to the training dataset.
    rank_size = None
    rank_id = None

    dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
    print("dataset_path:", dataset_path)
    ds_train = create_dataset(dataset_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
                              rank_id=rank_id)

    # NOTE(review): the config classes themselves (not the instances created
    # above) are passed to ModelBuilder - confirm this is intended.
    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    loss_file_name = './loss.log'
    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=loss_file_name)
    callback_list = [time_callback, loss_callback]

    eval_file_name = './auc.log'
    ds_eval = create_dataset(dataset_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=train_config.batch_size,
                             data_type=DataType(data_config.data_format))
    eval_callback = EvalCallBack(model, ds_eval, auc_metric, eval_file_path=eval_file_name)
    callback_list.append(eval_callback)

    print("train_config.train_epochs:", train_config.train_epochs)
    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)

    # Upper bound accepted for the loss recorded by the loss callback.
    expected_loss_value = 0.51
    print("loss_callback.loss:", loss_callback.loss)
    assert loss_callback.loss < expected_loss_value

    # Upper bound accepted for the per-step time (presumably milliseconds - TODO confirm).
    expected_per_step_time = 40.0
    print("time_callback:", time_callback.per_step_time)
    assert time_callback.per_step_time < expected_per_step_time
    print("*******test case pass!********")
def get_callback_list(self, model=None, eval_dataset=None):
    """
    Assemble the training callbacks enabled in ``self.train_config``.

    Callbacks are appended in this order when their flag is set: checkpoint
    (``save_checkpoint``), evaluation (``eval_callback``), loss
    (``loss_callback``).

    Args:
        model (Cell): Network handed to the eval callback; must not be None
            when ``train_config.eval_callback`` is set (default=None)
        eval_dataset (Dataset): Dataset handed to the eval callback; must not
            be None when ``train_config.eval_callback`` is set (default=None)

    Returns:
        list or None, the enabled callbacks, or None when no callback is enabled.

    Raises:
        RuntimeError: If the eval callback is enabled but ``model`` or
            ``eval_dataset`` is None.
    """
    cfg = self.train_config
    callbacks = []

    if cfg.save_checkpoint:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=cfg.save_checkpoint_steps,
            keep_checkpoint_max=cfg.keep_checkpoint_max)
        callbacks.append(ModelCheckpoint(
            prefix=cfg.ckpt_file_name_prefix,
            directory=cfg.output_path,
            config=ckpt_config))

    if cfg.eval_callback:
        # The eval callback needs both inputs; model is checked first.
        if model is None:
            raise RuntimeError(
                "train_config.eval_callback is {}; get_callback_list() args model is {}"
                .format(cfg.eval_callback, model))
        if eval_dataset is None:
            raise RuntimeError(
                "train_config.eval_callback is {}; get_callback_list() args eval_dataset is {}"
                .format(cfg.eval_callback, eval_dataset))
        metric = AUCMetric()
        eval_log_path = os.path.join(cfg.output_path, cfg.eval_file_name)
        callbacks.append(EvalCallBack(model, eval_dataset, metric, eval_file_path=eval_log_path))

    if cfg.loss_callback:
        loss_log_path = os.path.join(cfg.output_path, cfg.loss_file_name)
        callbacks.append(LossCallBack(loss_file_path=loss_log_path))

    # An empty list is reported as None.
    return callbacks or None
# Wrap the train/eval networks; the AUC metric is registered under the key "auc".
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

# Step-time and loss callbacks are always installed.
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
callback_list = [time_callback, loss_callback]

if train_config.save_checkpoint:
    if rank_size:
        # When rank_size is truthy, the checkpoint prefix and directory are
        # made rank-specific (presumably a distributed run - TODO confirm).
        train_config.ckpt_file_name_prefix += str(get_rank())
        args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')

    # Non-Ascend targets use steps_size as the checkpoint interval; Ascend
    # uses the interval from the train config.
    ckpt_interval = (steps_size if args_opt.device_target != "Ascend"
                     else train_config.save_checkpoint_steps)
    config_ck = CheckpointConfig(save_checkpoint_steps=ckpt_interval,
                                 keep_checkpoint_max=train_config.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
                              directory=args_opt.ckpt_path,
                              config=config_ck)
    callback_list.append(ckpt_cb)

if args_opt.do_eval:
    # The eval dataset is only built when evaluation was requested.
    ds_eval = create_dataset(args_opt.dataset_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=train_config.batch_size,
                             data_type=DataType(data_config.data_format))
    eval_callback = EvalCallBack(model, ds_eval, auc_metric,
                                 eval_file_path=args_opt.eval_file_name)
    callback_list.append(eval_callback)

model.train(train_config.train_epochs, ds_train, callbacks=callback_list)