# Example 1
def train(args_opt, config):
    """Train the CNN-CTC model, optionally in data-parallel distributed mode.

    Args:
        args_opt: Parsed command-line options; uses ``run_distribute`` and
            ``device_id``.
        config: Training configuration object (class count, hidden size,
            learning-rate parameters, checkpoint paths, ...).
    """
    distributed = args_opt.run_distribute
    if distributed:
        init()
        context.set_auto_parallel_context(parallel_mode="data_parallel")

    ds = dataset_creator(distributed)

    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE,
                       config.FINAL_FEATURE_WIDTH)
    net.set_train(True)

    # Resume from a checkpoint when one is configured; otherwise start fresh.
    if config.CKPT_PATH != '':
        load_param_into_net(net, load_checkpoint(config.CKPT_PATH))
        print('parameters loaded!')
    else:
        print('train from scratch...')

    criterion = ctc_loss()
    opt = mindspore.nn.RMSProp(params=net.trainable_params(),
                               centered=True,
                               learning_rate=config.LR_PARA,
                               momentum=config.MOMENTUM,
                               loss_scale=config.LOSS_SCALE)

    net = WithLossCell(net, criterion)
    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        config.LOSS_SCALE, False)
    model = Model(net,
                  optimizer=opt,
                  loss_scale_manager=loss_scale_manager,
                  amp_level="O2")

    loss_cb = LossCallBack()
    ck_config = CheckpointConfig(
        save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
        keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
    ckpt_cb = ModelCheckpoint(prefix="CNNCTC",
                              config=ck_config,
                              directory=config.SAVE_PATH)

    # In distributed mode only device 0 writes checkpoints; every other
    # device trains with the loss callback alone.
    callbacks = [loss_cb]
    if not distributed or args_opt.device_id == 0:
        callbacks.append(ckpt_cb)

    model.train(config.TRAIN_EPOCHS,
                ds,
                callbacks=callbacks,
                dataset_sink_mode=False)
# Example 2
def test_deepfm():
    """End-to-end DeepFM training smoke test on the Criteo dataset.

    Trains for the configured number of epochs on a single Ascend device,
    evaluating AUC alongside training, then asserts that the final loss and
    the per-step time stay under fixed thresholds.

    Raises:
        AssertionError: If loss >= 0.51 or per-step time >= 40.0 ms.
    """
    data_config = DataConfig()
    train_config = TrainConfig()
    # Default to device 0 when DEVICE_ID is not exported;
    # int(os.getenv('DEVICE_ID')) would raise TypeError on None.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    # Single-device run: no rank information is needed.
    rank_size = None
    rank_id = None

    dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
    print("dataset_path:", dataset_path)
    ds_train = create_dataset(dataset_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
                              rank_id=rank_id)

    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    loss_file_name = './loss.log'
    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=loss_file_name)
    callback_list = [time_callback, loss_callback]

    # Evaluation runs on a separate (non-training) view of the dataset.
    eval_file_name = './auc.log'
    ds_eval = create_dataset(dataset_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=train_config.batch_size,
                             data_type=DataType(data_config.data_format))
    eval_callback = EvalCallBack(model,
                                 ds_eval,
                                 auc_metric,
                                 eval_file_path=eval_file_name)
    callback_list.append(eval_callback)

    print("train_config.train_epochs:", train_config.train_epochs)
    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)

    # Regression thresholds for this fixed configuration.
    export_loss_value = 0.51
    print("loss_callback.loss:", loss_callback.loss)
    assert loss_callback.loss < export_loss_value
    export_per_step_time = 40.0
    print("time_callback:", time_callback.per_step_time)
    assert time_callback.per_step_time < export_per_step_time
    print("*******test case pass!********")
# Example 3
    def get_callback_list(self, model=None, eval_dataset=None):
        """
        Build the configured training callbacks.

        Depending on ``self.train_config`` this may include a checkpoint
        callback, an eval (AUC) callback, and a loss-logging callback.

        Args:
            model (Cell): The network the callbacks are attached to
                (default=None). Required when eval_callback is enabled.
            eval_dataset (Dataset): Dataset for eval (default=None).
                Required when eval_callback is enabled.

        Returns:
            list or None: The callbacks, or None when nothing is configured.
        """
        cfg = self.train_config
        callbacks = []

        if cfg.save_checkpoint:
            ckpt_config = CheckpointConfig(
                save_checkpoint_steps=cfg.save_checkpoint_steps,
                keep_checkpoint_max=cfg.keep_checkpoint_max)
            callbacks.append(ModelCheckpoint(
                prefix=cfg.ckpt_file_name_prefix,
                directory=cfg.output_path,
                config=ckpt_config))

        if cfg.eval_callback:
            # Evaluation needs both a model and a dataset; fail loudly if
            # either is missing while the config asks for eval.
            if model is None:
                raise RuntimeError(
                    "train_config.eval_callback is {}; get_callback_list() args model is {}"
                    .format(cfg.eval_callback, model))
            if eval_dataset is None:
                raise RuntimeError(
                    "train_config.eval_callback is {}; get_callback_list() args eval_dataset is {}"
                    .format(cfg.eval_callback, eval_dataset))
            callbacks.append(EvalCallBack(
                model,
                eval_dataset,
                AUCMetric(),
                eval_file_path=os.path.join(cfg.output_path,
                                            cfg.eval_file_name)))

        if cfg.loss_callback:
            callbacks.append(LossCallBack(
                loss_file_path=os.path.join(cfg.output_path,
                                            cfg.loss_file_name)))

        # Preserve the original contract: None rather than an empty list.
        return callbacks or None
# Example 4
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
                              rank_id=rank_id)

    steps_size = ds_train.get_dataset_size()

    if model_config.convert_dtype:
        model_config.convert_dtype = args_opt.device_target != "CPU"
    model_builder = ModelBuilder(model_config, train_config)
    train_net, eval_net = model_builder.get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
    callback_list = [time_callback, loss_callback]

    if train_config.save_checkpoint:
        if rank_size:
            train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
            args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
        if args_opt.device_target != "Ascend":
            config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
                                         keep_checkpoint_max=train_config.keep_checkpoint_max)
        else:
            config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
                                         keep_checkpoint_max=train_config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
                                  directory=args_opt.ckpt_path,
                                  config=config_ck)