Example no. 1
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']

            # For minibatch class reweighting, initialize per-batch loss function
            if class_reweight_mode == 'minibatch':
                batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(
                    labels=labels)
            else:
                batch_loss_fn = loss_fn

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1),
                                  y_pred=logits,
                                  from_logits=True))

                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply L2 regularization to the BN parameters and bias terms.
                    # This excludes only the fast-weight approximate posterior/prior
                    # parameters, but be careful about their naming scheme.
                    if 'bn' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                kl = sum(model.losses)
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                loss = negative_log_likelihood + l2_loss + kl_loss

                # Scale the loss, since TPUStrategy will reduce-sum the
                # gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            probs = tf.squeeze(tf.nn.sigmoid(logits))

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/kl'].update_state(kl)
            metrics['train/kl_scale'].update_state(kl_scale)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)

            if not use_tpu:
                metrics['train/ece'].add_batch(probs, label=labels)
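
Both variants in this example rely on utils.get_minibatch_reweighted_loss_fn, which is not shown here. A minimal sketch of what such a helper could look like, assuming binary labels and inverse-frequency class weights computed per minibatch (the actual helper in utils may differ):

import tensorflow as tf

def get_minibatch_reweighted_loss_fn(labels):
  """Builds a loss fn with class weights set per minibatch (illustrative)."""
  labels = tf.cast(labels, tf.float32)
  pos_fraction = tf.reduce_mean(labels)
  # Inverse-frequency weights, scaled so a balanced batch gets weights of 1.
  pos_weight = tf.math.divide_no_nan(0.5, pos_fraction)
  neg_weight = tf.math.divide_no_nan(0.5, 1.0 - pos_fraction)

  def batch_loss_fn(y_true, y_pred, from_logits=False):
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1, 1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1, 1])
    per_example = tf.keras.losses.binary_crossentropy(
        y_true=y_true, y_pred=y_pred, from_logits=from_logits)
    class_weights = (tf.squeeze(y_true, axis=-1) * pos_weight +
                     (1.0 - tf.squeeze(y_true, axis=-1)) * neg_weight)
    return per_example * class_weights

  return batch_loss_fn
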
    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        if FLAGS.use_bfloat16:
          logits = tf.cast(logits, tf.float32)

        negative_log_likelihood = tf.reduce_mean(
            batch_loss_fn(
                y_true=tf.expand_dims(labels, axis=-1),
                y_pred=logits,
                from_logits=True))
        l2_loss = sum(model.losses)
        loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

        # Scale the loss, since TPUStrategy will reduce-sum the gradients
        # across replicas.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      probs = tf.squeeze(tf.nn.sigmoid(logits))

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)
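
In the full training scripts, a step function like this is not called directly; it is dispatched to every replica from inside a tf.function. A minimal driver sketch, reusing strategy and step_fn from above (train_dataset and total_steps are illustrative names):

import tensorflow as tf

@tf.function
def train_step(iterator):
  """Runs step_fn on each replica; the strategy reduce-sums the gradients."""
  strategy.run(step_fn, args=(next(iterator),))

train_iterator = iter(train_dataset)  # e.g. strategy.experimental_distribute_dataset(...)
for _ in range(total_steps):
  train_step(train_iterator)
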
Example no. 3
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            if FLAGS.ensemble_size > 1:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
                labels = tf.tile(labels, [FLAGS.ensemble_size])

            # For minibatch class reweighting, initialize per-batch loss function
            if class_reweight_mode == 'minibatch':
                batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(
                    labels=labels)
            else:
                batch_loss_fn = loss_fn

            with tf.GradientTape() as tape:
                if FLAGS.num_mc_samples_train > 1:
                    # Pythonic implementation: draw MC samples in a Python loop.
                    logits_list = []
                    for _ in range(FLAGS.num_mc_samples_train):
                        logits = model(images, training=True)
                        logits = tf.squeeze(logits, axis=-1)
                        if FLAGS.use_bfloat16:
                            logits = tf.cast(logits, tf.float32)

                        logits_list.append(logits)

                    # Logits dimension is (num_samples, batch_size).
                    logits_list = tf.stack(logits_list, axis=0)

                    probs_list = tf.nn.sigmoid(logits_list)
                    probs = tf.reduce_mean(probs_list, axis=0)
                    negative_log_likelihood = tf.reduce_mean(
                        batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1),
                                      y_pred=probs,
                                      from_logits=False))
                else:
                    # Single train step
                    logits = model(images, training=True)
                    if FLAGS.use_bfloat16:
                        logits = tf.cast(logits, tf.float32)

                    negative_log_likelihood = tf.reduce_mean(
                        batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1),
                                      y_pred=logits,
                                      from_logits=True))
                    probs = tf.squeeze(tf.nn.sigmoid(logits))

                l2_loss = compute_l2_loss(model)
                kl = sum(model.losses) / train_dataset_size
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                loss = negative_log_likelihood + l2_loss + kl_loss

                # Scale the loss, since TPUStrategy will reduce-sum the
                # gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply a different learning rate to the fast weights.
                    # This excludes BN parameters and slow weights, but be
                    # careful about the naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/kl'].update_state(kl)
            metrics['train/kl_scale'].update_state(kl_scale)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)

            if not use_tpu:
                metrics['train/ece'].add_batch(probs, label=labels)
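
The first snippet of this example calls compute_l2_loss(model) without defining it. A plausible sketch, assuming it mirrors the inline filtered-variables pattern used in the other snippets (the real helper may select different variables):

import tensorflow as tf

def compute_l2_loss(model):
  # Collect the BN parameters and bias terms, mirroring the inline
  # filtering in the other snippets; fast-weight approximate
  # posterior/prior parameters are excluded.
  filtered_variables = [
      tf.reshape(var, (-1,))
      for var in model.trainable_variables
      if 'bn' in var.name or 'bias' in var.name
  ]
  # tf.nn.l2_loss(x) = sum(x ** 2) / 2, hence the factor of 2.
  return FLAGS.l2 * 2 * tf.nn.l2_loss(tf.concat(filtered_variables, axis=0))
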
    def step_fn(inputs):
      """Per-replica step function."""
      images = inputs['features']
      labels = inputs['labels']

      # For minibatch class reweighting, initialize per-batch loss function
      if class_reweight_mode == 'minibatch':
        batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels)
      else:
        batch_loss_fn = loss_fn

      with tf.GradientTape() as tape:
        # TODO(nband): TPU-friendly implementation.
        if FLAGS.num_mc_samples_train > 1:
          logits_arr = tf.TensorArray(
              tf.float32, size=FLAGS.num_mc_samples_train)

          for i in tf.range(FLAGS.num_mc_samples_train):
            logits = model(images, training=True)
            logits = tf.squeeze(logits, axis=-1)
            if FLAGS.use_bfloat16:
              # Cast before writing: the TensorArray was created as float32.
              logits = tf.cast(logits, tf.float32)

            logits_arr = logits_arr.write(i, logits)

          logits_list = logits_arr.stack()

          probs_list = tf.nn.sigmoid(logits_list)
          probs = tf.reduce_mean(probs_list, axis=0)
          negative_log_likelihood = tf.reduce_mean(
              batch_loss_fn(
                  y_true=tf.expand_dims(labels, axis=-1),
                  y_pred=probs,
                  from_logits=False))
        else:
          # Single train step
          logits = model(images, training=True)
          if FLAGS.use_bfloat16:
            logits = tf.cast(logits, tf.float32)
          negative_log_likelihood = tf.reduce_mean(
              batch_loss_fn(
                  y_true=tf.expand_dims(labels, axis=-1),
                  y_pred=logits,
                  from_logits=True))
          probs = tf.squeeze(tf.nn.sigmoid(logits))

        filtered_variables = []
        for var in model.trainable_variables:
          # Apply L2 regularization to the BN parameters and bias terms.
          # This excludes only the fast-weight approximate posterior/prior
          # parameters, but be careful about their naming scheme.
          if 'bn' in var.name or 'bias' in var.name:
            filtered_variables.append(tf.reshape(var, (-1,)))

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        kl = sum(model.losses)
        kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
        kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs
        kl_scale = tf.minimum(1., kl_scale)
        kl_loss = kl_scale * kl

        loss = negative_log_likelihood + l2_loss + kl_loss

        # Scale the loss, since TPUStrategy will reduce-sum the gradients
        # across replicas.
        scaled_loss = loss / strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
          negative_log_likelihood)
      metrics['train/kl'].update_state(kl)
      metrics['train/kl_scale'].update_state(kl_scale)
      metrics['train/accuracy'].update_state(labels, probs)
      metrics['train/auprc'].update_state(labels, probs)
      metrics['train/auroc'].update_state(labels, probs)

      if not use_tpu:
        metrics['train/ece'].add_batch(probs, label=labels)
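
For reference, every variational snippet above anneals the KL term linearly in the optimizer step and caps the weight at 1. The schedule in isolation, with an illustrative numeric check:

import tensorflow as tf

def kl_annealing_scale(iteration, steps_per_epoch, kl_annealing_epochs):
  # Linear warmup: the weight reaches 1.0 after kl_annealing_epochs epochs.
  scale = tf.cast(iteration + 1, tf.float32)
  scale /= steps_per_epoch * kl_annealing_epochs
  return tf.minimum(1.0, scale)

# With 100 steps per epoch and 10 annealing epochs, the KL weight at
# iteration 499 is (499 + 1) / (100 * 10) = 0.5.
print(kl_annealing_scale(499, 100, 10))  # tf.Tensor(0.5, shape=(), dtype=float32)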