Example #1
def local_train(initial_model, learning_rate, all_batches):
    # Mapping function to apply to each batch.
    @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model, batch):
        return batch_train(model, batch, learning_rate)

    return tff.sequence_reduce(all_batches, initial_model, batch_fn)
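All of the local_train examples in this listing rely on MODEL_TYPE, BATCH_TYPE, and a batch_train helper defined elsewhere in their projects. A minimal sketch of that assumed context, modeled on the TFF custom federated algorithms tutorial (the MNIST-style shapes and softmax loss are placeholders, not taken from the examples above):

import collections
import tensorflow as tf
import tensorflow_federated as tff

# Assumed element and model types (shapes are illustrative only).
BATCH_TYPE = tff.to_type(collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32)))
MODEL_TYPE = tff.to_type(collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32)))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
    # One SGD step on a single batch; returns the updated model.
    model_vars = collections.OrderedDict(
        (name, tf.Variable(initial_value=value, name=name))
        for name, value in initial_model.items())
    optimizer = tf.keras.optimizers.SGD(learning_rate)

    @tf.function
    def _train_on_batch(model_vars, batch):
        with tf.GradientTape() as tape:
            logits = tf.matmul(batch['x'], model_vars['weights']) + model_vars['bias']
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=batch['y'], logits=logits))
        grads = tape.gradient(loss, model_vars)
        optimizer.apply_gradients(
            zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
        return model_vars

    return _train_on_batch(model_vars, batch)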
Example #2
def local_train(initial_model, learning_rate, all_batches):

    # Mapping function to apply to each batch.
    # Wrapping batch_fn with tff.tf_computation instead of
    # tff.federated_computation also runs correctly.
    @tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model, batch):
        # The learning rate is fixed here rather than taken from the
        # learning_rate argument of the enclosing function.
        learning_rate = 0.1
        return batch_train(model, batch, learning_rate)

    return tff.sequence_reduce(all_batches, initial_model, batch_fn)
Example #3
def local_train(initial_model, learning_rate, all_batches):
    # The function below will be applied to every batch. It is needed because
    # batch_train requires learning_rate as an extra parameter.
    #md = initial_model

    @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model_, batch):
        return batch_train(model_, batch, learning_rate)

    #for _ in range(1):
    #md = tff.sequence_reduce(all_batches, md, batch_fn)
    return tff.sequence_reduce(all_batches, initial_model, batch_fn)
Example #4
def local_train(initial_model, lr, all_batchs):
    # Each batch of data is processed with the map function below.
    @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model, batch):
        # batch_train takes three arguments, while the reducer passed to
        # tff.sequence_reduce only receives two, so batch_train is wrapped in
        # this function and the enclosing function supplies the learning rate lr.
        return batch_train(model, batch, lr)

    # Train the model: walk through all of this client's data batch by batch,
    # calling batch_train (SGD) on each one; once every batch has been
    # processed, the whole client dataset has been trained on.
    # tff.sequence_reduce is used inside a federated computation; TF code
    # cannot appear directly inside the tff.sequence_reduce call.
    return tff.sequence_reduce(all_batchs, initial_model, batch_fn)
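A local_train like the ones above is typically itself declared as a federated computation over a sequence of batches and then invoked directly on one client's data. A sketch of that wrapper plus a hypothetical invocation, continuing the assumed MODEL_TYPE / BATCH_TYPE / batch_train context sketched after Example #1 (LOCAL_DATA_TYPE, the dummy data, and the 0.1 learning rate are illustrative):

import numpy as np

# The sequence-of-batches type that all_batches ranges over.
LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):
    @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model, batch):
        return batch_train(model, batch, learning_rate)
    return tff.sequence_reduce(all_batches, initial_model, batch_fn)

# Hypothetical invocation on a single client's data: a zero-initialized model
# and one random batch matching the assumed specs above.
initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))
one_client_batches = [collections.OrderedDict(
    x=np.random.rand(32, 784).astype(np.float32),
    y=np.random.randint(0, 10, size=32).astype(np.int32))]
trained_model = local_train(initial_model, 0.1, one_client_batches)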
Example #5
def train_on_one_client(model, batches):
    # Fold train_on_one_batch over every batch in this client's data,
    # threading the model through as the accumulator.
    return tff.sequence_reduce(batches, model, train_on_one_batch)
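Example #5 presumes a two-argument train_on_one_batch reducer defined elsewhere. A hypothetical sketch, reusing the batch_train assumption from the earlier context sketch; the 0.1 learning rate is a placeholder:

@tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
    # The reducer signature is (accumulated model, next batch) -> model, so
    # the learning rate has to be fixed here rather than passed in.
    return batch_train(model, batch, 0.1)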
Example #6
def local_train(initial_model, learning_rate, all_batches, classes):
    # Mapping function applied to each batch; learning_rate and classes are
    # captured from the enclosing local_train arguments.
    @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model, batch):
        return batch_train(model, batch, learning_rate, classes)

    return tff.sequence_reduce(all_batches, initial_model, batch_fn)
Example #7
def sum_floats(sequence):
    # Accumulate a running float32 sum over the sequence, starting from 0.
    return tff.sequence_reduce(sequence, 0., add_floats)
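Example #7 likewise assumes an add_floats reducer defined elsewhere. A minimal sketch of that reducer and of how sum_floats would typically be wrapped and invoked; the decorators and the list argument are assumptions in the style of the examples above:

@tff.tf_computation(tf.float32, tf.float32)
def add_floats(total, value):
    # Two-argument reducer: running total plus the next element.
    return total + value

@tff.federated_computation(tff.SequenceType(tf.float32))
def sum_floats(sequence):
    return tff.sequence_reduce(sequence, 0., add_floats)

# Hypothetical usage: sum_floats([1.0, 2.0, 3.0]) evaluates to 6.0.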