Ejemplo n.º 1
0
def train(model, optimizer, dataset, bar):
    """Train ``model`` for one epoch and report the mean losses.

    For every batch the token loss returned by ``losses.hole_loss`` is
    differentiated with respect to ``model.trainable_variables``; the
    gradients are passed through ``losses.clip_gradients`` and applied with
    ``optimizer``.

    Args:
        model: Model handed to ``losses.hole_loss``; must expose
            ``trainable_variables``.
        optimizer: Optimizer exposing ``apply_gradients``.
        dataset: Iterable yielding ``(hole_window, hole_target,
            seq_len_hole_target)`` triples. Axis 1 of each element is
            squeezed away before use.
        bar: Progress bar exposing ``update`` and ``set_postfix``
            (presumably a tqdm bar -- confirm against the caller).

    Returns:
        A tuple ``(subword_loss, token_loss, error)``: the batch-wise mean of
        the sub-word loss, the batch-wise mean of the token loss, and the
        confidence-interval half-width ``Z * sqrt(var(token losses) / n)``.

    Raises:
        ValueError: If ``CONFIDENCE_INTERVAL`` is neither 0.95 nor 0.99.
        ZeroDivisionError: If ``dataset`` yields no batches.
    """
    # Fail fast on an unsupported interval instead of discovering an unbound
    # ``Z`` only after the whole epoch has been trained.
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58
    else:
        raise ValueError(
            'CONFIDENCE_INTERVAL must be 0.95 or 0.99, got {!r}'.format(
                CONFIDENCE_INTERVAL))

    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0
    token_losses = []

    for hole_window, hole_target, seq_len_hole_target in dataset:
        hole_window = tf.squeeze(hole_window, axis=1)
        hole_target = tf.squeeze(hole_target, axis=1)
        seq_len_hole_target = tf.squeeze(seq_len_hole_target, axis=1)

        # Only the loss computation needs to be recorded on the tape.
        # NOTE(review): the trailing ``True`` is presumably the training
        # flag of ``losses.hole_loss`` -- confirm.
        with tf.GradientTape() as g:
            batch_token_loss, masked_loss = losses.hole_loss(
                model, hole_window, hole_target, seq_len_hole_target, True)

        grads = g.gradient(batch_token_loss, model.trainable_variables)
        optimizer.apply_gradients(
            losses.clip_gradients(zip(grads, model.trainable_variables)))

        # Per-example masked loss summed over axis 1, normalised by the
        # target length, then averaged over the batch.
        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        batch_loss_value = batch_token_loss.numpy()
        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += batch_loss_value
        token_losses.append(batch_loss_value)
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the numeric loss rather than the tensor reference wrapper.
            bar.set_postfix(OrderedDict(batch_loss={batch_loss_value}))

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches
    # Confidence-interval error of the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error
Ejemplo n.º 2
0
def evaluate(model, dataset, bar):
    """Evaluate ``model`` for one epoch and report the mean losses.

    No gradients are computed; the loss for each batch comes straight from
    ``losses.hole_loss``.

    Args:
        model: Model handed to ``losses.hole_loss``.
        dataset: Iterable yielding ``(hole_window, hole_target,
            seq_len_hole_target)`` triples. Axis 1 of each element is
            squeezed away before use.
        bar: Progress bar exposing ``update`` and ``set_postfix``
            (presumably a tqdm bar -- confirm against the caller).

    Returns:
        A tuple ``(subword_loss, token_loss, error)``: the batch-wise mean of
        the sub-word loss, the batch-wise mean of the token loss, and the
        confidence-interval half-width ``Z * sqrt(var(token losses) / n)``.

    Raises:
        ValueError: If ``CONFIDENCE_INTERVAL`` is neither 0.95 nor 0.99.
        ZeroDivisionError: If ``dataset`` yields no batches.
    """
    # Fail fast on an unsupported interval instead of discovering an unbound
    # ``Z`` only after the whole dataset has been processed.
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58
    else:
        raise ValueError(
            'CONFIDENCE_INTERVAL must be 0.95 or 0.99, got {!r}'.format(
                CONFIDENCE_INTERVAL))

    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0
    token_losses = []

    for hole_window, hole_target, seq_len_hole_target in dataset:
        hole_window = tf.squeeze(hole_window, axis=1)
        hole_target = tf.squeeze(hole_target, axis=1)
        seq_len_hole_target = tf.squeeze(seq_len_hole_target, axis=1)

        # NOTE(review): the trailing ``False`` is presumably the training
        # flag of ``losses.hole_loss`` -- confirm.
        batch_token_loss, masked_loss = losses.hole_loss(
            model, hole_window, hole_target, seq_len_hole_target, False)

        # Per-example masked loss summed over axis 1, normalised by the
        # target length, then averaged over the batch.
        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))

        batch_loss_value = batch_token_loss.numpy()
        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += batch_loss_value
        token_losses.append(batch_loss_value)
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            # Show the numeric loss rather than the tensor reference wrapper.
            bar.set_postfix(OrderedDict(batch_loss={batch_loss_value}))

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches
    # Confidence-interval error of the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error
