Exemplo n.º 1
0
 def _apply_dense(self, grad, var):
     """Take one proximal-gradient step on a dense variable.

     The variable is first moved along ``-lr * grad``. The candidate is then
     passed through the proximal operation ``tf_utils.prox_L2``, together
     with the "wzero" slot, the learning rate and ``lamb``. The result
     overwrites ``var``.

     Args:
         grad: Gradient tensor for ``var``.
         var: The variable being optimized.

     Returns:
         A grouped op that performs the assignment to ``var``.
     """
     dtype = var.dtype.base_dtype
     step_size = math_ops.cast(self._lr_t, dtype)
     lamb = math_ops.cast(self._lamb_t, dtype)
     anchor = self.get_slot(var, "wzero")

     # Plain gradient step followed by the proximal operation.
     candidate = var - step_size * grad
     projected = tf_utils.prox_L2(candidate, anchor, step_size, lamb)
     return control_flow_ops.group(state_ops.assign(var, projected))
Exemplo n.º 2
0
    def _apply_dense(self, grad, var):
        """Apply one corrected-gradient proximal step to a dense variable.

        The search direction is ``grad - f_w_0 + vzero``, built from the
        "f_w_0" and "vzero" slots. ``var`` is stepped along ``-lr`` times
        that direction. The proximal operation ``tf_utils.prox_L2`` is then
        applied with the "wzero" slot, and the result is assigned back to
        ``var``. The slots themselves are only read here.

        Args:
            grad: Gradient tensor for ``var``.
            var: The variable being optimized.

        Returns:
            A grouped op that performs the assignment to ``var``.
        """
        base_dtype = var.dtype.base_dtype
        lr = math_ops.cast(self._lr_t, base_dtype)
        lamb = math_ops.cast(self._lamb_t, base_dtype)

        f_w0, v0, w0 = (
            self.get_slot(var, slot_name)
            for slot_name in ("f_w_0", "vzero", "wzero"))

        direction = grad - f_w0 + v0
        stepped = var - lr * direction
        new_value = tf_utils.prox_L2(stepped, w0, lr, lamb)
        return control_flow_ops.group(state_ops.assign(var, new_value))
Exemplo n.º 3
0
    def _apply_sparse(self, grad, var):
        """Apply one proximal step for a sparse (``IndexedSlices``) gradient.

        Computes the search direction ``v = grad - preG + vzero`` in the
        "temp" slot and steps ``var`` along ``-lr * v``. It then applies the
        proximal operation ``tf_utils.prox_L2`` with the "wzero" slot, and
        stores ``v`` into the "vzero" slot for the next call.

        The dense part ``vzero - preG`` is written into "temp" first. The
        sparse gradient rows are then scatter-added at ``grad.indices``.
        Rows not listed in ``grad.indices`` therefore get no gradient
        contribution, and duplicate indices are summed. This assumes the
        "vzero"/"preG" slots have the same shape as ``var`` (TODO confirm
        against the slot creation code).

        Args:
            grad: Sparse gradient exposing ``.values`` and ``.indices``.
            var: The variable being optimized.

        Returns:
            A grouped op that updates ``var`` and the "vzero" slot.
        """
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        lamb_t = math_ops.cast(self._lamb_t, var.dtype.base_dtype)

        vzero = self.get_slot(var, "vzero")
        preG = self.get_slot(var, "preG")
        wzero = self.get_slot(var, "wzero")
        v_n_s = self.get_slot(var, "temp")

        # v_n_s = grad - preG + vzero: write the var-shaped dense part, then
        # scatter the sparse gradient rows into it. Doing it in this order
        # keeps every tensor shape-compatible with scatter_add.
        init_dense = state_ops.assign(v_n_s, vzero - preG)
        with ops.control_dependencies([init_dense]):
            vns_update = state_ops.scatter_add(
                v_n_s, grad.indices, grad.values)

        # Everything below must observe the fully assembled direction.
        with ops.control_dependencies([vns_update]):
            v_t = var - lr_t * vns_update
            prox = tf_utils.prox_L2(v_t, wzero, lr_t, lamb_t)
            var_update = state_ops.assign(var, prox)
            # Carry the direction forward for the next step.
            v_update = state_ops.assign(vzero, vns_update)

        return control_flow_ops.group(*[var_update, v_update, ])
