示例#1
0
def _dataset_reduce_fn(
    reduce_fn: _ReduceFnCallable,
    dataset: tf.data.Dataset,
    initial_state_fn: Callable[[], Any] = lambda: tf.constant(0)
) -> Any:
    """Folds `dataset` into a single value using `reduce_fn`.

    Args:
      reduce_fn: Callable `(state, element) -> new_state`, passed to
        `tf.data.Dataset.reduce` as its `reduce_func`.
      dataset: The dataset to fold.
      initial_state_fn: Zero-argument factory for the starting state. It is
        invoked once per call. Defaults to a scalar `tf.constant(0)`.

    Returns:
      The final state produced by `dataset.reduce`.
    """
    start = initial_state_fn()
    return dataset.reduce(start, reduce_fn)
示例#2
0
def update(dataset: tf.data.Dataset, state: State, message: server.Message,
           model_fn: Callable, optimizer_fn: Callable) -> Output:
    """Trains a local copy of the server model on `dataset` and reports the delta.

    Args:
      dataset: Client batches fed to `model.forward_pass`.
      state: Client state; `client_pos_weight` configures the model and the
        state is passed back unchanged.
      message: Server message whose `model` weights seed the local model.
      model_fn: Factory called as `model_fn(pos_weight=...)`.
      optimizer_fn: Zero-argument optimizer factory.

    Returns:
      An `Output` with:
        - `weights_delta`: trained trainable variables minus
          `message.model.trainable`;
        - `metrics`: the local model's reported outputs;
        - `client_weight`: the number of examples seen, as float32;
        - `client_state`: a copy of the incoming index and positive weight.
    """
    # Create variables outside any enclosing function graph.
    with tf.init_scope():
        model = model_fn(pos_weight=state.client_pos_weight)
        optimizer = optimizer_fn()

    message.model.assign_weights_to(model)

    def train_step(seen_so_far, batch):
        with tf.GradientTape() as tape:
            step_result = model.forward_pass(batch, training=True)

        grads = tape.gradient(step_result.loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        return seen_so_far + step_result.num_examples

    example_count = dataset.reduce(tf.constant(0, dtype=tf.int32),
                                   train_step)

    delta = tf.nest.map_structure(
        lambda trained, initial: trained - initial,
        model.trainable_variables, message.model.trainable)

    return Output(weights_delta=delta,
                  metrics=model.report_local_outputs(),
                  client_weight=tf.cast(example_count, dtype=tf.float32),
                  client_state=State(
                      client_index=state.client_index,
                      client_pos_weight=state.client_pos_weight))
示例#3
0
def evaluate(dataset: tf.data.Dataset, state: State,
             weights: tff.learning.ModelWeights, coefficient_fn: Callable,
             model_fn: Callable) -> Evaluation:
    """Evaluates the mixed client/server model and builds a 2x2 confusion matrix.

    The client's stored mixing coefficients are loaded into fresh variables.
    `state.model` and `weights` are then combined with `__mix_weights`, and
    the result is loaded into the evaluation model.

    Returns:
      An `Evaluation` whose `confusion_matrix` is the int32 sum of per-batch
      `tf.math.confusion_matrix` results (labels from `batch[1]`, predictions
      from rounding the sigmoid of the model output), together with the
      model's reported local outputs.
    """
    with tf.init_scope():
        mixing_coefficients = coefficient_fn()
        model = model_fn(pos_weight=state.client_pos_weight)

    tf.nest.map_structure(lambda var, value: var.assign(value),
                          mixing_coefficients, state.mixing_coefficients)
    mixed = __mix_weights(mixing_coefficients, state.model, weights)
    mixed.assign_weights_to(model)

    def accumulate(running_matrix, batch):
        outputs = model.forward_pass(batch, training=False)

        labels = tf.reshape(batch[1], (-1, ))
        flat_logits = tf.reshape(outputs.predictions, (-1, ))
        predicted = tf.round(tf.nn.sigmoid(flat_logits))

        batch_matrix = tf.math.confusion_matrix(labels, predicted,
                                                num_classes=2)
        return running_matrix + batch_matrix

    totals = dataset.reduce(tf.zeros((2, 2), dtype=tf.int32), accumulate)

    return Evaluation(confusion_matrix=totals,
                      metrics=model.report_local_outputs())
示例#4
0
def get_unique_elements(dataset: tf.data.Dataset,
                        max_string_length: Optional[int] = None):
  """Gets the unique elements from the input `dataset`.

  The input `dataset` must yield batched rank-1 tensors. This function reads
  each coordinate of the tensor as an individual element and returns the
  unique elements, in order of first appearance.

  Args:
    dataset: A `tf.data.Dataset`. Element type must be `tf.string`.
    max_string_length: The maximum length (in bytes) of strings in the dataset.
      Strings longer than `max_string_length` are truncated before
      de-duplication, so strings that share their first `max_string_length`
      bytes collapse into one element. Defaults to `None`, which means there
      is no limit on the string length.

  Returns:
    A rank-1 Tensor containing the unique elements of the input dataset.

  Raises:
    ValueError:
      -- If the shape of elements in `dataset` is not rank 1.
      -- If `max_string_length` is not `None` and is less than 1.
    TypeError: If `dataset.element_spec.dtype` is not `tf.string`.
  """
  rank = dataset.element_spec.shape.rank
  if rank != 1:
    raise ValueError('The shape of elements in `dataset` must be of rank 1, '
                     f'found rank = {rank} instead.')

  if max_string_length is not None and max_string_length < 1:
    raise ValueError('`max_string_length` must be at least 1 when it is not'
                     ' None.')

  if dataset.element_spec.dtype != tf.string:
    raise TypeError('`dataset.element_spec.dtype` must be `tf.string`, found'
                    f' element type {dataset.element_spec.dtype}')

  initial_list = tf.constant([], dtype=tf.string)

  def add_unique_element(element_list, element_batch):
    if max_string_length is not None:
      element_batch = tf.strings.substr(
          element_batch, 0, max_string_length, unit='BYTE')
    # Appending the new batch after the accumulated list and re-running
    # `tf.unique` keeps first-occurrence order across batches.
    element_list = tf.concat([element_list, element_batch], axis=0)
    element_list, _ = tf.unique(element_list)
    return element_list

  unique_element_list = dataset.reduce(
      initial_state=initial_list, reduce_func=add_unique_element)

  return unique_element_list
示例#5
0
def to_stacked_tensor(ds: tf.data.Dataset) -> tf.Tensor:
  """Encodes the `tf.data.Dataset` as stacked tensors.

  This is effectively the inverse of `tf.data.Dataset.from_tensor_slices()`.
  All elements from the input dataset are concatenated into a tensor structure,
  where the output structure matches the input `ds.element_spec`, and each
  output tensor has the same shape plus one additional leading dimension
  along which elements are stacked. For example, if the dataset contains 5
  elements with shape [3, 2], the returned tensor will have shape [5, 3, 2].
  Note that each element in the dataset could be a single tensor or a structure
  of tensors.

  Dataset elements must have fully-defined shapes. Any partially-defined element
  shapes will raise an error. If passing in a batched dataset, use
  `drop_remainder=True` to ensure the batched shape is fully defined.

  Args:
    ds: The input `tf.data.Dataset` to stack.

  Returns:
    A tensor, or a structure of tensors matching `ds.element_spec`, encoding
    the input dataset.

  Raises:
    ValueError: If any dataset element shape is not fully-defined.
  """
  py_typecheck.check_type(ds, tf.data.Dataset)

  def expanded_empty_tensor(tensor_spec: tf.TensorSpec) -> tf.Tensor:
    # Zero-length along the new leading axis, so the first concat yields
    # exactly one element.
    if not tensor_spec.shape.is_fully_defined():
      raise _TensorShapeNotFullyDefinedError()
    return tf.zeros(shape=[0] + tensor_spec.shape, dtype=tensor_spec.dtype)

  with tf.name_scope('to_stacked_tensor'):
    try:
      initial_state = tf.nest.map_structure(expanded_empty_tensor,
                                            ds.element_spec)
    except _TensorShapeNotFullyDefinedError as shape_not_defined_error:
      raise ValueError('Dataset elements must have fully-defined shapes. '
                       f'Found: {ds.element_spec}') from shape_not_defined_error

  @tf.function
  def append_tensor(stacked: tf.Tensor, tensor: tf.Tensor) -> tf.Tensor:
    expanded_tensor = tf.expand_dims(tensor, axis=0)
    return tf.concat((stacked, expanded_tensor), axis=0)

  @tf.function
  def reduce_func(old_state, input_element):
    tf.nest.assert_same_structure(old_state, input_element)
    return tf.nest.map_structure(append_tensor, old_state, input_element)

  return ds.reduce(initial_state, reduce_func)
示例#6
0
def _compute_kmeans_step(centroids: tf.Tensor, data: tf.data.Dataset):
    """Performs a k-means step on a dataset.

    This method finds, for each point in `data`, the closest centroid in
    `centroids`. It builds a `ClientResult` whose `update` attribute is a
    tuple `(cluster_sums, cluster_weights)`:

      - `cluster_sums` is a tensor of shape matching `centroids`, where
        `cluster_sums[i, :]` is the sum of all points closest to the i-th
        centroid;
      - `cluster_weights` is a `(num_centroids,)` tensor whose i-th component
        is the number of points closest to the i-th centroid.

    The `ClientResult.update_weight` attribute is left empty (`()`).

    Args:
      centroids: A `tf.Tensor` of centroids, indexed by the first axis.
      data: A `tf.data.Dataset` of points, each of which has shape matching
        that of `centroids.shape[1:]`.

    Returns:
      A tuple `(client_result, stat_output)`. `client_result` is the
      `client_works.ClientResult` described above. `stat_output` is an
      `OrderedDict` with a single `num_examples` entry counting the points
      in `data`.
    """
    cluster_sums = tf.zeros_like(centroids)
    cluster_weights = tf.zeros(shape=(centroids.shape[0], ),
                               dtype=_WEIGHT_DTYPE)
    num_examples = tf.constant(0, dtype=_WEIGHT_DTYPE)

    def reduce_fn(state, point):
        cluster_sums, cluster_weights, num_examples = state
        closest_centroid = _find_closest_centroid(centroids, point)
        # One index row selecting the closest centroid's slot.
        scatter_index = [[closest_centroid]]
        cluster_sums = tf.tensor_scatter_nd_add(cluster_sums, scatter_index,
                                                tf.expand_dims(point, axis=0))
        cluster_weights = tf.tensor_scatter_nd_add(cluster_weights,
                                                   scatter_index, [1])
        num_examples += 1
        return cluster_sums, cluster_weights, num_examples

    cluster_sums, cluster_weights, num_examples = data.reduce(
        initial_state=(cluster_sums, cluster_weights, num_examples),
        reduce_func=reduce_fn)

    stat_output = collections.OrderedDict(num_examples=num_examples)
    return client_works.ClientResult(update=(cluster_sums, cluster_weights),
                                     update_weight=()), stat_output
示例#7
0
def prepare_dataset(
    ds: tf.data.Dataset,
    batch_size: int,
    shuffle: bool = False,
    drop_remainder: bool = False,
) -> tf.data.Dataset:
    """Optionally shuffles, then batches and normalizes an image/label dataset.

    Args:
      ds: Unbatched dataset whose elements are mappings with "image" and
        "label" entries.
      batch_size: Number of examples per batch.
      shuffle: If True, shuffle the examples (seeded with `SEED`) before
        batching, using a buffer as large as the whole dataset.
      drop_remainder: Forwarded to `tf.data.Dataset.batch`.

    Returns:
      A prefetched dataset of `(image, label)` batches. `image` is cast to
      float32, divided by 255 and reshaped to `(batch, -1)`.
    """
    if shuffle:
        # Counting needs a full pass over the data, so do it only when a
        # dataset-sized shuffle buffer is actually needed. `shuffle` rejects
        # a zero buffer size, so clamp to 1 for an empty dataset.
        size_of_dataset = ds.reduce(0, lambda count, _: count + 1).numpy()
        ds = ds.shuffle(buffer_size=max(int(size_of_dataset), 1), seed=SEED)
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)

    @tf.function
    def prepare_data(features):
        image = tf.cast(features["image"], tf.float32)
        bs = tf.shape(image)[0]
        # Scale pixels to [0, 1] (assuming 8-bit inputs) and flatten each image.
        image = tf.reshape(image / 255.0, (bs, -1))
        return image, features["label"]

    autotune = tf.data.experimental.AUTOTUNE
    ds = ds.map(prepare_data, num_parallel_calls=autotune).prefetch(autotune)
    return ds
