Example #1
  def Test(self):
    np.random.seed(1)

    if dtype_ in (np.float32, np.float64):
      x = np.random.uniform(low=-1.0, high=1.0,
                            size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    elif dtype_ == np.complex64:
      x = (np.random.uniform(low=-1.0, high=1.0,
                             size=np.prod(shape_)).reshape(shape_).astype(np.float32)
           + 1j * np.random.uniform(low=-1.0, high=1.0,
                                    size=np.prod(shape_)).reshape(shape_).astype(np.float32))
    else:
      x = (np.random.uniform(low=-1.0, high=1.0,
                             size=np.prod(shape_)).reshape(shape_).astype(np.float64)
           + 1j * np.random.uniform(low=-1.0, high=1.0,
                                    size=np.prod(shape_)).reshape(shape_).astype(np.float64))

    for compute_uv in False, True:
      for full_matrices in False, True:
        with self.test_session():
          if x.ndim == 2:
            if compute_uv:
              tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
                                        compute_uv=compute_uv,
                                        full_matrices=full_matrices)
            else:
              tf_s = tf.svd(tf.constant(x),
                            compute_uv=compute_uv,
                            full_matrices=full_matrices)
          else:
            if compute_uv:
              tf_s, tf_u, tf_v = tf.batch_svd(
                  tf.constant(x),
                  compute_uv=compute_uv,
                  full_matrices=full_matrices)
            else:
              tf_s = tf.batch_svd(
                  tf.constant(x),
                  compute_uv=compute_uv,
                  full_matrices=full_matrices)
          if compute_uv:
            np_u, np_s, np_v = np.linalg.svd(x,
                                             compute_uv=compute_uv,
                                             full_matrices=full_matrices)
          else:
            np_s = np.linalg.svd(x,
                                 compute_uv=compute_uv,
                                 full_matrices=full_matrices)
          CompareSingularValues(self, np_s, tf_s.eval())
          if compute_uv:
            CompareSingularVectors(self, np_u, tf_u.eval(), min(shape_[-2:]))
            CompareSingularVectors(self, np.conj(np.swapaxes(np_v, -2, -1)),
                                   tf_v.eval(), min(shape_[-2:]))
            CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices)
            CheckUnitary(self, tf_u)
            CheckUnitary(self, tf_v)
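The CompareSingularValues, CompareSingularVectors, CheckApproximation and CheckUnitary helpers are defined elsewhere in the test file. A minimal NumPy sketch of what they plausibly assert (our approximation; signatures and tolerances are illustrative, not the actual helpers):

import numpy as np

def CompareSingularValues(test, np_s, tf_s, atol=1e-5):
  # Singular values are uniquely determined, so compare them directly.
  test.assertAllClose(np_s, tf_s, atol=atol)

def CompareSingularVectors(test, np_u, tf_u, rank, atol=1e-5):
  # Only the first `rank` singular vectors are determined, and each column is
  # defined only up to a phase (a sign in the real case); align phases first.
  np_u, tf_u = np_u[..., :rank], tf_u[..., :rank]
  phase = np.sum(np.conj(np_u) * tf_u, axis=-2, keepdims=True)
  phase /= np.abs(phase)
  test.assertAllClose(np_u * phase, tf_u, atol=atol)

def CheckApproximation(test, a, u, s, v, full_matrices, atol=1e-5):
  # The input should be recovered from the thin factors as U * diag(S) * V^H.
  u, s, v = u.eval(), s.eval(), v.eval()
  rank = s.shape[-1]
  recon = np.matmul(u[..., :rank] * s[..., None, :],
                    np.conj(np.swapaxes(v[..., :rank], -2, -1)))
  test.assertAllClose(a, recon, atol=atol)

def CheckUnitary(test, x):
  # Columns of U and V should be orthonormal: X^H X = I.
  x = x.eval()
  xx = np.matmul(np.conj(np.swapaxes(x, -2, -1)), x)
  identity = np.broadcast_to(np.eye(x.shape[-1], dtype=x.dtype), xx.shape)
  test.assertAllClose(identity, xx, atol=1e-5)
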
Example #2
 def Test(self):
   np.random.seed(1)
   x = np.random.uniform(
       low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
   if dtype_ == np.float32:
     atol = 1e-4
   else:
     atol = 1e-14
   for compute_uv in False, True:
     for full_matrices in False, True:
       with self.test_session():
         if x.ndim == 2:
           if compute_uv:
             tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
                                       compute_uv=compute_uv,
                                       full_matrices=full_matrices)
           else:
             tf_s = tf.svd(tf.constant(x),
                           compute_uv=compute_uv,
                           full_matrices=full_matrices)
         else:
           if compute_uv:
             tf_s, tf_u, tf_v = tf.batch_svd(
                 tf.constant(x),
                 compute_uv=compute_uv,
                 full_matrices=full_matrices)
           else:
             tf_s = tf.batch_svd(
                 tf.constant(x),
                 compute_uv=compute_uv,
                 full_matrices=full_matrices)
         if compute_uv:
           np_u, np_s, np_v = np.linalg.svd(x,
                                            compute_uv=compute_uv,
                                            full_matrices=full_matrices)
         else:
           np_s = np.linalg.svd(x,
                                compute_uv=compute_uv,
                                full_matrices=full_matrices)
         self.assertAllClose(np_s, tf_s.eval(), atol=atol)
         if compute_uv:
           CompareSingularVectors(self, np_u, tf_u.eval(), min(shape_[-2:]),
                                  atol)
           CompareSingularVectors(self, np.swapaxes(np_v, -2, -1), tf_v.eval(),
                                  min(shape_[-2:]), atol)
           CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices, atol)
           CheckUnitary(self, tf_u)
           CheckUnitary(self, tf_v)
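The transpose of np_v before comparing (and the extra np.conj in the complex case of Example #1) compensates for a convention difference: np.linalg.svd returns the adjoint V^H directly, whereas tf.svd and tf.batch_svd return V itself. A quick NumPy-only illustration:

import numpy as np

a = np.random.randn(5, 3)
u, s, vh = np.linalg.svd(a, full_matrices=False)  # NumPy returns V^H ("vh")
v = np.conj(np.swapaxes(vh, -2, -1))              # V, the convention tf.svd uses

# Either convention reconstructs a = U * diag(S) * V^H.
np.testing.assert_allclose(a, (u * s) @ vh, atol=1e-12)
# Hence the tests transpose (and, for complex dtypes, conjugate) np_v before
# comparing it against tf_v.
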
Example #3
 def Test(self):
     np.random.seed(1)
     x = np.random.uniform(
         low=-1.0, high=1.0,
         size=np.prod(shape_)).reshape(shape_).astype(dtype_)
     if dtype_ == np.float32:
         atol = 1e-4
     else:
         atol = 1e-14
     for compute_uv in False, True:
         for full_matrices in False, True:
             with self.test_session():
                 if x.ndim == 2:
                     if compute_uv:
                         tf_s, tf_u, tf_v = tf.svd(
                             tf.constant(x),
                             compute_uv=compute_uv,
                             full_matrices=full_matrices)
                     else:
                         tf_s = tf.svd(tf.constant(x),
                                       compute_uv=compute_uv,
                                       full_matrices=full_matrices)
                 else:
                     if compute_uv:
                         tf_s, tf_u, tf_v = tf.batch_svd(
                             tf.constant(x),
                             compute_uv=compute_uv,
                             full_matrices=full_matrices)
                     else:
                         tf_s = tf.batch_svd(tf.constant(x),
                                             compute_uv=compute_uv,
                                             full_matrices=full_matrices)
                 if compute_uv:
                     np_u, np_s, np_v = np.linalg.svd(
                         x,
                         compute_uv=compute_uv,
                         full_matrices=full_matrices)
                 else:
                     np_s = np.linalg.svd(x,
                                          compute_uv=compute_uv,
                                          full_matrices=full_matrices)
                 self.assertAllClose(np_s, tf_s.eval(), atol=atol)
                 if compute_uv:
                     _CompareSingularVectors(self, np_u, tf_u.eval(), atol)
                     _CompareSingularVectors(self,
                                             np.swapaxes(np_v, -2, -1),
                                             tf_v.eval(), atol)
Example #4
 def Test(self):
     np.random.seed(1)
     x = np.random.uniform(
         low=-1.0, high=1.0,
         size=np.prod(shape_)).reshape(shape_).astype(dtype_)
     for compute_uv in False, True:
         for full_matrices in False, True:
             with self.test_session():
                 if x.ndim == 2:
                     if compute_uv:
                         tf_s, tf_u, tf_v = tf.svd(
                             tf.constant(x),
                             compute_uv=compute_uv,
                             full_matrices=full_matrices)
                     else:
                         tf_s = tf.svd(tf.constant(x),
                                       compute_uv=compute_uv,
                                       full_matrices=full_matrices)
                 else:
                     if compute_uv:
                         tf_s, tf_u, tf_v = tf.batch_svd(
                             tf.constant(x),
                             compute_uv=compute_uv,
                             full_matrices=full_matrices)
                     else:
                         tf_s = tf.batch_svd(tf.constant(x),
                                             compute_uv=compute_uv,
                                             full_matrices=full_matrices)
                 if compute_uv:
                     np_u, np_s, np_v = np.linalg.svd(
                         x,
                         compute_uv=compute_uv,
                         full_matrices=full_matrices)
                 else:
                     np_s = np.linalg.svd(x,
                                          compute_uv=compute_uv,
                                          full_matrices=full_matrices)
                 CompareSingularValues(self, np_s, tf_s.eval())
                 if compute_uv:
                     CompareSingularVectors(self, np_u, tf_u.eval(),
                                            min(shape_[-2:]))
                     CompareSingularVectors(self, np.swapaxes(np_v, -2, -1),
                                            tf_v.eval(), min(shape_[-2:]))
                     CheckApproximation(self, x, tf_u, tf_s, tf_v,
                                        full_matrices)
                     CheckUnitary(self, tf_u)
                     CheckUnitary(self, tf_v)
Example #5
    def testWrongDimensions(self):
        # The input to svd should be a 2-dimensional tensor.
        scalar = tf.constant(1.)
        with self.assertRaises(ValueError):
            tf.svd(scalar)
        vector = tf.constant([1., 2.])
        with self.assertRaises(ValueError):
            tf.svd(vector)
        tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
        with self.assertRaises(ValueError):
            tf.svd(tensor)

        # The input to batch_svd should be a tensor of at least rank 2.
        scalar = tf.constant(1.)
        with self.assertRaises(ValueError):
            tf.batch_svd(scalar)
        vector = tf.constant([1., 2.])
        with self.assertRaises(ValueError):
            tf.batch_svd(vector)
Example #6
  def testWrongDimensions(self):
    # The input to svd should be a 2-dimensional tensor.
    scalar = tf.constant(1.)
    with self.assertRaises(ValueError):
      tf.svd(scalar)
    vector = tf.constant([1., 2.])
    with self.assertRaises(ValueError):
      tf.svd(vector)
    tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
    with self.assertRaises(ValueError):
      tf.svd(tensor)

    # The input to batch_svd should be a tensor of at least rank 2.
    scalar = tf.constant(1.)
    with self.assertRaises(ValueError):
      tf.batch_svd(scalar)
    vector = tf.constant([1., 2.])
    with self.assertRaises(ValueError):
      tf.batch_svd(vector)
Example #7
    def testWrongDimensions(self):
        # The input to svd should be a 2-dimensional tensor.
        scalar = tf.constant(1.)
        with self.assertRaisesRegexp(ValueError,
                                     "Shape must be rank 2 but is rank 0"):
            tf.svd(scalar)
        vector = tf.constant([1., 2.])
        with self.assertRaisesRegexp(ValueError,
                                     "Shape must be rank 2 but is rank 1"):
            tf.svd(vector)
        tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
        with self.assertRaisesRegexp(ValueError,
                                     "Shape must be rank 2 but is rank 3"):
            tf.svd(tensor)
        scalar = tf.constant(1. + 1.0j)
        with self.assertRaises(ValueError):
            tf.svd(scalar)
        vector = tf.constant([1. + 1.0j, 2. + 2.0j])
        with self.assertRaises(ValueError):
            tf.svd(vector)
        tensor = tf.constant([[[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]],
                              [[1. + 1.0j, 2. + 2.0j], [3. + 3.0j,
                                                        4. + 4.0j]]])
        with self.assertRaises(ValueError):
            tf.svd(tensor)

        # The input to batch_svd should be a tensor of at least rank 2.
        scalar = tf.constant(1.)
        with self.assertRaisesRegexp(
                ValueError, "Shape must be at least rank 2 but is rank 0"):
            tf.batch_svd(scalar)
        vector = tf.constant([1., 2.])
        with self.assertRaisesRegexp(
                ValueError, "Shape must be at least rank 2 but is rank 1"):
            tf.batch_svd(vector)
        scalar = tf.constant(1. + 1.0j)
        with self.assertRaises(ValueError):
            tf.batch_svd(scalar)
        vector = tf.constant([1. + 1.0j, 2. + 2.0j])
        with self.assertRaises(ValueError):
            tf.batch_svd(vector)
Example #8
  def testWrongDimensions(self):
    # The input to svd should be a 2-dimensional tensor.
    scalar = tf.constant(1.)
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be rank 2 but is rank 0"):
      tf.svd(scalar)
    vector = tf.constant([1., 2.])
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be rank 2 but is rank 1"):
      tf.svd(vector)
    tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be rank 2 but is rank 3"):
      tf.svd(tensor)
    scalar = tf.constant(1. + 1.0j)
    with self.assertRaises(ValueError):
      tf.svd(scalar)
    vector = tf.constant([1. + 1.0j, 2. + 2.0j])
    with self.assertRaises(ValueError):
      tf.svd(vector)
    tensor = tf.constant([[[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]],
                          [[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]]])
    with self.assertRaises(ValueError):
      tf.svd(tensor)

    # The input to batch_svd should be a tensor of at least rank 2.
    scalar = tf.constant(1.)
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 0"):
      tf.batch_svd(scalar)
    vector = tf.constant([1., 2.])
    with self.assertRaisesRegexp(ValueError,
                                 "Shape must be at least rank 2 but is rank 1"):
      tf.batch_svd(vector)
    scalar = tf.constant(1. + 1.0j)
    with self.assertRaises(ValueError):
      tf.batch_svd(scalar)
    vector = tf.constant([1. + 1.0j, 2. + 2.0j])
    with self.assertRaises(ValueError):
      tf.batch_svd(vector)
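A hypothetical guard (the helper name is ours, not part of TensorFlow) showing how calling code could check the static rank up front instead of relying on these graph-construction-time ValueErrors:

def svd_with_rank_check(tensor):
  # tf.svd needs a matrix; tf.batch_svd accepts any tensor of rank >= 2.
  ndims = tensor.get_shape().ndims
  if ndims is None or ndims < 2:
    raise ValueError('SVD needs a tensor of rank >= 2, got rank %s' % ndims)
  return tf.svd(tensor) if ndims == 2 else tf.batch_svd(tensor)
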
Example #9
    def compute_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        updateOps = []
        statsUpdates = {}
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH xKW x C)
                            # patches = B x Oh x Ow x (KH xKW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                                # find closest rank-1 approx to the feature map
                                S, U, V = tf.batch_svd(tf.reshape(
                                    fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slides
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            fpropFactor = SVD_factors[stats_var_dim]

                        else:
                            # poor mem usage implementation
                            patches = tf.extract_image_patches(
                                fpropFactor,
                                ksizes=[1, convkernel_size[0], convkernel_size[1], 1],
                                strides=strides,
                                rates=[1, 1, 1, 1],
                                padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' % (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fpropFactor = tf.concat([fpropFactor, tf.ones(
                                [tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat(
                                [fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' % (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatial independent
                            # structure does not apply here. summing over
                            # spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' % (var.name)))
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume sampled loss is averaged. TODO: figure out a
                    # better way to handle this
                    bpropFactor *= tf.to_float(B)
                    ##

                    cov_b = tf.matmul(
                        bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                          [tf.convert_to_tensor('step:'),
                                           self.global_step,
                                           tf.convert_to_tensor(
                                               'computing stats'),
                                           ])
        self.statsUpdates = statsUpdates
        return statsUpdates
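The Conv2D branch above replaces the [B, KH*KW, C] activation factor with its best per-example rank-1 approximation, built from the leading singular triple of a batched SVD. A NumPy-only sketch of that step (sizes are made up; note NumPy returns V^H where tf.batch_svd returns V):

import numpy as np

B, KHKW, C = 4, 9, 3                       # hypothetical sizes
factor = np.random.randn(B, KHKW, C)

# Batched thin SVD: u is B x KHKW x P, s is B x P, vh is B x P x C.
u, s, vh = np.linalg.svd(factor, full_matrices=False)

sqrt_s1 = np.sqrt(s[:, :1])                # B x 1, sqrt of the top singular value
patches_k = u[:, :, 0] * sqrt_s1           # B x KH*KW  (spatial piece)
patches_c = vh[:, 0, :] * sqrt_s1          # B x C      (channel piece)

# The per-example outer product recovers the best rank-1 approximation.
approx = patches_k[:, :, None] * patches_c[:, None, :]
err = np.linalg.norm(factor - approx) / np.linalg.norm(factor)
print('relative rank-1 approximation error:', err)
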
Example #10
    def compute_stats(self, loss_sampled, var_list=None):
        """
        compute the stats values

        :param loss_sampled: ([TensorFlow Tensor]) the loss function output
        :param var_list: ([TensorFlow Tensor]) The parameters
        :return: ([TensorFlow Tensor]) stats updates
        """
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gradient_sampled = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gradient_sampled = gradient_sampled

        # remove unused variables
        gradient_sampled, varlist = zip(*[(grad, var) for (grad, var) in zip(gradient_sampled, varlist)
                                          if grad is not None])

        factors = self.get_factors(gradient_sampled, varlist)
        stats = self.get_stats(factors, varlist)

        update_ops = []
        stats_updates = {}
        stats_updates_cache = {}
        for var in varlist:
            op_type = factors[var]['opName']
            fops = factors[var]['op']
            fprop_factor = factors[var]['fpropFactors_concat']
            fprop_stats_vars = stats[var]['fprop_concat_stats']
            bprop_factor = factors[var]['bpropFactors_concat']
            bprop_stats_vars = stats[var]['bprop_concat_stats']
            svd_factors = {}
            for stats_var in fprop_stats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in stats_updates_cache:
                    batch_size = (tf.shape(fprop_factor)[0])  # batch size
                    if op_type == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        kernel_height = int(convkernel_size[0])
                        kernel_width = int(convkernel_size[1])
                        chan = int(convkernel_size[2])
                        flatten_size = int(kernel_height * kernel_width * chan)

                        operator_height = int(bprop_factor.get_shape()[1])
                        operator_width = int(bprop_factor.get_shape()[2])

                        if operator_height == 1 and operator_width == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH xKW x C)
                            # patches = B x Oh x Ow x (KH xKW x C)
                            if len(svd_factors) == 0:
                                if KFAC_DEBUG:
                                    print(('approx %s act factor with rank-1 SVD factors' % var.name))
                                # find closest rank-1 approx to the feature map
                                S, U, V = tf.batch_svd(tf.reshape(
                                    fprop_factor, [-1, kernel_height * kernel_width, chan]))
                                # get rank-1 approx slides
                                sqrt_s1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrt_s1  # B x KH*KW
                                full_factor_shape = fprop_factor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], kernel_height * kernel_width])
                                patches_c = V[:, :, 0] * sqrt_s1  # B x C
                                patches_c.set_shape([full_factor_shape[0], chan])
                                svd_factors[chan] = patches_c
                                svd_factors[kernel_height * kernel_width] = patches_k
                            fprop_factor = svd_factors[stats_var_dim]

                        else:
                            # poor mem usage implementation
                            patches = tf.extract_image_patches(
                                fprop_factor,
                                ksizes=[1, convkernel_size[0], convkernel_size[1], 1],
                                strides=strides,
                                rates=[1, 1, 1, 1],
                                padding=padding)

                            if self._approx_t2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' % var.name))
                                # T^2 terms * 1/T^2, size: B x C
                                fprop_factor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fprop_factor = tf.reshape(
                                    patches, [-1, flatten_size]) / operator_height / operator_width
                    fprop_factor_size = int(fprop_factor.get_shape()[-1])
                    if stats_var_dim == (fprop_factor_size + 1) and not self._blockdiag_bias:
                        if op_type == 'Conv2D' and not self._approx_t2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fprop_factor = tf.concat([fprop_factor, tf.ones(
                                [tf.shape(fprop_factor)[0], 1]) / operator_height / operator_width], 1)
                        else:
                            # use homogeneous coordinates
                            fprop_factor = tf.concat(
                                [fprop_factor, tf.ones([tf.shape(fprop_factor)[0], 1])], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fprop_factor, fprop_factor,
                                    transpose_a=True) / tf.cast(batch_size, tf.float32)
                    update_ops.append(cov)
                    stats_updates[stats_var] = cov
                    if op_type != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        stats_updates_cache[stats_var] = cov

            for stats_var in bprop_stats_vars:
                if stats_var not in stats_updates_cache:
                    bprop_factor_shape = bprop_factor.get_shape()
                    batch_size = tf.shape(bprop_factor)[0]  # batch size
                    chan = int(bprop_factor_shape[-1])  # num channels
                    if op_type == 'Conv2D' or len(bprop_factor_shape) == 4:
                        if fprop_factor is not None:
                            if self._approx_t2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' % var.name))
                                bprop_factor = tf.reduce_sum(
                                    bprop_factor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bprop_factor = tf.reshape(
                                    bprop_factor, [-1, chan]) * operator_height * operator_width  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatial independent
                            # structure does not apply here. summing over
                            # spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' % var.name))
                            bprop_factor = tf.reduce_sum(bprop_factor, [1, 2])

                    # assume sampled loss is averaged. TODO: figure out a
                    # better way to handle this
                    bprop_factor *= tf.cast(batch_size, tf.float32)
                    ##

                    cov_b = tf.matmul(bprop_factor, bprop_factor,
                                      transpose_a=True) / tf.cast(tf.shape(bprop_factor)[0], tf.float32)

                    update_ops.append(cov_b)
                    stats_updates[stats_var] = cov_b
                    stats_updates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            a_key = list(stats_updates.keys())[0]
            stats_updates[a_key] = tf.Print(stats_updates[a_key], [tf.convert_to_tensor('step:'), self.global_step,
                                                                   tf.convert_to_tensor('computing stats')])
        self.stats_updates = stats_updates
        return stats_updates
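Both inner loops ultimately produce the two Kronecker factors K-FAC maintains per layer: the activation covariance A^T A / B (with a homogeneous column when the bias is folded in) and the pre-activation-gradient covariance G^T G / B. A minimal NumPy sketch for a dense layer (names and sizes are illustrative, not the library's code):

import numpy as np

batch_size, n_in, n_out = 64, 16, 8           # hypothetical layer sizes
acts = np.random.randn(batch_size, n_in)      # forward activations (fprop factor)
grads = np.random.randn(batch_size, n_out)    # backprop pre-activation gradients

# Fold the bias into the activation factor via a homogeneous coordinate.
acts_h = np.concatenate([acts, np.ones((batch_size, 1))], axis=1)

cov_a = acts_h.T @ acts_h / batch_size        # (n_in + 1) x (n_in + 1)
cov_g = grads.T @ grads / batch_size          # n_out x n_out

# K-FAC approximates the layer's Fisher block by the Kronecker product of the
# two small factors (the order depends on how the weights are vectorised).
fisher_block = np.kron(cov_a, cov_g)
print(fisher_block.shape)                     # (17 * 8, 17 * 8)
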