示例#1
0
def test_deepfm():
    """
    Train DeepFM on the Criteo H5 dataset and check the loss and step time.

    The test runs in graph mode on a single Ascend device, trains for
    ``TrainConfig.train_epochs`` epochs with time, loss and eval callbacks
    attached, and then asserts that the last recorded loss and the measured
    per-step time are below fixed thresholds.

    The device is selected through the ``DEVICE_ID`` environment variable;
    device 0 is used when the variable is not set.
    """
    data_config = DataConfig()
    train_config = TrainConfig()
    # os.getenv returns None when DEVICE_ID is not exported and int(None)
    # raises TypeError, so fall back to device 0 in that case.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    # No rank information is passed: the dataset is created without sharding
    # arguments (presumably single-device training; confirm in create_dataset).
    rank_size = None
    rank_id = None

    dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
    print("dataset_path:", dataset_path)
    ds_train = create_dataset(dataset_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
                              rank_id=rank_id)

    # NOTE(review): the config *classes* (not the instances created above)
    # are handed to ModelBuilder, exactly as in the original code.
    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    loss_file_name = './loss.log'
    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=loss_file_name)
    callback_list = [time_callback, loss_callback]

    eval_file_name = './auc.log'
    ds_eval = create_dataset(dataset_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=train_config.batch_size,
                             data_type=DataType(data_config.data_format))
    eval_callback = EvalCallBack(model,
                                 ds_eval,
                                 auc_metric,
                                 eval_file_path=eval_file_name)
    callback_list.append(eval_callback)

    print("train_config.train_epochs:", train_config.train_epochs)
    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)

    # Upper bound the loss recorded by the loss callback must stay below.
    expected_loss_value = 0.51
    print("loss_callback.loss:", loss_callback.loss)
    assert loss_callback.loss < expected_loss_value, \
        "loss {} is not below {}".format(loss_callback.loss,
                                         expected_loss_value)
    # Upper bound for the per-step time reported by the time callback
    # (unit as reported by TimeMonitor; presumably milliseconds, confirm).
    expected_per_step_time = 40.0
    print("time_callback:", time_callback.per_step_time)
    assert time_callback.per_step_time < expected_per_step_time, \
        "per-step time {} is not below {}".format(time_callback.per_step_time,
                                                  expected_per_step_time)
    print("*******test case pass!********")
示例#2
0
    def get_callback_list(self, model=None, eval_dataset=None):
        """
        Build the callbacks that are switched on in ``self.train_config``.

        Callbacks are collected in a fixed order: checkpoint callback first,
        then eval callback, then loss callback. Each one is only created when
        its ``train_config`` flag (``save_checkpoint``, ``eval_callback``,
        ``loss_callback``) is truthy.

        Args:
            model (Cell): Object handed to the eval callback; must not be
                None when the eval callback is enabled (default=None)
            eval_dataset (Dataset): Dataset handed to the eval callback; must
                not be None when the eval callback is enabled (default=None)

        Returns:
            list, the enabled callbacks, or None when none is enabled.

        Raises:
            RuntimeError: If the eval callback is enabled while ``model`` or
                ``eval_dataset`` is None.
        """
        cfg = self.train_config
        callbacks = []

        if cfg.save_checkpoint:
            ckpt_settings = CheckpointConfig(
                save_checkpoint_steps=cfg.save_checkpoint_steps,
                keep_checkpoint_max=cfg.keep_checkpoint_max)
            callbacks.append(
                ModelCheckpoint(prefix=cfg.ckpt_file_name_prefix,
                                directory=cfg.output_path,
                                config=ckpt_settings))

        if cfg.eval_callback:
            # Both collaborators are mandatory once evaluation is requested;
            # the model is validated before the dataset.
            if model is None:
                raise RuntimeError(
                    "train_config.eval_callback is {}; get_callback_list() args model is {}"
                    .format(cfg.eval_callback, model))
            if eval_dataset is None:
                raise RuntimeError(
                    "train_config.eval_callback is {}; get_callback_list() args eval_dataset is {}"
                    .format(cfg.eval_callback, eval_dataset))
            eval_log_path = os.path.join(cfg.output_path, cfg.eval_file_name)
            callbacks.append(
                EvalCallBack(model,
                             eval_dataset,
                             AUCMetric(),
                             eval_file_path=eval_log_path))

        if cfg.loss_callback:
            loss_log_path = os.path.join(cfg.output_path, cfg.loss_file_name)
            callbacks.append(LossCallBack(loss_file_path=loss_log_path))

        # An empty list is reported as None, not as [].
        return callbacks or None
示例#3
0
文件: train.py 项目: yrpang/mindspore
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
    callback_list = [time_callback, loss_callback]

    if train_config.save_checkpoint:
        if rank_size:
            train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
            args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
        if args_opt.device_target != "Ascend":
            config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
                                         keep_checkpoint_max=train_config.keep_checkpoint_max)
        else:
            config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
                                         keep_checkpoint_max=train_config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
                                  directory=args_opt.ckpt_path,
                                  config=config_ck)
        callback_list.append(ckpt_cb)

    if args_opt.do_eval:
        ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
                                 epochs=1,
                                 batch_size=train_config.batch_size,
                                 data_type=DataType(data_config.data_format))
        eval_callback = EvalCallBack(model, ds_eval, auc_metric,
                                     eval_file_path=args_opt.eval_file_name)
        callback_list.append(eval_callback)
    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)