def _server_helper_true_topk(gradient, Vvelocity, Verror, args, lr):
    assert args.error_type == "virtual"

    rho = args.virtual_momentum

    # Vvelocity = rho * Vvelocity + gradient
    torch.add(gradient, Vvelocity, alpha=rho, out=Vvelocity)
    Verror += Vvelocity

    update = _topk(Verror, k=args.k)

    # we need to do momentum factor masking on the worker
    # momentum vectors for true_topk, which we can't do in
    # the worker because we don't know the global topk yet
    global g_participating_clients
    global g_client_velocities
    if args.local_momentum > 0:
        rows = g_participating_clients.view(-1, 1)
        nz = update.nonzero()[:, 0]
        g_client_velocities[rows, nz] = 0

    # error feedback
    Verror[update.nonzero()] = 0
    # momentum factor masking
    Vvelocity[update.nonzero()] = 0

    return update * lr, Vvelocity, Verror
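
# `_topk` is used above but defined elsewhere in the codebase. The sketch
# below is an assumption about its behavior (flat 1-D input, zero out every
# coordinate outside the k largest in absolute value), not the actual
# implementation.
def _topk_sketch(vec, k):
    # indices of the k largest-magnitude entries
    _, topk_indices = torch.topk(vec.abs(), k)
    # keep only those entries; everything else becomes zero
    sparse_vec = torch.zeros_like(vec)
    sparse_vec[topk_indices] = vec[topk_indices]
    return sparse_vec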
def get_new_worker_weights(ps_weights, worker_weights, args):
    device = args.device

    ps_weights = ps_weights.to(device)
    worker_weights = worker_weights.to(device)

    # we'll update the old worker_weights with a possibly compressed
    # version of diff_vec
    diff_vec = ps_weights - worker_weights
    if args.do_topk_down:
        weight_update = _topk(diff_vec, k=args.k)
    else:
        weight_update = diff_vec

    new_worker_weights = worker_weights + weight_update
    return new_worker_weights
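
# Hypothetical usage sketch (not part of the original code): with
# args.do_topk_down enabled, the downlink weight difference is top-k
# compressed, so a worker only partially catches up to the parameter
# server in one round. The residual computed below measures what the
# worker is still missing after the update.
def _downlink_residual_example(ps_weights, worker_weights, args):
    new_worker_weights = get_new_worker_weights(ps_weights, worker_weights,
                                                args)
    residual = ps_weights.to(args.device) - new_worker_weights
    return new_worker_weights, residual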
def local_step(model, velocity, error, batch, compute_loss, args):
    # g is a (possibly compressed) gradient
    g, results = forward_grad(model, batch, compute_loss, args)

    # locally, we need to deal with the sum of gradients across
    # examples, since we will torch.distributed.reduce the to_transmits
    g *= batch[0].size(0)

    # if needed, do local momentum
    if args.local_momentum > 0:
        # this does velocity[:] = m * velocity + g, but twice as fast
        torch.add(g, velocity, alpha=args.local_momentum, out=velocity)

    # if needed, do local error correction
    if args.error_type == "local":
        error += velocity if velocity is not None else g
        to_transmit = error
    else:
        to_transmit = velocity if velocity is not None else g

    if args.mode == "local_topk":
        assert args.error_type in ["local", "none"]
        # topk is impossibly slow on CPU, very fast on GPU
        to_transmit = _topk(to_transmit.to(args.device), k=args.k)

        nz = to_transmit.nonzero()
        if error is not None:
            # error feedback
            error[nz] = 0
        # if we're doing local momentum, do momentum factor masking
        if args.local_momentum > 0:
            velocity[nz] = 0

    # sketched sgd with local error accumulation doesn't really make
    # sense, since when we send a sketch we don't know what portion
    # of the sketch is the "error"
    if error is not None:
        assert args.mode not in ["sketch", "uncompressed"]

    # we want to do momentum factor masking for all the compression
    # methods, but that's not possible to do for sketching, since
    # it's unknown which coordinates to mask out
    if velocity is not None:
        assert args.mode != "sketch"

    return results, to_transmit
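
# Hypothetical aggregation sketch (not part of the original code): each
# worker's to_transmit is already scaled by its local batch size, so the
# aggregating process can sum the transmitted vectors and divide by the
# total number of examples to recover an example-weighted average gradient.
def _example_aggregate(to_transmits, total_examples):
    # to_transmits: list of same-shaped tensors, one per worker
    summed = torch.stack(to_transmits).sum(dim=0)
    return summed / total_examples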