def Test(self):
  np.random.seed(1)
  if dtype_ in (np.float32, np.float64):
    x = np.random.uniform(
        low=-1.0, high=1.0,
        size=np.prod(shape_)).reshape(shape_).astype(dtype_)
  elif dtype_ == np.complex64:
    x = (np.random.uniform(
        low=-1.0, high=1.0,
        size=np.prod(shape_)).reshape(shape_).astype(np.float32) +
         1j * np.random.uniform(
             low=-1.0, high=1.0,
             size=np.prod(shape_)).reshape(shape_).astype(np.float32))
  else:
    x = (np.random.uniform(
        low=-1.0, high=1.0,
        size=np.prod(shape_)).reshape(shape_).astype(np.float64) +
         1j * np.random.uniform(
             low=-1.0, high=1.0,
             size=np.prod(shape_)).reshape(shape_).astype(np.float64))
  for compute_uv in False, True:
    for full_matrices in False, True:
      with self.test_session():
        if x.ndim == 2:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
                                      compute_uv=compute_uv,
                                      full_matrices=full_matrices)
          else:
            tf_s = tf.svd(tf.constant(x),
                          compute_uv=compute_uv,
                          full_matrices=full_matrices)
        else:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.batch_svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
          else:
            tf_s = tf.batch_svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
        if compute_uv:
          np_u, np_s, np_v = np.linalg.svd(x,
                                           compute_uv=compute_uv,
                                           full_matrices=full_matrices)
        else:
          np_s = np.linalg.svd(x,
                               compute_uv=compute_uv,
                               full_matrices=full_matrices)
        CompareSingularValues(self, np_s, tf_s.eval())
        if compute_uv:
          CompareSingularVectors(self, np_u, tf_u.eval(), min(shape_[-2:]))
          CompareSingularVectors(self, np.conj(np.swapaxes(np_v, -2, -1)),
                                 tf_v.eval(), min(shape_[-2:]))
          CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices)
          CheckUnitary(self, tf_u)
          CheckUnitary(self, tf_v)
def Test(self):
  np.random.seed(1)
  x = np.random.uniform(
      low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
  if dtype_ == np.float32:
    atol = 1e-4
  else:
    atol = 1e-14
  for compute_uv in False, True:
    for full_matrices in False, True:
      with self.test_session():
        if x.ndim == 2:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
                                      compute_uv=compute_uv,
                                      full_matrices=full_matrices)
          else:
            tf_s = tf.svd(tf.constant(x),
                          compute_uv=compute_uv,
                          full_matrices=full_matrices)
        else:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.batch_svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
          else:
            tf_s = tf.batch_svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
        if compute_uv:
          np_u, np_s, np_v = np.linalg.svd(x,
                                           compute_uv=compute_uv,
                                           full_matrices=full_matrices)
        else:
          np_s = np.linalg.svd(x,
                               compute_uv=compute_uv,
                               full_matrices=full_matrices)
        self.assertAllClose(np_s, tf_s.eval(), atol=atol)
        if compute_uv:
          CompareSingularVectors(self, np_u, tf_u.eval(),
                                 min(shape_[-2:]), atol)
          CompareSingularVectors(self, np.swapaxes(np_v, -2, -1),
                                 tf_v.eval(), min(shape_[-2:]), atol)
          CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices, atol)
          CheckUnitary(self, tf_u)
          CheckUnitary(self, tf_v)
def Test(self):
  np.random.seed(1)
  x = np.random.uniform(
      low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
  if dtype_ == np.float32:
    atol = 1e-4
  else:
    atol = 1e-14
  for compute_uv in False, True:
    for full_matrices in False, True:
      with self.test_session():
        if x.ndim == 2:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
          else:
            tf_s = tf.svd(tf.constant(x),
                          compute_uv=compute_uv,
                          full_matrices=full_matrices)
        else:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.batch_svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
          else:
            tf_s = tf.batch_svd(tf.constant(x),
                                compute_uv=compute_uv,
                                full_matrices=full_matrices)
        if compute_uv:
          np_u, np_s, np_v = np.linalg.svd(
              x, compute_uv=compute_uv, full_matrices=full_matrices)
        else:
          np_s = np.linalg.svd(x,
                               compute_uv=compute_uv,
                               full_matrices=full_matrices)
        self.assertAllClose(np_s, tf_s.eval(), atol=atol)
        if compute_uv:
          _CompareSingularVectors(self, np_u, tf_u.eval(), atol)
          _CompareSingularVectors(self, np.swapaxes(np_v, -2, -1),
                                  tf_v.eval(), atol)
def Test(self):
  np.random.seed(1)
  x = np.random.uniform(
      low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
  for compute_uv in False, True:
    for full_matrices in False, True:
      with self.test_session():
        if x.ndim == 2:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
          else:
            tf_s = tf.svd(tf.constant(x),
                          compute_uv=compute_uv,
                          full_matrices=full_matrices)
        else:
          if compute_uv:
            tf_s, tf_u, tf_v = tf.batch_svd(
                tf.constant(x),
                compute_uv=compute_uv,
                full_matrices=full_matrices)
          else:
            tf_s = tf.batch_svd(tf.constant(x),
                                compute_uv=compute_uv,
                                full_matrices=full_matrices)
        if compute_uv:
          np_u, np_s, np_v = np.linalg.svd(
              x, compute_uv=compute_uv, full_matrices=full_matrices)
        else:
          np_s = np.linalg.svd(x,
                               compute_uv=compute_uv,
                               full_matrices=full_matrices)
        CompareSingularValues(self, np_s, tf_s.eval())
        if compute_uv:
          CompareSingularVectors(self, np_u, tf_u.eval(), min(shape_[-2:]))
          CompareSingularVectors(self, np.swapaxes(np_v, -2, -1),
                                 tf_v.eval(), min(shape_[-2:]))
          CheckApproximation(self, x, tf_u, tf_s, tf_v, full_matrices)
          CheckUnitary(self, tf_u)
          CheckUnitary(self, tf_v)
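The Test variants above close over shape_ and dtype_ and call helper checkers (CompareSingularValues, CheckApproximation, CheckUnitary) defined elsewhere in the test module. A minimal sketch of how such a parameterized body is typically bound to a test class follows; the factory name _GetSvdOpTest, the class name SvdOpTest, and the dtype/shape grid are illustrative assumptions, not part of the original file.

# Minimal sketch, assuming a tf.test.TestCase subclass named SvdOpTest exists.
# The factory closes over dtype_ and shape_ so each generated method exercises
# the parameterized Test body for one (dtype, shape) combination.
def _GetSvdOpTest(dtype_, shape_):

  def Test(self):
    np.random.seed(1)
    x = np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    with self.test_session():
      np_s = np.linalg.svd(x, compute_uv=False)
      tf_s = tf.svd(tf.constant(x), compute_uv=False)
      self.assertAllClose(np_s, tf_s.eval(),
                          atol=1e-4 if dtype_ == np.float32 else 1e-14)

  return Test


if __name__ == "__main__":
  # Hypothetical registration loop: attach one test method per combination.
  for dtype_ in np.float32, np.float64:
    for rows in 1, 2, 5, 10:
      for cols in 1, 2, 5, 10:
        name = "%s_%d_%d" % (dtype_.__name__, rows, cols)
        setattr(SvdOpTest, "testSvd_" + name,
                _GetSvdOpTest(dtype_, (rows, cols)))
  tf.test.main()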
def testWrongDimensions(self):
  # The input to svd should be a 2-dimensional tensor.
  scalar = tf.constant(1.)
  with self.assertRaises(ValueError):
    tf.svd(scalar)
  vector = tf.constant([1., 2.])
  with self.assertRaises(ValueError):
    tf.svd(vector)
  tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
  with self.assertRaises(ValueError):
    tf.svd(tensor)

  # The input to batch_svd should be a tensor of at least rank 2.
  scalar = tf.constant(1.)
  with self.assertRaises(ValueError):
    tf.batch_svd(scalar)
  vector = tf.constant([1., 2.])
  with self.assertRaises(ValueError):
    tf.batch_svd(vector)
def testWrongDimensions(self):
  # The input to svd should be a 2-dimensional tensor.
  scalar = tf.constant(1.)
  with self.assertRaisesRegexp(ValueError,
                               "Shape must be rank 2 but is rank 0"):
    tf.svd(scalar)
  vector = tf.constant([1., 2.])
  with self.assertRaisesRegexp(ValueError,
                               "Shape must be rank 2 but is rank 1"):
    tf.svd(vector)
  tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
  with self.assertRaisesRegexp(ValueError,
                               "Shape must be rank 2 but is rank 3"):
    tf.svd(tensor)
  scalar = tf.constant(1. + 1.0j)
  with self.assertRaises(ValueError):
    tf.svd(scalar)
  vector = tf.constant([1. + 1.0j, 2. + 2.0j])
  with self.assertRaises(ValueError):
    tf.svd(vector)
  tensor = tf.constant([[[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]],
                        [[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]]])
  with self.assertRaises(ValueError):
    tf.svd(tensor)

  # The input to batch_svd should be a tensor of at least rank 2.
  scalar = tf.constant(1.)
  with self.assertRaisesRegexp(
      ValueError, "Shape must be at least rank 2 but is rank 0"):
    tf.batch_svd(scalar)
  vector = tf.constant([1., 2.])
  with self.assertRaisesRegexp(
      ValueError, "Shape must be at least rank 2 but is rank 1"):
    tf.batch_svd(vector)
  scalar = tf.constant(1. + 1.0j)
  with self.assertRaises(ValueError):
    tf.batch_svd(scalar)
  vector = tf.constant([1. + 1.0j, 2. + 2.0j])
  with self.assertRaises(ValueError):
    tf.batch_svd(vector)
def compute_stats(self, loss_sampled, var_list=None):
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()

    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    self.gs = gs
    factors = self.getFactors(gs, varlist)
    stats = self.getStats(factors, varlist)

    updateOps = []
    statsUpdates = {}
    statsUpdates_cache = {}
    for var in varlist:
        opType = factors[var]['opName']
        fops = factors[var]['op']
        fpropFactor = factors[var]['fpropFactors_concat']
        fpropStats_vars = stats[var]['fprop_concat_stats']
        bpropFactor = factors[var]['bpropFactors_concat']
        bpropStats_vars = stats[var]['bprop_concat_stats']
        SVD_factors = {}
        for stats_var in fpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                old_fpropFactor = fpropFactor
                B = (tf.shape(fpropFactor)[0])  # batch size
                if opType == 'Conv2D':
                    strides = fops.get_attr("strides")
                    padding = fops.get_attr("padding")
                    convkernel_size = var.get_shape()[0:3]
                    KH = int(convkernel_size[0])
                    KW = int(convkernel_size[1])
                    C = int(convkernel_size[2])
                    flatten_size = int(KH * KW * C)

                    Oh = int(bpropFactor.get_shape()[1])
                    Ow = int(bpropFactor.get_shape()[2])

                    if Oh == 1 and Ow == 1 and self._channel_fac:
                        # factorization along the channels
                        # assume independence among input channels
                        # factor = B x 1 x 1 x (KH xKW x C)
                        # patches = B x Oh x Ow x (KH xKW x C)
                        if len(SVD_factors) == 0:
                            if KFAC_DEBUG:
                                print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                            # find closest rank-1 approx to the feature map
                            S, U, V = tf.batch_svd(tf.reshape(
                                fpropFactor, [-1, KH * KW, C]))
                            # get rank-1 approx slides
                            sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                            patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                            full_factor_shape = fpropFactor.get_shape()
                            patches_k.set_shape(
                                [full_factor_shape[0], KH * KW])
                            patches_c = V[:, :, 0] * sqrtS1  # B x C
                            patches_c.set_shape([full_factor_shape[0], C])
                            SVD_factors[C] = patches_c
                            SVD_factors[KH * KW] = patches_k
                        fpropFactor = SVD_factors[stats_var_dim]
                    else:
                        # poor mem usage implementation
                        patches = tf.extract_image_patches(fpropFactor,
                                                           ksizes=[1, convkernel_size[0],
                                                                   convkernel_size[1], 1],
                                                           strides=strides,
                                                           rates=[1, 1, 1, 1],
                                                           padding=padding)

                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 act fisher for %s' % (var.name)))
                            # T^2 terms * 1/T^2, size: B x C
                            fpropFactor = tf.reduce_mean(patches, [1, 2])
                        else:
                            # size: (B x Oh x Ow) x C
                            fpropFactor = tf.reshape(
                                patches, [-1, flatten_size]) / Oh / Ow
                fpropFactor_size = int(fpropFactor.get_shape()[-1])
                if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                    if opType == 'Conv2D' and not self._approxT2:
                        # correct padding for numerical stability (we
                        # divided out OhxOw from activations for T1 approx)
                        fpropFactor = tf.concat([fpropFactor, tf.ones(
                            [tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
                    else:
                        # use homogeneous coordinates
                        fpropFactor = tf.concat(
                            [fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)

                # average over the number of data points in a batch
                # divided by B
                cov = tf.matmul(fpropFactor, fpropFactor,
                                transpose_a=True) / tf.cast(B, tf.float32)
                updateOps.append(cov)
                statsUpdates[stats_var] = cov
                if opType != 'Conv2D':
                    # HACK: for convolution we recompute fprop stats for
                    # every layer including forking layers
                    statsUpdates_cache[stats_var] = cov

        for stats_var in bpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                old_bpropFactor = bpropFactor
                bpropFactor_shape = bpropFactor.get_shape()
                B = tf.shape(bpropFactor)[0]  # batch size
                C = int(bpropFactor_shape[-1])  # num channels
                if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                    if fpropFactor is not None:
                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 grad fisher for %s' % (var.name)))
                            bpropFactor = tf.reduce_sum(
                                bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                        else:
                            bpropFactor = tf.reshape(
                                bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                    else:
                        # just doing block diag approx. spatial independent
                        # structure does not apply here. summing over
                        # spatial locations
                        if KFAC_DEBUG:
                            print(('block diag approx fisher for %s' % (var.name)))
                        bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                # assume sampled loss is averaged. TO-DO: figure out better
                # way to handle this
                bpropFactor *= tf.to_float(B)
                ##
                cov_b = tf.matmul(
                    bpropFactor, bpropFactor,
                    transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

                updateOps.append(cov_b)
                statsUpdates[stats_var] = cov_b
                statsUpdates_cache[stats_var] = cov_b

    if KFAC_DEBUG:
        aKey = list(statsUpdates.keys())[0]
        statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                      [tf.convert_to_tensor('step:'),
                                       self.global_step,
                                       tf.convert_to_tensor('computing stats'),
                                       ])
    self.statsUpdates = statsUpdates
    return statsUpdates
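In the _channel_fac branch above, compute_stats replaces each B x (KH*KW) x C activation factor with its closest rank-1 approximation per sample, splitting the leading singular value evenly between the two factors. The NumPy sketch below is illustrative only (not part of kfac.py) and shows the same decomposition and how the outer product of the two factors reconstructs the rank-1 approximation.

import numpy as np

# Illustrative sketch of the rank-1 SVD factorization used in compute_stats.
rng = np.random.RandomState(0)
batch, kh_kw, chan = 4, 9, 3
acts = rng.uniform(-1.0, 1.0, size=(batch, kh_kw, chan))  # B x KH*KW x C

u, s, vh = np.linalg.svd(acts, full_matrices=False)        # batched SVD
sqrt_s1 = np.sqrt(s[:, 0])[:, None]                        # sqrt of leading singular value
patches_k = u[:, :, 0] * sqrt_s1                           # B x KH*KW
patches_c = vh[:, 0, :] * sqrt_s1                          # B x C

# The outer product of the two factors recovers the best rank-1 approximation
# of each per-sample activation slice.
rank1 = patches_k[:, :, None] * patches_c[:, None, :]      # B x KH*KW x C
best = u[:, :, :1] * s[:, None, :1] @ vh[:, :1, :]
assert np.allclose(rank1, best)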
def compute_stats(self, loss_sampled, var_list=None): """ compute the stats values :param loss_sampled: ([TensorFlow Tensor]) the loss function output :param var_list: ([TensorFlow Tensor]) The parameters :return: ([TensorFlow Tensor]) stats updates """ varlist = var_list if varlist is None: varlist = tf.trainable_variables() gradient_sampled = tf.gradients(loss_sampled, varlist, name='gradientsSampled') self.gradient_sampled = gradient_sampled # remove unused variables gradient_sampled, varlist = zip(*[(grad, var) for (grad, var) in zip(gradient_sampled, varlist) if grad is not None]) factors = self.get_factors(gradient_sampled, varlist) stats = self.get_stats(factors, varlist) update_ops = [] stats_updates = {} stats_updates_cache = {} for var in varlist: op_type = factors[var]['opName'] fops = factors[var]['op'] fprop_factor = factors[var]['fpropFactors_concat'] fprop_stats_vars = stats[var]['fprop_concat_stats'] bprop_factor = factors[var]['bpropFactors_concat'] bprop_stats_vars = stats[var]['bprop_concat_stats'] svd_factors = {} for stats_var in fprop_stats_vars: stats_var_dim = int(stats_var.get_shape()[0]) if stats_var not in stats_updates_cache: batch_size = (tf.shape(fprop_factor)[0]) # batch size if op_type == 'Conv2D': strides = fops.get_attr("strides") padding = fops.get_attr("padding") convkernel_size = var.get_shape()[0:3] kernel_height = int(convkernel_size[0]) kernel_width = int(convkernel_size[1]) chan = int(convkernel_size[2]) flatten_size = int(kernel_height * kernel_width * chan) operator_height = int(bprop_factor.get_shape()[1]) operator_width = int(bprop_factor.get_shape()[2]) if operator_height == 1 and operator_width == 1 and self._channel_fac: # factorization along the channels # assume independence among input channels # factor = B x 1 x 1 x (KH xKW x C) # patches = B x Oh x Ow x (KH xKW x C) if len(svd_factors) == 0: if KFAC_DEBUG: print(('approx %s act factor with rank-1 SVD factors' % var.name)) # find closest rank-1 approx to the feature map S, U, V = tf.batch_svd(tf.reshape( fprop_factor, [-1, kernel_height * kernel_width, chan])) # get rank-1 approx slides sqrt_s1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1) patches_k = U[:, :, 0] * sqrt_s1 # B x KH*KW full_factor_shape = fprop_factor.get_shape() patches_k.set_shape( [full_factor_shape[0], kernel_height * kernel_width]) patches_c = V[:, :, 0] * sqrt_s1 # B x C patches_c.set_shape([full_factor_shape[0], chan]) svd_factors[chan] = patches_c svd_factors[kernel_height * kernel_width] = patches_k fprop_factor = svd_factors[stats_var_dim] else: # poor mem usage implementation patches = tf.extract_image_patches(fprop_factor, ksizes=[1, convkernel_size[ 0], convkernel_size[1], 1], strides=strides, rates=[1, 1, 1, 1], padding=padding) if self._approx_t2: if KFAC_DEBUG: print(('approxT2 act fisher for %s' % var.name)) # T^2 terms * 1/T^2, size: B x C fprop_factor = tf.reduce_mean(patches, [1, 2]) else: # size: (B x Oh x Ow) x C fprop_factor = tf.reshape( patches, [-1, flatten_size]) / operator_height / operator_width fprop_factor_size = int(fprop_factor.get_shape()[-1]) if stats_var_dim == (fprop_factor_size + 1) and not self._blockdiag_bias: if op_type == 'Conv2D' and not self._approx_t2: # correct padding for numerical stability (we # divided out OhxOw from activations for T1 approx) fprop_factor = tf.concat([fprop_factor, tf.ones( [tf.shape(fprop_factor)[0], 1]) / operator_height / operator_width], 1) else: # use homogeneous coordinates fprop_factor = tf.concat( [fprop_factor, tf.ones([tf.shape(fprop_factor)[0], 1])], 
1) # average over the number of data points in a batch # divided by B cov = tf.matmul(fprop_factor, fprop_factor, transpose_a=True) / tf.cast(batch_size, tf.float32) update_ops.append(cov) stats_updates[stats_var] = cov if op_type != 'Conv2D': # HACK: for convolution we recompute fprop stats for # every layer including forking layers stats_updates_cache[stats_var] = cov for stats_var in bprop_stats_vars: if stats_var not in stats_updates_cache: bprop_factor_shape = bprop_factor.get_shape() batch_size = tf.shape(bprop_factor)[0] # batch size chan = int(bprop_factor_shape[-1]) # num channels if op_type == 'Conv2D' or len(bprop_factor_shape) == 4: if fprop_factor is not None: if self._approx_t2: if KFAC_DEBUG: print(('approxT2 grad fisher for %s' % var.name)) bprop_factor = tf.reduce_sum( bprop_factor, [1, 2]) # T^2 terms * 1/T^2 else: bprop_factor = tf.reshape( bprop_factor, [-1, chan]) * operator_height * operator_width # T * 1/T terms else: # just doing block diag approx. spatial independent # structure does not apply here. summing over # spatial locations if KFAC_DEBUG: print(('block diag approx fisher for %s' % var.name)) bprop_factor = tf.reduce_sum(bprop_factor, [1, 2]) # assume sampled loss is averaged. TODO:figure out better # way to handle this bprop_factor *= tf.cast(batch_size, tf.float32) ## cov_b = tf.matmul(bprop_factor, bprop_factor, transpose_a=True) / tf.cast(tf.shape(bprop_factor)[0], tf.float32) update_ops.append(cov_b) stats_updates[stats_var] = cov_b stats_updates_cache[stats_var] = cov_b if KFAC_DEBUG: a_key = list(stats_updates.keys())[0] stats_updates[a_key] = tf.Print(stats_updates[a_key], [tf.convert_to_tensor('step:'), self.global_step, tf.convert_to_tensor('computing stats')]) self.stats_updates = stats_updates return stats_updates
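Both versions of compute_stats ultimately accumulate the two Kronecker factors of each layer's Fisher block: an activation covariance a^T a / B (with a ones column appended in homogeneous coordinates for the bias) and a gradient covariance g^T g / B, where the gradients are first rescaled by B because the sampled loss is averaged over the batch. A small NumPy sketch of those two statistics, illustrative only and not part of kfac.py:

import numpy as np

# Illustrative sketch of the per-layer statistics accumulated above.
rng = np.random.RandomState(0)
batch, n_in, n_out = 8, 5, 3
acts = rng.randn(batch, n_in)    # forward activations for one layer
grads = rng.randn(batch, n_out)  # backpropagated gradients w.r.t. layer outputs

# fprop statistic: append a ones column (homogeneous coordinates for the bias),
# then average the per-sample outer products over the batch.
acts_h = np.concatenate([acts, np.ones((batch, 1))], axis=1)
cov_a = acts_h.T @ acts_h / batch

# bprop statistic: undo the 1/B from the averaged sampled loss before averaging.
grads_scaled = grads * batch
cov_g = grads_scaled.T @ grads_scaled / batch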