示例#8
0
def evaluate(dataset: tf.data.Dataset, state: State,
             weights: tff.learning.ModelWeights,
             model_fn: Callable) -> Evaluation:
    """Evaluates a two-part client model and builds a 2x2 confusion matrix.

    `weights` is loaded into `model.base_model` and `state.model` into
    `model.personalized_model` before the evaluation pass.

    Returns:
      An `Evaluation` whose `confusion_matrix` is the int32 sum of per-batch
      `tf.math.confusion_matrix` results (labels from `batch[1]`, predictions
      from rounding the sigmoid of the model output), together with the
      model's reported local outputs.
    """
    with tf.init_scope():
        model = model_fn(pos_weight=state.client_pos_weight)

    weights.assign_weights_to(model.base_model)
    state.model.assign_weights_to(model.personalized_model)

    def accumulate(running_matrix, batch):
        outputs = model.forward_pass(batch, training=False)

        labels = tf.reshape(batch[1], (-1, ))
        flat_logits = tf.reshape(outputs.predictions, (-1, ))
        predicted = tf.round(tf.nn.sigmoid(flat_logits))

        batch_matrix = tf.math.confusion_matrix(labels, predicted,
                                                num_classes=2)
        return running_matrix + batch_matrix

    totals = dataset.reduce(tf.zeros((2, 2), dtype=tf.int32), accumulate)

    return Evaluation(confusion_matrix=totals,
                      metrics=model.report_local_outputs())
