def testZeroSketch(self):
    d = 100
    c = 20
    r = 5
    a = CSVec(d, c, r, **self.csvecArgs)

    vec = torch.rand(d).to(self.device)
    a.accumulateVec(vec)

    zeros = torch.zeros((r, c)).to(self.device)
    self.assertFalse(torch.allclose(a.table, zeros))

    a.zero()
    self.assertTrue(torch.allclose(a.table, zeros))
def testSketchSum(self):
    d = 5
    c = 10000
    r = 20

    summed = CSVec(d, c, r, **self.csvecArgs)
    for i in range(d):
        vec = torch.zeros(d).to(self.device)
        vec[i] = 1
        sketch = CSVec(d, c, r, **self.csvecArgs)
        sketch.accumulateVec(vec)
        summed += sketch

    recovered = summed.unSketch(k=d)
    trueSum = torch.ones(d).to(self.device)
    self.assertTrue(torch.allclose(recovered, trueSum))
class TopHCS(object):
    """ Represents one worker """
    def __init__(self, d, c, r, h, numBlocks, device='cpu'):
        self.h, self.d = h, d
        self.device = device
        self.topH = torch.zeros(d, dtype=torch.float, device=self.device)
        self.csvec = CSVec(d=d, c=c, r=r, numBlocks=numBlocks,
                           device=self.device)

    def zero(self):
        """ Clear csvec and topH tensor """
        self.csvec.zero()
        self.topH = torch.zeros(self.d, dtype=torch.float,
                                device=self.device)

    # formerly store(...)
    def accumulateVec(self, vec):
        """ Compress a vector: save the top-h elements in self.topH and
        sketch the remaining d-h elements. csvec and topH should be zero
        before storing.
        """
        # assert(self.topH.nonzero().numel() == 0)
        # (assertion removed for the communication-efficient optimizer)
        self.topH = topk(vec, self.h).to(self.device)
        self.csvec.accumulateVec((vec - self.topH).to(self.device))

    def accumulateTable(self, table):
        if table.size() != self.csvec.table.size():
            msg = "Passed in table has size {}, expecting {}"
            raise ValueError(msg.format(table.size(),
                                        self.csvec.table.size()))
        self.csvec.accumulateTable(table)

    @classmethod
    def topKSum(cls, workers, k, unSketchNum=0):
        assert isinstance(workers, list), "workers must be a list"
        sketchSum = copy.deepcopy(workers[0].csvec)
        sketchSum.zero()
        topHSum = torch.zeros_like(workers[0].topH)
        for w in workers:
            sketchSum.accumulateTable(w.csvec.table)
            topHSum += w.topH

        d = len(topHSum)
        unSketchNum = d if unSketchNum == 0 else unSketchNum
        unSketchedSum = sketchSum.unSketch(k=unSketchNum)
        if topHSum.size() != unSketchedSum.size():
            msg = "topHSum has size {}, unSketchedSum size {}"
            raise ValueError(msg.format(topHSum.size(),
                                        unSketchedSum.size()))
        return topk(topHSum + unSketchedSum, k)
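
# TopHCS relies on a topk() helper defined elsewhere in the repo. Below
# is a minimal sketch under the assumption (implied by `vec - self.topH`
# above) that topk(vec, k) returns a dense d-dimensional vector keeping
# only the k largest-magnitude entries of vec and zeroing the rest:
import torch

def topk(vec, k):
    """Keep the k largest-magnitude entries of vec; zero out the rest."""
    ret = torch.zeros_like(vec)
    _, idx = vec.abs().topk(k)
    ret[idx] = vec[idx]
    return ret

# Under that assumption, aggregation looks like: each worker calls
# w.accumulateVec(local_grad), and the server then calls
# TopHCS.topKSum(workers, k) to recover the top-k of the summed vectors.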
def testUnsketch(self):
    # make sure heavy hitter recovery works correctly
    # use a gigantic sketch so there's no chance of collision
    d = 5
    c = 10000
    r = 20
    a = CSVec(d, c, r, **self.csvecArgs)

    vec = torch.rand(d).to(self.device)
    a.accumulateVec(vec)

    with self.subTest(method="topk"):
        recovered = a.unSketch(k=d)
        self.assertTrue(torch.allclose(recovered, vec))

    with self.subTest(method="epsilon"):
        thr = vec.abs().min() * 0.9
        recovered = a.unSketch(epsilon=thr / vec.norm())
        self.assertTrue(torch.allclose(recovered, vec))
def testSketchVec(self):
    # sketch a vector that is all zeros except for a single 1;
    # the table should then be zero everywhere except for a single
    # +-1 in each row
    d = 100
    c = 1
    r = 5
    a = CSVec(d=d, c=c, r=r, **self.csvecArgs)
    vec = torch.zeros(d).to(self.device)
    vec[0] = 1
    a.accumulateVec(vec)

    # make sure the sketch has only one nonzero entry per row
    for i in range(r):
        with self.subTest(row=i):
            self.assertEqual(a.table[i, :].nonzero().numel(), 1)

    # make sure each row sums to +-1
    summed = a.table.abs().sum(dim=1).view(-1)
    ones = torch.ones(r).to(self.device)
    self.assertTrue(torch.allclose(summed, ones))
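
# The test methods above assume a unittest.TestCase fixture providing
# self.device and self.csvecArgs. A minimal sketch of that harness is
# below; the numBlocks value here is an assumption, not taken from
# this file.
import unittest

import torch
from csvec import CSVec


class TestCSVec(unittest.TestCase):
    def setUp(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # keyword arguments forwarded to every CSVec under test
        self.csvecArgs = {"numBlocks": 1, "device": self.device}


if __name__ == "__main__":
    unittest.main()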
def forward_grad(model, batch, compute_loss, args, compute_grad=True):
    device = args.device

    # divide up the batch (for gradient accumulation when memory
    # constrained)
    #num_shards = args.num_train_batch_shards
    # need the max(1, ...) since the last batch in an epoch might be small
    #microbatch_size = max(1, batch[0].size()[0] // num_shards)
    if args.microbatch_size > 0:
        microbatch_size = min(batch[0].size()[0], args.microbatch_size)
    else:
        microbatch_size = batch[0].size()[0]

    # accumulators for the loss & metric values
    accum_loss = 0
    accum_metrics = None

    num_iters = math.ceil(batch[0].size()[0] / microbatch_size)
    for i in range(num_iters):
        # extract the current microbatch
        start = i * microbatch_size
        end = (i + 1) * microbatch_size
        microbatch = [t[start:end] for t in batch]

        # forward pass
        loss, *metrics = compute_loss(model, microbatch, args)

        # the first time through, we find out how many metrics there are
        if accum_metrics is None:
            accum_metrics = [0 for _ in metrics]

        # accumulate loss & metrics, weighted by how many data points
        # were actually used
        accum_loss += loss.item() * microbatch[0].size()[0]
        for j, m in enumerate(metrics):
            accum_metrics[j] += m.item() * microbatch[0].size()[0]

        # backward pass
        if compute_grad:
            loss.backward()

    # gradient clipping
    if (compute_grad and args.max_grad_norm is not None
            and args.mode not in ["sketch"]):
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       args.max_grad_norm * num_iters)

    # "average" here is over the data in the batch
    average_loss = accum_loss / batch[0].size()[0]
    average_metrics = [m / batch[0].size()[0] for m in accum_metrics]

    results = [average_loss] + average_metrics

    if not compute_grad:
        return results

    grad = get_grad(model, args)

    if args.do_dp:
        grad = clip_grad(args.l2_norm_clip, grad)
        if args.dp_mode == "worker":
            noise = torch.normal(mean=0, std=args.noise_multiplier,
                                 size=grad.size()).to(args.device)
            noise *= np.sqrt(args.num_workers)
            grad += noise

    # compress the gradient if needed
    if args.mode == "sketch":
        sketch = CSVec(d=args.grad_size, c=args.num_cols,
                       r=args.num_rows, device=args.device,
                       numBlocks=args.num_blocks)
        sketch.accumulateVec(grad)
        # gradient clipping
        if compute_grad and args.max_grad_norm is not None:
            sketch = clip_grad(args.max_grad_norm, sketch)
        g = sketch.table
    elif args.mode == "true_topk":
        g = grad
    elif args.mode == "local_topk":
        # ideally we'd return the compressed version of the gradient,
        # i.e. _topk(grad, k=args.k). However, for sketching we do
        # momentum in the sketch, whereas for topk we do momentum
        # before taking topk, so we have to return an inconsistent
        # quantity here
        g = grad
    elif args.mode == "fedavg":
        # the logic for doing fedavg happens in process_batch
        g = grad
    elif args.mode == "uncompressed":
        g = grad

    return g, results
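
# forward_grad calls two helpers, get_grad and clip_grad, that live
# elsewhere in the repo. Below is a minimal sketch under common
# assumptions: get_grad flattens every parameter gradient into one 1-D
# tensor, and clip_grad rescales its argument to an L2-norm bound. (The
# sketch-mode call above passes a CSVec rather than a tensor, which
# would need analogous handling of the table's norm.)
import torch

def get_grad(model, args):
    """Flatten all parameter gradients into a single 1-D tensor."""
    grads = [p.grad.detach().view(-1)
             for p in model.parameters() if p.grad is not None]
    return torch.cat(grads).to(args.device)

def clip_grad(l2_norm_clip, grad):
    """Rescale grad so its L2 norm is at most l2_norm_clip."""
    norm = grad.norm()
    if norm > l2_norm_clip:
        grad = grad * (l2_norm_clip / norm)
    return grad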
# learning-rate-style warmup/decay schedule for the scaling factor
if epoch <= 5:
    alpha_t = alpha + epoch * (0.3 - alpha) / 5
elif epoch > 5 and epoch <= 10:
    alpha_t = 0.3 - (epoch - 5) * (0.3 - alpha) / 5
else:
    alpha_t = alpha

# scale the averaged received table by the scheduled factor
table_rx = alpha_t * (1.0 / N) * table_rx
S.table = torch.tensor(table_rx)

# error feedback: fold the new sketch into the error sketch, extract
# the top-k update, then subtract its sketch to keep only the residual
S_e = S + S_e
unsketched = S_e.unSketch(k=k)
np_unsketched = unsketched.numpy()
S.zero()
S.accumulateVec(unsketched)
S_e.table = S_e.table - S.table

# Reshape unsketched back into the per-layer weight shapes
print(unsketched.shape)
shapes = [model.trainable_weights[i].shape
          for i in range(len(model.trainable_weights))]
grad_tx = []
n_prev = 0
for i in range(len(shapes)):
    n = n_prev + int(tf.math.reduce_prod(shapes[i]))
    # slice this layer's entries out of the flat vector and reshape
    grad_tx.append(tf.reshape(np_unsketched[n_prev:n], shapes[i]))
    n_prev = n
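
# The block above is server-side error feedback with a count sketch:
# the received table is scaled and folded into an error sketch S_e,
# the top-k update is unsketched and applied, and its re-sketched form
# is subtracted so only the residual stays in S_e. Below is a compact
# sketch of that step with illustrative names, assuming S and S_e are
# CSVec instances with matching dimensions:
import torch

def server_error_feedback_step(S, S_e, table_rx, alpha_t, N, k):
    # load the scaled, averaged received table into the fresh sketch
    S.table = torch.as_tensor(alpha_t * table_rx / N)
    # fold it into the running error sketch
    S_e.accumulateTable(S.table)
    # extract the top-k update actually applied this round
    update = S_e.unSketch(k=k)
    # subtract the transmitted update's sketch, keeping the residual
    S.zero()
    S.accumulateVec(update)
    S_e.table -= S.table
    return update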