def testZeroSketch(self):
    """After accumulating a vector, zero() must reset the sketch table."""
    dim, cols, rows = 100, 20, 5
    sketch = CSVec(dim, cols, rows, **self.csvecArgs)
    noise = torch.rand(dim).to(self.device)
    sketch += noise
    empty = torch.zeros((rows, cols)).to(self.device)
    # Accumulating a random vector should leave some nonzero counters...
    self.assertFalse(torch.allclose(sketch.table, empty))
    # ...and zero() should clear every one of them.
    sketch.zero()
    self.assertTrue(torch.allclose(sketch.table, empty))
class TopHCS(object):
    """Per-worker compressor: exact top-h entries plus a count sketch of the rest.

    The h largest entries of an accumulated vector are kept exactly in
    ``topH``; the remaining d - h entries are folded into a CSVec count
    sketch, so the compressed states of many workers can be summed.
    """

    def __init__(self, d, c, r, h, numBlocks, device='cpu'):
        # d: vector dimension; h: number of coordinates kept exactly.
        # c, r, numBlocks parameterize the underlying CSVec sketch.
        self.h, self.d = h, d
        self.device = device
        self.topH = torch.zeros(d, dtype=torch.float, device=self.device)
        self.csvec = CSVec(d=d, c=c, r=r, numBlocks=numBlocks,
                           device=self.device)

    def zero(self):
        """ Clear csvec and topH tensor """
        self.csvec.zero()
        self.topH = torch.zeros(self.d, dtype=torch.float,
                                device=self.device)

    # formerly store(...)
    def accumulateVec(self, vec):
        """Compress ``vec``: save its top-h elements in ``self.topH`` and
        sketch the remaining d - h elements into ``self.csvec``.

        ``csvec`` and ``topH`` should be zero before storing.
        """
        # NOTE: the original had two extra bare string literals here, which
        # were no-op statements, not docstrings; merged into the docstring.
        # assert(self.topH.nonzero().numel() == 0)
        # changed this for commefficient optimizer
        self.topH = topk(vec, self.h).to(self.device)
        self.csvec.accumulateVec((vec - self.topH).to(self.device))

    def accumulateTable(self, table):
        """Add another sketch table into this worker's sketch.

        Raises:
            ValueError: if ``table``'s shape differs from the local table's.
        """
        if table.size() != self.csvec.table.size():
            msg = "Passed in table has size {}, expecting {}"
            raise ValueError(msg.format(table.size(),
                                        self.csvec.table.size()))
        # No `else` needed: the mismatch branch raises.
        self.csvec.accumulateTable(table)

    @classmethod
    def topKSum(cls, workers, k, unSketchNum=0):
        """Return the top-k of the summed, decompressed worker vectors.

        Args:
            workers: list of TopHCS instances to aggregate.
            k: number of coordinates kept in the returned vector.
            unSketchNum: how many coordinates to recover from the summed
                sketch; 0 (the default) means recover all d of them.

        Raises:
            TypeError: if ``workers`` is not a list.
            ValueError: if the recovered sketch sum and the summed top-h
                buffers have mismatched sizes.
        """
        # Raise instead of `assert` so the check survives `python -O`.
        if not isinstance(workers, list):
            raise TypeError("workers must be a list")
        # Sum all workers' sketch tables and exact top-h buffers.
        sketchSum = copy.deepcopy(workers[0].csvec)
        sketchSum.zero()
        topHSum = torch.zeros_like(workers[0].topH)
        for w in workers:
            sketchSum.accumulateTable(w.csvec.table)
            topHSum += w.topH
        d = len(topHSum)
        unSketchNum = d if unSketchNum == 0 else unSketchNum
        unSketchedSum = sketchSum.unSketch(k=unSketchNum)
        if topHSum.size() != unSketchedSum.size():
            msg = "topHSum has size {}, unSketchedSum size {}"
            raise ValueError(msg.format(topHSum.size(),
                                        unSketchedSum.size()))
        return topk(topHSum + unSketchedSum, k)
# Ramp up alpha if epoch <= 5: alpha_t = alpha + epoch * (0.3 - alpha) / 5 elif epoch > 5 and epoch <= 10: alpha_t = 0.3 - (epoch - 5) * (0.3 - alpha) / 5 else: alpha_t = alpha table_rx = alpha * (1.0 / N) * table_rx S.table = torch.tensor(table_rx) S_e = S + S_e unsketched = S_e.unSketch(k=k) np_unsketched = unsketched.numpy() S.zero() S.accumulateVec(unsketched) S_e.table = S_e.table - S.table # Rehape unsketched print(unsketched.shape) shapes = [ model.trainable_weights[i].shape for i in range(len(model.trainable_weights)) ] grad_tx = [] n_prev = 0 for i in range(len(shapes)): n = n_prev + tf.math.reduce_prod(shapes[i])