def training_step(x, y):
    x1, x2 = x
    lossdict = {}

    # target projections
    targ1 = tf.nn.l2_normalize(target(x1, training=False), 1)
    targ2 = tf.nn.l2_normalize(target(x2, training=False), 1)

    with tf.GradientTape() as tape:
        # online projections
        z1 = online(x1, training=True)
        z2 = online(x2, training=True)
        # online predictions
        pred1 = tf.nn.l2_normalize(prediction(z1, training=True), 1)
        pred2 = tf.nn.l2_normalize(prediction(z2, training=True), 1)
        # compute mean-squared error both ways
        mse_loss = _mse(targ1, pred2) + _mse(targ2, pred1)
        #mse_loss = _dot_product_loss(targ1, pred2) + _dot_product_loss(targ2, pred1)
        lossdict["loss"] = mse_loss
        lossdict["mse_loss"] = mse_loss

        if weight_decay > 0:
            lossdict["l2_loss"] = compute_l2_loss(online) + \
                                  compute_l2_loss(prediction)
            lossdict["loss"] += weight_decay * lossdict["l2_loss"]

    # UPDATE WEIGHTS OF ONLINE MODEL
    gradients = tape.gradient(lossdict["loss"], trainvars)
    optimizer.apply_gradients(zip(gradients, trainvars))
    # UPDATE WEIGHTS OF TARGET MODEL
    #lossdict["target_online_avg_weight_diff"] = exponential_model_update(target, online, tau)
    return lossdict
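# The _mse helper used above isn't defined in this snippet. A minimal sketch of
# one plausible implementation, assuming it takes two (batch, d) tensors of
# L2-normalized vectors and returns a scalar; for unit vectors this squared
# error equals 2 - 2 * cosine similarity, the usual BYOL regression target.
def _mse(y_true, y_pred):
    # sum squared differences over the feature axis, then average over the batch
    return tf.reduce_mean(tf.reduce_sum(tf.square(y_true - y_pred), axis=1))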
def teststep(img_batch, text_batch):
    img_embed = img_model(img_batch, training=False)
    text_embed = text_model(text_batch, training=False)

    nce_loss, acc = compute_nce_loss(img_embed, text_embed, temp, True)
    if weight_decay > 0:
        l2_loss = compute_l2_loss(img_model) + compute_l2_loss(text_model)
    else:
        l2_loss = 0
    loss = nce_loss + weight_decay * l2_loss
    return loss, acc
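# Hypothetical usage: averaging teststep() outputs over a validation dataset.
# The test_ds iterable of (img_batch, text_batch) pairs is an assumption,
# not part of the snippet above.
def evaluate(test_ds):
    losses, accs = [], []
    for img_batch, text_batch in test_ds:
        loss, acc = teststep(img_batch, text_batch)
        losses.append(float(loss))
        accs.append(float(acc))
    return np.mean(losses), np.mean(accs)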
def train_step(x, y):
    with tf.GradientTape() as tape:
        total_loss = 0
        lossdict = {}
        outputs = model(x, training=True)
        # hack to make this work with a single task
        if len(tasks) == 1:
            outputs = [outputs]

        for pred, y_true, weight, task in zip(outputs, y, task_loss_weights, tasks):
            task_loss = masked_sparse_categorical_crossentropy(y_true, pred)
            lossdict["loss_" + task] = task_loss

            if adaptive:
                # interpret weight as log(sigma^2). Kendall's paper mentions
                # that they use this as it's more numerically stable
                inv_sig_sq = tf.math.exp(-1 * weight)
                total_loss += task_loss * inv_sig_sq + 0.5 * weight
            else:
                total_loss += weight * task_loss

        if weight_decay > 0:
            lossdict["l2_loss"] = compute_l2_loss(model)
            total_loss += weight_decay * lossdict["l2_loss"]

        if (distill_func is not None) and (distill_weight > 0):
            lossdict["distill_loss"] = distill_func(outputs, x)
            total_loss += distill_weight * lossdict["distill_loss"]

        lossdict["total_loss"] = total_loss

    gradients = tape.gradient(total_loss, trainvars)
    optimizer.apply_gradients(zip(gradients, trainvars))
    return lossdict
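# For the adaptive case above, each task weight is interpreted as log(sigma^2)
# following Kendall et al.'s uncertainty weighting. A sketch of how those weights
# might be created as trainable variables and folded into trainvars so the tape
# updates them alongside the model (the names here are assumptions):
if adaptive:
    task_loss_weights = [tf.Variable(0.0, dtype=tf.float32, name="log_sigma_sq_" + t)
                         for t in tasks]
    trainvars = model.trainable_variables + task_loss_weights
else:
    trainvars = model.trainable_variables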
def trainstep(img_batch, text_batch):
    with tf.GradientTape() as tape:
        img_embed = img_model(img_batch, training=True)
        text_embed = text_model(text_batch, training=True)

        nce_loss = compute_nce_loss(img_embed, text_embed, temp)
        if weight_decay > 0:
            l2_loss = compute_l2_loss(img_model) + compute_l2_loss(text_model)
        else:
            l2_loss = 0
        loss = nce_loss + weight_decay * l2_loss

    grads = tape.gradient(loss, trainvars)
    optimizer.apply_gradients(zip(grads, trainvars))
    lossdict = {"loss": loss, "l2_loss": l2_loss, "nce_loss": nce_loss}
    return lossdict
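# compute_nce_loss isn't shown in these snippets. A rough sketch of a symmetric
# InfoNCE (CLIP-style) loss, assuming it normalizes both embeddings, scales
# similarities by 1/temp, and treats matching rows as positive pairs; the
# optional return_acc flag mirrors the extra argument passed in teststep().
def compute_nce_loss(img_embed, text_embed, temp, return_acc=False):
    img_norm = tf.nn.l2_normalize(img_embed, 1)
    text_norm = tf.nn.l2_normalize(text_embed, 1)
    # (batch, batch) similarity matrix; diagonal entries are the positives
    logits = tf.matmul(img_norm, text_norm, transpose_b=True) / temp
    labels = tf.range(tf.shape(logits)[0])
    # cross-entropy in both directions: image->text and text->image
    loss_i = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    loss_t = tf.keras.losses.sparse_categorical_crossentropy(labels, tf.transpose(logits), from_logits=True)
    loss = tf.reduce_mean(loss_i + loss_t) / 2
    if return_acc:
        acc = tf.reduce_mean(tf.cast(
            tf.argmax(logits, 1) == tf.cast(labels, tf.int64), tf.float32))
        return loss, acc
    return loss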
def test_l2_loss():
    # build a simple model
    inpt = tf.keras.layers.Input(3)
    net = tf.keras.layers.Dense(5)(inpt)
    model = tf.keras.Model(inpt, net)
    # check the L2 loss; should be a scalar bigger than zero
    assert compute_l2_loss(model).numpy() > 0
    # should also work with multiple inputs
    assert compute_l2_loss(model, model).numpy() > 0
    # set all trainable weights to zero
    new_weights = [np.zeros(x.shape, dtype=np.float32) for x in model.get_weights()]
    model.set_weights(new_weights)
    assert compute_l2_loss(model).numpy() == 0.
    # add a batch norm layer- shouldn't change anything
    # because the function should skip that layer
    net = tf.keras.layers.BatchNormalization()(net)
    model = tf.keras.Model(inpt, net)
    assert compute_l2_loss(model).numpy() == 0.
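# compute_l2_loss itself isn't included here. A sketch consistent with the test
# above: sum of squared trainable weights across one or more models, skipping
# BatchNormalization parameters (the exact scaling and skip rule are assumptions
# about the real helper).
def compute_l2_loss(*models):
    loss = 0.
    for model in models:
        for layer in model.layers:
            # batchnorm scale/offset parameters shouldn't be decayed
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                continue
            for w in layer.trainable_variables:
                loss += tf.reduce_sum(tf.square(w))
    return loss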
def training_step(x, y):
    with tf.GradientTape() as tape:
        z = model(x, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y, z))
        if weight_decay > 0:
            loss += weight_decay * compute_l2_loss(model)

    # update fast model
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return {"loss": loss}
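# Hypothetical driver loop for the step above; train_ds and num_epochs are
# assumptions, not part of the snippet. Wrapping the step in tf.function
# compiles it into a graph so it isn't retraced every batch.
@tf.function
def compiled_step(x, y):
    return training_step(x, y)

for epoch in range(num_epochs):
    for x_batch, y_batch in train_ds:
        metrics = compiled_step(x_batch, y_batch)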
def _step(x, y):
    with tf.GradientTape() as tape:
        loss = 0
        # get replica context- we'll use this to aggregate embeddings
        # across different GPUs
        context = tf.distribute.get_replica_context()

        # run images through model and normalize embeddings
        z1 = tf.nn.l2_normalize(model(x, training=True), 1)
        z2 = tf.nn.l2_normalize(model(y, training=True), 1)
        # aggregate projections across replicas. z1 and z2 should
        # now correspond to the global batch size (gbs, d)
        z1 = context.all_gather(z1, 0)
        z2 = context.all_gather(z2, 0)

        with tape.stop_recording():
            gbs = z1.shape[0]
            mask = _build_negative_mask(gbs)

        # SimCLR case
        if (tau_plus == 0) and (beta == 0):
            softmax_prob, nce_batch_acc = _simclr_softmax_prob(z1, z2, temp, mask)
        # HCL case
        elif (tau_plus > 0) and (beta > 0):
            softmax_prob, nce_batch_acc = _hcl_softmax_prob(
                z1, z2, temp, beta, tau_plus, mask)
        else:
            assert False, "both tau_plus and beta must be nonzero to run HCL"

        softmax_loss = tf.reduce_mean(-1 * tf.math.log(softmax_prob))
        loss += softmax_loss

        if weight_decay > 0:
            l2_loss = compute_l2_loss(model)
            loss += weight_decay * l2_loss
        else:
            l2_loss = 0

    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    return {"loss": loss,
            "nt_xent_loss": softmax_loss,
            "l2_loss": l2_loss,
            "nce_batch_accuracy": nce_batch_acc}
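# _build_negative_mask isn't defined in this snippet. One plausible construction,
# assuming it returns a (gbs, gbs) float matrix that is 1 for every negative pair
# and 0 on the diagonal, so each example's similarity with itself is excluded
# from the softmax denominator:
def _build_negative_mask(gbs):
    return tf.ones((gbs, gbs), dtype=tf.float32) - tf.eye(gbs, dtype=tf.float32)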
def training_step(x, y):
    with tf.GradientTape() as tape:
        # run images through model and normalize embeddings
        z1 = tf.nn.l2_normalize(embed_model(x, training=True), 1)
        z2 = tf.nn.l2_normalize(embed_model(y, training=True), 1)

        if not data_parallel:
            # get replica context- we'll use this to aggregate embeddings
            # across different GPUs
            context = tf.distribute.get_replica_context()
            # aggregate projections across replicas. z1 and z2 should
            # now correspond to the global batch size (gbs, d)
            z1 = context.all_gather(z1, 0)
            z2 = context.all_gather(z2, 0)

        xent_loss, batch_acc = _contrastive_loss(z1, z2, temperature, decoupled, eps)

        if weight_decay > 0:
            l2_loss = compute_l2_loss(embed_model)
        else:
            l2_loss = 0

        loss = xent_loss + weight_decay * l2_loss

    gradients = tape.gradient(loss, embed_model.trainable_variables)
    optimizer.apply_gradients(
        zip(gradients, embed_model.trainable_variables))

    return {"nt_xent_loss": xent_loss,
            "l2_loss": l2_loss,
            "loss": loss,
            "nce_batch_acc": batch_acc}
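# Hypothetical usage under a tf.distribute strategy; the strategy and the
# dist_dataset iterable are assumptions. Each replica runs training_step, and
# when data_parallel is False the all_gather call above lets every replica
# compute the contrastive loss over the global batch.
strategy = tf.distribute.MirroredStrategy()

@tf.function
def distributed_step(x, y):
    return strategy.run(training_step, args=(x, y))

for x_batch, y_batch in dist_dataset:
    metrics = distributed_step(x_batch, y_batch)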
def train_step(lab, unlab):
    x, y = lab
    x_unlab_wk, x_unlab_str = unlab

    with tf.GradientTape() as tape:
        # semisupervised case
        if lam > 0:
            # concatenate labeled/pseudolabeled batches
            N = x.shape[0]
            mu_N = x_unlab_str.shape[0]
            x_batch = tf.concat([x, x_unlab_wk, x_unlab_str], 0)
            pred_batch = model(x_batch, training=True)
            # then split the labeled/pseudolabeled pieces
            preds = pred_batch[:N, :]
            wk_preds = pred_batch[N:N + mu_N, :]
            str_preds = pred_batch[N + mu_N:, :]

            # GENERATE FIXMATCH PSEUDOLABELS
            # round predictions to pseudolabels
            pseudolabels = tf.cast(wk_preds > 0.5, tf.float32)
            # since we only incorporate high-confidence cases, also compute a
            # mask that's 1 wherever the prediction is close to 0 or 1
            mask = _build_mask(wk_preds, tau)

            # keep track of how accurate these pseudolabels are
            ssl_acc = tf.reduce_mean(
                tf.cast(
                    tf.cast(str_preds > 0.5, tf.float32) == pseudolabels,
                    tf.float32))
            crossent_tensor = K.binary_crossentropy(pseudolabels, str_preds)
            fixmatch_loss = tf.reduce_mean(mask * crossent_tensor)
        else:
            fixmatch_loss = 0
            ssl_acc = -1
            mask = -1
            preds = model(x, training=True)

        trainloss = tf.reduce_mean(K.binary_crossentropy(y, preds))

        if weight_decay > 0:
            l2_loss = compute_l2_loss(model)
        else:
            l2_loss = 0

        total_loss = trainloss + lam * fixmatch_loss + weight_decay * l2_loss

    # compute and apply gradients
    gradients = tape.gradient(total_loss, trainvars)
    optimizer.apply_gradients(zip(gradients, trainvars))

    return {"total_loss": total_loss,
            "supervised_loss": trainloss,
            "fixmatch_loss": fixmatch_loss,
            "l2_loss": l2_loss,
            "fixmatch_prediction_accuracy": ssl_acc,
            "fixmatch_mask_fraction": tf.reduce_mean(mask)}
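# _build_mask isn't shown above. A sketch of one way it could work, assuming a
# multi-label (sigmoid) setup: keep only positions where the weakly-augmented
# prediction is confidently above tau or below 1 - tau.
def _build_mask(wk_preds, tau):
    confident = tf.logical_or(wk_preds >= tau, wk_preds <= 1 - tau)
    return tf.cast(confident, tf.float32)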