def assign_preconditioner_to_host_vars(self):
    """Assign/Grab latest copy of preconditioners.

    Walks every variable in `self._all_vars_for_preconditioning`, skips the
    ones for which `_fallback_to_diagonal_for_shape` is true, partitions the
    rest, and collects the preconditioner slot of each (partition, dim) pair
    that `_preconditioner_available_for_dims` reports as available. The
    latest values for those slots are then requested in a single
    `x_ops.get_preconditioners` call and written back with assign ops.

    Returns:
        `tf.no_op()` when no preconditioner slot was collected; otherwise a
        `tf.group` containing one assign op per collected slot.
    """
    keys_shapes_and_preconditioner_vars = []
    for var in self._all_vars_for_preconditioning:
        if self._fallback_to_diagonal_for_shape(var.get_shape()):
            continue
        partitioned_v = TensorPartitioner.partition_tensor(
            var, self._partition_info)
        num_partitions = len(partitioned_v)
        for pt_idx, pt in enumerate(partitioned_v):
            pt_shape = pt.get_shape()
            preconditioner_exists_for_dim = (
                self._preconditioner_available_for_dims(pt_shape))
            for i in range(len(pt_shape)):
                if not preconditioner_exists_for_dim[i]:
                    continue
                key = self._key_for_var(var, i, pt_idx)
                preconditioner = self.get_slot(
                    var,
                    self._preconditioner_key_for_partition_and_dim(
                        i, pt_idx, num_partitions))
                keys_shapes_and_preconditioner_vars.append(
                    (key, tf.shape(preconditioner), preconditioner))

    if not keys_shapes_and_preconditioner_vars:
        return tf.no_op()

    keys, shapes, preconditioner_vars = zip(
        *keys_shapes_and_preconditioner_vars)

    preconditioner_vals, successes = x_ops.get_preconditioners(
        shapes,
        keys=keys,
        preconditioner_compute_graphdef=(
            self._preconditioner_compute_graphdef))

    assign_ops = []
    for preconditioner_var, preconditioner_val, success in zip(
            preconditioner_vars, preconditioner_vals, successes):
        # Cast the success flag to the dtype of the slot being updated.
        # The earlier code read the dtype from a variable left over from
        # the collection loop (i.e. the last collected slot), which is
        # only correct if every slot shares that dtype.
        success_mult = tf.cast(success, preconditioner_var.dtype)
        # Blend: with success_mult == 1 the fetched value is written, with
        # success_mult == 0 the slot is re-assigned its current value.
        assign_ops.append(
            state_ops.assign(
                preconditioner_var,
                (1.0 - success_mult) * preconditioner_var
                + success_mult * preconditioner_val))
    return tf.group(*assign_ops)
def testPreconditioning(self):
    """Exercise get_preconditioners / compute_preconditioners end to end.

    Before the compute op runs, no status may be truthy and both fetched
    outputs must be (close to) 4x4 zeros. After running the compute op, at
    least one status must be truthy and the fetched outputs must match
    `self.inverse_pth_root` of the symmetric inputs within atol=1e-1.
    """
    graphdef = self.inverse_pth_root_graph()
    with tf.Session() as sess:
        step = tf.train.get_or_create_global_step()
        self.evaluate(tf.global_variables_initializer())

        # Two random matrices; M * M^T makes each input symmetric.
        first_rand = np.random.rand(4, 4)
        second_rand = np.random.rand(4, 4)
        exponents = [-0.25, -0.25]
        first_sym = np.dot(first_rand, first_rand.T)
        second_sym = np.dot(second_rand, second_rand.T)
        sym_inputs = [first_sym, second_sym]

        shapes = [tf.shape(first_sym), tf.shape(second_sym)]
        outputs, statuses = ops.get_preconditioners(
            shapes,
            keys=['a', 'b'],
            preconditioner_compute_graphdef=graphdef)
        self.assertFalse(any(sess.run(statuses)))

        compute_op = ops.compute_preconditioners(
            sym_inputs,
            exponents,
            tf.cast(step, tf.int32),
            keys=['a', 'b'],
            sync=True,
            preconditioner_compute_graphdef=graphdef)

        # Nothing has been computed yet, so the fetched values are zeros.
        for idx in (0, 1):
            self.assertAllClose(
                outputs[idx].eval(), np.zeros((4, 4)), atol=1e-4)

        compute_op.run()
        self.assertTrue(any(sess.run(statuses)))

        # Build the reference results first, then fetch the outputs once.
        expected = [
            self.inverse_pth_root(matrix, exponent)
            for matrix, exponent in zip(sym_inputs, exponents)
        ]
        fetched = sess.run(outputs)
        for idx in (0, 1):
            self.assertAllClose(
                fetched[idx], expected[idx][0].eval(), atol=1e-1)