def _apply_transformation(self, vecs_and_vars, transform):
    """Applies a block-wise transformation to the corresponding vectors.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      transform: A callable of the form f(fb, vec), where vec is the vector
        to transform and fb is its corresponding block in the matrix. It
        returns the transformed vector.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """
    # Index the incoming vectors by variable so that each block can look up
    # the vector(s) registered under its own parameter key.
    vec_by_var = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)

    transformed = utils.SequenceDict()
    for block_params, block in self._layers.fisher_blocks.items():
        block_vec = vec_by_var[block_params]
        transformed[block_params] = transform(block, block_vec)

    # Emit results following the caller's ordering, not the block ordering.
    return [(transformed[var], var) for _, var in vecs_and_vars]
def testSetItemMultipleKeys(self):
    """Checks item assignment with a tuple of keys and a tuple of values."""
    keys = ('a', 'b', 'c')
    values = ('foo', 'bar', 'baz')

    mapping = utils.SequenceDict()
    mapping[keys] = values

    # Every (key, value) pairing should be present in the stored items.
    expected_items = list(zip(keys, values))
    self.assertItemsEqual(expected_items, mapping.items())
def testSetItemSingleKey(self):
    """Checks item assignment with one plain key."""
    mapping = utils.SequenceDict()
    mapping['a'] = 'foo'

    expected_items = [('a', 'foo')]
    self.assertEqual(expected_items, mapping.items())
def testGetItemMultipleKeys(self):
    """Checks that indexing with a tuple of keys yields a list of values."""
    initial = {'a': 'foo', 'b': 'bar'}
    mapping = utils.SequenceDict(initial)

    fetched = mapping[('a', 'b')]
    self.assertEqual(['foo', 'bar'], fetched)
def testGetItemSingleKey(self):
    """Checks that indexing with one plain key yields its value."""
    initial = {'a': 'foo', 'b': 'bar'}
    mapping = utils.SequenceDict(initial)

    fetched = mapping['a']
    self.assertEqual('foo', fetched)
def testSequenceDictInitWithIterable(self):
    """Checks construction from an iterable of (key, value) pairs."""
    expected = {'a': 'foo', 'b': 'bar'}

    # Feed the constructor a pair iterable rather than a dict.
    pairs = zip(expected.keys(), expected.values())
    built = utils.SequenceDict(pairs)

    self.assertEqual(expected, built._dict)
def testSequenceDictInit(self):
    """Checks that a SequenceDict built with no arguments starts out empty."""
    empty = utils.SequenceDict()
    # The backing `_dict` attribute must be falsy on a fresh instance.
    self.assertFalse(empty._dict)