Ejemplo n.º 1
0
    def invoke_async_preconditioner_computation(self, global_step_int32):
        """Builds the op that computes preconditioners from stored statistics.

        Every variable in `self._all_vars_for_preconditioning` whose shape
        does not fall back to the diagonal method is partitioned. For each
        partition, one (key, statistic slot, rank) entry is gathered per
        dimension that has a preconditioner available, where rank is the
        number of such dimensions in that partition. Each statistic is then
        paired with the exponent -1 / (2 * rank).

        NOTE(review): the previous docstring described this as invoking the
        SVD preconditioner with the graph running on the CPU; neither detail
        is visible from this method alone -- confirm against `x_ops`.

        Args:
          global_step_int32: Forwarded unchanged to
            `x_ops.compute_preconditioners` (presumably the global step as
            an int32 tensor -- confirm against callers).

        Returns:
          `tf.no_op()` when no entry was gathered; otherwise whatever
          `x_ops.compute_preconditioners` returns.
        """
        entries = []
        for variable in self._all_vars_for_preconditioning:
            variable_shape = variable.get_shape()
            if self._fallback_to_diagonal_for_shape(variable_shape):
                # Handled elsewhere by the diagonal fallback; nothing to do.
                continue

            pieces = TensorPartitioner.partition_tensor(
                variable, self._partition_info)
            piece_count = len(pieces)
            for piece_idx, piece in enumerate(pieces):
                piece_shape = piece.get_shape()
                has_preconditioner = self._preconditioner_available_for_dims(
                    piece_shape)
                for dim in range(len(piece_shape)):
                    if not has_preconditioner[dim]:
                        continue
                    # Rank counts the dims of this partition that carry a
                    # preconditioner; it sets the exponent used below.
                    rank = sum(has_preconditioner)
                    key = self._key_for_var(variable, dim, piece_idx)
                    slot_name = self._statistics_key_for_partition_and_dim(
                        dim, piece_idx, piece_count)
                    statistic = self.get_slot(variable, slot_name)
                    entries.append((key, statistic, rank))

        if not entries:
            return tf.no_op()

        # Transpose the list of triples into three parallel tuples.
        keys, stats, ranks = zip(*entries)
        exponents = [-1.0 / (2.0 * rank) for rank in ranks]

        return x_ops.compute_preconditioners(
            stats,
            exponents,
            global_step_int32,
            keys=keys,
            sync=self._synchronous_preconditioning,
            preconditioner_compute_graphdef=(
                self._preconditioner_compute_graphdef))
 def testPreconditioning(self):
     """Checks preconditioner outputs before and after a synchronous compute.

     Before `compute_preconditioners` runs, the outputs fetched for keys
     'a' and 'b' must be (near) zero and no status may be set. After it
     runs, at least one status must be set and each output must match
     `self.inverse_pth_root` of the corresponding input within atol=1e-1.
     """
     graphdef = self.inverse_pth_root_graph()
     with tf.Session() as sess:
         step = tf.train.get_or_create_global_step()
         self.evaluate(tf.global_variables_initializer())

         # Two random 4x4 matrices, each multiplied by its own transpose to
         # obtain a symmetric input.
         random_inputs = [np.random.rand(4, 4) for _ in range(2)]
         exponents = [-0.25, -0.25]
         symmetric_inputs = [
             np.dot(matrix, matrix.transpose()) for matrix in random_inputs
         ]

         outputs, statuses = ops.get_preconditioners(
             [tf.shape(matrix) for matrix in symmetric_inputs],
             keys=['a', 'b'],
             preconditioner_compute_graphdef=graphdef)
         # Nothing has been computed yet, so no status should be set.
         self.assertFalse(any(sess.run(statuses)))

         compute_op = ops.compute_preconditioners(
             symmetric_inputs,
             exponents,
             tf.cast(step, tf.int32),
             keys=['a', 'b'],
             sync=True,
             preconditioner_compute_graphdef=graphdef)

         # Until the compute op is run, the fetched outputs stay at zero.
         for idx in range(2):
             self.assertAllClose(
                 outputs[idx].eval(), np.zeros((4, 4)), atol=1e-4)

         compute_op.run()
         self.assertTrue(any(sess.run(statuses)))

         expected = [
             self.inverse_pth_root(symmetric_inputs[idx], exponents[idx])
             for idx in range(2)
         ]
         actual = sess.run(outputs)
         for idx in range(2):
             self.assertAllClose(
                 actual[idx], expected[idx][0].eval(), atol=1e-1)