# Assumed imports and module-level constants (defined elsewhere in the repo):
from collections import OrderedDict

import numpy as np
import tensorflow as tf

import losses  # project module providing hole_loss, clip_gradients, inner_loss_eval and support_loss_train

# CONFIDENCE_INTERVAL (0.95 or 0.99) is assumed to be set at module level.


def train(model, optimizer, dataset, bar):
    """
    Description: Performs training for one epoch.
    """
    # z-score for the two supported confidence levels
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58

    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0
    token_losses = []

    for (batch, (hole_window, hole_target, seq_len_hole_target)) in enumerate(dataset):
        hole_window = tf.squeeze(hole_window, axis=1)
        hole_target = tf.squeeze(hole_target, axis=1)
        seq_len_hole_target = tf.squeeze(seq_len_hole_target, axis=1)

        # Compute the hole loss and apply one clipped gradient update
        with tf.GradientTape() as g:
            batch_token_loss, masked_loss = losses.hole_loss(
                model, hole_window, hole_target, seq_len_hole_target, True)
        grads = g.gradient(batch_token_loss, model.trainable_variables)
        optimizer.apply_gradients(
            losses.clip_gradients(zip(grads, model.trainable_variables)))

        # Per-hole loss averaged over the true number of target subwords
        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += batch_token_loss.numpy()
        token_losses.append(batch_token_loss.numpy())
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the current batch loss value in the progress bar
            postfix = OrderedDict(batch_loss=batch_token_loss.numpy())
            bar.set_postfix(postfix)

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches
    # Perplexities (exp of the mean cross-entropy)
    subword_ppl = np.exp(subword_loss)
    token_ppl = np.exp(token_loss)

    # Confidence interval half-width around the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error
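# --- Hedged sketch (not from the original repo) ---
# Illustrates the subword-loss normalisation used in train() above and in the functions
# below: masked_loss holds per-subword cross-entropies padded to a common length, so each
# hole's loss is summed over its subwords and divided by its true target length before
# averaging over the batch. The tensor values are made up purely to show the computation.
import tensorflow as tf

masked_loss = tf.constant([[0.5, 0.7, 0.0],    # hole 1: two target subwords, one padding slot
                           [0.4, 0.6, 0.8]])   # hole 2: three target subwords
seq_len_hole_target = tf.constant([2, 3])

batch_subword_loss = tf.reduce_mean(
    tf.reduce_sum(masked_loss, axis=1) /
    tf.cast(seq_len_hole_target, dtype=tf.float32))
print(batch_subword_loss.numpy())   # mean of [1.2/2, 1.8/3] = 0.6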
def evaluate(model, dataset, bar):
    """
    Description: Performs evaluation for one epoch.
    """
    # z-score for the two supported confidence levels
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58

    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0
    token_losses = []

    for (batch, (hole_window, hole_target, seq_len_hole_target)) in enumerate(dataset):
        hole_window = tf.squeeze(hole_window, axis=1)
        hole_target = tf.squeeze(hole_target, axis=1)
        seq_len_hole_target = tf.squeeze(seq_len_hole_target, axis=1)

        batch_token_loss, masked_loss = losses.hole_loss(
            model, hole_window, hole_target, seq_len_hole_target, False)
        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += batch_token_loss.numpy()
        token_losses.append(batch_token_loss.numpy())
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the current batch loss value in the progress bar
            postfix = OrderedDict(batch_loss=batch_token_loss.numpy())
            bar.set_postfix(postfix)

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches

    # Confidence interval half-width around the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error
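# --- Hedged sketch (not from the original repo) ---
# The `error` returned by train()/evaluate() is the half-width of the confidence interval
# around the mean per-batch token loss: Z * sqrt(Var(losses) / n), i.e. the z-score times
# the standard error of the mean. The loss values below are made up purely to illustrate
# the computation.
import numpy as np

Z = 1.96                                        # z-score for a 95% confidence interval
token_losses = [2.31, 2.45, 2.28, 2.52, 2.40]   # hypothetical per-batch token losses
mean_loss = np.mean(token_losses)
error = Z * np.sqrt(np.var(token_losses) / len(token_losses))
print(f'token loss = {mean_loss:.3f} ± {error:.3f} (95% CI)')
print(f'token perplexity ≈ {np.exp(mean_loss):.2f}')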
def evaluate(model, dataset, method, bar, inner_learning_rate, sup_batch_size, num_of_updates):
    """
    Description: Calculates the average token hole-target cross-entropy across the
    dataset after performing inner updates on the support set for each hole target.
    Both the optimizer state and the model parameters are reset before the next hole
    target cross-entropy is calculated.
    """
    # z-score for the two supported confidence levels
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58

    token_losses = []
    # dict with key = hole identity, value = hole loss
    hole_features = {}
    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0

    # Store the initial weights of the base model so that they can be restored later
    trained_model_trainable_variables = []
    for entry in model.get_weights():
        trained_model_trainable_variables.append(entry)

    for (batch, (hole_window, hole_target, seq_len_hole_target,
                 sup_window, sup_token, seq_len_sup_token,
                 hole_identity, sup_flag)) in enumerate(dataset):

        # Reset the model before evaluating each hole target
        y = tf.reshape(tf.Variable(1, dtype=tf.int32), (1, 1))
        model(y, y, False)  # dummy forward pass to (re)build the variables
        model.set_weights(trained_model_trainable_variables)

        if sup_flag:
            sup_window = tf.squeeze(sup_window, axis=0)
            sup_token = tf.squeeze(sup_token, axis=0)
            seq_len_sup_token = tf.squeeze(seq_len_sup_token, axis=0)

        if sup_flag and (method == 'tssa' or method == 'dyn_eval'):
            # Get the new model object after performing num_of_updates inner updates
            model_new = losses.inner_loss_eval(model, sup_window, sup_token,
                                               seq_len_sup_token, False, method,
                                               inner_learning_rate, sup_batch_size,
                                               num_of_updates)
            # Calculate the hole target loss using the adapted model
            batch_token_loss, masked_loss = losses.hole_loss(
                model_new, hole_window, hole_target, seq_len_hole_target, False)

        # If no support tokens are found in the file, or if evaluating a base model,
        # calculate the hole target loss directly from the hole window
        if not sup_flag or method == 'base_model':
            batch_token_loss, masked_loss = losses.hole_loss(
                model, hole_window, hole_target, seq_len_hole_target, False)

        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        token_loss = batch_token_loss.numpy()
        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += token_loss
        token_losses.append(token_loss)
        hole_features[hole_identity.numpy()] = token_loss
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the current batch loss value in the progress bar
            postfix = OrderedDict(batch_loss=token_loss)
            bar.set_postfix(postfix)

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches

    # Confidence interval half-width around the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error, hole_features
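# --- Hedged sketch (not from the original repo) ---
# Illustrates the adapt -> evaluate -> restore pattern that evaluate() above applies per
# hole target: the base weights are saved once, a few inner gradient steps are taken on
# the support examples with a fresh optimizer, the hole loss is computed with the adapted
# weights, and the base weights are restored before the next hole. The toy model, data
# and hyperparameters below are placeholders chosen only to make the snippet
# self-contained; they are not the repo's model or losses.
import numpy as np
import tensorflow as tf

toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
loss_fn = tf.keras.losses.MeanSquaredError()
base_weights = [w.copy() for w in toy_model.get_weights()]   # saved once, like trained_model_trainable_variables

def adapted_hole_loss(sup_x, sup_y, hole_x, hole_y, inner_lr=0.01, num_of_updates=2):
    inner_opt = tf.keras.optimizers.SGD(learning_rate=inner_lr)  # fresh optimizer state per hole
    for _ in range(num_of_updates):                              # inner updates on the support set
        with tf.GradientTape() as tape:
            sup_loss = loss_fn(sup_y, toy_model(sup_x))
        grads = tape.gradient(sup_loss, toy_model.trainable_variables)
        inner_opt.apply_gradients(zip(grads, toy_model.trainable_variables))
    hole_loss = loss_fn(hole_y, toy_model(hole_x)).numpy()       # evaluate with adapted weights
    toy_model.set_weights(base_weights)                          # restore before the next hole
    return hole_loss

sup_x = np.random.randn(8, 4).astype('float32')
sup_y = np.random.randn(8, 1).astype('float32')
hole_x = np.random.randn(2, 4).astype('float32')
hole_y = np.random.randn(2, 1).astype('float32')
print(adapted_hole_loss(sup_x, sup_y, hole_x, hole_y))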
def train(model, optimizer_inner, optimizer_outer, dataset, train_method, bar,
          epsilon_reptile, batch_size_sup, num_of_updates):
    """
    Description: Performs meta-training for one epoch.
    """
    # z-score for the two supported confidence levels
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58

    token_losses = []
    # During training we want different holes to be sampled from the same file across
    # epochs, so we do not use a fixed seed
    np.random.seed(None)

    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0

    for (batch, (hole_window, hole_target, seq_len_hole_target,
                 sup_window, sup_token, seq_len_sup_token,
                 hole_identity, sup_flag)) in enumerate(dataset):

        if sup_flag:
            sup_window = tf.squeeze(sup_window, axis=0)
            sup_token = tf.squeeze(sup_token, axis=0)
            seq_len_sup_token = tf.squeeze(seq_len_sup_token, axis=0)

        # Store the current weights for use in the Reptile update later
        old_model_trainable_variables = []
        for entry in model.get_weights():
            old_model_trainable_variables.append(entry)

        # FOMAML: get the adapted model after num_of_updates inner updates, then take the
        # gradient of the hole loss w.r.t. the updated parameters to form the outer update
        if sup_flag and train_method == 'fomaml':
            model_new = losses.support_loss_train(model, sup_window, sup_token,
                                                  seq_len_sup_token, True,
                                                  optimizer_inner, batch_size_sup,
                                                  num_of_updates)
            with tf.GradientTape() as g:
                batch_token_loss, masked_loss = losses.hole_loss(
                    model_new, hole_window, hole_target, seq_len_hole_target, True)
            grads = g.gradient(batch_token_loss, model.trainable_variables)
            optimizer_outer.apply_gradients(
                losses.clip_gradients(zip(grads, model.trainable_variables)))

        # Reptile: get the adapted model after num_of_updates inner updates, then
        # interpolate between the old and the adapted weights for the outer update
        if sup_flag and train_method == 'reptile':
            model_new = losses.support_loss_train(model, sup_window, sup_token,
                                                  seq_len_sup_token, True,
                                                  optimizer_inner, batch_size_sup,
                                                  num_of_updates)
            batch_token_loss, masked_loss = losses.hole_loss(
                model_new, hole_window, hole_target, seq_len_hole_target, True)
            new_weights = []
            for i in range(len(model_new.trainable_variables)):
                new_weights.append(
                    old_model_trainable_variables[i] +
                    epsilon_reptile * (model_new.trainable_variables[i] -
                                       old_model_trainable_variables[i]))
            model.set_weights(new_weights)

        # If no support tokens are found in the file, take the gradient of the hole
        # target loss directly
        if not sup_flag:
            with tf.GradientTape() as g:
                batch_token_loss, masked_loss = losses.hole_loss(
                    model, hole_window, hole_target, seq_len_hole_target, True)
            grads = g.gradient(batch_token_loss, model.trainable_variables)
            optimizer_outer.apply_gradients(
                losses.clip_gradients(zip(grads, model.trainable_variables)))

        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        token_loss = batch_token_loss.numpy()
        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += token_loss
        token_losses.append(token_loss)
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the current batch loss value in the progress bar
            postfix = OrderedDict(batch_loss=token_loss)
            bar.set_postfix(postfix)

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches

    # Confidence interval half-width around the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error
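# --- Hedged sketch (not from the original repo) ---
# The Reptile branch above performs the outer update
#     w <- w_old + epsilon_reptile * (w_adapted - w_old),
# i.e. it moves the meta-parameters a fraction epsilon_reptile of the way towards the
# weights obtained after the inner updates on the support set. The toy numpy weights
# below are placeholders; with epsilon_reptile = 1.0 the update simply copies the
# adapted weights.
import numpy as np

old_weights = [np.zeros((2, 2)), np.zeros(2)]          # meta-parameters before the task
adapted_weights = [np.ones((2, 2)), np.ones(2)]        # parameters after inner updates
epsilon_reptile = 0.1

new_weights = [old + epsilon_reptile * (adapted - old)
               for old, adapted in zip(old_weights, adapted_weights)]
print(new_weights[0])   # each entry moved 10% of the way towards the adapted value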
def evaluate(model, dataset, bar, inner_learning_rate, sup_batch_size, num_of_updates):
    """
    Description: Performs evaluation for one epoch. Both the optimizer state and the
    model parameters are reset before calculating the next hole target cross-entropy.
    """
    # z-score for the two supported confidence levels
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58

    # Use a fixed seed so that the same holes are evaluated across epochs
    np.random.seed(42)

    token_losses = []
    # dict with key = hole identity, value = hole loss
    hole_features = {}
    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0

    # Store the trained weights so that they can be restored after each per-hole adaptation
    trained_model_trainable_variables = []
    for entry in model.get_weights():
        trained_model_trainable_variables.append(entry)

    for (batch, (hole_window, hole_target, seq_len_hole_target,
                 sup_window, sup_token, seq_len_sup_token,
                 hole_identity, sup_flag)) in enumerate(dataset):

        # Reset the model before evaluating each hole target
        y = tf.reshape(tf.Variable(1, dtype=tf.int32), (1, 1))
        model(y, y, False)  # dummy forward pass to (re)build the variables
        model.set_weights(trained_model_trainable_variables)

        if sup_flag:
            sup_window = tf.squeeze(sup_window, axis=0)
            sup_token = tf.squeeze(sup_token, axis=0)
            seq_len_sup_token = tf.squeeze(seq_len_sup_token, axis=0)
            model_new = losses.inner_loss_eval(model, sup_window, sup_token,
                                               seq_len_sup_token, False, 'tssa',
                                               inner_learning_rate, sup_batch_size,
                                               num_of_updates)
            batch_token_loss, masked_loss = losses.hole_loss(
                model_new, hole_window, hole_target, seq_len_hole_target, False)

        # If no support tokens are found in the file, calculate the hole target loss
        # directly from the hole window
        if not sup_flag:
            batch_token_loss, masked_loss = losses.hole_loss(
                model, hole_window, hole_target, seq_len_hole_target, False)

        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        token_loss = batch_token_loss.numpy()
        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += token_loss
        token_losses.append(token_loss)
        hole_features[hole_identity.numpy()] = token_loss
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the current batch loss value in the progress bar
            postfix = OrderedDict(batch_loss=token_loss)
            bar.set_postfix(postfix)

    # Restore the trained weights for training in the next epoch
    model.set_weights(trained_model_trainable_variables)

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches

    # Confidence interval half-width around the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error, hole_features
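# --- Hedged usage sketch (not from the original repo) ---
# Shows how the meta-training train() and evaluate() above might be driven per epoch
# with tqdm progress bars, which is what the `bar` arguments expect (update() and
# set_postfix()). The dataset objects, batch counts, optimizer choices and hyperparameter
# values are placeholders, not names or settings from this repo.
import numpy as np
import tensorflow as tf
from tqdm import tqdm

def run_meta_training(model, train_ds, valid_ds, num_train_batches, num_valid_batches,
                      num_epochs=5, train_method='reptile'):
    optimizer_inner = tf.keras.optimizers.SGD(learning_rate=1e-2)
    optimizer_outer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    for epoch in range(num_epochs):
        with tqdm(total=num_train_batches, desc=f'train epoch {epoch}') as bar:
            train(model, optimizer_inner, optimizer_outer, train_ds, train_method, bar,
                  epsilon_reptile=0.1, batch_size_sup=8, num_of_updates=2)
        with tqdm(total=num_valid_batches, desc='valid') as bar:
            _, token_loss, error, _ = evaluate(model, valid_ds, bar,
                                               inner_learning_rate=1e-2,
                                               sup_batch_size=8, num_of_updates=2)
        # Report perplexity as exp of the mean cross-entropy, with the CI half-width
        print(f'valid token ppl {np.exp(token_loss):.2f} '
              f'(loss {token_loss:.4f} ± {error:.4f})')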