Ejemplo n.º 1
0
    def split(self, data_inst, shuffle=True):
        """Yield ``n_splits`` (train_data, test_data) pairs for k-fold CV.

        :param data_inst: distributed table of instances; its schema header
            is copied onto every fold's output tables.
        :param shuffle: shuffle sample ids before folding (default True).
        :return: generator of (train_data, test_data) table pairs.
        """
        header = data_inst.schema.get('header')

        data_sids_iter, data_size = collect_index(data_inst)

        # Remember the original key type: np.array converts every sid to a
        # numpy scalar/str, and joining on keys of the wrong type would
        # silently match nothing.
        data_sids = []
        key_type = None
        for sid, _ in data_sids_iter:
            if key_type is None:
                key_type = type(sid)
            data_sids.append(sid)
        data_sids = np.array(data_sids)

        if shuffle:
            np.random.shuffle(data_sids)

        kf = sk_KFold(n_splits=self.n_splits)

        for train, test in kf.split(data_sids):
            train_sids = data_sids[train]
            test_sids = data_sids[test]
            # BUG FIX: was str(x); cast keys back to their original type so
            # the join below actually matches data_inst's keys.
            train_sids_table = [(key_type(x), 1) for x in train_sids]
            test_sids_table = [(key_type(x), 1) for x in test_sids]
            train_table = eggroll.parallelize(train_sids_table,
                                              include_key=True,
                                              partition=data_inst._partitions)
            train_data = data_inst.join(train_table, lambda x, y: x)
            test_table = eggroll.parallelize(test_sids_table,
                                             include_key=True,
                                             partition=data_inst._partitions)
            test_data = data_inst.join(test_table, lambda x, y: x)
            train_data.schema['header'] = header
            test_data.schema['header'] = header
            yield train_data, test_data
Ejemplo n.º 2
0
 def test_collect_index(self):
     """collect_index must report the true row count and yield one sid per row."""
     data_num = 100
     feature_num = 20
     data_instances = self.prepare_data(data_num=data_num, feature_num=feature_num)
     data_sids_iter, data_size = indices.collect_index(data_instances)
     self.assertEqual(data_num, data_size)
     # The iterator itself must also produce exactly data_num entries.
     real_index_num = sum(1 for _ in data_sids_iter)
     self.assertEqual(data_num, real_index_num)
Ejemplo n.º 3
0
    def __init_mini_batch_data_seperator(self, data_insts, batch_size, batch_strategy, masked_rate, shuffle):
        """Create the batch generator for data_insts and cache its derived state.

        Side effects: sets self.data_sids_iter, self.batch_data_generator,
        self.batch_nums, self.batch_mutable and self.masked_batch_size; when
        the batching is static, materializes the batches immediately.
        """
        self.data_sids_iter, data_size = indices.collect_index(data_insts)

        generator = get_batch_generator(
            data_size, batch_size, batch_strategy, masked_rate, shuffle=shuffle)
        self.batch_data_generator = generator
        self.batch_nums = generator.batch_nums
        self.masked_batch_size = generator.masked_batch_size
        self.batch_mutable = generator.batch_mutable()

        # Static batching: the batch layout never changes, so build it once now.
        if self.batch_mutable is False:
            self.__generate_batch_data()
Ejemplo n.º 4
0
    def __mini_batch_data_seperator(self, data_insts, batch_size):
        """Split data_insts into consecutive mini-batches and materialize them.

        Collects every sample id, slices the id list into ceil(data_size /
        batch_size) consecutive chunks, then joins each chunk's id table back
        against data_insts to build the batch tables.

        Side effects: sets self.batch_size (when shrunk to fit the data),
        self.batch_nums, self.all_batch_data and self.all_index_data.
        Returns the list of id chunks, each a list of (sid, None) pairs.
        """
        data_sids_iter, data_size = indices.collect_index(data_insts)

        # A batch can never hold more samples than the data set contains.
        if batch_size > data_size:
            batch_size = data_size
            self.batch_size = batch_size

        # Ceiling division (also preserves the ZeroDivisionError the original
        # raised when batch_size ends up 0, e.g. on an empty data set).
        batch_nums = (data_size + batch_size - 1) // batch_size

        # NOTE(review): assumes the iterator yields exactly data_size items,
        # which collect_index is expected to guarantee — confirm.
        sid_list = [(sid, None) for sid, _ in data_sids_iter]
        batch_data_sids = [sid_list[i * batch_size:(i + 1) * batch_size]
                           for i in range(batch_nums)]

        self.batch_nums = len(batch_data_sids)

        all_batch_data = []
        all_index_data = []
        for index_data in batch_data_sids:
            index_table = eggroll.parallelize(index_data,
                                              include_key=True,
                                              partition=data_insts._partitions)
            # Keep the instance value (y) for every id present in this batch.
            batch_data = index_table.join(data_insts, lambda x, y: y)
            all_batch_data.append(batch_data)
            all_index_data.append(index_table)
        self.all_batch_data = all_batch_data
        self.all_index_data = all_index_data
        return batch_data_sids
Ejemplo n.º 5
0
    def split(self, data_inst):
        """Yield ``self.n_splits`` (train_data, test_data) pairs for k-fold CV.

        Sample ids are optionally shuffled (seeded by self.random_seed), and
        fold keys are cast back to their original type before joining so the
        fold tables match data_inst's keys exactly.
        """
        np.random.seed(self.random_seed)

        header = data_inst.schema.get('header')

        data_sids_iter, data_size = collect_index(data_inst)

        # Record the key type while draining the iterator; np.array would
        # otherwise erase it.
        key_type = None
        sids = []
        for sid, _ in data_sids_iter:
            if key_type is None:
                key_type = type(sid)
            sids.append(sid)
        data_sids = np.array(sids)

        if self.shuffle:
            np.random.shuffle(data_sids)

        kf = sk_KFold(n_splits=self.n_splits)

        for fold_num, (train_idx, test_idx) in enumerate(kf.split(data_sids), start=1):
            # Rebuild (key, flag) tables for this fold, restoring key types.
            train_pairs = [(key_type(x), 1) for x in data_sids[train_idx]]
            test_pairs = [(key_type(x), 1) for x in data_sids[test_idx]]

            train_table = session.parallelize(train_pairs,
                                              include_key=True,
                                              partition=data_inst._partitions)
            train_data = data_inst.join(train_table, lambda x, y: x)

            test_table = session.parallelize(test_pairs,
                                             include_key=True,
                                             partition=data_inst._partitions)
            test_data = data_inst.join(test_table, lambda x, y: x)

            train_data.schema['header'] = header
            test_data.schema['header'] = header
            yield train_data, test_data
Ejemplo n.º 6
0
    def __mini_batch_data_seperator(self, data_insts, batch_size):
        """Partition the sample ids of data_insts into consecutive batches.

        Side effects: sets self.batch_nums.
        Returns a list of batches, each a list of (sid, None) pairs; the last
        batch may be shorter than batch_size.
        """
        data_sids_iter, data_size = indices.collect_index(data_insts)
        # Ceiling division: number of batches needed to cover all samples.
        batch_nums = (data_size + batch_size - 1) // batch_size

        # NOTE(review): assumes the iterator yields exactly data_size items,
        # which collect_index is expected to guarantee — confirm.
        sid_list = [(sid, None) for sid, _ in data_sids_iter]
        batch_data_sids = [sid_list[i * batch_size:(i + 1) * batch_size]
                           for i in range(batch_nums)]

        self.batch_nums = len(batch_data_sids)

        return batch_data_sids