コード例 #1
0
 def test_batcher():
     """Return a batch generator over the test half of the stored split.

     Re-opens the existing HDF5 split at h5_path (overwrite_previous=False)
     without shuffling, so evaluation order is deterministic.
     """
     splits, sizes = split_data(None, h5_path, overwrite_previous=False, shuffle=False)
     _, test_split = splits
     # Data were normalized before being written out, so the normalizer
     # here is the identity; only the per-record transform is applied.
     return batch_data(
         test_split,
         batch_size=batch_size,
         normalizer_fun=lambda x: x,
         transformer_fun=transformer_fun,
         flatten=True)
コード例 #2
0
def split_and_batch(data_loader, 
                    batch_size, 
                    doclength,
                    h5_path,
                    rng_seed=888,
                    normalizer_fun=data_utils.normalize,
                    transformer_fun=data_utils.to_one_hot):
    """
    Convenience wrapper for the most common splitting and batching
    workflow in neon.

    Splits data to an HDF5 path, if it does not already exist, and then
    returns zero-argument functions for getting generators over the
    train/test datasets (gets around limitations of input to
    neon_utils.DiskDataIterator).

    Args:
        data_loader: iterable of raw records to split and batch.
        batch_size: number of records per batch.
        doclength: unused here; kept for interface compatibility.
        h5_path: path of the HDF5 file holding (or to hold) the split.
        rng_seed: seed for the train/test split.
        normalizer_fun: applied once to each record before the split.
        transformer_fun: applied to each record on every pass over a split.

    Returns:
        ((train_batcher, test_batcher), (train_size, test_size)) where each
        batcher is a zero-argument callable producing a fresh batch generator.
    """
    # Normalize up front; the per-record transform is deferred to the
    # batcher functions so it runs on each pass over the stored splits.
    data_batches = batch_data(data_loader, batch_size,
        normalizer_fun=normalizer_fun,
        transformer_fun=None)
    (_, _), (train_size, test_size) = split_data(data_batches,
            h5_path, overwrite_previous=False, rng_seed=rng_seed)

    def _make_batches(use_train, shuffle):
        """Re-open the stored split and return a batch generator.

        use_train selects the train (True) or test (False) half;
        shuffle is forwarded to split_data.
        """
        (train_split, test_split), _ = split_data(
            None, h5_path, overwrite_previous=False, shuffle=shuffle)
        split = train_split if use_train else test_split
        # Records were already normalized above, so pass identity here.
        return batch_data(split,
            normalizer_fun=lambda x: x,
            transformer_fun=transformer_fun,
            flatten=True,
            batch_size=batch_size)

    def train_batcher():
        # Shuffle each epoch for training.
        return _make_batches(True, shuffle=True)

    def test_batcher():
        # Deterministic order for evaluation.
        return _make_batches(False, shuffle=False)

    return (train_batcher, test_batcher), (train_size, test_size)
コード例 #3
0
            # (continuation of a call started above this view — TODO confirm
            # which loader/split call this closes)
            shuffle=True)
        import sys

        # get a record: peek at the first (text, label) pair from the
        # training iterator.
        next_text, next_label = next(iter(amtr))

        # NOTE(review): print statements — this module targets Python 2.
        try:
            print "Next record shape: {}".format(next_text.shape)
        except AttributeError as e:
            # Plain strings have no .shape; fall back to printing the text.
            print "(No shape) Text: '{}'".format(next_text)


        # batch training, testing sets
        am_train_batch = batch_data(amtr,
            normalizer_fun=lambda x: data_utils.normalize(x, 
                max_length=300, 
                truncate_left=True),
            transformer_fun=None)
        am_test_batch = batch_data(amte,
            normalizer_fun=None,transformer_fun=None)

        # Spit out some sample data
        next_batch = am_train_batch.next()
        data, label = next_batch
        # NOTE(review): np.nan as a threshold is rejected by NumPy >= 1.24;
        # fine on the legacy versions this Python 2 code presumably targets.
        np.set_printoptions(threshold=np.nan)
        print "Batch properties:"
        print "Shape (data): {}".format(data.shape)
        print "Shape (label): {}".format(label.shape)
        print "Type: {}".format(type(data))
        print
        print "First record of first batch:"
コード例 #4
0
 def __iter__(self):
     """Start a fresh pass by rebuilding the batch generator from the
     arguments stored at construction time."""
     saved_args = self.reset_args
     saved_kwargs = self.reset_kwargs
     return batch_data(*saved_args, **saved_kwargs)
コード例 #5
0
    # MNIST IDX files (absolute paths under the author's local venv).
    file2 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\train-labels.idx1-ubyte'
    file3 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\t10k-images.idx3-ubyte'
    file4 = r'C:\Users\张立才\PycharmProjects\pythonProject\venv\data\t10k-labels.idx1-ubyte'
    # Each loader returns (data, header); header is presumably the IDX
    # magic/shape metadata — confirm against loadImageSet/loadLabelSet.
    train_data, train_data_head = loadImageSet(file1)
    train_labels, train_labels_head = loadLabelSet(file2)
    test_data, test_data_head = loadImageSet(file3)
    test_labels, test_labels_head = loadLabelSet(file4)

    # Process and batch the training data for each client.
    clients = create_clients(train_data,
                             train_labels,
                             num_clients=1,
                             initial='client')
    clients_batched = dict()
    for (client_name, data) in clients.items():
        clients_batched[client_name] = batch_data(data)

    # Process and batch the test set: a single batch holding every test record.
    test_batched = tf.data.Dataset.from_tensor_slices(
        (test_data, test_labels)).batch(len(test_labels))

    # Define an optimizer, loss function and metrics for model compilation.
    lr = 0.01
    comms_round = 100  # number of federated communication rounds
    loss = 'categorical_crossentropy'
    metrics = ['accuracy']
    # NOTE(review): `lr=`/`decay=` are legacy Keras SGD arguments — confirm
    # the installed Keras version still accepts them.
    optimizer = SGD(lr=lr, decay=lr / comms_round, momentum=0.9)

    # Initialize the global model; build(784, 10) presumably means 784
    # input features -> 10 classes — verify SimpleMLP.build's signature.
    smlp_global = SimpleMLP()
    global_model = smlp_global.build(784, 10)