def step_fn(inputs): """Per-replica step function.""" images = inputs['features'] labels = inputs['labels'] # For minibatch class reweighting, initialize per-batch loss function if class_reweight_mode == 'minibatch': batch_loss_fn = utils.get_minibatch_reweighted_loss_fn( labels=labels) else: batch_loss_fn = loss_fn with tf.GradientTape() as tape: logits = model(images, training=True) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) negative_log_likelihood = tf.reduce_mean( batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1), y_pred=logits, from_logits=True)) filtered_variables = [] for var in model.trainable_variables: # Apply l2 on the BN parameters and bias terms. This # excludes only fast weight approximate posterior/prior parameters, # but pay caution to their naming scheme. if 'bn' in var.name or 'bias' in var.name: filtered_variables.append(tf.reshape(var, (-1, ))) l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss( tf.concat(filtered_variables, axis=0)) kl = sum(model.losses) kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype) kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs kl_scale = tf.minimum(1., kl_scale) kl_loss = kl_scale * kl loss = negative_log_likelihood + l2_loss + kl_loss # Scale the loss given the TPUStrategy will reduce sum all gradients. scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.squeeze(tf.nn.sigmoid(logits)) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/accuracy'].update_state(labels, probs) metrics['train/auprc'].update_state(labels, probs) metrics['train/auroc'].update_state(labels, probs) if not use_tpu: metrics['train/ece'].add_batch(probs, label=labels)
def step_fn(inputs): """Per-replica step function.""" images = inputs['features'] labels = inputs['labels'] # For minibatch class reweighting, initialize per-batch loss function if class_reweight_mode == 'minibatch': batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels) else: batch_loss_fn = loss_fn with tf.GradientTape() as tape: logits = model(images, training=True) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) negative_log_likelihood = tf.reduce_mean( batch_loss_fn( y_true=tf.expand_dims(labels, axis=-1), y_pred=logits, from_logits=True)) l2_loss = sum(model.losses) loss = negative_log_likelihood + (FLAGS.l2 * l2_loss) # Scale the loss given the TPUStrategy will reduce sum all gradients. scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) probs = tf.squeeze(tf.nn.sigmoid(logits)) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/accuracy'].update_state(labels, probs) metrics['train/auprc'].update_state(labels, probs) metrics['train/auroc'].update_state(labels, probs) if not use_tpu: metrics['train/ece'].add_batch(probs, label=labels)
def step_fn(inputs): """Per-Replica StepFn.""" images = inputs['features'] labels = inputs['labels'] if FLAGS.ensemble_size > 1: images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1]) labels = tf.tile(labels, [FLAGS.ensemble_size]) # For minibatch class reweighting, initialize per-batch loss function if class_reweight_mode == 'minibatch': batch_loss_fn = utils.get_minibatch_reweighted_loss_fn( labels=labels) else: batch_loss_fn = loss_fn with tf.GradientTape() as tape: if FLAGS.num_mc_samples_train > 1: # Pythonic Implem logits_list = [] for _ in range(FLAGS.num_mc_samples_train): logits = model(images, training=True) logits = tf.squeeze(logits, axis=-1) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) logits_list.append(logits) # Logits dimension is (num_samples, batch_size). logits_list = tf.stack(logits_list, axis=0) probs_list = tf.nn.sigmoid(logits_list) probs = tf.reduce_mean(probs_list, axis=0) negative_log_likelihood = tf.reduce_mean( batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1), y_pred=probs, from_logits=False)) else: # Single train step logits = model(images, training=True) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) negative_log_likelihood = tf.reduce_mean( batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1), y_pred=logits, from_logits=True)) probs = tf.squeeze(tf.nn.sigmoid(logits)) l2_loss = compute_l2_loss(model) kl = sum(model.losses) / train_dataset_size kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype) kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs kl_scale = tf.minimum(1., kl_scale) kl_loss = kl_scale * kl # Scale the loss given the TPUStrategy will reduce sum all gradients. loss = negative_log_likelihood + l2_loss + kl_loss scaled_loss = loss / strategy.num_replicas_in_sync # elbo = -(negative_log_likelihood + l2_loss + kl) grads = tape.gradient(scaled_loss, model.trainable_variables) # Separate learning rate implementation. if FLAGS.fast_weight_lr_multiplier != 1.0: grads_and_vars = [] for grad, var in zip(grads, model.trainable_variables): # Apply different learning rate on the fast weights. This excludes BN # and slow weights, but pay caution to the naming scheme. if ('batch_norm' not in var.name and 'kernel' not in var.name): grads_and_vars.append( (grad * FLAGS.fast_weight_lr_multiplier, var)) else: grads_and_vars.append((grad, var)) optimizer.apply_gradients(grads_and_vars) else: optimizer.apply_gradients(zip(grads, model.trainable_variables)) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/accuracy'].update_state(labels, probs) metrics['train/auprc'].update_state(labels, probs) metrics['train/auroc'].update_state(labels, probs) if not use_tpu: metrics['train/ece'].add_batch(probs, label=labels)
def step_fn(inputs): """Per-replica step function.""" images = inputs['features'] labels = inputs['labels'] # For minibatch class reweighting, initialize per-batch loss function if class_reweight_mode == 'minibatch': print('Retracing loss fn retrieval') batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(labels=labels) else: batch_loss_fn = loss_fn with tf.GradientTape() as tape: # TODO(nband): TPU-friendly implem if FLAGS.num_mc_samples_train > 1: logits_arr = tf.TensorArray( tf.float32, size=FLAGS.num_mc_samples_train) for i in tf.range(FLAGS.num_mc_samples_train): logits = model(images, training=True) # logits = tf.squeeze(logits, axis=-1) # if FLAGS.use_bfloat16: # logits = tf.cast(logits, tf.float32) logits_arr = logits_arr.write(i, logits) logits_list = logits_arr.stack() # if FLAGS.num_mc_samples_train > 1: # # Pythonic Implem # logits_list = [] # for _ in range(FLAGS.num_mc_samples_train): # print('Tracing for loop') # logits = model(images, training=True) # if FLAGS.use_bfloat16: # print('tracing bfloat conditional') # logits = tf.cast(logits, tf.float32) # # logits = tf.squeeze(logits, axis=-1) # logits_list.append(logits) # # # Logits dimension is (num_samples, batch_size). # logits_list = tf.stack(logits_list, axis=0) probs_list = tf.nn.sigmoid(logits_list) probs = tf.reduce_mean(probs_list, axis=0) negative_log_likelihood = tf.reduce_mean( batch_loss_fn( y_true=tf.expand_dims(labels, axis=-1), y_pred=probs, from_logits=False)) else: # Single train step logits = model(images, training=True) if FLAGS.use_bfloat16: logits = tf.cast(logits, tf.float32) negative_log_likelihood = tf.reduce_mean( batch_loss_fn( y_true=tf.expand_dims(labels, axis=-1), y_pred=logits, from_logits=True)) probs = tf.squeeze(tf.nn.sigmoid(logits)) filtered_variables = [] for var in model.trainable_variables: # Apply l2 on the BN parameters and bias terms. This # excludes only fast weight approximate posterior/prior parameters, # but pay caution to their naming scheme. if 'bn' in var.name or 'bias' in var.name: filtered_variables.append(tf.reshape(var, (-1,))) l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss( tf.concat(filtered_variables, axis=0)) kl = sum(model.losses) kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype) kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs kl_scale = tf.minimum(1., kl_scale) kl_loss = kl_scale * kl loss = negative_log_likelihood + l2_loss + kl_loss # Scale the loss given the TPUStrategy will reduce sum all gradients. scaled_loss = loss / strategy.num_replicas_in_sync grads = tape.gradient(scaled_loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) metrics['train/loss'].update_state(loss) metrics['train/negative_log_likelihood'].update_state( negative_log_likelihood) metrics['train/kl'].update_state(kl) metrics['train/kl_scale'].update_state(kl_scale) metrics['train/accuracy'].update_state(labels, probs) metrics['train/auprc'].update_state(labels, probs) metrics['train/auroc'].update_state(labels, probs) if not use_tpu: metrics['train/ece'].add_batch(probs, label=labels)