示例#9
0
 def get_target_length(dataset: tf.data.Dataset) -> np.int64:
     """Returns the total leading-dimension size of `y[1]` over all elements.

     Each dataset element is indexed as `y[1]` (presumably an `(x, target)`
     pair; TODO confirm). The size of that tensor's first dimension is added
     to a running total.

     Returns:
       A scalar `tf.int64` tensor (despite the `np.int64` annotation).
     """
     # Accumulate directly in int64. The previous int32 accumulator could
     # overflow before the final cast. `tf.shape(...)[0]` also works when the
     # leading dimension is not statically known, where `len()` on a symbolic
     # tensor may not.
     return dataset.reduce(
         np.int64(0),
         lambda total, y: total + tf.shape(y[1], out_type=tf.int64)[0])
示例#10
0
 def get_target_sum(dataset: tf.data.Dataset) -> np.int64:
     """Returns the sum of all entries of `y[1]` over every dataset element.

     Each dataset element is indexed as `y[1]` (presumably integer targets of
     an `(x, target)` pair; TODO confirm).

     Returns:
       A scalar `tf.int64` tensor (despite the `np.int64` annotation).
     """
     # Cast before summing. Adding an int32 (or other integer) sum to the
     # int64 accumulator would otherwise fail with a dtype mismatch, and
     # summing in int64 avoids per-batch int32 overflow.
     return dataset.reduce(
         np.int64(0),
         lambda total, y: total + tf.math.reduce_sum(tf.cast(y[1], tf.int64)))