Ejemplo n.º 3
0
def evaluate(model, dataset, method, bar, inner_learning_rate, sup_batch_size,
             num_of_updates):
    """Average the hole-target cross-entropy over ``dataset``.

    Before each hole target is scored the model weights are reset to the
    values they had when this function was called. When a support set is
    available and ``method`` is ``'tssa'`` or ``'dyn_eval'``, the model is
    first adapted with ``losses.inner_loss_eval``; otherwise the hole target
    is scored directly with the base model.

    Args:
        model: Model to evaluate; must support ``get_weights`` /
            ``set_weights`` and be callable as ``model(x, x, False)``.
        dataset: Iterable yielding ``(hole_window, hole_target,
            seq_len_hole_target, sup_window, sup_token, seq_len_sup_token,
            hole_identity, sup_flag)`` tuples.
        method: One of ``'tssa'``, ``'dyn_eval'`` or ``'base_model'``.
        bar: Progress bar exposing ``update`` and ``set_postfix``.
        inner_learning_rate: Forwarded to ``losses.inner_loss_eval``.
        sup_batch_size: Forwarded to ``losses.inner_loss_eval``.
        num_of_updates: Forwarded to ``losses.inner_loss_eval``.

    Returns:
        A tuple ``(subword_loss, token_loss, error, hole_features)``: the
        mean sub-word loss, the mean token loss, the confidence-interval
        half-width ``Z * sqrt(var(token losses) / n)``, and a dict mapping
        ``hole_identity.numpy()`` to that hole's token loss.

    Raises:
        ValueError: If ``CONFIDENCE_INTERVAL`` is neither 0.95 nor 0.99, or
            if a hole with a support set is met while ``method`` is not one
            of the supported values.
        ZeroDivisionError: If ``dataset`` yields no items.
    """
    # Fail fast on an unsupported interval instead of discovering an unbound
    # ``Z`` only after the whole dataset has been processed.
    if CONFIDENCE_INTERVAL == 0.95:
        Z = 1.96
    elif CONFIDENCE_INTERVAL == 0.99:
        Z = 2.58
    else:
        raise ValueError(
            'CONFIDENCE_INTERVAL must be 0.95 or 0.99, got {!r}'.format(
                CONFIDENCE_INTERVAL))

    token_losses = []
    # dict with key = hole features, value = hole loss
    hole_features = {}

    total_subword_loss = 0.0
    total_token_loss = 0.0
    total_batches = 0

    # Snapshot of the base model's weights so they can be restored later.
    initial_weights = list(model.get_weights())

    # Dummy (1, 1) input fed to the model before every reset; it never
    # changes, so it is built once instead of once per hole.
    y = tf.reshape(tf.Variable(1, dtype=tf.int32), (1, 1))

    for (hole_window, hole_target, seq_len_hole_target, sup_window, sup_token,
         seq_len_sup_token, hole_identity, sup_flag) in dataset:

        # Reset before evaluating each hole target.
        # NOTE(review): the dummy forward pass presumably ensures the model
        # variables exist before ``set_weights`` -- confirm.
        model(y, y, False)
        model.set_weights(initial_weights)

        has_support = bool(sup_flag)

        if has_support:
            sup_window = tf.squeeze(sup_window, axis=0)
            sup_token = tf.squeeze(sup_token, axis=0)
            seq_len_sup_token = tf.squeeze(seq_len_sup_token, axis=0)

        if has_support and method in ('tssa', 'dyn_eval'):
            # Get the model object after num_of_updates inner updates and
            # score the hole target with it.
            model_new = losses.inner_loss_eval(model, sup_window, sup_token,
                                               seq_len_sup_token, False,
                                               method, inner_learning_rate,
                                               sup_batch_size, num_of_updates)
            batch_token_loss, masked_loss = losses.hole_loss(
                model_new, hole_window, hole_target, seq_len_hole_target,
                False)
        elif not has_support or method == 'base_model':
            # No support tokens in the file, or base-model evaluation: score
            # the hole target directly from the hole window.
            batch_token_loss, masked_loss = losses.hole_loss(
                model, hole_window, hole_target, seq_len_hole_target, False)
        else:
            # Previously this case silently reused the loss of the previous
            # hole (or failed with a NameError on the first one).
            raise ValueError(
                "method must be 'tssa', 'dyn_eval' or 'base_model', "
                'got {!r}'.format(method))

        # Per-example masked loss summed over axis 1, normalised by the
        # target length, then averaged over the batch.
        batch_subword_loss = tf.reduce_mean(
            tf.reduce_sum(masked_loss, axis=1) /
            tf.cast(seq_len_hole_target, dtype=tf.float32))
        token_loss = batch_token_loss.numpy()

        total_subword_loss += batch_subword_loss.numpy()
        total_token_loss += token_loss
        token_losses.append(token_loss)

        hole_features[hole_identity.numpy()] = token_loss
        total_batches += 1

        if total_batches % 10 == 0:
            bar.update(10)
            postfix = OrderedDict(batch_loss={token_loss})
            bar.set_postfix(postfix)

    # Leave the caller's model with the weights it was passed in with rather
    # than whatever the last hole's evaluation left behind.
    model.set_weights(initial_weights)

    # Calculate mean batch-wise losses
    subword_loss = total_subword_loss / total_batches
    token_loss = total_token_loss / total_batches
    # Confidence-interval error of the mean token loss
    error = Z * np.sqrt(np.var(token_losses) / total_batches)
    return subword_loss, token_loss, error, hole_features
