Example #1
0
 def __init__(self):
     """Create the server-side loss-scatter and convergence helpers.

     Takes no arguments besides ``self``. Both helpers are constructed with
     no arguments.

     Attributes set:
         loss_scatter: a ``loss_scatter.Server()`` instance.
         has_converged: a ``has_converged.Server()`` instance.
     """
     self.loss_scatter = loss_scatter.Server()
     self.has_converged = has_converged.Server()
Example #2
0
File: enter_point.py  Project: zpskt/FATE
    def __init__(self, trans_var):
        """Initialise the parent, then build the server-side sync helpers.

        Args:
            trans_var: forwarded unchanged to the parent constructor as the
                ``trans_var`` keyword argument.

        The helpers are built from attributes of ``self.transfer_variable``,
        which is read after the parent constructor has run (presumably the
        parent sets it from ``trans_var`` -- TODO confirm).
        """
        super().__init__(trans_var=trans_var)
        self.model = None

        tv = self.transfer_variable
        self.aggregator = secure_mean_aggregator.Server(tv.secure_aggregator_trans_var)
        self.loss_scatter = loss_scatter.Server(tv.loss_scatter_trans_var)
        self.has_converged = has_converged.Server(tv.has_converged_trans_var)

        # Summary starts with an empty loss history and is_converged=False.
        self._summary = {"loss_history": [], "is_converged": False}
Example #3
0
def loss_scatter_call(job_id, role, ind, *args):
    """Run one party's side of a loss-scatter exchange and return its result.

    Args:
        job_id: accepted for call-signature compatibility; not used here.
        role: compared against ``consts.ARBITER`` and ``consts.HOST``; any
            other role sends the loss at index 0.
        ind: host index; a host sends ``losses[ind + 1]``.
        *args: ``args[0]`` is the sequence of losses. It is required for
            non-arbiter roles and ignored (may be omitted) for the arbiter.

    Returns:
        For the arbiter, ``list(loss_scatter.Server().get_losses())``;
        otherwise the return value of ``loss_scatter.Client().send_loss``.

    Raises:
        IndexError: a non-arbiter role was called without ``args[0]``, or
            the required position is missing from the losses.
    """
    if role == consts.ARBITER:
        # The arbiter only collects losses, so it must not read args[0].
        # Reading it here unconditionally raised IndexError when no
        # losses were passed.
        return list(loss_scatter.Server().get_losses())

    losses = args[0]
    # Host ``ind`` sends the entry at ind + 1; every other role sends entry 0.
    loss = losses[ind + 1] if role == consts.HOST else losses[0]
    return loss_scatter.Client().send_loss(loss)
Example #4
0
    def __init__(self, trans_var=None):
        """Build the server-side sync helpers from a transfer-variable bundle.

        Args:
            trans_var: object exposing ``get_parties(roles=...)``,
                ``client_parties`` and the ``loss_scatter``,
                ``has_converged``, ``model_scatter``, ``model_broadcaster``
                and ``random_padding_cipher`` attributes. When omitted (or
                ``None``), a new ``LegacyAggregatorTransVar()`` is created for
                this call.
        """
        # A default of ``LegacyAggregatorTransVar()`` in the signature is
        # evaluated once, when the ``def`` runs, so every instance built
        # without an argument would share the same object. Creating it here
        # gives each instance its own.
        if trans_var is None:
            trans_var = LegacyAggregatorTransVar()

        self._guest_parties = trans_var.get_parties(roles=[consts.GUEST])
        self._host_parties = trans_var.get_parties(roles=[consts.HOST])
        self._client_parties = trans_var.client_parties

        self._loss_sync = loss_scatter.Server(trans_var.loss_scatter)
        self._converge_sync = has_converged.Server(trans_var.has_converged)
        self._model_scatter = model_scatter.Server(trans_var.model_scatter)
        self._model_broadcaster = model_broadcaster.Server(
            trans_var.model_broadcaster)
        self._random_padding_cipher = random_padding_cipher.Server(
            trans_var.random_padding_cipher)
Example #5
0
def server_init_model(self, param):
    """Initialise server-side aggregation state on ``self`` from ``param``.

    Defined as a free function taking ``self``; presumably bound to a model
    class elsewhere -- TODO confirm.

    Args:
        self: receives the state; must expose ``transfer_variable`` with
            ``secure_aggregator_trans_var``, ``loss_scatter_trans_var`` and
            ``has_converged_trans_var`` attributes.
        param: parameter object providing ``secure_aggregate``,
            ``max_iter`` and ``early_stop`` (with ``converge_func`` and
            ``eps``). It is also stored as ``self.param``.
    """
    self.aggregate_iteration_num = 0
    self.aggregator = secure_mean_aggregator.Server(
        self.transfer_variable.secure_aggregator_trans_var)
    self.loss_scatter = loss_scatter.Server(
        self.transfer_variable.loss_scatter_trans_var)
    self.has_converged = has_converged.Server(
        self.transfer_variable.has_converged_trans_var)

    self._summary = dict(loss_history=[], is_converged=False)

    self.param = param
    self.enable_secure_aggregate = param.secure_aggregate
    self.max_aggregate_iteration_num = param.max_iter
    # Take early-stop settings from the same ``param`` as every other
    # setting, not from ``self.model_param``. This function never sets
    # ``self.model_param``, so it could be stale or missing.
    early_stop = param.early_stop
    self.converge_func = converge_func_factory(early_stop.converge_func,
                                               early_stop.eps).is_converge
    # False only for the "weight_diff" converge function; presumably that
    # criterion does not consume losses -- confirm.
    self.loss_consumed = early_stop.converge_func != "weight_diff"
 def __init__(self, verbose=False):
     """Create a secure sum-aggregation server and a loss-scatter server.

     Args:
         verbose: stored unchanged as ``self.verbose``.

     Secure aggregation is always enabled for the aggregator built here.
     """
     self.verbose = verbose
     self.aggregator = secure_sum_aggregator.Server(enable_secure_aggregate=True)
     self.scatter = loss_scatter.Server()