示例#11
0
    def dataset_split_fn(
            client_dataset: tf.data.Dataset,
            round_num: tf.Tensor) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """A `DatasetSplitFn` built with the given arguments.

        Splits `client_dataset` into a reconstruction dataset and a
        post-reconstruction dataset, then repeats and optionally truncates
        each one. The strategy, proportion, epoch counts and step limits are
        closed-over names from the enclosing scope: `split_dataset`,
        `split_dataset_strategy`, `split_dataset_proportion`,
        `recon_epochs_max`, `recon_epochs_constant`, `recon_steps_max`,
        `post_recon_epochs` and `post_recon_steps_max`.

        Args:
          client_dataset: `tf.data.Dataset` representing client data.
          round_num: Scalar tf.int64 tensor representing the 1-indexed round
            number during training. During evaluation, this is 0.

        Returns:
          A tuple of two `tf.data.Dataset`s, the first to be used for
          reconstruction, the second to be used post-reconstruction.

        Raises:
          ValueError: If `split_dataset` is true and `split_dataset_strategy`
            is neither `SPLIT_STRATEGY_SKIP` nor `SPLIT_STRATEGY_AGGREGATED`.
        """
        # `enumerate()` yields `(index, element)` pairs; drop the index again
        # after filtering.
        get_entry = lambda i, entry: entry
        if split_dataset:
            if split_dataset_strategy == SPLIT_STRATEGY_SKIP:
                # Indices 0, p, 2p, ... (p = split_dataset_proportion) go to
                # reconstruction; every other index goes to
                # post-reconstruction.

                def recon_condition(i, _):
                    return tf.equal(
                        tf.math.floormod(i, split_dataset_proportion), 0)

                def post_recon_condition(i, _):
                    return tf.greater(
                        tf.math.floormod(i, split_dataset_proportion), 0)

            elif split_dataset_strategy == SPLIT_STRATEGY_AGGREGATED:
                # One full pass counts the elements, as float32 so that the
                # division below is fractional before the int64 cast.
                num_elements = client_dataset.reduce(
                    tf.constant(0.0, dtype=tf.float32), lambda x, _: x + 1)

                # Indices 0..int(n / p) inclusive (int(n / p) + 1 elements) go
                # to reconstruction; the rest go to post-reconstruction.
                def recon_condition(i, _):
                    return i <= tf.cast(
                        num_elements / split_dataset_proportion,
                        dtype=tf.int64)

                def post_recon_condition(i, _):
                    return i > tf.cast(num_elements / split_dataset_proportion,
                                       dtype=tf.int64)

            else:
                raise ValueError(
                    'Unimplemented `split_dataset_strategy`: Must be one of '
                    '`{}`, or `{}`. Found {}'.format(
                        SPLIT_STRATEGY_SKIP, SPLIT_STRATEGY_AGGREGATED,
                        split_dataset_strategy))
        # split_dataset=False: both datasets see every client element.
        else:
            recon_condition = lambda i, _: True
            post_recon_condition = lambda i, _: True

        recon_dataset = client_dataset.enumerate().filter(recon_condition).map(
            get_entry)
        post_recon_dataset = client_dataset.enumerate().filter(
            post_recon_condition).map(get_entry)

        # Number of reconstruction epochs is exactly recon_epochs_max if
        # recon_epochs_constant is True, and min(round_num, recon_epochs_max) if
        # not.
        # NOTE(review): with recon_epochs_constant False and round_num == 0
        # (evaluation, per the docstring) this repeats 0 times, so the
        # reconstruction dataset is empty -- confirm this is intended.
        num_recon_epochs = recon_epochs_max
        if not recon_epochs_constant:
            num_recon_epochs = tf.math.minimum(round_num, recon_epochs_max)

        # Apply `num_recon_epochs` before limiting to a maximum number of batches
        # if needed.
        recon_dataset = recon_dataset.repeat(num_recon_epochs)
        if recon_steps_max is not None:
            recon_dataset = recon_dataset.take(recon_steps_max)

        # Do the same for post-reconstruction.
        post_recon_dataset = post_recon_dataset.repeat(post_recon_epochs)
        if post_recon_steps_max is not None:
            post_recon_dataset = post_recon_dataset.take(post_recon_steps_max)

        return recon_dataset, post_recon_dataset
