Example #1
    def _build(self, data, is_meta_training=True, test_data=None):
        if isinstance(data, list):
            if self.limited:
                data = data_module.ProblemInstance(*data)
            else:
                data = data_unlimited.ProblemInstance(*data)

        if test_data is not None:
            test_latent = self.encoder(test_data.tr_input)

            tr_loss, tr_acc = self.least_square(test_data,
                                                test_latent,
                                                use_val=False)

        self.is_meta_training = is_meta_training
        self.save_problem_instance_stats(data.tr_input)

        latents = self.encoder(data.tr_input)
        val_loss, val_accuracy = self.least_square(data, latents)

        batch_val_loss = tf.reduce_mean(val_loss)
        if self.limited:
            batch_val_loss *= data.weight
        batch_val_accuracy = tf.reduce_mean(val_accuracy)
        regularization_penalty = self._l2_regularization

        if test_data is not None:
            batch_val_loss += tr_loss * self.t_weight

        return batch_val_loss, regularization_penalty, batch_val_accuracy
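
This variant calls `self.least_square`, which is not shown, and relies on `self.limited`, `data.weight`, and `self.t_weight` defined elsewhere in that project. As a rough, hypothetical reading of what a least-squares head computes, here is a minimal NumPy sketch that fits a ridge-regression classifier on one split and scores it on another; the function name, signature, regularizer, and the squared-error/argmax scoring are illustrative assumptions, not the project's actual implementation.

import numpy as np

def least_square_head(tr_x, tr_y, val_x, val_y, num_classes, l2=1e-3):
    """Fit a ridge-regression classifier on (tr_x, tr_y) and score it on
    (val_x, val_y).  tr_x/val_x: (num_examples, dim); tr_y/val_y: (num_examples,)."""
    targets = np.eye(num_classes)[tr_y]                 # one-hot targets, (num_examples, C)
    gram = tr_x.T @ tr_x + l2 * np.eye(tr_x.shape[1])   # regularized Gram matrix, (dim, dim)
    weights = np.linalg.solve(gram, tr_x.T @ targets)   # closed-form solution, (dim, C)
    logits = val_x @ weights                            # (num_val_examples, C)
    loss = np.mean((np.eye(num_classes)[val_y] - logits) ** 2)
    accuracy = np.mean(np.argmax(logits, axis=1) == val_y)
    return loss, accuracy

In the example itself, the returned loss and accuracy are averaged with `tf.reduce_mean`, and when `test_data` is supplied the weighted training-split loss `tr_loss * self.t_weight` is added to the batch loss.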
Example #2
    def _build(self, data, is_meta_training=True):
        """Connects the LEO module to the graph, creating the variables.

    Args:
      data: A data_module.ProblemInstance containing Tensors with the
          following shapes:
          - tr_input: (N, K, dim)
          - tr_output: (N, K, 1)
          - tr_info: (N, K)
          - val_input: (N, K_valid, dim)
          - val_output: (N, K_valid, 1)
          - val_info: (N, K_valid)
            where N is the number of classes (as in N-way), K and K_valid are
            the numbers of training and validation examples within a problem
            instance respectively (as in K-shot), and dim is the
            dimensionality of the embedding.
      is_meta_training: A boolean describing whether we run in the training
        mode.

    Returns:
      Tensor with the inner validation loss of LEO (including both adaptation in
      the latent space and finetuning).
    """
        if isinstance(data, list):
            data = data_module.ProblemInstance(*data)

        self.is_meta_training = is_meta_training
        self.save_problem_instance_stats(data.tr_input)

        latents, kl = self.forward_encoder(data)
        # tr_loss, adapted_classifier_weights, encoder_penalty = self.leo_inner_loop(
        #     data, latents)
        tr_loss, adapted_classifier_weights = self.forward_decoder(
            data, latents)

        # val_loss, val_accuracy = self.finetuning_inner_loop(
        #     data, tr_loss, adapted_classifier_weights)
        val_loss, val_accuracy = self.calculate_inner_loss(
            data.val_input, data.val_output, adapted_classifier_weights)

        # val_loss += self._kl_weight * kl
        # val_loss += self._encoder_penalty_weight * encoder_penalty

        # regularization_penalty = (
        #    self._l2_regularization + self._decoder_orthogonality_reg)
        regularization_penalty = tf.constant(0, dtype=self._float_dtype)

        batch_val_loss = tf.reduce_mean(val_loss)
        batch_val_accuracy = tf.reduce_mean(val_accuracy)

        return batch_val_loss + regularization_penalty, batch_val_accuracy
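
The docstring above (repeated verbatim in Examples #3 and #6) describes the exact tensor shapes a problem instance carries. The few lines below make those shapes concrete with a stand-in namedtuple; the field names come from the docstring and from Example #4, while the particular values of N, K, K_valid, and dim are arbitrary illustration, not taken from the project.

import collections
import numpy as np

# Stand-in with the same field names as data_module.ProblemInstance.
ProblemInstance = collections.namedtuple(
    "ProblemInstance",
    ["tr_input", "tr_output", "tr_info", "val_input", "val_output", "val_info"])

N, K, K_valid, dim = 5, 1, 15, 64   # N-way, K-shot, K_valid val examples, embedding dim
instance = ProblemInstance(
    tr_input=np.zeros((N, K, dim), dtype=np.float32),          # (N, K, dim)
    tr_output=np.zeros((N, K, 1), dtype=np.int64),              # (N, K, 1) class labels
    tr_info=np.zeros((N, K)),                                   # (N, K)
    val_input=np.zeros((N, K_valid, dim), dtype=np.float32),    # (N, K_valid, dim)
    val_output=np.zeros((N, K_valid, 1), dtype=np.int64),       # (N, K_valid, 1)
    val_info=np.zeros((N, K_valid)))                             # (N, K_valid)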
Example #3
  def _build(self, data, is_meta_training=True):
    """Connects the LEO module to the graph, creating the variables.

    Args:
      data: A data_module.ProblemInstance containing Tensors with the
          following shapes:
          - tr_input: (N, K, dim)
          - tr_output: (N, K, 1)
          - tr_info: (N, K)
          - val_input: (N, K_valid, dim)
          - val_output: (N, K_valid, 1)
          - val_info: (N, K_valid)
            where N is the number of classes (as in N-way), K and K_valid are
            the numbers of training and validation examples within a problem
            instance respectively (as in K-shot), and dim is the
            dimensionality of the embedding.
      is_meta_training: A boolean describing whether we run in the training
        mode.

    Returns:
      Tensor with the inner validation loss of LEO (including both adaptation in
      the latent space and finetuning).
    """
    if isinstance(data, list):
      data = data_module.ProblemInstance(*data)
    self.is_meta_training = is_meta_training
    self.save_problem_instance_stats(data.tr_input)

    latents, kl = self.forward_encoder(data)
    tr_loss, adapted_classifier_weights, encoder_penalty = self.leo_inner_loop(
        data, latents)
    # print(encoder_penalty)

    val_loss, val_accuracy, val_output = self.finetuning_inner_loop(
        data, tr_loss, adapted_classifier_weights)

    val_loss += self._kl_weight * kl
    val_loss += self._encoder_penalty_weight * encoder_penalty
    # The l2 regularization is already added to the graph when constructing
    # the snt.Linear modules. We pass the orthogonality regularizer separately,
    # because it is not used in self.grads_and_vars.
    regularization_penalty = (
        self._l2_regularization + self._decoder_orthogonality_reg)

    batch_val_loss = tf.reduce_mean(val_loss)
    batch_val_accuracy = tf.reduce_mean(val_accuracy)

    additional_loss = (self._kl_weight * kl +
                       self._encoder_penalty_weight * encoder_penalty +
                       regularization_penalty)
    if self._deconfound:
      # return batch_val_loss + regularization_penalty, batch_val_accuracy, val_output
      return batch_val_loss + regularization_penalty, additional_loss, batch_val_accuracy, val_output
    else:
      return batch_val_loss + regularization_penalty, batch_val_accuracy
Example #4
def _random_problem_instance(num_classes=7,
                             num_examples_per_class=5,
                             embedding_dim=17,
                             use_64bits_dtype=True):
    inputs_dtype = tf.float64 if use_64bits_dtype else tf.float32
    inputs = tf.constant(np.random.random(
        (num_classes, num_examples_per_class, embedding_dim)),
                         dtype=inputs_dtype)
    outputs_dtype = tf.int64 if use_64bits_dtype else tf.int32
    outputs = tf.constant(np.random.randint(low=0,
                                            high=num_classes,
                                            size=(num_classes,
                                                  num_examples_per_class, 1)),
                          dtype=outputs_dtype)
    problem = data.ProblemInstance(tr_input=inputs,
                                   val_input=inputs,
                                   tr_info=inputs,
                                   tr_output=outputs,
                                   val_output=outputs,
                                   val_info=inputs)
    return problem
Example #5
    def test_inner_loop_adaptation(self):
        problem_instance = data.ProblemInstance(
            tr_input=constant_float64([[[4.]]]),
            tr_output=tf.constant([[[0]]], dtype=tf.int64),
            tr_info=[],
            val_input=[],
            val_output=[],
            val_info=[],
        )
        # encoder = decoder = id
        # predict returns classifier_weights**2 * inputs = latents**2 * inputs
        # loss = id = inputs*latents
        # dl/dlatent = 2 * latent * inputs
        # 4 -> 4 - 0.1 * 2 * 4 * 4 = 0.8
        # 0.8 -> 0.8 - 0.1 * 2 * 0.8 * 4 = 0.16
        # 0.16 -> 0.16 - 0.1 * 2 * 0.16 * 4 = 0.032

        # is_meta_training=False disables kl and encoder penalties
        adapted_parameters, _ = self._leo(problem_instance,
                                          is_meta_training=False)

        with self.session() as sess:
            sess.run(tf.global_variables_initializer())
            self.assertAllClose(sess.run(adapted_parameters), 0.032)
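
The comment block in this test works the three inner-loop adaptation steps by hand. The plain-Python lines below just reproduce that arithmetic; the loss form and the 0.1 inner learning rate are read off the comment, nothing else is assumed.

# encoder = decoder = identity, prediction = latent**2 * input,
# so d(loss)/d(latent) = 2 * latent * input (as in the comment above).
latent, x, lr = 4.0, 4.0, 0.1
for _ in range(3):                 # three adaptation steps
    grad = 2.0 * latent * x
    latent -= lr * grad
print(latent)                      # ~0.032, the value the test asserts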
Example #6
File: model.py Project: fzohra/despurold
    def _build(self, data, is_meta_training=True):
        """Connects the LEO module to the graph, creating the variables.

    Args:
      data: A data_module.ProblemInstance containing Tensors with the
          following shapes:
          - tr_input: (N, K, dim)
          - tr_output: (N, K, 1)
          - tr_info: (N, K)
          - val_input: (N, K_valid, dim)
          - val_output: (N, K_valid, 1)
          - val_info: (N, K_valid)
            where N is the number of classes (as in N-way), K and K_valid are
            the numbers of training and validation examples within a problem
            instance respectively (as in K-shot), and dim is the
            dimensionality of the embedding.
      is_meta_training: A boolean describing whether we run in the training
        mode.

    Returns:
      Tensor with the inner validation loss of LEO (including both adaptation in
      the latent space and finetuning).
    """
        if isinstance(data, list):
            data = data_module.ProblemInstance(*data)
        self.is_meta_training = is_meta_training
        self.save_problem_instance_stats(data.tr_input)

        latents, kl, kl_components, kl_zn, distribution_params = self.forward_encoder(
            data)
        tr_loss, adapted_classifier_weights, encoder_penalty, corr_penalty, adapted_latents, adapted_kl, adapted_kl_components, adapted_kl_zn, spurious = self.leo_inner_loop(
            data, latents, distribution_params)

        val_loss, val_accuracy = self.finetuning_inner_loop(
            data, tr_loss, adapted_classifier_weights)

        # tr_loss: can we observe this for each latent component?
        # val_loss: can we observe this for each latent component?
        # Compute generalization_loss = val_loss - tr_loss.
        # If generalization_loss is high for a latent component, simply threshold and drop it.
        # Graph the generalization loss for the components during training: are there
        # any that have a high generalization loss?
        # (An illustrative sketch of this thresholding idea follows after this example.)

        # Remove correlations between latent-space gradient dimensions.
        val_loss += self._kl_weight * kl
        val_loss += self._encoder_penalty_weight * encoder_penalty
        # The l2 regularization is already added to the graph when constructing
        # the snt.Linear modules. We pass the orthogonality regularizer separately,
        # because it is not used in self.grads_and_vars.
        regularization_penalty = (self._l2_regularization +
                                  self._decoder_orthogonality_reg)

        batch_val_loss = tf.reduce_mean(val_loss)
        batch_val_accuracy = tf.reduce_mean(val_accuracy)
        batch_generalization_loss = tf.reshape(tf.reduce_mean(val_loss, 1),
                                               [5, 1]) - tr_loss

        if self.is_meta_training:
            tr_out = tf.cast(data.tr_output, dtype=tf.float32)
            tr_out_tiled = tf.tile(tr_out, multiples=[1, 1, 64])
            tr_out_tiled_expanded = tf.expand_dims(tr_out_tiled, -1)
            kl_components_y = tf.concat(
                [tr_out_tiled_expanded,
                 tf.expand_dims(kl_components, -1)],
                axis=-1)
            adapted_kl_components_y = tf.concat([
                tr_out_tiled_expanded,
                tf.expand_dims(adapted_kl_components, -1)
            ],
                                                axis=-1)
            kl_zn_y = tf.concat([tf.squeeze(tr_out, -1), kl_zn], axis=-1)
            adapted_kl_zn_y = tf.concat(
                [tf.squeeze(tr_out, -1), adapted_kl_zn], axis=-1)
            latents_y = tf.concat(
                [tr_out_tiled_expanded,
                 tf.expand_dims(latents, -1)], axis=-1)
            adapted_latents_y = tf.concat(
                [tr_out_tiled_expanded,
                 tf.expand_dims(adapted_latents, -1)],
                axis=-1)
            spurious_y = tf.concat(
                [tr_out_tiled_expanded,
                 tf.expand_dims(spurious, -1)], axis=-1)
        else:
            kl_components_y = kl_components
            adapted_kl_components_y = adapted_kl_components
            kl_zn_y = kl_zn
            adapted_kl_zn_y = adapted_kl_zn
            latents_y = latents
            adapted_latents_y = adapted_latents
            spurious_y = spurious
        return (batch_val_loss + regularization_penalty, batch_val_accuracy,
                batch_generalization_loss, kl_components_y, adapted_kl_components_y,
                kl_zn_y, adapted_kl_zn_y, kl, adapted_kl, latents_y,
                adapted_latents_y, spurious_y)
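
The comments in Example #6 sketch an idea the code only partially implements: compute a per-component generalization gap (val_loss - tr_loss) and drop latent components whose gap is large. The example computes `batch_generalization_loss` but never masks anything; the NumPy sketch below is a hypothetical illustration of that thresholding step, with the threshold value, the array shapes, and the assumption that per-component losses are observable all supplied purely for illustration.

import numpy as np

def drop_high_generalization_components(latents, tr_loss, val_loss, threshold=1.0):
    """Zero out latent components whose generalization gap exceeds `threshold`.

    latents:  (N, num_latents) adapted latent codes
    tr_loss:  (num_latents,) per-component training loss (assumed observable)
    val_loss: (num_latents,) per-component validation loss (assumed observable)
    """
    generalization_loss = val_loss - tr_loss          # the gap described in the comments
    keep = generalization_loss <= threshold           # boolean mask over components
    return latents * keep[np.newaxis, :].astype(latents.dtype)

# Toy usage: the second component generalizes poorly and gets zeroed out.
latents = np.ones((5, 3))
print(drop_high_generalization_components(
    latents, tr_loss=np.array([0.2, 0.1, 0.3]), val_loss=np.array([0.4, 2.5, 0.5])))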