예제 #1
0
    def _compute_transformation(self, vecs_and_vars, transform):
        """Computes a block-wise transformation of a list of vectors.

        For every entry of `self.layers.fisher_blocks`, a zero-argument thunk
        is built that applies `transform` to that block and the vector(s) of
        its parameters. The thunks, together with the matching parameter
        keys, are handed to `self._place_and_compute_transformation_thunks`.
        The results it returns are then mapped back onto the variables.

        Args:
          vecs_and_vars: List of (vector, variable) pairs.
          transform: A function of the form f(fb, vec) that returns the
              transformed vector, where vec is the vector to transform and fb
              is its corresponding block.

        Returns:
          A list of (transformed vector, var) pairs in the same order as
          vecs_and_vars.
        """
        vec_by_var = utils.SequenceDict(
            (var, vec) for vec, var in vecs_and_vars)

        def bind(block, block_params):
            # The enclosing call gives each thunk its own `block`/`block_params`.
            # A lambda written directly in the loop would late-bind to the last
            # pair. The vector lookup itself stays deferred until the thunk runs.
            return lambda: transform(block, vec_by_var[block_params])

        thunk_list = []
        block_params_list = []
        for block_params, block in self.layers.fisher_blocks.items():
            thunk_list.append(bind(block, block_params))
            block_params_list.append(block_params)
        thunks = tuple(thunk_list)
        params_list = tuple(block_params_list)

        results = self._place_and_compute_transformation_thunks(
            thunks, params_list)

        out_by_var = utils.SequenceDict()
        for block_params, result in zip(params_list, results):
            # When `block_params` is a tuple of keys, SequenceDict assigns the
            # values element-wise across those keys.
            out_by_var[block_params] = result

        return [(out_by_var[var], var) for _, var in vecs_and_vars]
예제 #2
0
    def _apply_transformation(self, vecs_and_vars, transform):
        """Applies a block-wise transformation to the corresponding vectors.

        Each entry of `self._layers.fisher_blocks` is transformed eagerly, in
        the iteration order of that mapping.

        Args:
          vecs_and_vars: List of (vector, variable) pairs.
          transform: A function of the form f(fb, vec) that returns the
              transformed vector, where vec is the vector to transform and fb
              is its corresponding block in the matrix.

        Returns:
          A list of (transformed vector, var) pairs in the same order as
          vecs_and_vars.
        """
        vec_by_var = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)

        out_by_var = utils.SequenceDict()
        for block_params, block in self._layers.fisher_blocks.items():
            block_vecs = vec_by_var[block_params]
            # Item assignment (not the constructor) is used on purpose: a tuple
            # key is spread across its individual variables.
            out_by_var[block_params] = transform(block, block_vecs)

        return [(out_by_var[v], v) for _, v in vecs_and_vars]
예제 #3
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testSetItemMultipleKeys(self):
   keys = ('a', 'b', 'c')
   values = ('foo', 'bar', 'baz')
   seq_dict = utils.SequenceDict()
   # Assigning through a tuple of keys stores each value under its own key.
   seq_dict[keys] = values
   expected_items = list(zip(keys, values))
   self.assertItemsEqual(expected_items, seq_dict.items())
예제 #4
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testSetItemSingleKey(self):
   seq_dict = utils.SequenceDict()
   seq_dict['a'] = 'foo'
   # A plain (non-sequence) key yields exactly one stored entry.
   expected = [('a', 'foo')]
   self.assertEqual(expected, seq_dict.items())
예제 #5
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testGetItemMultipleKeys(self):
   seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
   # Indexing with a tuple of keys returns a list of values in key order.
   looked_up = seq_dict[('a', 'b')]
   self.assertEqual(['foo', 'bar'], looked_up)
예제 #6
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testGetItemSingleKey(self):
   initial = {'a': 'foo', 'b': 'bar'}
   seq_dict = utils.SequenceDict(initial)
   # A single key returns the bare value, not a list.
   self.assertEqual('foo', seq_dict['a'])
예제 #7
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testSequenceDictInitWithIterable(self):
   expected = {'a': 'foo', 'b': 'bar'}
   # Feed (key, value) pairs from zip() rather than a dict, so the constructor
   # is exercised with a generic iterable.
   pairs = zip(expected.keys(), expected.values())
   self.assertEqual(expected, utils.SequenceDict(pairs)._dict)
예제 #8
0
파일: utils_test.py 프로젝트: leox1v/kfac
 def testSequenceDictInit(self):
   # Constructed without arguments, the wrapped dict must be empty (falsy).
   self.assertFalse(utils.SequenceDict()._dict)