Ejemplo n.º 1
0
    def assign_preconditioner_to_host_vars(self):
        """Build an op that refreshes preconditioner slots with fetched values.

        Every variable in ``self._all_vars_for_preconditioning`` is visited;
        variables for which ``self._fallback_to_diagonal_for_shape`` is true
        are skipped. For each remaining variable, each partition and each
        dimension that ``self._preconditioner_available_for_dims`` marks as
        available contributes one ``(key, shape, slot variable)`` entry.

        After ALL variables have been visited, a single
        ``x_ops.get_preconditioners`` call requests every collected key, and
        one assign op is created per slot. (Previously the fetch/assign step
        ran inside the per-variable loop, which returned early when the first
        variable contributed nothing and re-created ops for already-seen
        entries on every later iteration.)

        Returns:
            ``tf.no_op()`` when no preconditioner entry was collected,
            otherwise ``tf.group`` over one assign op per collected slot.
        """
        keys_shapes_and_preconditioner_vars = []
        for var in self._all_vars_for_preconditioning:
            if self._fallback_to_diagonal_for_shape(var.get_shape()):
                continue
            partitioned_v = TensorPartitioner.partition_tensor(
                var, self._partition_info)
            num_partitions = len(partitioned_v)
            for pt_idx, pt in enumerate(partitioned_v):
                pt_shape = pt.get_shape()
                preconditioner_exists_for_dim = (
                    self._preconditioner_available_for_dims(pt_shape))
                for i in range(len(pt_shape)):
                    if not preconditioner_exists_for_dim[i]:
                        continue
                    key = self._key_for_var(var, i, pt_idx)
                    preconditioner = self.get_slot(
                        var,
                        self._preconditioner_key_for_partition_and_dim(
                            i, pt_idx, num_partitions))
                    keys_shapes_and_preconditioner_vars.append(
                        (key, tf.shape(preconditioner), preconditioner))

        # Decide only once everything has been collected, so that a variable
        # without preconditioners cannot short-circuit the remaining ones.
        if not keys_shapes_and_preconditioner_vars:
            return tf.no_op()

        keys, shapes, preconditioner_vars = zip(
            *keys_shapes_and_preconditioner_vars)

        preconditioner_vals, successes = x_ops.get_preconditioners(
            shapes,
            keys=keys,
            preconditioner_compute_graphdef=(
                self._preconditioner_compute_graphdef))

        assign_ops = []
        for preconditioner_var, preconditioner_val, success in zip(
                preconditioner_vars, preconditioner_vals, successes):
            # Cast with this slot's own dtype rather than whichever slot the
            # collection loop happened to visit last.
            success_mult = tf.cast(success, preconditioner_var.dtype)
            # Arithmetic select: a multiplier of 1 takes the fetched value,
            # a multiplier of 0 keeps the slot's current value.
            # NOTE(review): presumes `success` is a boolean-like flag - confirm
            # against the get_preconditioners op definition.
            assign_ops.append(
                state_ops.assign(
                    preconditioner_var,
                    (1.0 - success_mult) * preconditioner_var +
                    success_mult * preconditioner_val))
        return tf.group(*assign_ops)
 def testPreconditioning(self):
      """Check get/compute preconditioner ops before and after a computation.

      Steps, in order:
        1. Request preconditioners for keys 'a' and 'b'; assert that no
           status is truthy and that both outputs are close to a 4x4 zero
           matrix.
        2. Run ``ops.compute_preconditioners`` with ``sync=True`` on two
           random 4x4 matrices of the form ``M . M^T``.
        3. Assert that at least one status is now truthy and that each output
           is within ``atol=1e-1`` of ``self.inverse_pth_root`` applied to the
           corresponding input and exponent.
      """
      graphdef = self.inverse_pth_root_graph()
      with tf.Session() as sess:
           step = tf.train.get_or_create_global_step()
           self.evaluate(tf.global_variables_initializer())

           # Draw both random matrices first, then symmetrize each one.
           raw_1 = np.random.rand(4, 4)
           raw_2 = np.random.rand(4, 4)
           exponents = [-0.25, -0.25]
           sym_1 = np.dot(raw_1, raw_1.transpose())
           sym_2 = np.dot(raw_2, raw_2.transpose())
           cache_keys = ['a', 'b']

           requested_shapes = [tf.shape(sym_1), tf.shape(sym_2)]
           outputs, statuses = ops.get_preconditioners(
               requested_shapes,
               keys=cache_keys,
               preconditioner_compute_graphdef=graphdef)
           # Nothing has been computed yet, so no status may be set.
           self.assertFalse(any(sess.run(statuses)))

           compute_op = ops.compute_preconditioners(
               [sym_1, sym_2],
               exponents,
               tf.cast(step, tf.int32),
               keys=cache_keys,
               sync=True,
               preconditioner_compute_graphdef=graphdef)

           # Before the compute op runs, both fetched values are ~zero.
           for fetched in (outputs[0], outputs[1]):
                self.assertAllClose(
                    fetched.eval(), np.zeros((4, 4)), atol=1e-4)

           compute_op.run()
           self.assertTrue(any(sess.run(statuses)))

           expected = [
               self.inverse_pth_root(sym_1, exponents[0]),
               self.inverse_pth_root(sym_2, exponents[1]),
           ]
           fetched_values = sess.run(outputs)
           for idx in (0, 1):
                self.assertAllClose(
                    fetched_values[idx],
                    expected[idx][0].eval(),
                    atol=1e-1)