Example 1
def test_batcher():
    # Reopen the saved HDF5 split without shuffling and batch the test
    # portion (b); the identity normalizer leaves stored records as-is.
    (a, b), (a_size, b_size) = split_data(None, h5_path,
                                          overwrite_previous=False,
                                          shuffle=False)
    return batch_data(b,
        normalizer_fun=lambda x: x,
        transformer_fun=transformer_fun,
        flatten=True,
        batch_size=batch_size)
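
A minimal consumption sketch (the driver lines below are not from the source; h5_path, transformer_fun, and batch_size are assumed bound in the enclosing scope):

    test_gen = test_batcher()        # build a fresh generator over the test split
    X_test, y_test = next(test_gen)  # one (data, label) batch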
Example 2
def split_and_batch(data_loader, 
                    batch_size, 
                    doclength,
                    h5_path,
                    rng_seed=888,
                    normalizer_fun=data_utils.normalize,
                    transformer_fun=data_utils.to_one_hot):
    """
    Convenience wrapper for most common splitting and batching
    workflow in neon. Splits data to an HDF5 path, if it does not already exist,
    and then returns functions for getting geerators over the datasets
    (gets around limitations of input to neon_utils.DiskDataIterator)
    """
    data_batches = batch_data(data_loader, batch_size,
        normalizer_fun=normalizer_fun,
        transformer_fun=None)
    (_, _), (train_size, test_size) = split_data(data_batches, 
            h5_path, overwrite_previous=False, rng_seed=rng_seed)
    def train_batcher():
        # Reopen the HDF5 split and return a shuffled batch generator
        # over the training portion (a); records were normalized before
        # being written, so the normalizer here is the identity.
        (a, b), (a_size, b_size) = split_data(
            None, h5_path=h5_path, overwrite_previous=False, shuffle=True)
        return batch_data(a,
            normalizer_fun=lambda x: x,
            transformer_fun=transformer_fun,
            flatten=True,
            batch_size=batch_size)

    def test_batcher():
        # Same, but an unshuffled generator over the test portion (b).
        (a, b), (a_size, b_size) = split_data(
            None, h5_path=h5_path, overwrite_previous=False, shuffle=False)
        return batch_data(b,
            normalizer_fun=lambda x: x,
            transformer_fun=transformer_fun,
            flatten=True,
            batch_size=batch_size)

    return (train_batcher, test_batcher), (train_size, test_size)               
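
A hedged usage sketch (my_loader and the HDF5 path are hypothetical names, not from the source):

    (train_batcher, test_batcher), (train_size, test_size) = split_and_batch(
        my_loader, batch_size=128, doclength=300, h5_path='reviews.h5')
    train_gen = train_batcher()   # fresh shuffled generator, e.g. once per epoch
    X, y = next(train_gen)        # one flattened, one-hot-encoded batch

Because each batcher rebuilds its generator from the HDF5 file, they can be called repeatedly, which is what makes them usable where a one-shot generator would not be.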
Example 3
            shuffle=True)
        import sys

        # get a record
        next_text, next_label = next(iter(amtr))

        try:
            print("Next record shape: {}".format(next_text.shape))
        except AttributeError:
            print("(No shape) Text: '{}'".format(next_text))


        # batch training, testing sets
        am_train_batch = batch_data(amtr,
            normalizer_fun=lambda x: data_utils.normalize(x,
                max_length=300,
                truncate_left=True),
            transformer_fun=None)
        am_test_batch = batch_data(amte,
            normalizer_fun=None, transformer_fun=None)

        # Spit out some sample data
        next_batch = next(am_train_batch)
        data, label = next_batch
        np.set_printoptions(threshold=sys.maxsize)
        print("Batch properties:")
        print("Shape (data): {}".format(data.shape))
        print("Shape (label): {}".format(label.shape))
        print("Type: {}".format(type(data)))
        print()
        print("First record of first batch:")
Example 4
def __iter__(self):
    # Rebuild the batch generator from the saved constructor
    # arguments so the object can be iterated more than once.
    return batch_data(*self.reset_args, **self.reset_kwargs)
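
This __iter__ implies a re-iterable wrapper around the one-shot batch_data generator. A minimal sketch of such a class (the class name and constructor are assumptions; only __iter__ comes from the source):

    class BatchIterable(object):
        # Hypothetical wrapper: batch_data returns a generator that is
        # exhausted after one pass, so each iter() call rebuilds it from
        # the arguments captured at construction time.
        def __init__(self, *args, **kwargs):
            self.reset_args = args
            self.reset_kwargs = kwargs

        def __iter__(self):
            return batch_data(*self.reset_args, **self.reset_kwargs)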
Example 5
    file1 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\train-images.idx3-ubyte'
    file2 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\train-labels.idx1-ubyte'
    file3 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\t10k-images.idx3-ubyte'
    file4 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\t10k-labels.idx1-ubyte'
    train_data, train_data_head = loadImageSet(file1)
    train_labels, train_labels_head = loadLabelSet(file2)
    test_data, test_data_head = loadImageSet(file3)
    test_labels, test_labels_head = loadLabelSet(file4)

    # process and batch the training data for each client
    clients = create_clients(train_data,
                             train_labels,
                             num_clients=1,
                             initial='client')
    clients_batched = dict()
    for (client_name, data) in clients.items():
        clients_batched[client_name] = batch_data(data)

    # process and batch the test set
    test_batched = tf.data.Dataset.from_tensor_slices(
        (test_data, test_labels)).batch(len(test_labels))

    # define an optimizer, loss function, and metrics for model compilation
    lr = 0.01
    comms_round = 100
    loss = 'categorical_crossentropy'
    metrics = ['accuracy']
    optimizer = SGD(lr=lr, decay=lr / comms_round, momentum=0.9)

    # initialize the global model
    smlp_global = SimpleMLP()
    global_model = smlp_global.build(784, 10)
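
Example 5 calls batch_data(data) without showing its definition. A minimal sketch consistent with this usage (the body is an assumption, not necessarily this project's implementation), assuming each client shard is a list of (image, label) pairs:

    import tensorflow as tf

    def batch_data(data_shard, bs=32):
        # Hypothetical: turn one client's list of (image, label) pairs
        # into a shuffled, batched tf.data.Dataset, mirroring how the
        # test set is batched with from_tensor_slices above.
        data, labels = zip(*data_shard)
        dataset = tf.data.Dataset.from_tensor_slices((list(data), list(labels)))
        return dataset.shuffle(len(labels)).batch(bs)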