示例#12
0
def update(dataset: tf.data.Dataset, state: State, message: server.Message,
           coefficient_fn: Callable, model_fn: Callable,
           optimizer_fn: Callable) -> Output:
    """Runs one round of local training with a mixed local/global model.

    Three models share one architecture (`model_fn(pos_weight=...)`):
      - a global model, initialised from `message.model`;
      - a local model, initialised from `state.model`;
      - a mixed model, built with `__mix_weights` from the local and global
        weights and the client's mixing coefficients.

    For each batch it:
      1. computes gradients of the global and mixed models;
      2. derives coefficient gradients with `__mixing_gradient`;
      3. steps the global, local and coefficient optimizers;
      4. clips the coefficients into [0, 1];
      5. re-mixes the mixed model from the updated local and global weights.

    Returns:
      An `Output` with:
        - `weights_delta`: global trainable variables minus
          `message.model.trainable`;
        - `client_weight`: the float32 count of examples reported by the
          global model's forward passes;
        - `metrics`: the global model's reported outputs;
        - `client_state`: the updated local model weights and mixing
          coefficients.
    """
    # Create all variables (models, optimizers, coefficients) outside any
    # enclosing function graph.
    with tf.init_scope():
        mixing_optimizer = optimizer_fn()
        mixing_coefficients = coefficient_fn()

        global_optimizer = optimizer_fn()
        global_model = model_fn(pos_weight=state.client_pos_weight)

        local_optimizer = optimizer_fn()
        local_model = model_fn(pos_weight=state.client_pos_weight)

        mixed_model = model_fn(pos_weight=state.client_pos_weight)

    tf.nest.map_structure(lambda v, t: v.assign(t), mixing_coefficients,
                          state.mixing_coefficients)
    message.model.assign_weights_to(global_model)
    state.model.assign_weights_to(local_model)
    __mix_weights(mixing_coefficients, state.model,
                  message.model).assign_weights_to(mixed_model)

    def training_fn(num_examples, batch):
        with tf.GradientTape() as global_tape:
            global_outputs = global_model.forward_pass(batch, training=True)

        with tf.GradientTape() as mixed_tape:
            mixed_outputs = mixed_model.forward_pass(batch, training=True)

        global_gradients = global_tape.gradient(
            global_outputs.loss, global_model.trainable_variables)
        mixed_gradients = mixed_tape.gradient(mixed_outputs.loss,
                                              mixed_model.trainable_variables)
        coefficient_gradients = __mixing_gradient(
            mixed_gradients, local_model.trainable_variables,
            global_model.trainable_variables)

        # Update global model
        global_optimizer.apply_gradients(
            zip(global_gradients, global_model.trainable_variables))

        # Update local model
        # NOTE(review): these are gradients w.r.t. the *mixed* model's
        # variables, applied unscaled to the local model -- confirm this is
        # the intended update rule.
        local_optimizer.apply_gradients(
            zip(mixed_gradients, local_model.trainable_variables))

        # Update coefficient
        mixing_optimizer.apply_gradients(
            zip(coefficient_gradients, mixing_coefficients))

        # Clip the mixing coefficients themselves (not their gradients) to
        # stay within the interval [0, 1].
        tf.nest.map_structure(lambda v: v.assign(tf.clip_by_value(v, 0, 1)),
                              mixing_coefficients)

        # Rebuild the mixed model from the freshly updated local/global
        # weights so the next batch sees the new mixture.
        __mix_weights(mixing_coefficients,
                      tff.learning.ModelWeights.from_model(local_model),
                      tff.learning.ModelWeights.from_model(
                          global_model)).assign_weights_to(mixed_model)

        return num_examples + global_outputs.num_examples

    client_weight = dataset.reduce(tf.constant(0, dtype=tf.int32), training_fn)

    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          global_model.trainable_variables,
                                          message.model.trainable)

    return Output(weights_delta=weights_delta,
                  client_weight=tf.cast(client_weight, dtype=tf.float32),
                  metrics=global_model.report_local_outputs(),
                  client_state=State(
                      client_index=state.client_index,
                      client_pos_weight=state.client_pos_weight,
                      model=tff.learning.ModelWeights.from_model(local_model),
                      mixing_coefficients=mixing_coefficients))
