Example #1
        def testZeroSketch(self):
            d = 100
            c = 20
            r = 5
            a = CSVec(d, c, r, **self.csvecArgs)
            vec = torch.rand(d).to(self.device)
            a += vec

            zeros = torch.zeros((r, c)).to(self.device)
            self.assertFalse(torch.allclose(a.table, zeros))

            a.zero()
            self.assertTrue(torch.allclose(a.table, zeros))
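
A minimal round-trip sketch of the CSVec API exercised above; the import path and the use of default constructor arguments are assumptions:

import torch
from csvec import CSVec  # assumed import path

d, c, r = 100, 20, 5
sketch = CSVec(d, c, r)

# A vector with a few heavy coordinates, which a count sketch recovers well
vec = torch.zeros(d)
vec[:5] = torch.tensor([10.0, -8.0, 6.0, 4.0, 2.0])

sketch += vec                      # same __iadd__ the test uses
recovered = sketch.unSketch(k=5)   # dense estimate keeping the top-5 coordinates

sketch.zero()                      # table is all zeros again, as the test asserts
assert torch.allclose(sketch.table, torch.zeros(r, c))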
Example #2
import copy

import torch

from csvec import CSVec  # assumed import path
# `topk` (a dense top-k masking helper) is assumed to be defined in the surrounding module.


class TopHCS(object):
	"""Represents one worker."""
	def __init__(self, d, c, r, h, numBlocks, device='cpu'): 
		self.h, self.d = h, d 
		self.device = device
		self.topH = torch.zeros(d, dtype=torch.float, device=self.device)
		self.csvec = CSVec(d=d, c=c, r=r, numBlocks=numBlocks, device=self.device)
	
	def zero(self):
		""" Clear csvec and topH tensor """
		self.csvec.zero()
		self.topH = torch.zeros(self.d, dtype=torch.float, device=self.device)
        
	# formerly store(...)
	def accumulateVec(self, vec):
		"""Compress a vector: save the top-h elements in self.topH and
		sketch the remaining d - h elements; csvec and topH should be
		zero before storing."""
		# assert(self.topH.nonzero().numel() == 0)
		# (assertion relaxed for the CommEfficient optimizer)
		self.topH = topk(vec, self.h).to(self.device)
		self.csvec.accumulateVec((vec - self.topH).to(self.device))

	def accumulateTable(self, table):
		if table.size() != self.csvec.table.size():
			msg = "Passed-in table has size {}, expecting {}"
			raise ValueError(msg.format(table.size(), self.csvec.table.size()))
		self.csvec.accumulateTable(table)

	@classmethod
	def topKSum(cls, workers, k, unSketchNum=0):
		"""Sum all workers' sketches and topH tensors, then return the
		top-k of (topHSum + unsketched sum); unSketchNum=0 means
		recover all d coordinates before the final top-k."""
		assert isinstance(workers, list), "workers must be a list"
		sketchSum = copy.deepcopy(workers[0].csvec) 
		sketchSum.zero()
		topHSum = torch.zeros_like(workers[0].topH)
		for w in workers:
			sketchSum.accumulateTable(w.csvec.table)
			topHSum += w.topH
		d = len(topHSum)
		unSketchNum = d if (unSketchNum == 0) else unSketchNum
		unSketchedSum = sketchSum.unSketch(k=unSketchNum) 
		if topHSum.size() != unSketchedSum.size():
			msg = "topHSum has size {}, unSketchedSum size {}"
			raise ValueError(msg.format(topHSum.size(), unSketchedSum.size()))
		ret = topk(topHSum + unSketchedSum, k)
		return ret
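
A hedged usage sketch for TopHCS: several workers each store one gradient, and topKSum approximates the top-k of their sum. The sizes below are illustrative only, and `topk` is the same project helper the class itself calls:

import torch

d, c, r, h, k = 1000, 100, 5, 50, 50
workers = [TopHCS(d, c, r, h, numBlocks=1) for _ in range(4)]

grads = [torch.randn(d) for _ in workers]
for w, g in zip(workers, grads):
    w.zero()            # csvec and topH must be zero before storing
    w.accumulateVec(g)

# Approximate top-k of the summed gradients...
approx = TopHCS.topKSum(workers, k)
# ...versus the exact top-k for comparison
exact = topk(torch.stack(grads).sum(dim=0), k)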
Example #3
            # Ramp alpha up to 0.3 over the first 5 epochs, then back down over the next 5
            if epoch <= 5:
                alpha_t = alpha + epoch * (0.3 - alpha) / 5
            elif epoch <= 10:
                alpha_t = 0.3 - (epoch - 5) * (0.3 - alpha) / 5
            else:
                alpha_t = alpha

            # Average the received table over the N workers, scaled by the ramped alpha
            table_rx = alpha_t * (1.0 / N) * table_rx
            S.table = torch.tensor(table_rx)

            # Error feedback: fold the new sketch into the accumulated residual
            S_e = S + S_e
            unsketched = S_e.unSketch(k=k)
            np_unsketched = unsketched.numpy()

            # Re-sketch the update that was actually applied...
            S.zero()
            S.accumulateVec(unsketched)

            # ...and keep the unapplied remainder as the new residual
            S_e.table = S_e.table - S.table

            # Reshape unsketched into per-layer shapes
            print(unsketched.shape)
            shapes = [w.shape for w in model.trainable_weights]

            grad_tx = []
            n_prev = 0
            for i in range(len(shapes)):
                n = n_prev + tf.math.reduce_prod(shapes[i])
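
The loop above is cut off at the section boundary. A hedged sketch of the cumulative-offset pattern it appears to implement, slicing the flat unsketched update back into per-layer tensors (variable names follow the snippet; the slicing and reshape are assumptions, not the original code):

grad_tx = []
n_prev = 0
for shape in shapes:
    # Each layer owns the next prod(shape) entries of the flat vector
    n = n_prev + int(tf.math.reduce_prod(shape))
    grad_tx.append(tf.reshape(np_unsketched[n_prev:n], shape))
    n_prev = n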