Ejemplo n.º 4
0
def train(model, optimizer_inner, optimizer_outer, dataset, train_method, bar, epsilon_reptile, batch_size_sup, num_of_updates):
  """Meta-train ``model`` for one epoch and report the mean losses.

  For every hole that has a support set, ``losses.support_loss_train``
  performs the inner updates and the outer update then follows
  ``train_method`` (``'fomaml'`` or ``'reptile'``). Holes without a support
  set are trained with a plain gradient step on the hole-target loss.

  Args:
      model: Model being trained; must expose ``trainable_variables``,
          ``get_weights`` and ``set_weights``.
      optimizer_inner: Optimizer forwarded to ``losses.support_loss_train``.
      optimizer_outer: Optimizer used for the outer gradient step.
      dataset: Iterable yielding ``(hole_window, hole_target,
          seq_len_hole_target, sup_window, sup_token, seq_len_sup_token,
          hole_identity, sup_flag)`` tuples.
      train_method: ``'fomaml'`` or ``'reptile'``.
      bar: Progress bar exposing ``update`` and ``set_postfix``.
      epsilon_reptile: Step size of the reptile update
          ``old + epsilon_reptile * (new - old)``.
      batch_size_sup: Forwarded to ``losses.support_loss_train``.
      num_of_updates: Forwarded to ``losses.support_loss_train``.

  Returns:
      A tuple ``(subword_loss, token_loss, error)``: the mean sub-word loss,
      the mean token loss, and the confidence-interval half-width
      ``Z * sqrt(var(token losses) / n)``.

  Raises:
      ValueError: If ``CONFIDENCE_INTERVAL`` is neither 0.95 nor 0.99, or if
          a hole with a support set is met while ``train_method`` is neither
          ``'fomaml'`` nor ``'reptile'``.
      ZeroDivisionError: If ``dataset`` yields no items.
  """
  # Fail fast on an unsupported interval instead of discovering an unbound
  # ``Z`` only after the whole epoch has been trained.
  if CONFIDENCE_INTERVAL == 0.95:
    Z = 1.96
  elif CONFIDENCE_INTERVAL == 0.99:
    Z = 2.58
  else:
    raise ValueError(
        'CONFIDENCE_INTERVAL must be 0.95 or 0.99, got {!r}'.format(
            CONFIDENCE_INTERVAL))
  token_losses = []

  # During training, we want different holes to be sampled from the same file
  # across epochs, so we do not want a fixed seed.
  np.random.seed(None)

  total_subword_loss = 0.0
  total_token_loss = 0.0
  total_batches = 0
  for (hole_window, hole_target, seq_len_hole_target, sup_window, sup_token,
       seq_len_sup_token, hole_identity, sup_flag) in dataset:

    has_support = bool(sup_flag)

    if has_support:
      sup_window = tf.squeeze(sup_window, axis=0)
      sup_token = tf.squeeze(sup_token, axis=0)
      seq_len_sup_token = tf.squeeze(seq_len_sup_token, axis=0)

    if not has_support:
      # No support tokens found in the file: take a plain gradient step on
      # the hole-target loss.
      with tf.GradientTape() as g:
        batch_token_loss, masked_loss = losses.hole_loss(
            model, hole_window, hole_target, seq_len_hole_target, True)
      grads = g.gradient(batch_token_loss, model.trainable_variables)
      # Uses the same clipping helper as the fomaml branch below.
      optimizer_outer.apply_gradients(
          losses.clip_gradients(zip(grads, model.trainable_variables)))

    elif train_method == 'fomaml':
      # Do num_of_updates inner updates, then use the gradient of the hole
      # loss at the updated parameters as the outer update.
      model_new = losses.support_loss_train(
          model, sup_window, sup_token, seq_len_sup_token, True,
          optimizer_inner, batch_size_sup, num_of_updates)
      with tf.GradientTape() as g:
        batch_token_loss, masked_loss = losses.hole_loss(
            model_new, hole_window, hole_target, seq_len_hole_target, True)
      # NOTE(review): the loss is computed with ``model_new`` but
      # differentiated w.r.t. ``model.trainable_variables``; this presumably
      # relies on both sharing the same variables -- confirm.
      grads = g.gradient(batch_token_loss, model.trainable_variables)
      optimizer_outer.apply_gradients(
          losses.clip_gradients(zip(grads, model.trainable_variables)))

    elif train_method == 'reptile':
      # The pre-update weights are only needed for the reptile interpolation,
      # so the snapshot is taken here, before the inner updates run.
      old_weights = list(model.get_weights())
      model_new = losses.support_loss_train(
          model, sup_window, sup_token, seq_len_sup_token, True,
          optimizer_inner, batch_size_sup, num_of_updates)
      batch_token_loss, masked_loss = losses.hole_loss(
          model_new, hole_window, hole_target, seq_len_hole_target, True)
      # Move the old weights a fraction epsilon_reptile towards the adapted
      # ones.
      # NOTE(review): ``old_weights`` comes from ``get_weights()`` but is
      # indexed in ``trainable_variables`` order; assumes both line up
      # one-to-one -- confirm.
      new_weights = [
          old_weights[i] + epsilon_reptile * (new_var - old_weights[i])
          for i, new_var in enumerate(model_new.trainable_variables)
      ]
      model.set_weights(new_weights)

    else:
      # Previously this case silently reused the loss of the previous hole
      # (or failed with a NameError on the first one).
      raise ValueError(
          "train_method must be 'fomaml' or 'reptile', got {!r}".format(
              train_method))

    # Per-example masked loss summed over axis 1, normalised by the target
    # length, then averaged over the batch.
    batch_subword_loss = tf.reduce_mean(
        tf.reduce_sum(masked_loss, axis=1) /
        tf.cast(seq_len_hole_target, dtype=tf.float32))
    token_loss = batch_token_loss.numpy()
    total_subword_loss += batch_subword_loss.numpy()
    total_token_loss += token_loss
    token_losses.append(token_loss)
    total_batches += 1

    if total_batches % 10 == 0:
      bar.update(10)
      # Show the numeric loss rather than the tensor reference wrapper.
      bar.set_postfix(OrderedDict(batch_loss={token_loss}))

  # Calculate mean batch-wise losses
  subword_loss = total_subword_loss / total_batches
  token_loss = total_token_loss / total_batches
  # Confidence-interval error of the mean token loss
  error = Z * np.sqrt(np.var(token_losses) / total_batches)
  return subword_loss, token_loss, error
