コード例 #1
0
ファイル: train.py プロジェクト: zimaxeg/mindspore
def create_train_dataset(batch_size):
    """Build the training dataset pipeline described by ``bert_train_cfg``.

    Steps, in order:
      1. load the seven feature columns from ``bert_train_cfg.DATA_DIR`` using
         the schema at ``bert_train_cfg.SCHEMA_DIR``;
      2. cast every column except ``masked_lm_weights`` to int32;
      3. group samples into batches of ``batch_size``, dropping an incomplete
         trailing batch;
      4. repeat the batched pipeline ``bert_train_cfg.epoch_size`` times.

    Args:
        batch_size: number of samples per batch.

    Returns:
        The batched and repeated dataset object.
    """
    epochs = bert_train_cfg.epoch_size
    loaded_columns = [
        "input_ids", "input_mask", "segment_ids", "next_sentence_labels",
        "masked_lm_positions", "masked_lm_ids", "masked_lm_weights",
    ]
    dataset = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
                                columns_list=loaded_columns)

    # One shared cast op applied column by column; masked_lm_weights is
    # deliberately not in this tuple, so it keeps its stored type.
    to_int32 = deMap.TypeCastOp("int32")
    for column in ("masked_lm_ids", "masked_lm_positions", "next_sentence_labels",
                   "segment_ids", "input_mask", "input_ids"):
        dataset = dataset.map(input_columns=column, operations=to_int32)

    # Batch first, then repeat the batched stream once per epoch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(epochs)
コード例 #2
0
def me_de_train_dataset():
    """Build the test training dataset from ``DATA_DIR`` / ``SCHEMA_DIR``.

    Loads the seven feature columns and casts every column except
    ``masked_lm_weights`` to int32. It then groups samples into batches of 16,
    dropping an incomplete trailing batch, and repeats the batched pipeline
    once.

    Returns:
        The batched dataset object.
    """
    epochs = 1
    dataset = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=[
        "input_ids", "input_mask", "segment_ids", "next_sentence_labels",
        "masked_lm_positions", "masked_lm_ids", "masked_lm_weights",
    ])

    # A single int32 cast op reused for each column listed below;
    # masked_lm_weights is intentionally left out.
    to_int32 = deMap.TypeCastOp("int32")
    for column in ("masked_lm_ids", "masked_lm_positions", "next_sentence_labels",
                   "segment_ids", "input_mask", "input_ids"):
        dataset = dataset.map(input_columns=column, operations=to_int32)

    samples_per_batch = 16
    dataset = dataset.batch(samples_per_batch, drop_remainder=True)
    return dataset.repeat(epochs)