示例#13
0
    def __init__(self,
                 model_vars: ModelVarsGLM,
                 noise_model: str,
                 constraints_loc,
                 constraints_scale,
                 sample_indices=None,
                 data_set: tf.data.Dataset = None,
                 data_batch: tf.Tensor = None,
                 mode_jac="analytic",
                 mode_hessian="analytic",
                 mode_fim="analytic",
                 compute_a=True,
                 compute_b=True,
                 compute_jac=True,
                 compute_hessian=True,
                 compute_fim=True,
                 compute_ll=True):
        """Builds variables and an assign op for GLM jacobian/hessian/FIM/ll.

        The per-sample quantities come from `self.assemble_tensors`, which is
        defined elsewhere. They are assembled either from a whole dataset
        (summed with `reduce`) or from a single batch, and assigned into
        `tf.Variable` containers. Slices of those variables are exposed for
        the location ("a") and scale ("b") parameter blocks.

        :param model_vars: ModelVarsGLM
            Source of `params`, `a_var`, `b_var` (only their first-dimension
            sizes are read here), `dtype` and `n_features`.
        :param noise_model: str
            Stored as `self.noise_model`.
        :param constraints_loc: np.ndarray (constraints on mean model x mean model parameters)
            Constraints for location model. Stored as `self.constraints_loc`.
            Array with constraints in rows and model parameters in columns.
            Each constraint contains non-zero entries for the parameters that
            have to sum to zero. This constraint is enforced by binding one
            parameter to the negative sum of the other parameters, effectively
            representing that parameter as a function of the other parameters.
            This dependent parameter is indicated by a -1 in this array, the
            independent parameters of that constraint (which may be dependent
            at an earlier constraint) are indicated by a 1.
        :param constraints_scale: np.ndarray (constraints on mean model x mean model parameters)
            Constraints for scale model, in the same format as
            `constraints_loc`. Stored as `self.constraints_scale`.
        :param sample_indices:
            Passed as `idx` to `assemble_tensors` when `data_batch` is used.
        :param data_set: tf.data.Dataset
            Dataset whose elements are `(idx, data)` pairs. Each element is
            assembled and the results are summed element-wise.
        :param data_batch: tf.Tensor
            A single batch, assembled once together with `sample_indices`.
            Exactly one of `data_set` / `data_batch` must be given.
        :param mode_jac: str
            Stored as `self.mode_jac`; presumably read by `assemble_tensors`
            ("analytic" closed form vs. automatic differentiation) -- TODO
            confirm.
        :param mode_hessian: str
            Stored as `self.mode_hessian`; see `mode_jac`.
        :param mode_fim: str
            Stored as `self.mode_fim`; see `mode_jac`.
        :param compute_a: bool
            Whether the location ("a") parameter block is included.
        :param compute_b: bool
            Whether the scale ("b") parameter block is included.
        :param compute_jac: bool
            Whether to allocate/compute the jacobian.
        :param compute_hessian: bool
            Whether to allocate/compute the hessian.
        :param compute_fim: bool
            Whether to compute the Fisher information blocks (combined with
            `compute_a` / `compute_b` into `compute_fim_a` / `compute_fim_b`).
        :param compute_ll: bool
            Whether to allocate/compute the per-feature log-likelihood.

        Quantities that are not computed are represented by scalar zeros, so
        the five-tuple structure stays fixed.

        :raises AssertionError: If both `data_set` and `data_batch` are given.
        :raises ValueError: If neither `data_set` nor `data_batch` is given.
        """
        assert data_set is None or data_batch is None

        self.noise_model = noise_model
        self.model_vars = model_vars
        self.constraints_loc = constraints_loc
        self.constraints_scale = constraints_scale

        self.compute_a = compute_a
        self.compute_b = compute_b

        self.mode_jac = mode_jac
        self.mode_hessian = mode_hessian
        self.mode_fim = mode_fim

        self.compute_jac = compute_jac
        self.compute_hessian = compute_hessian
        self.compute_fim_a = compute_fim and compute_a
        self.compute_fim_b = compute_fim and compute_b
        self.compute_ll = compute_ll

        n_var_all = self.model_vars.params.shape[0]
        n_var_a = self.model_vars.a_var.shape[0]
        n_var_b = self.model_vars.b_var.shape[0]
        dtype = self.model_vars.dtype
        self.dtype = dtype

        # Produces the 5-tuple (jac, hessian, fim_a, fim_b, ll) for one batch.
        def map_fun(idx, data):
            return self.assemble_tensors(idx=idx, data=data)

        # Zero accumulators shaped like map_fun's output. Disabled quantities
        # use scalar zeros.
        def init_fun():
            if self.compute_a and self.compute_b:
                n_var_train = n_var_all
            elif self.compute_a and not self.compute_b:
                n_var_train = n_var_a
            elif not self.compute_a and self.compute_b:
                n_var_train = n_var_b
            else:
                n_var_train = 0

            if self.compute_jac and n_var_train > 0:
                jac_init = tf.zeros([model_vars.n_features, n_var_train],
                                    dtype=dtype)
            else:
                jac_init = tf.zeros((), dtype=dtype)

            if self.compute_hessian and n_var_train > 0:
                hessian_init = tf.zeros(
                    [model_vars.n_features, n_var_train, n_var_train],
                    dtype=dtype)
            else:
                hessian_init = tf.zeros((), dtype=dtype)

            if self.compute_fim_a:
                fim_a_init = tf.zeros(
                    [model_vars.n_features, n_var_a, n_var_a], dtype=dtype)
            else:
                fim_a_init = tf.zeros((), dtype=dtype)
            if self.compute_fim_b:
                fim_b_init = tf.zeros(
                    [model_vars.n_features, n_var_b, n_var_b], dtype=dtype)
            else:
                fim_b_init = tf.zeros((), dtype=dtype)

            if self.compute_ll:
                ll_init = tf.zeros([model_vars.n_features], dtype=dtype)
            else:
                ll_init = tf.zeros((), dtype=dtype)

            return jac_init, hessian_init, fim_a_init, fim_b_init, ll_init

        # Element-wise sum of two 5-tuples.
        def reduce_fun(old, new):
            return (tf.add(old[0], new[0]), tf.add(old[1], new[1]),
                    tf.add(old[2],
                           new[2]), tf.add(old[3],
                                           new[3]), tf.add(old[4], new[4]))

        if data_set is not None:
            # Dataset elements are (idx, data) pairs.
            set_op = data_set.reduce(initial_state=init_fun(),
                                     reduce_func=lambda old, new: reduce_fun(
                                         old, map_fun(new[0], new[1])))
            jac, hessian, fim_a, fim_b, ll = set_op
        elif data_batch is not None:
            set_op = map_fun(idx=sample_indices, data=data_batch)
            jac, hessian, fim_a, fim_b, ll = set_op
        else:
            raise ValueError("supply either data_set or data_batch")

        p_shape_a = self.model_vars.a_var.shape[
            0]  # This has to be _var to work with constraints.

        # With relay across tf1.Variable:
        # Containers and specific slices and transforms:
        # The variable shapes below mirror init_fun(). The *_a / *_b / *_train
        # attributes are views (slices or aliases) of those variables, or None
        # when the corresponding block is not computed.
        if self.compute_a and self.compute_b:
            if self.compute_jac:
                self.jac = tf.Variable(tf.zeros(
                    [self.model_vars.n_features, n_var_all], dtype=dtype),
                                       dtype=dtype)
                self.jac_a = self.jac[:, :p_shape_a]
                self.jac_b = self.jac[:, p_shape_a:]
            else:
                self.jac = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
                self.jac_a = self.jac
                self.jac_b = self.jac
            self.jac_train = self.jac

            if self.compute_hessian:
                self.hessian = tf.Variable(tf.zeros(
                    [self.model_vars.n_features, n_var_all, n_var_all],
                    dtype=dtype),
                                           dtype=dtype)
                self.hessian_aa = self.hessian[:, :p_shape_a, :p_shape_a]
                self.hessian_bb = self.hessian[:, p_shape_a:, p_shape_a:]
            else:
                self.hessian = tf.Variable(tf.zeros((), dtype=dtype),
                                           dtype=dtype)
                self.hessian_aa = self.hessian
                self.hessian_bb = self.hessian
            self.hessian_train = self.hessian

            if self.compute_fim_a or self.compute_fim_b:
                self.fim_a = tf.Variable(tf.zeros(
                    [self.model_vars.n_features, n_var_a, n_var_a],
                    dtype=dtype),
                                         dtype=dtype)
                self.fim_b = tf.Variable(tf.zeros(
                    [self.model_vars.n_features, n_var_b, n_var_b],
                    dtype=dtype),
                                         dtype=dtype)
            else:
                self.fim_a = tf.Variable(tf.zeros((), dtype=dtype),
                                         dtype=dtype)
                self.fim_b = tf.Variable(tf.zeros((), dtype=dtype),
                                         dtype=dtype)
        elif self.compute_a and not self.compute_b:
            if self.compute_jac:
                self.jac = tf.Variable(tf.zeros(
                    [self.model_vars.n_features, n_var_a], dtype=dtype),
                                       dtype=dtype)
                self.jac_a = self.jac
            else:
                self.jac = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
                self.jac_a = self.jac
            self.jac_b = None
            self.jac_train = self.jac_a

            if self.compute_hessian:
                self.hessian = tf.Variable(tf.zeros(
                    [model_vars.n_features, n_var_a, n_var_a], dtype=dtype),
                                           dtype=dtype)
                self.hessian_aa = self.hessian
            else:
                self.hessian = tf.Variable(tf.zeros((), dtype=dtype),
                                           dtype=dtype)
                self.hessian_aa = self.hessian
            self.hessian_bb = None
            self.hessian_train = self.hessian_aa

            if self.compute_fim_a:
                self.fim_a = tf.Variable(tf.zeros(
                    [model_vars.n_features, n_var_a, n_var_a], dtype=dtype),
                                         dtype=dtype)
            else:
                self.fim_a = tf.Variable(tf.zeros((), dtype=dtype),
                                         dtype=dtype)
            self.fim_b = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
        elif not self.compute_a and self.compute_b:
            if self.compute_jac:
                self.jac = tf.Variable(tf.zeros(
                    [self.model_vars.n_features, n_var_b], dtype=dtype),
                                       dtype=dtype)
                self.jac_b = self.jac
            else:
                self.jac = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
                self.jac_b = self.jac
            self.jac_a = None
            self.jac_train = self.jac_b

            if self.compute_hessian:
                self.hessian = tf.Variable(tf.zeros(
                    [model_vars.n_features, n_var_b, n_var_b], dtype=dtype),
                                           dtype=dtype)
                self.hessian_bb = self.hessian
            else:
                self.hessian = tf.Variable(tf.zeros((), dtype=dtype),
                                           dtype=dtype)
                self.hessian_bb = self.hessian
            self.hessian_aa = None
            self.hessian_train = self.hessian_bb

            self.fim_a = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
            if self.compute_fim_b:
                self.fim_b = tf.Variable(tf.zeros(
                    [model_vars.n_features, n_var_b, n_var_b], dtype=dtype),
                                         dtype=dtype)
            else:
                self.fim_b = tf.Variable(tf.zeros((), dtype=dtype),
                                         dtype=dtype)
        else:
            self.jac = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
            self.jac_a = None
            self.jac_b = None
            self.jac_train = None

            self.hessian = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
            self.hessian_aa = None
            self.hessian_bb = None
            self.hessian_train = None

            self.fim_a = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)
            self.fim_b = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)

        if self.compute_ll:
            self.ll = tf.Variable(tf.zeros([model_vars.n_features],
                                           dtype=dtype),
                                  dtype=dtype)
        else:
            self.ll = tf.Variable(tf.zeros((), dtype=dtype), dtype=dtype)

        # Negated views. self.jac, self.hessian and self.ll are always
        # Variables at this point, so the None guards only matter for the
        # *_a / *_b / *_train views.
        self.neg_jac = tf.negative(self.jac) if self.jac is not None else None
        self.neg_jac_a = tf.negative(
            self.jac_a) if self.jac_a is not None else None
        self.neg_jac_b = tf.negative(
            self.jac_b) if self.jac_b is not None else None
        self.neg_jac_train = tf.negative(
            self.jac_train) if self.jac_train is not None else None

        self.neg_hessian = tf.negative(
            self.hessian) if self.hessian is not None else None
        self.neg_hessian_aa = tf.negative(
            self.hessian_aa) if self.hessian_aa is not None else None
        self.neg_hessian_bb = tf.negative(
            self.hessian_bb) if self.hessian_bb is not None else None
        self.neg_hessian_train = tf.negative(
            self.hessian_train) if self.hessian_train is not None else None

        self.neg_ll = tf.negative(self.ll) if self.ll is not None else None

        # Setting operation:
        # Copy the assembled values into the container variables. `self.set`
        # groups these assignments with the assembly result.
        jac_set = tf.compat.v1.assign(self.jac, jac)
        hessian_set = tf.compat.v1.assign(self.hessian, hessian)
        fim_a_set = tf.compat.v1.assign(self.fim_a, fim_a)
        fim_b_set = tf.compat.v1.assign(self.fim_b, fim_b)
        ll_set = tf.compat.v1.assign(self.ll, ll)

        self.set = tf.group(set_op, jac_set, hessian_set, fim_a_set, fim_b_set,
                            ll_set)
