Example #1
 def _set_vocab_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
     """Enable the PS embeddingLookup cache."""
     self.embedding_table.cache_enable = True
     self.embedding_table.is_param_ps = True
     _set_cache_enable(True)
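     # Only the worker role registers this table's cache shape with the parameter server.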
     if _is_role_worker():
         _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
Example #2
 def _process_vocab_cache(self, slice_mode):
     """PS embeddingLookup cache check and process."""
     self.cache_enable = False
     if self.vocab_cache_size > 0:
         if self.target == 'CPU':
             logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                            "current target is CPU, so it will be ignored.")
             return
         enable_ps = _get_ps_context("enable_ps")
         if not enable_ps:
             logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
                            "mode, current mode is not parameter server trainning mode, so it will be ignored.")
             return
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
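         # In auto-parallel, the cache requires full_batch with a row-sliced table; the cache size is scaled to cover all devices.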
         if is_auto_parallel:
             device_num = get_group_size()
             full_batch = _get_full_batch()
             if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"):
                 raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
                                  "in 'full_batch' and 'table_row_slice' parallel strategy.")
             self.vocab_cache_size = self.vocab_cache_size * device_num
         self.cache_enable = True
         if _is_role_worker():
             self.vocab_size = self.vocab_cache_size
Example #3
 def _process_vocab_cache(self, slice_mode):
     """PS embeddingLookup cache check and process."""
     self.cache_enable = False
     if self.vocab_cache_size > 0:
         if self.target == 'CPU':
             logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                            "current target is CPU, so it will be ignored.")
             return
         enable_ps = _get_ps_context("enable_ps")
         if not enable_ps:
             logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
                            "mode, current mode is not parameter server trainning mode, so it will be ignored.")
             return
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
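         # In auto-parallel, the cache requires full_batch with a row-sliced table; scale the cache to all devices and record this rank for the PS cache.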
         if is_auto_parallel:
             rank_size = get_group_size()
             rank_id = get_rank()
             full_batch = _get_full_batch()
             if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                 raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
                                  "in 'full_batch' and 'table_row_slice' parallel strategy.")
             self.vocab_cache_size = self.vocab_cache_size * rank_size
             _set_rank_id(rank_id)
         self.cache_enable = True
         if _is_role_worker():
             self.vocab_size = self.vocab_cache_size
             if context.get_context("enable_sparse") != self.sparse:
                 raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
                                  "kernels and equal the value of 'enable_sparse' in context setting in "
                                  "parameter server cache mode")
Example #4
def train_and_eval(config):
    """
    test_train_eval
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    cache_enable = config.vocab_cache_size > 0
    if cache_enable:
        config.full_batch = True
    print("epochs is {}".format(epochs))
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        ds.config.set_seed(1)
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size * get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size * get_group_size(),
                                 data_type=dataset_type)
    else:
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

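    # With the embedding cache enabled, each rank writes its own parallel strategy file.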
    if cache_enable:
        config.stra_ckpt = os.path.join(
            config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
        context.set_auto_parallel_context(
            strategy_ckpt_save_file=config.stra_ckpt)

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
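    # Checkpoint cadence depends on the role and on whether the embedding cache is enabled.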
    if _is_role_worker():
        if cache_enable:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size() * epochs,
                keep_checkpoint_max=1,
                integrated_save=False)
        else:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size(),
                keep_checkpoint_max=5)
    else:
        ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
                                      keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' +
                                 str(get_rank()) + '/',
                                 config=ckptconfig)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=bool(parameter_server and cache_enable))
Example #5
def train_and_eval(config):
    """
    test_train_eval
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    cache_enable = config.vocab_cache_size > 0
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=batch_size,
                              data_type=dataset_type)
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=batch_size,
                             data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
    callback = LossCallBack(config=config)
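    # Checkpoint settings differ by role: workers save real checkpoints, other PS roles keep only a single one.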
    if _is_role_worker():
        if cache_enable:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size() * epochs,
                keep_checkpoint_max=1)
        else:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size(),
                keep_checkpoint_max=5)
    else:
        ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
                                      keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path,
                                 config=ckptconfig)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback,
        ckpoint_cb
    ]

    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=(parameter_server and cache_enable))
Example #6
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 param_init='normal',
                 target='CPU',
                 slice_mode='batch_slice',
                 manual_shapes=None,
                 max_norm=None,
                 sparse=True,
                 vocab_cache_size=0):
        super(EmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.target = target
        if target not in ('CPU', 'DEVICE'):
            raise ValueError(
                'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
                str(target) +
                ', should be one of \'CPU\' or \'DEVICE\'.')
        if not sparse and target == 'CPU':
            raise ValueError(
                'When target is CPU, embedding_lookup must be sparse.')
        enable_ps = context.get_ps_context("enable_ps")
        if not enable_ps and vocab_cache_size > 0:
            logger.warning(
                "The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, "
                "current mode is not parameter server trainning mode, so it will be ignored."
            )
            vocab_cache_size = 0
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.GatherV2()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
            'primitive_target', 'CPU')
        self.vocab_size = validator.check_positive_int(vocab_size,
                                                       'vocab_size')
        self.vocab_cache_size = validator.check_non_negative_int(
            vocab_cache_size, 'vocab_cache_size')
        self.embedding_size = validator.check_positive_int(
            embedding_size, 'embedding_size')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                             ParallelMode.AUTO_PARALLEL)
        self.cache_enable = self.vocab_cache_size > 0
        if self.cache_enable:
            if is_auto_parallel:
                self.vocab_cache_size = self.vocab_cache_size * get_group_size()
            self.vocab_size = self.vocab_cache_size

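        # With the cache enabled, the device-side table holds only the cached rows; the parameter server keeps the full table.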
        self.embedding_table = Parameter(
            initializer(param_init, [self.vocab_size, self.embedding_size]),
            name='embedding_table')
        if self.cache_enable:
            self.embedding_table.cache_enable = True
            _set_cache_enable(True)
            if _is_role_worker():
                _insert_hash_table_size(self.embedding_table.name,
                                        vocab_cache_size, embedding_size,
                                        vocab_size)
        self.forward_unique = False
        self.gather_revert = P.GatherV2()
        self.unique = P.Unique().shard(((1, ), ))
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        indices_shape_size = 2
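        # Set shard strategies for the chosen slice_mode: slice the table by field, row, or column, or slice the indices by batch.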
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError(
                    "in slice field mode, the manual_shapes should not be none"
                )
            if not isinstance(manual_shapes, tuple):
                raise TypeError(
                    "manual_shapes type must be tuple(int) cannot be {}!".
                    format(type(manual_shapes)))
            for dim in manual_shapes:
                validator.check_positive_int(dim, 'manual shape dim',
                                             self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(
                ((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            if target == 'DEVICE' and not self.cache_enable:
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(), )))
                self.forward_unique = True
            indices_strategy = (1, ) * indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(
                ((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1, )))
                self.forward_unique = True
            indices_strategy = (1, ) * indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(
                ((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1] * (indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                raise ValueError(
                    "slice_mode '" + str(slice_mode) +
                    "' is not supported by nn.EmbeddingLookup.")
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(
                self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
Example #7
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    losses = []
    for _ in range(epoch):
        data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
        label = Tensor(np.random.randint(0, 9, (32), np.int32))
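        # A parameter server process only needs to launch the graph once and can then exit; workers collect the per-step loss.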
        if _is_role_pserver():
            train_network(data, label)
            sys.exit()
        else:
            loss = train_network(data, label).asnumpy()
            losses.append(loss)
    print(losses)
    return losses


envs = os.environ
if __name__ == "__main__":
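    # Train once with the parameter server enabled, then (worker only) again without it, and compare the losses.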
    set_seed(0)
    ps_loss = do_sparse_embedding(True)

    if _is_role_worker():
        context.reset_ps_context()
        set_seed(0)
        no_ps_loss = do_sparse_embedding()
        context.set_ps_context(enable_ps=True)

        assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)