コード例 #1
0
    def _apply_transformation(self, vecs_and_vars, transform):
        """Applies a block-wise transformation to the corresponding vectors.

        Every entry of ``self._layers.fisher_blocks`` is paired with the
        vector(s) belonging to its parameters, ``transform`` is applied once
        per block, and the results are handed back in the caller's order.

        Args:
          vecs_and_vars: List of (vector, variable) pairs.
          transform: A callable ``f(fb, vec)`` where ``fb`` is a block of the
            matrix and ``vec`` is the vector (or vectors) for that block's
            parameters; it returns the transformed vector(s).

        Returns:
          A list of (transformed vector, var) pairs in the same order as
          ``vecs_and_vars``.
        """
        # Index the incoming vectors by their variable so each block can
        # look up its own inputs. When a block's params key is a tuple of
        # variables, SequenceDict reads/writes one entry per variable.
        vector_of = utils.SequenceDict(
            (variable, vector) for vector, variable in vecs_and_vars)

        transformed_of = utils.SequenceDict()
        for block_params, block in self._layers.fisher_blocks.items():
            block_input = vector_of[block_params]
            transformed_of[block_params] = transform(block, block_input)

        # Reassemble in the original order of the caller's pairs.
        return [(transformed_of[variable], variable)
                for _vector, variable in vecs_and_vars]
コード例 #2
0
 def testSetItemMultipleKeys(self):
     """Assigning to a tuple of keys stores each value under its own key."""
     seq_dict = utils.SequenceDict()
     keys = ('a', 'b', 'c')
     values = ('foo', 'bar', 'baz')
     seq_dict[keys] = values
     # Order-insensitive comparison that does not depend on assertItemsEqual,
     # which only exists on Python 2's unittest.TestCase (Python 3 renamed it
     # assertCountEqual). Sorting works whether items() is a list or a view.
     self.assertEqual(sorted(zip(keys, values)), sorted(seq_dict.items()))
コード例 #3
0
 def testSetItemSingleKey(self):
     """Assigning to a single key stores exactly one (key, value) item."""
     seq_dict = utils.SequenceDict()
     seq_dict['a'] = 'foo'
     # Materialize items(): if it returns a dict view (as dict.items() does on
     # Python 3), comparing it directly to a list is always unequal.
     self.assertEqual([('a', 'foo')], list(seq_dict.items()))
コード例 #4
0
 def testGetItemMultipleKeys(self):
     """Indexing with a tuple of keys yields the matching values as a list."""
     lookup = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
     fetched = lookup[('a', 'b')]
     self.assertEqual(['foo', 'bar'], fetched)
コード例 #5
0
 def testGetItemSingleKey(self):
     """Indexing with a single key yields that key's value."""
     lookup = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
     fetched = lookup['a']
     self.assertEqual('foo', fetched)
コード例 #6
0
 def testSequenceDictInitWithIterable(self):
     """Constructing from an iterable of (key, value) pairs copies them in."""
     expected = {'a': 'foo', 'b': 'bar'}
     pairs = zip(expected.keys(), expected.values())
     built = utils.SequenceDict(pairs)
     # Inspects the private backing dict directly.
     self.assertEqual(expected, built._dict)
コード例 #7
0
 def testSequenceDictInit(self):
     """A SequenceDict built with no arguments has an empty backing store."""
     empty_seq = utils.SequenceDict()
     backing = empty_seq._dict
     self.assertFalse(backing)