Exemplo n.º 4
0
    def solve_inner(self, optimizer, data, num_epochs=1, batch_size=32):
        '''Solves the local optimization problem.

        Args:
            optimizer: Name of the local update rule:
                - "fedavg": ``num_epochs`` full passes over ``data`` in
                  mini-batches.
                - "fedprox" / "fedsgd": ``num_epochs`` steps, each on one
                  random mini-batch.
                - "fedsarah" / "fedsvrg": one proximal step from the stored
                  ``self.vzero`` direction, followed by ``num_epochs - 1``
                  mini-batch steps.
                Any other name performs no training.
            data: Client dataset; ``data['y']`` holds the labels. It is
                consumed by ``batch_data`` / ``suffer_data`` (presumably
                also holds the features -- confirm with those helpers).
            num_epochs: Number of epochs (fedavg) or local steps (others).
            batch_size: Mini-batch size; 0 means use the whole dataset.

        Returns:
            ``(soln, comp)``. ``soln`` is the model parameters after
            training. ``comp`` is
            ``num_epochs * (len(data['y']) // batch_size) * batch_size *
            self.flops``, or 0 when the effective batch size is 0 (no data).
        '''
        if batch_size == 0:  # 0 means "use the full dataset as one batch"
            batch_size = len(data['y'])

        if optimizer == "fedavg":
            for _ in trange(num_epochs, desc='Epoch: ', leave=False, ncols=120):
                for X, y in batch_data(data, batch_size):
                    with self.graph.as_default():
                        self.sess.run(self.train_op, feed_dict={
                            self.features: X, self.labels: y})

        elif optimizer in ("fedprox", "fedsgd"):
            data_x, data_y = suffer_data(data)
            # Each "epoch" here is a single randomly sampled mini-batch step.
            for _ in range(num_epochs):  # t = 1,2,3,4,5,...m
                X, y = get_random_batch_sample(data_x, data_y, batch_size)
                with self.graph.as_default():
                    self.sess.run(self.train_op, feed_dict={
                        self.features: X, self.labels: y})

        elif optimizer in ("fedsarah", "fedsvrg"):
            data_x, data_y = suffer_data(data)

            # First step uses the stored direction: w1 = prox(w0 - lr * vzero).
            wzero = self.get_params()
            w1 = wzero - self.optimizer._lr * np.array(self.vzero)
            w1 = prox_L2(np.array(w1), np.array(wzero),
                         self.optimizer._lr, self.optimizer._lamb)
            self.set_params(w1)

            # Remaining num_epochs - 1 steps.
            for e in range(num_epochs - 1):  # t = 1,2,3,4,5,...m
                X, y = get_random_batch_sample(data_x, data_y, batch_size)
                with self.graph.as_default():
                    if optimizer == "fedsvrg":
                        current_weight = self.get_params()

                        # Gradient at the anchor point wzero on this batch.
                        self.set_params(wzero)
                        fwzero = self.sess.run(self.grads, feed_dict={
                            self.features: X, self.labels: y})
                        self.optimizer.set_fwzero(fwzero, self)

                        # Restore the current weights and take the step.
                        self.set_params(current_weight)
                        self.sess.run(self.train_op, feed_dict={
                            self.features: X, self.labels: y})
                    else:  # "fedsarah"
                        if e == 0:
                            # The previous gradient is the one at wzero.
                            self.set_params(wzero)
                            grad_w0 = self.sess.run(self.grads, feed_dict={
                                self.features: X, self.labels: y})
                            self.optimizer.set_preG(grad_w0, self)

                            self.set_params(w1)
                            preW = self.get_params()  # previous iterate is w1

                            self.sess.run(self.train_op, feed_dict={
                                self.features: X, self.labels: y})
                        else:
                            curW = self.get_params()

                            # Gradient at the previous iterate on this batch.
                            self.set_params(preW)
                            grad_preW = self.sess.run(self.grads, feed_dict={
                                self.features: X, self.labels: y})
                            self.optimizer.set_preG(grad_preW, self)
                            preW = curW

                            # Restore the current weights and take the step.
                            self.set_params(curW)
                            self.sess.run(self.train_op, feed_dict={
                                self.features: X, self.labels: y})

        soln = self.get_params()
        # With no samples the effective batch size is 0; avoid dividing by it.
        if batch_size == 0:
            comp = 0
        else:
            comp = num_epochs * \
                (len(data['y']) // batch_size) * batch_size * self.flops
        return soln, comp