def _apply_dense(self, grad, var):
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    lamb_t = math_ops.cast(self._lamb_t, var.dtype.base_dtype)
    wzero = self.get_slot(var, "wzero")
    # Gradient step followed by the proximal operation toward wzero.
    prox = tf_utils.prox_L2(var - lr_t * grad, wzero, lr_t, lamb_t)
    var_update = state_ops.assign(var, prox)
    return control_flow_ops.group(*[var_update])
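# For reference, a minimal NumPy sketch of what tf_utils.prox_L2 is assumed to
# compute: the proximal operator of the quadratic penalty (lamb/2)*||w - wzero||^2
# with step size lr. The signature mirrors the calls above; the closed form here
# is an assumption derived from that penalty, not the repository's implementation.
import numpy as np

def prox_L2_sketch(w, wzero, lr, lamb):
    # argmin_x { (lamb/2)*||x - wzero||^2 + (1/(2*lr))*||x - w||^2 }
    # => x = (w + lr*lamb*wzero) / (1 + lr*lamb)
    return (np.asarray(w) + lr * lamb * np.asarray(wzero)) / (1.0 + lr * lamb)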
def _apply_dense(self, grad, var):
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    lamb_t = math_ops.cast(self._lamb_t, var.dtype.base_dtype)
    f_w_0 = self.get_slot(var, "f_w_0")
    vzero = self.get_slot(var, "vzero")
    wzero = self.get_slot(var, "wzero")
    # Variance-reduced direction: v = grad - grad_at_wzero + vzero.
    v_n_s = grad - f_w_0 + vzero
    v_t = var - lr_t * v_n_s
    prox = tf_utils.prox_L2(v_t, wzero, lr_t, lamb_t)
    var_update = state_ops.assign(var, prox)
    return control_flow_ops.group(*[var_update])
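# Hedged NumPy sketch of the dense variance-reduced step above, assuming vzero
# holds the full gradient at wzero and f_w_0 holds the mini-batch gradient at
# wzero, as in standard SVRG. The helper name is hypothetical.
import numpy as np

def svrg_dense_step_sketch(var, grad, f_w_0, vzero, wzero, lr, lamb):
    v = grad - f_w_0 + vzero                 # variance-reduced direction
    w = var - lr * v                         # plain gradient step
    return (w + lr * lamb * wzero) / (1.0 + lr * lamb)  # assumed prox_L2 closed form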
def _apply_sparse(self, grad, var):
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    lamb_t = math_ops.cast(self._lamb_t, var.dtype.base_dtype)
    vzero = self.get_slot(var, "vzero")
    preG = self.get_slot(var, "preG")
    wzero = self.get_slot(var, "wzero")
    v_n_s = self.get_slot(var, "temp")
    # Recursive direction v_n_s = grad - preG + vzero, built in two steps:
    # assign the dense part (vzero - preG) first, then scatter the sparse
    # gradient values onto the touched indices.
    temp = state_ops.assign(v_n_s, vzero - preG)
    with ops.control_dependencies([temp]):
        vns_update = state_ops.scatter_add(v_n_s, grad.indices, grad.values)
    with ops.control_dependencies([vns_update]):
        # Remember the direction for the next iteration.
        v_update = state_ops.assign(vzero, vns_update)
    v_t = var - lr_t * vns_update
    prox = tf_utils.prox_L2(v_t, wzero, lr_t, lamb_t)
    var_update = state_ops.assign(var, prox)
    return control_flow_ops.group(*[var_update, v_update])
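# Hedged NumPy sketch of the recursion the sparse kernel above implements,
# v_t = grad(w_t) - grad(w_{t-1}) + v_{t-1} (a SARAH-style update), with the
# current gradient applied only on the touched rows. Names and shapes are
# assumptions for illustration.
import numpy as np

def sarah_sparse_direction_sketch(grad_values, indices, preG, v_prev):
    v = v_prev - preG                   # dense part: v_{t-1} - grad(w_{t-1})
    np.add.at(v, indices, grad_values)  # sparse part: + grad(w_t) on touched rows
    return v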
def solve_inner(self, optimizer, data, num_epochs=1, batch_size=32):
    '''Solves the local optimization problem.'''
    if batch_size == 0:  # batch_size == 0 means train on the full local dataset
        batch_size = len(data['y'])

    if optimizer == "fedavg":
        for _ in trange(num_epochs, desc='Epoch: ', leave=False, ncols=120):
            for X, y in batch_data(data, batch_size):
                with self.graph.as_default():
                    self.sess.run(self.train_op,
                                  feed_dict={self.features: X, self.labels: y})

    if optimizer == "fedprox" or optimizer == "fedsgd":
        data_x, data_y = suffer_data(data)
        for _ in range(num_epochs):  # local iterations t = 1, 2, ..., m
            X, y = get_random_batch_sample(data_x, data_y, batch_size)
            with self.graph.as_default():
                self.sess.run(self.train_op,
                              feed_dict={self.features: X, self.labels: y})

    if optimizer == "fedsarah" or optimizer == "fedsvrg":
        data_x, data_y = suffer_data(data)
        wzero = self.get_params()
        # First local step: w_1 = prox(w_0 - lr * v_0).
        w1 = wzero - self.optimizer._lr * np.array(self.vzero)
        w1 = prox_L2(np.array(w1), np.array(wzero),
                     self.optimizer._lr, self.optimizer._lamb)
        self.set_params(w1)

        for e in range(num_epochs - 1):  # local iterations t = 1, 2, ..., m
            X, y = get_random_batch_sample(data_x, data_y, batch_size)
            with self.graph.as_default():
                if optimizer == "fedsvrg":
                    # Save the current weights, evaluate the gradient at w_0
                    # on the same mini-batch, then restore and take a step.
                    current_weight = self.get_params()
                    self.set_params(wzero)
                    fwzero = self.sess.run(
                        self.grads,
                        feed_dict={self.features: X, self.labels: y})
                    self.optimizer.set_fwzero(fwzero, self)
                    self.set_params(current_weight)
                    self.sess.run(self.train_op,
                                  feed_dict={self.features: X, self.labels: y})
                elif optimizer == "fedsarah":
                    if e == 0:
                        # First recursion step: the previous iterate is w_0.
                        self.set_params(wzero)
                        grad_w0 = self.sess.run(
                            self.grads,
                            feed_dict={self.features: X, self.labels: y})  # gradient at w_0
                        self.optimizer.set_preG(grad_w0, self)
                        self.set_params(w1)
                        preW = self.get_params()  # previous iterate is now w_1
                        self.sess.run(self.train_op,
                                      feed_dict={self.features: X, self.labels: y})
                    else:
                        # Evaluate the gradient at the previous iterate on the
                        # current mini-batch, then step from the current iterate.
                        curW = self.get_params()
                        self.set_params(preW)
                        grad_preW = self.sess.run(
                            self.grads,
                            feed_dict={self.features: X, self.labels: y})  # gradient at w_{t-1}
                        self.optimizer.set_preG(grad_preW, self)
                        preW = curW
                        self.set_params(curW)  # restore the current weights
                        self.sess.run(self.train_op,
                                      feed_dict={self.features: X, self.labels: y})

    soln = self.get_params()
    # Estimated computation: epochs * number of batches * batch size * per-sample FLOPs.
    comp = num_epochs * (len(data['y']) // batch_size) * batch_size * self.flops
    return soln, comp
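# Self-contained NumPy sketch (toy quadratic loss, hypothetical names) of the
# fedsarah local loop above: w_1 = prox(w_0 - lr*v_0), then recursive updates
# v_t = g(w_t) - g(w_{t-1}) + v_{t-1} followed by a prox step toward w_0.
import numpy as np

def grad(w):                                 # toy loss f(w) = 0.5*||w - 1||^2
    return w - 1.0

lr, lamb, m = 0.1, 0.01, 20
wzero = np.zeros(3)
v = grad(wzero)                              # v_0: gradient at w_0
prox = lambda w: (w + lr * lamb * wzero) / (1.0 + lr * lamb)

w_prev, w = wzero, prox(wzero - lr * v)      # first step to w_1
for _ in range(m - 1):
    v = grad(w) - grad(w_prev) + v           # SARAH recursion
    w_prev, w = w, prox(w - lr * v)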