def stream_model_updates(experiment_id, model_updates, client, secret,
                         task_id):
    """Yield the two ``ModelUpdate`` messages for an already-serialized update.

    The first message identifies the sender and the task; the second carries
    ``model_updates`` unchanged in its ``model_update`` field.

    Presumably consumed as a request stream by a gRPC stub (inferred from the
    ``globalserver_pb2`` module name) -- confirm at the call site.

    Args:
        experiment_id: Experiment identifier copied into the header message.
        model_updates: Payload placed as-is into ``model_update``; assumed to
            already be ``bytes`` -- TODO confirm against callers.
        client: Client identifier for the header message.
        secret: Client secret for the header message.
        task_id: Task identifier for the header message.

    Yields:
        globalserver_pb2.ModelUpdate: the header message, then the payload
        message.
    """
    header = globalserver_pb2.ModelUpdate(
        client=client,
        secret=secret,
        experiment_id=experiment_id,
        task_id=task_id,
    )
    yield header
    yield globalserver_pb2.ModelUpdate(model_update=model_updates)
Esempio n. 2
0
def stream_model(experiment_id, model, global_weights, client, secret, task_id):
    """Yield a header ``ModelUpdate`` followed by a serialized weight payload.

    The payload is selected by the ``SERVER`` environment variable, parsed
    with ``int`` (the default is ``1`` when the variable is unset):

    * non-zero -- ``get_gradient(get_weights(model), global_weights)``;
    * zero     -- ``global_weights`` unchanged.

    The chosen payload is serialized with ``array_to_bytes``.

    Args:
        experiment_id: Experiment identifier copied into the header message.
        model: Local model; only read (via ``get_weights``) in server mode.
        global_weights: Weights handed to ``get_gradient`` in server mode, or
            sent directly otherwise.
        client: Client identifier for the header message.
        secret: Client secret for the header message.
        task_id: Task identifier for the header message.

    Yields:
        globalserver_pb2.ModelUpdate: the header message, then the payload
        message.

    Raises:
        ValueError: if ``SERVER`` is set to a string ``int`` cannot parse.
            Being a generator, this surfaces on the first ``next()``.
    """
    # The payload is resolved before the header is emitted, so in server
    # mode any failure in get_weights/get_gradient happens before anything
    # is yielded.
    if int(os.getenv('SERVER', 1)):
        payload = get_gradient(get_weights(model), global_weights)
    else:
        # NOTE(review): this branch ships the *global* weights back and
        # ignores ``model`` entirely -- confirm this is intended.
        payload = global_weights

    yield globalserver_pb2.ModelUpdate(client=client, secret=secret,
                                       experiment_id=experiment_id,
                                       task_id=task_id)
    yield globalserver_pb2.ModelUpdate(model_update=array_to_bytes(payload))
Esempio n. 3
0
def stream_model_P2P(experiment_id, model, client, secret, task_id):
    """Yield a header ``ModelUpdate`` followed by a JSON snapshot of ``model``.

    The snapshot is a JSON object with two string fields:

    * ``'trees'``  -- ``str(model.get_dump())``, kept for visualization;
    * ``'pickle'`` -- ``str(pickle.dumps(model))``.

    It is sent UTF-8 encoded in the ``model_update`` field.

    Args:
        experiment_id: Experiment identifier copied into the header message.
        model: Model object; must provide ``get_dump()`` and be picklable.
        client: Client identifier for the header message.
        secret: Client secret for the header message.
        task_id: Task identifier for the header message.

    Yields:
        globalserver_pb2.ModelUpdate: the header message, then the payload
        message.
    """
    snapshot = {
        # Text dump of the trees, used to visualize the model.
        'trees': str(model.get_dump()),
        # str() of a bytes object gives its Python literal form ("b'...'"),
        # not decoded text, so a receiver has to parse that literal back
        # into bytes. NOTE(review): the receiver presumably unpickles the
        # result -- only safe when the sending peer is trusted.
        'pickle': str(pickle.dumps(model)),
    }

    yield globalserver_pb2.ModelUpdate(client=client, secret=secret,
                                       experiment_id=experiment_id,
                                       task_id=task_id)
    encoded = json.dumps(snapshot).encode('utf-8')
    yield globalserver_pb2.ModelUpdate(model_update=encoded)
Esempio n. 4
0
def stream_model_RF(experiment_id, model, client, secret, task_id):
    """Yield a header ``ModelUpdate`` followed by ``model.model_update`` as JSON.

    Args:
        experiment_id: Experiment identifier copied into the header message.
        model: Object whose ``model_update`` attribute (filled by the RF
            training step, per the original note) is JSON-serialized.
        client: Client identifier for the header message.
        secret: Client secret for the header message.
        task_id: Task identifier for the header message.

    Yields:
        globalserver_pb2.ModelUpdate: the header message, then a message whose
        ``model_update`` is the UTF-8 encoded ``json.dumps`` output.
    """
    # NOTE(review): the RF training step reportedly stores a json-like
    # *string* here. If so, json.dumps wraps it in a second layer of JSON
    # string quoting, and the receiver must json.loads twice -- confirm.
    serialized = json.dumps(model.model_update)

    yield globalserver_pb2.ModelUpdate(client=client, secret=secret,
                                       experiment_id=experiment_id,
                                       task_id=task_id)
    yield globalserver_pb2.ModelUpdate(model_update=serialized.encode('utf-8'))