def _discretized_histogram_counts(client_data: tf.data.Dataset,
                                  lower_bound: float, upper_bound: float,
                                  num_bins: int) -> tf.Tensor:
    """Discretizes `client_data` and creates a histogram on the discretized data.

    Records are allocated into `num_bins` equal-width bins covering the
    half-open range `[lower_bound, upper_bound)`. Records outside that range,
    and NaN records, are ignored.

    Args:
      client_data: A `tf.data.Dataset` of `tf.float32` records. Each element
        is expected to be a scalar (TODO confirm with callers).
      lower_bound: Inclusive lower bound of the data range.
      upper_bound: Exclusive upper bound of the data range.
      num_bins: The number of bins; must be positive.

    Returns:
      A `tf.int32` tensor of shape `(num_bins,)` holding per-bin counts.

    Raises:
      ValueError: If `upper_bound < lower_bound`, if `num_bins <= 0`, or if
        the dataset's element dtype is not `tf.float32`.
    """

    if upper_bound < lower_bound:
        raise ValueError(f'upper_bound: {upper_bound} is smaller than '
                         f'lower_bound: {lower_bound}.')

    if num_bins <= 0:
        raise ValueError(f'num_bins: {num_bins} smaller or equal to zero.')

    data_type = client_data.element_spec.dtype

    if data_type != tf.float32:
        raise ValueError(f'`client_data` contains {data_type} values. '
                         f'`tf.float32` is expected.')

    precision = (upper_bound - lower_bound) / num_bins

    def insert_record(histogram, record):
        """Inserts a record into the histogram.

        Records outside `[lower_bound, upper_bound)`, and NaN records, are
        dropped.

        Args:
          histogram: A `tf.Tensor` of shape `(num_bins,)` with current counts.
          record: A scalar `tf.float32` record.

        Returns:
          A `tf.Tensor` with the updated counts.
        """

        if histogram.shape != (num_bins, ):
            raise ValueError(f'Expected shape ({num_bins}, ). '
                             f'Got {histogram.shape}.')

        # Written as a positive range test so NaN (which fails every
        # comparison) is dropped rather than mapped to a garbage bin index.
        in_range = tf.logical_and(record >= lower_bound, record < upper_bound)

        def add_to_bin():
            bin_index = tf.cast(
                tf.math.floor((record - lower_bound) / precision), tf.int32)
            # Float rounding can push a record just below `upper_bound` into
            # bin `num_bins`, which would be out of range for the scatter.
            bin_index = tf.minimum(bin_index, num_bins - 1)
            return tf.tensor_scatter_nd_add(tensor=histogram,
                                            indices=[[bin_index]],
                                            updates=[1])

        return tf.cond(in_range, add_to_bin, lambda: histogram)

    histogram = client_data.reduce(tf.zeros([num_bins], dtype=tf.int32),
                                   insert_record)

    return histogram
示例#15
0
def _sum_dataset(dataset: tf.data.Dataset) -> int:
    """Adds up every element of `dataset`, starting from an int32 zero.

    The total is whatever `tf.data.Dataset.reduce` produces (a tensor, despite
    the `int` annotation).
    """
    zero = tf.cast(0, tf.int32)
    return dataset.reduce(zero, lambda total, value: tf.add(total, value))
示例#16
0
def get_num_examples(dataset: tf.data.Dataset) -> int:
    """Counts the elements of `dataset` with one full pass over it.

    The count is accumulated as an int64 and converted with `.numpy()`, so
    this must run eagerly.
    """
    def count_one(running_total, _):
        return running_total + 1

    total = dataset.reduce(np.int64(0), count_one)
    return total.numpy()