Ejemplo n.º 5
0
def evaluate(model, dataset, bar, inner_learning_rate, sup_batch_size, num_of_updates):
  """Evaluate ``model`` for one epoch with per-hole adaptation.

  Before each hole target is scored the model weights are reset to the values
  they had when this function was called. Holes with a support set are scored
  after adapting the model with ``losses.inner_loss_eval`` (method
  ``'tssa'``); holes without one are scored directly. The original weights
  are restored before returning so training can continue in the next epoch.

  Args:
      model: Model to evaluate; must support ``get_weights`` /
          ``set_weights`` and be callable as ``model(x, x, False)``.
      dataset: Iterable yielding ``(hole_window, hole_target,
          seq_len_hole_target, sup_window, sup_token, seq_len_sup_token,
          hole_identity, sup_flag)`` tuples.
      bar: Progress bar exposing ``update`` and ``set_postfix``.
      inner_learning_rate: Forwarded to ``losses.inner_loss_eval``.
      sup_batch_size: Forwarded to ``losses.inner_loss_eval``.
      num_of_updates: Forwarded to ``losses.inner_loss_eval``.

  Returns:
      A tuple ``(subword_loss, token_loss, error, hole_features)``: the mean
      sub-word loss, the mean token loss, the confidence-interval half-width
      ``Z * sqrt(var(token losses) / n)``, and a dict mapping
      ``hole_identity.numpy()`` to that hole's token loss.

  Raises:
      ValueError: If ``CONFIDENCE_INTERVAL`` is neither 0.95 nor 0.99.
      ZeroDivisionError: If ``dataset`` yields no items.
  """
  # Fail fast on an unsupported interval instead of discovering an unbound
  # ``Z`` only after the whole dataset has been processed.
  if CONFIDENCE_INTERVAL == 0.95:
    Z = 1.96
  elif CONFIDENCE_INTERVAL == 0.99:
    Z = 2.58
  else:
    raise ValueError(
        'CONFIDENCE_INTERVAL must be 0.95 or 0.99, got {!r}'.format(
            CONFIDENCE_INTERVAL))

  # Fixed seed for NumPy's global random generator during evaluation.
  np.random.seed(42)
  token_losses = []
  hole_features = {}

  total_subword_loss = 0.0
  total_token_loss = 0.0
  total_batches = 0

  # Snapshot of the trained weights so they can be restored per hole and at
  # the end of evaluation.
  initial_weights = list(model.get_weights())

  # Dummy (1, 1) input fed to the model before every reset; it never changes,
  # so it is built once instead of once per hole.
  y = tf.reshape(tf.Variable(1, dtype=tf.int32), (1, 1))

  for (hole_window, hole_target, seq_len_hole_target, sup_window, sup_token,
       seq_len_sup_token, hole_identity, sup_flag) in dataset:

    # Reset before evaluating each hole target.
    # NOTE(review): the dummy forward pass presumably ensures the model
    # variables exist before ``set_weights`` -- confirm.
    model(y, y, False)
    model.set_weights(initial_weights)

    if sup_flag:
      sup_window = tf.squeeze(sup_window, axis=0)
      sup_token = tf.squeeze(sup_token, axis=0)
      seq_len_sup_token = tf.squeeze(seq_len_sup_token, axis=0)

      # Adapt on the support set, then score the hole with the adapted model.
      model_new = losses.inner_loss_eval(
          model, sup_window, sup_token, seq_len_sup_token, False, 'tssa',
          inner_learning_rate, sup_batch_size, num_of_updates)
      batch_token_loss, masked_loss = losses.hole_loss(
          model_new, hole_window, hole_target, seq_len_hole_target, False)
    else:
      # No support tokens found in the file: score the hole target directly
      # from the hole window.
      batch_token_loss, masked_loss = losses.hole_loss(
          model, hole_window, hole_target, seq_len_hole_target, False)

    # Per-example masked loss summed over axis 1, normalised by the target
    # length, then averaged over the batch.
    batch_subword_loss = tf.reduce_mean(
        tf.reduce_sum(masked_loss, axis=1) /
        tf.cast(seq_len_hole_target, dtype=tf.float32))
    token_loss = batch_token_loss.numpy()

    total_subword_loss += batch_subword_loss.numpy()
    total_token_loss += token_loss
    token_losses.append(token_loss)

    hole_features[hole_identity.numpy()] = token_loss
    total_batches += 1

    if total_batches % 10 == 0:
      bar.update(10)
      # Show the numeric loss rather than the tensor reference wrapper.
      bar.set_postfix(OrderedDict(batch_loss={token_loss}))

  # For training for next epoch
  model.set_weights(initial_weights)

  # Calculate mean batch-wise losses
  subword_loss = total_subword_loss / total_batches
  token_loss = total_token_loss / total_batches
  # Confidence-interval error of the mean token loss
  error = Z * np.sqrt(np.var(token_losses) / total_batches)
  return subword_loss, token_loss, error, hole_features