def invoke_async_preconditioner_computation(self, global_step_int32):
  """Builds the op that computes preconditioners for all eligible statistics.

  Every variable in `self._all_vars_for_preconditioning` is visited. Variables
  whose shape makes `self._fallback_to_diagonal_for_shape` return true are
  skipped. Every other variable is split with
  `TensorPartitioner.partition_tensor`. For each partition dimension flagged by
  `self._preconditioner_available_for_dims`, three things are collected:

  - a lookup key from `self._key_for_var`;
  - the matching statistics slot;
  - the partition's rank, i.e. how many of its dimensions are flagged.

  Args:
    global_step_int32: Passed through unchanged to
      `x_ops.compute_preconditioners` (presumably the global step as an int32
      tensor, per its name -- TODO confirm).

  Returns:
    `tf.no_op()` if nothing was collected. Otherwise, the result of
    `x_ops.compute_preconditioners`, called with one exponent
    `-1 / (2 * rank)` per collected statistic. Its `sync` flag is taken from
    `self._synchronous_preconditioning`.
  """
  collected = []
  for variable in self._all_vars_for_preconditioning:
    if self._fallback_to_diagonal_for_shape(variable.get_shape()):
      # Diagonal fallback: this variable gets no per-dimension statistics here.
      continue
    partitions = TensorPartitioner.partition_tensor(
        variable, self._partition_info)
    partition_count = len(partitions)
    for partition_index, partition in enumerate(partitions):
      partition_shape = partition.get_shape()
      has_preconditioner = self._preconditioner_available_for_dims(
          partition_shape)
      for dim in range(len(partition_shape)):
        if not has_preconditioner[dim]:
          continue
        # Number of preconditioned dims in this partition; it determines the
        # root exponent passed to the compute op below.
        rank = sum(has_preconditioner)
        lookup_key = self._key_for_var(variable, dim, partition_index)
        statistic = self.get_slot(
            variable,
            self._statistics_key_for_partition_and_dim(
                dim, partition_index, partition_count))
        collected.append((lookup_key, statistic, rank))

  if not collected:
    return tf.no_op()

  # zip(*...) yields tuples, which is what the op receives for keys/stats.
  keys, statistics, ranks = zip(*collected)
  exponents = [-1.0 / (2.0 * rank) for rank in ranks]
  return x_ops.compute_preconditioners(
      statistics,
      exponents,
      global_step_int32,
      keys=keys,
      sync=self._synchronous_preconditioning,
      preconditioner_compute_graphdef=self._preconditioner_compute_graphdef)
def testPreconditioning(self):
  """Tests synchronous preconditioner computation end to end.

  The test has three phases:

  1. Build two random symmetric 4x4 inputs under keys 'a' and 'b'.
  2. Check that, before the compute op runs, `get_preconditioners` reports no
     available preconditioner and returns all-zero outputs.
  3. Run `compute_preconditioners` with `sync=True`. Then check that at least
     one status is set and that each output matches `self.inverse_pth_root`
     for the same input and exponent.
  """
  preconditioner_compute_graphdef = self.inverse_pth_root_graph()
  # Seeded, test-local generator. The inputs, and therefore the conditioning
  # of the matrices whose inverse roots are compared, are reproducible across
  # runs instead of depending on unseeded global NumPy state.
  rng = np.random.RandomState(42)
  with tf.Session() as sess:
    global_step = tf.train.get_or_create_global_step()
    self.evaluate(tf.global_variables_initializer())
    rand_input_1_t = rng.rand(4, 4)
    rand_input_2_t = rng.rand(4, 4)
    exponents = [-0.25, -0.25]
    # A . A^T is symmetric positive semi-definite.
    symmetric_input_1_t = np.dot(rand_input_1_t, rand_input_1_t.transpose())
    symmetric_input_2_t = np.dot(rand_input_2_t, rand_input_2_t.transpose())
    outputs, statuses = ops.get_preconditioners(
        [tf.shape(symmetric_input_1_t), tf.shape(symmetric_input_2_t)],
        keys=['a', 'b'],
        preconditioner_compute_graphdef=preconditioner_compute_graphdef)
    # No preconditioner has been computed for either key yet.
    self.assertFalse(any(sess.run(statuses)))
    preconditioner = ops.compute_preconditioners(
        [symmetric_input_1_t, symmetric_input_2_t],
        exponents,
        tf.cast(global_step, tf.int32),
        keys=['a', 'b'],
        sync=True,
        preconditioner_compute_graphdef=preconditioner_compute_graphdef)
    # Constructing the compute op, without running it, must leave the
    # outputs at zero.
    self.assertAllClose(outputs[0].eval(), np.zeros((4, 4)), atol=1e-4)
    self.assertAllClose(outputs[1].eval(), np.zeros((4, 4)), atol=1e-4)
    preconditioner.run()
    self.assertTrue(any(sess.run(statuses)))
    expected_output_1_t = self.inverse_pth_root(
        symmetric_input_1_t, exponents[0])
    expected_output_2_t = self.inverse_pth_root(
        symmetric_input_2_t, exponents[1])
    outputs_np = sess.run(outputs)
    # Only element [0] of inverse_pth_root's result is compared (presumably
    # the root itself; TODO confirm against inverse_pth_root).
    self.assertAllClose(
        outputs_np[0], expected_output_1_t[0].eval(), atol=1e-1)
    self.assertAllClose(
        outputs_np[1], expected_output_2_t[0].eval(), atol=1e-1)