コード例 #1
0
ファイル: datasets.py プロジェクト: tweikiang/mindspore-1
def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
                    line_per_sample=1000, rank_size=None, rank_id=None):
    """
    Build a batched TFRecord dataset from the tfrecord files under ``data_dir``.

    Files are selected recursively when their name contains both the split
    prefix ('train' when ``train_mode`` is True, otherwise 'test') and
    'tfrecord'. ``batch_size // line_per_sample`` records are grouped per
    batch and then flattened so each batch holds ``batch_size`` rows.

    Args:
        data_dir (str): Root directory searched with ``os.walk``.
        train_mode (bool): Pick train files and shuffle them (default=True).
        epochs (int): Number of times the dataset is repeated (default=1).
        batch_size (int): Rows per output batch; must be a positive multiple
            of ``line_per_sample`` (default=1000).
        line_per_sample (int): Rows packed into one tfrecord record; the
            reshape below relies on this (default=1000).
        rank_size (int): Number of shards; only used together with
            ``rank_id`` (default=None).
        rank_id (int): Shard index of this device; only used together with
            ``rank_size`` (default=None).

    Returns:
        Dataset with columns 'feat_ids' and 'feat_vals' reshaped to
        (batch_size, 39) and 'label' reshaped to (batch_size, 1).

    Raises:
        ValueError: If ``batch_size`` is not a positive multiple of
            ``line_per_sample``; such sizes would otherwise only fail later
            inside the pipeline's reshape.
    """
    if line_per_sample <= 0 or batch_size <= 0 or batch_size % line_per_sample:
        raise ValueError(
            "batch_size ({}) must be a positive multiple of line_per_sample ({})"
            .format(batch_size, line_per_sample))

    file_prefix_name = 'train' if train_mode else 'test'
    dataset_files = []
    for dirpath, _, filenames in os.walk(data_dir):
        for filename in filenames:
            if file_prefix_name in filename and "tfrecord" in filename:
                dataset_files.append(os.path.join(dirpath, filename))

    schema = de.Schema()
    schema.add_column('feat_ids', de_type=mstype.int32)
    schema.add_column('feat_vals', de_type=mstype.float32)
    schema.add_column('label', de_type=mstype.float32)

    # Sharding is enabled only when both distributed arguments are given.
    shard_kwargs = {}
    if rank_size is not None and rank_id is not None:
        shard_kwargs = {'num_shards': rank_size,
                        'shard_id': rank_id,
                        'shard_equal_rows': True}
    ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=train_mode,
                            schema=schema, num_parallel_workers=8,
                            **shard_kwargs)

    # Each record carries line_per_sample rows, so this many records make one
    # batch of batch_size rows (exact integer division, checked above).
    ds = ds.batch(batch_size // line_per_sample, drop_remainder=True)
    ds = ds.map(operations=(lambda x, y, z: (
        np.array(x).flatten().reshape(batch_size, 39),
        np.array(y).flatten().reshape(batch_size, 39),
        np.array(z).flatten().reshape(batch_size, 1))),
                input_columns=['feat_ids', 'feat_vals', 'label'],
                # NOTE(review): `columns_order` is the older MindSpore keyword
                # (later releases spell it `column_order`); confirm against the
                # installed version before changing it.
                columns_order=['feat_ids', 'feat_vals', 'label'],
                num_parallel_workers=8)
    ds = ds.repeat(epochs)
    return ds
コード例 #2
0
def _get_tf_dataset(directory,
                    train_mode=True,
                    epochs=1,
                    batch_size=1000,
                    line_per_sample=1000,
                    rank_size=None,
                    rank_id=None,
                    num_samples=3000):
    """
    Get dataset with tfrecord format.

    Args:
        directory (str): Dataset directory, searched recursively for files whose
            name contains 'train' (or 'test' in eval mode) and 'tfrecord'.
        train_mode (bool): Whether the dataset is used for train or eval; also
            enables shuffling (default=True).
        epochs (int): Dataset epoch size (default=1).
        batch_size (int): Dataset batch size in rows; must be a positive
            multiple of ``line_per_sample`` (default=1000).
        line_per_sample (int): The number of rows stored per tfrecord sample
            (default=1000).
        rank_size (int): The number of devices, not necessary for single device
            (default=None).
        rank_id (int): Id of device, not necessary for single device
            (default=None).
        num_samples (int): Number of tfrecord samples to read, forwarded to
            ``TFRecordDataset``; pass None to read the whole dataset
            (default=3000, the value previously hard-coded here).

    Returns:
        Dataset with 'feat_ids'/'feat_vals' of shape (batch_size, 39) and
        'label' of shape (batch_size, 1).

    Raises:
        ValueError: If ``batch_size`` is not a positive multiple of
            ``line_per_sample``.
    """
    if line_per_sample <= 0 or batch_size <= 0 or batch_size % line_per_sample:
        raise ValueError(
            "batch_size ({}) must be a positive multiple of line_per_sample ({})"
            .format(batch_size, line_per_sample))

    file_prefix_name = 'train' if train_mode else 'test'
    dataset_files = []
    for (dir_path, _, filenames) in os.walk(directory):
        for filename in filenames:
            if file_prefix_name in filename and 'tfrecord' in filename:
                dataset_files.append(os.path.join(dir_path, filename))

    schema = de.Schema()
    schema.add_column('feat_ids', de_type=mstype.int32)
    schema.add_column('feat_vals', de_type=mstype.float32)
    schema.add_column('label', de_type=mstype.float32)

    # Sharding is enabled only when both distributed arguments are given.
    shard_kwargs = {}
    if rank_size is not None and rank_id is not None:
        shard_kwargs = {'num_shards': rank_size,
                        'shard_id': rank_id,
                        'shard_equal_rows': True}
    ds = de.TFRecordDataset(dataset_files=dataset_files,
                            shuffle=train_mode,
                            schema=schema,
                            num_parallel_workers=8,
                            num_samples=num_samples,
                            **shard_kwargs)

    # Group enough records to yield exactly batch_size rows per batch.
    ds = ds.batch(batch_size // line_per_sample, drop_remainder=True)
    ds = ds.map(operations=(lambda x, y, z:
                            (np.array(x).flatten().reshape(batch_size, 39),
                             np.array(y).flatten().reshape(batch_size, 39),
                             np.array(z).flatten().reshape(batch_size, 1))),
                input_columns=['feat_ids', 'feat_vals', 'label'],
                column_order=['feat_ids', 'feat_vals', 'label'],
                num_parallel_workers=8)
    ds = ds.repeat(epochs)
    return ds
コード例 #3
0
def _get_tf_dataset(data_dir,
                    schema_dict,
                    input_shape_dict,
                    train_mode=True,
                    epochs=1,
                    batch_size=4096,
                    line_per_sample=4096,
                    rank_size=None,
                    rank_id=None):
    """
    Build the batched TFRecord pipeline for the 21-column feature layout.

    Args:
        data_dir (str): Directory searched recursively; files whose name
            contains 'train' (or 'eval' when ``train_mode`` is False) and
            'tfrecord' are read.
        schema_dict (dict): Mapping whose keys are declared as schema columns;
            'label' and 'continue_val' are float32, every other key int32.
            NOTE(review): the map step below hard-codes 21 column names, which
            presumably must all appear here — confirm with callers.
        input_shape_dict (dict): Not used by this function; kept so existing
            call sites keep working.
        train_mode (bool): Select train files and shuffle them (default=True).
        epochs (int): Number of times the dataset is repeated (default=1).
        batch_size (int): Rows per output batch; must be a positive multiple of
            ``line_per_sample`` (default=4096).
        line_per_sample (int): Rows stored per tfrecord record (default=4096).
        rank_size (int): Number of shards, used together with ``rank_id``
            (default=None).
        rank_id (int): Shard index of this device, used together with
            ``rank_size`` (default=None).

    Returns:
        Dataset whose 'label' column is reshaped to (batch_size,) and every
        other column to (batch_size, -1).

    Raises:
        ValueError: If ``batch_size`` is not a positive multiple of
            ``line_per_sample``.
    """
    if line_per_sample <= 0 or batch_size <= 0 or batch_size % line_per_sample:
        raise ValueError(
            "batch_size ({}) must be a positive multiple of line_per_sample ({})"
            .format(batch_size, line_per_sample))

    file_prefix_name = 'train' if train_mode else 'eval'
    dataset_files = []
    for (dirpath, _, filenames) in os.walk(data_dir):
        for filename in filenames:
            if file_prefix_name in filename and "tfrecord" in filename:
                dataset_files.append(os.path.join(dirpath, filename))

    float_keys = {"label", "continue_val"}
    schema = de.Schema()
    for key in schema_dict:
        ms_dtype = mstype.float32 if key in float_keys else mstype.int32
        schema.add_column(key, de_type=ms_dtype)

    # Sharding is enabled only when both distributed arguments are given.
    shard_kwargs = {}
    if rank_size is not None and rank_id is not None:
        shard_kwargs = {'num_shards': rank_size,
                        'shard_id': rank_id,
                        'shard_equal_rows': True}
    ds = de.TFRecordDataset(dataset_files=dataset_files,
                            shuffle=bool(train_mode),
                            schema=schema,
                            num_parallel_workers=8,
                            **shard_kwargs)
    ds = ds.batch(batch_size // line_per_sample, drop_remainder=True)

    def _reshape_columns(label, *features):
        # The label becomes a flat (batch_size,) vector; every other column is
        # flattened and regrouped per row as (batch_size, -1).
        label = np.asarray(label.reshape(batch_size, ))
        reshaped = tuple(np.array(col).flatten().reshape(batch_size, -1)
                         for col in features)
        return (label,) + reshaped

    input_columns = [
        'label', 'continue_val', 'indicator_id', 'emb_128_id',
        'emb_64_single_id', 'multi_doc_ad_category_id',
        'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id',
        'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id',
        'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id',
        'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id',
        'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id',
        'multi_doc_ad_topic_id_mask', 'ad_id', 'display_ad_and_is_leak',
        'display_id', 'is_leak'
    ]
    # Same columns as input_columns, except that 'display_id' is moved ahead of
    # 'ad_id' and 'display_ad_and_is_leak' in the final output order.
    columns_order = [
        'label', 'continue_val', 'indicator_id', 'emb_128_id',
        'emb_64_single_id', 'multi_doc_ad_category_id',
        'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id',
        'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id',
        'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id',
        'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id',
        'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id',
        'multi_doc_ad_topic_id_mask', 'display_id', 'ad_id',
        'display_ad_and_is_leak', 'is_leak'
    ]
    ds = ds.map(
        operations=_reshape_columns,
        input_columns=input_columns,
        columns_order=columns_order,
        num_parallel_workers=8)

    ds = ds.repeat(epochs)
    return ds
コード例 #4
0
ファイル: dataset.py プロジェクト: lswzjuer/mindspore_hccr
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    create a train or eval dataset

    Every file found recursively under ``dataset_path`` is read as a TFRecord
    file containing an 'image' (112x112x1 uint8) and a 'label' (int64 scalar).

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
            NOTE(review): currently unused — both modes shuffle and apply the
            same transforms; confirm this is intended.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset of batches with 'image' resized, rescaled by 1/255, cast to
        float32 and converted HWC->CHW, and 'label' cast to int32.

    Raises:
        KeyError: If RANK_SIZE requests more than one device but DEVICE_ID
            is not set.
    """
    # A missing RANK_SIZE means a single-device run.
    device_num = int(os.getenv("RANK_SIZE", "1"))

    # Join against the directory actually being walked so files inside
    # subdirectories get their real path.
    data_path_list = [
        os.path.join(root, name)
        for root, _, files in os.walk(dataset_path)
        for name in files
    ]

    schema = dt.Schema()
    # Binary data usually use uint8 here.
    schema.add_column('image', de_type=mstype.uint8, shape=[112, 112, 1])
    schema.add_column('label', de_type=mstype.int64, shape=[])

    shard_kwargs = {}
    if device_num != 1:
        # Sharded reading needs this device's shard index.
        # NOTE(review): DEVICE_ID is used as the shard id; on multi-node setups
        # the global rank may differ from the local device id — confirm.
        shard_kwargs = {'num_shards': device_num,
                        'shard_id': int(os.environ["DEVICE_ID"])}
    ds = dt.TFRecordDataset(data_path_list,
                            num_parallel_workers=4,
                            shuffle=True,
                            schema=schema,
                            shard_equal_rows=True,
                            **shard_kwargs)

    resize_op = C.Resize((config.image_height, config.image_width))
    rescale_op = C.Rescale(1.0 / 255.0, 0.0)
    type_cast_op_train = C2.TypeCast(mstype.float32)
    change_swap_op = C.HWC2CHW()
    trans = [resize_op, rescale_op, type_cast_op_train, change_swap_op]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="label", operations=type_cast_op)
    ds = ds.map(input_columns="image", operations=trans)

    ds = ds.shuffle(buffer_size=config.buffer_size)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)

    return ds