def draft_summary_beam_search(model, input_ids, enc_output, dec_padding_mask,
                              beam_size):

    log.info(f"Building: 'Draft beam search decoder'")
    input_ids = tfa.seq2seq.tile_batch(input_ids, multiplier=beam_size)
    enc_output = tfa.seq2seq.tile_batch(enc_output, multiplier=beam_size)
    dec_padding_mask = tfa.seq2seq.tile_batch(dec_padding_mask,
                                              multiplier=beam_size)

    def beam_search_decoder(output):
        # (batch_size, seq_len, d_bert)
        embeddings = model.embedding(output)
        predictions, attention_weights = model.decoder(input_ids, embeddings,
                                                       enc_output, False, None,
                                                       dec_padding_mask)
        # (batch_size, 1, target_vocab_size)
        return (predictions[:, -1:, :])

    return beam_search(beam_search_decoder, [CLS_ID] * h_parms.batch_size,
                       beam_size,
                       config.summ_length,
                       config.input_vocab_size,
                       h_parms.length_penalty,
                       stop_early=False,
                       eos_id=[[SEP_ID]])
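# Why tile_batch is used above: beam search keeps beam_size hypotheses per input,
# so the encoder output and masks must be repeated once per beam. A minimal
# standalone sketch (illustration only, not repo code):
import tensorflow as tf
import tensorflow_addons as tfa

def _tile_batch_demo():
    t = tf.constant([[1, 2], [3, 4]])                # shape (2, 2)
    tiled = tfa.seq2seq.tile_batch(t, multiplier=3)  # shape (6, 2)
    # Rows repeat contiguously, [r0, r0, r0, r1, r1, r1], so every beam of an
    # example sees an identical copy of that example's encoder state.
    return tiled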
Example #2
def _embedding_from_bert():
    log.info("Extracting pretrained word embeddings weights from BERT")
    vocab_of_BERT = TFBertModel.from_pretrained(config.pretrained_bert_model,
                                                trainable=False)
    embedding_matrix = vocab_of_BERT.get_weights()[0]
    log.info(f"Embedding matrix shape '{embedding_matrix.shape}'")
    return (embedding_matrix, vocab_of_BERT)
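# A hedged usage sketch for the matrix returned above: wrap it in a frozen Keras
# Embedding layer. The layer construction is illustrative, not from this repo
# (assumes tensorflow is imported as tf):
def _frozen_embedding_layer():
    embedding_matrix, _ = _embedding_from_bert()
    return tf.keras.layers.Embedding(
        input_dim=embedding_matrix.shape[0],   # vocab size (30522 for bert-base)
        output_dim=embedding_matrix.shape[1],  # hidden size (768 for bert-base)
        embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix),
        trainable=False)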
Example #3
def _embedding_from_bert():
    log.info("Extracting pretrained word embeddings weights from BERT")
    BERT_MODEL_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1"
    vocab_of_BERT = hub.KerasLayer(BERT_MODEL_URL, trainable=False)
    embedding_matrix = vocab_of_BERT.get_weights()[0]
    log.info(f"Embedding matrix shape '{embedding_matrix.shape}'")
    return (embedding_matrix, vocab_of_BERT)
Example #4
def draft_summary_beam_search(input_ids, enc_output, dec_padding_mask,
                              beam_size):

    log.info(f"Building: 'Draft beam search decoder'")
    input_ids = tfa.seq2seq.tile_batch(input_ids, multiplier=beam_size)
    enc_output = tfa.seq2seq.tile_batch(enc_output, multiplier=beam_size)
    dec_padding_mask = tfa.seq2seq.tile_batch(dec_padding_mask,
                                              multiplier=beam_size)

    def beam_search_decoder(output):
        # (batch_size, seq_len, d_bert)
        embeddings = model.embedding(output)
        predictions, dec_op, attention_weights = model.decoder(
            embeddings, enc_output, False, None, dec_padding_mask)
        if config.copy_gen:
            predictions = model.decoder.pointer_generator(
                dec_op[:, -1:, :],
                predictions[:, -1:, :],
                attention_weights[:, :, -1:, :],
                input_ids,
                tf.shape(input_ids)[1],
                tf.shape(predictions[:, -1:, :])[1],
                training=False,
            )
        # (batch_size, 1, target_vocab_size)
        return (predictions[:, -1:, :])

    return beam_search(beam_search_decoder, [CLS_ID] * h_parms.batch_size,
                       beam_size,
                       config.summ_length,
                       config.input_vocab_size,
                       h_parms.length_penalty,
                       stop_early=False,
                       eos_id=[[SEP_ID]])
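# What pointer_generator above computes, in outline: a copy-gate mixture in the
# style of See et al. (2017). The repo's own implementation is not shown in this
# listing, so the sketch below is a generic, hypothetical stand-in (names and
# shapes are assumptions; assumes tensorflow is imported as tf):
def pointer_generator_sketch(p_gen, vocab_dist, attention, input_ids, vocab_size):
    # p_gen: (batch, tgt_len, 1), vocab_dist: (batch, tgt_len, vocab),
    # attention: (batch, tgt_len, src_len), input_ids: (batch, src_len)
    one_hot_src = tf.one_hot(input_ids, vocab_size)                # (batch, src_len, vocab)
    # Scatter the attention mass onto the source token ids -> copy distribution.
    copy_dist = tf.einsum('bts,bsv->btv', attention, one_hot_src)  # (batch, tgt_len, vocab)
    return p_gen * vocab_dist + (1.0 - p_gen) * copy_dist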
Example #5
def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
    """
        Inference call, builds a refined summary
        
        It first masks each word in the summary draft one by one,
        then feeds the draft to BERT to generate context vectors.
        """

    log.info(f"Building: 'Refined {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    refined_summary = draft_summary
    batch = tf.shape(draft_summary)[0]
    log.info(f'draft_summary shape {tf.shape(draft_summary)}')
    dec_outputs = []
    dec_logits = []
    attention_dists = []
    for i in (range(1, config.summ_length)):

        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)

        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]

        # (batch_size, seq_len, d_bert), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            context_vectors,
            enc_output,
            training=training,
            look_ahead_mask=None,
            padding_mask=padding_mask)

        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(
                nucleus_sampling((dec_output_i / temperature), p=p), tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(
                top_k_sampling(((dec_output_i) / temperature), k=k), tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(
                topp_topk(((dec_output_i) / temperature), p=p, k=k), tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling((dec_output_i) / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        refined_summary = with_column(refined_summary, i, preds)
    # (batch_size, seq_len), (_)
    return refined_summary, attention_dist
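# Hedged sketches of the helpers assumed above (mask_timestamp and with_column
# are not part of this listing; these are plausible stand-ins, not the originals,
# and assume integer sequences of shape (batch_size, seq_len) with preds of
# shape (batch_size, 1)):
def mask_timestamp(seq, i, mask_id):
    # Replace column i of the sequence batch with mask_id.
    hot = tf.one_hot(i, tf.shape(seq)[1], dtype=seq.dtype)  # (seq_len,)
    return seq * (1 - hot) + mask_id * hot

def with_column(seq, i, preds):
    # Write preds into column i, leaving the other columns untouched.
    hot = tf.one_hot(i, tf.shape(seq)[1], dtype=seq.dtype)
    return seq * (1 - hot) + preds * hot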
Example #6
def draft_summary_sampling(model,
                           inp,
                           enc_output,
                           look_ahead_mask,
                           padding_mask,
                           sampling_type='greedy',
                           temperature=0.9,
                           p=0.9,
                           k=25,
                           training=False):
    """
    Inference call, builds a draft summary auto-regressively
    """
    log.info(f"Building: 'Draft {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    T = tf.shape(enc_output)[1]

    # (batch_size, 1)
    dec_input = tf.ones([N, 1], dtype=tf.int32) * CLS_ID
    summary, dec_outputs, dec_logits, attention_dists = [], [], [], []
    summary += [dec_input]
    for i in (range(0, config.summ_length)):
        _, _, dec_padding_mask = create_masks(inp, dec_input)
        # (batch_size, i+1, d_bert)
        embeddings = model.embedding(dec_input)

        # (batch_size, i+1, vocab), (_)
        dec_output, attention_dist = model.decoder(inp, embeddings, enc_output,
                                                   training, look_ahead_mask,
                                                   padding_mask)

        # (batch_size, 1, vocab)
        dec_output_i = dec_output[:, -1:, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(
                nucleus_sampling(((dec_output_i) / temperature), p=p),
                tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(
                top_k_sampling(((dec_output_i) / temperature), k=k), tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling((dec_output_i) / temperature), tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(
                topp_topk(((dec_output_i) / temperature), p=p, k=k), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        dec_outputs += [dec_output_i]
        #dec_logits_i = dec_logits_i[:, -1:, :]
        #dec_logits += [dec_logits_i]
        summary += [preds]
        dec_input = with_column(dec_input, i + 1, preds)
    summary = tf.concat(summary, axis=1)
    # (batch_size, seq_len), (_)
    return summary, attention_dist
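# Hedged sketch of one of the sampling helpers used above (top_k_sampling,
# nucleus_sampling, etc. are not part of this listing; the implementation below
# is a plausible stand-in, not the repo's own):
def top_k_sampling_sketch(logits, k=25):
    # logits: (batch_size, 1, vocab). Keep the k largest logits per step, mask
    # the rest to -inf, and draw one token id from the renormalized distribution.
    top_values, _ = tf.math.top_k(logits, k=k)
    cutoff = top_values[..., -1:]                           # k-th largest logit
    masked = tf.where(logits < cutoff,
                      tf.fill(tf.shape(logits), logits.dtype.min), logits)
    flat = tf.reshape(masked, [-1, tf.shape(masked)[-1]])   # (batch_size, vocab)
    samples = tf.random.categorical(flat, num_samples=1)    # (batch_size, 1)
    return tf.reshape(samples, [tf.shape(logits)[0], 1])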
Example #7
def infer_data_from_df(num_of_infer_examples=config.num_examples_to_infer):
    doc, summ = create_dataframe(file_path.infer_csv_path,
                                 num_of_infer_examples)
    infer_examples = tf.data.Dataset.from_tensor_slices((doc, summ))
    infer_buffer_size = len(doc)
    infer_dataset = map_batch_shuffle(infer_examples,
                                      infer_buffer_size,
                                      split='infer',
                                      batch_size=1)
    log.info('infer tf_dataset created')
    return infer_dataset
Example #8
def count_recs(batch, epoch, num_of_train_examples):
    if epoch == 0:
        try:
            if batch > 0:
                num_of_recs_post_filter_atmost = (
                    batch * h_parms.batch_size) / num_of_train_examples
                log.info(
                    f'Percentage of records used for training should be close to {num_of_recs_post_filter_atmost*100:.2f}'
                )
        except NameError:
            log.info('End of epoch')
Example #9
def check_ckpt(checkpoint_path):
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=10)
    if tf.train.latest_checkpoint(checkpoint_path):
        ckpt.restore(ckpt_manager.latest_checkpoint)
        log.info(ckpt_manager.latest_checkpoint + ' restored')
        latest_ckpt = int(ckpt_manager.latest_checkpoint[-2:])
    else:
        latest_ckpt = 0
        log.info('Training from scratch')
    return (ckpt_manager, latest_ckpt, ckpt)
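# Note: int(ckpt_manager.latest_checkpoint[-2:]) above only works for two-digit
# checkpoint suffixes ('ckpt-10' ... 'ckpt-99'). Since tf.train.CheckpointManager
# names checkpoints 'ckpt-N' by default, a more robust parse would be:
#     latest_ckpt = int(ckpt_manager.latest_checkpoint.split('-')[-1])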
Example #10
def create_train_data(num_samples_to_train=config.num_examples_to_train,
                      shuffle=True,
                      filter_off=False):

    if config.use_tfds:
        examples, metadata = tfds.load(
            config.tfds_name,
            with_info=True,
            as_supervised=True,
            data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
            builder_kwargs={"version": "3.0.0"})

        other_ds = 'validation' if 'validation' in examples else 'test'
        train_examples = examples['train']
        valid_examples = examples[other_ds]
        train_buffer_size = metadata.splits['train'].num_examples
        valid_buffer_size = metadata.splits[other_ds].num_examples
    else:
        doc, summ = create_dataframe(file_path.train_csv_path,
                                     num_samples_to_train)
        X_train, X_test, y_train, y_test = train_test_split(
            doc, summ, test_size=config.test_size, random_state=42)
        train_examples = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        valid_examples = tf.data.Dataset.from_tensor_slices((X_test, y_test))
        train_buffer_size = len(X_train)
        valid_buffer_size = len(X_test)
    train_dataset = map_batch_shuffle(train_examples,
                                      train_buffer_size,
                                      split='train',
                                      shuffle=shuffle,
                                      batch_size=h_parms.batch_size,
                                      filter_off=filter_off)
    valid_dataset = map_batch_shuffle(valid_examples,
                                      valid_buffer_size,
                                      split='valid',
                                      batch_size=h_parms.validation_batch_size,
                                      filter_off=filter_off)
    log.info('Train and Test tf_datasets created')
    return (train_dataset, valid_dataset, train_buffer_size, valid_buffer_size)
Example #11
def check_ckpt(checkpoint_path):
    ckpt = tf.train.Checkpoint(transformer=transformer,
                               optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              keep_checkpoint_every_n_hours=0.3,
                                              max_to_keep=20)
    if tf.train.latest_checkpoint(checkpoint_path):
        ckpt.restore(ckpt_manager.latest_checkpoint)
        log.info(ckpt_manager.latest_checkpoint + ' restored')
        latest_ckpt = int(ckpt_manager.latest_checkpoint[-2:])
    else:
        latest_ckpt = 0
        log.info('Training from scratch')
    return (ckpt_manager, latest_ckpt)
Example #12
def draft_summary_beam_search(input_ids, beam_size):

    log.info(f"Building: 'Draft beam search decoder'")

    batch = tf.shape(input_ids)[0]
    end = [SEP_ID]
    # (batch_size, seq_len, d_bert)
    enc_output_ = model.bert_model(input_ids)[0]
    enc_output = tf.tile(enc_output_, multiples=[beam_size, 1, 1])
    input_ids = tf.tile(input_ids, multiples=[beam_size, 1])
    # (batch_size, 1, 1, seq_len), (_), (batch_size, 1, 1, seq_len)
    dec_input = tf.convert_to_tensor([CLS_ID] * batch)
    output = tf.expand_dims(dec_input, 0)
    def beam_search_decoder(output):
      _, _, dec_padding_mask = create_masks(input_ids, output)    
      embeddings = model.embedding(output)
      predictions, dec_op, attention_weights = model.decoder(
                                                            input_ids, 
                                                            embeddings, 
                                                            enc_output, 
                                                            False, 
                                                            None, 
                                                            dec_padding_mask
                                                            )
      if config.copy_gen:
        predictions = model.decoder.pointer_generator(
                                                      dec_op, 
                                                      predictions,
                                                      attention_weights,
                                                      input_ids,
                                                      tf.shape(input_ids)[1], 
                                                      tf.shape(output)[-1], 
                                                      False
                                                     )
      # (batch_size, 1, target_vocab_size)
      return (predictions[:,-1:,:])
    return (beam_search(
                        beam_search_decoder, 
                        dec_input, 
                        beam_size, 
                        config.summ_length, 
                        config.input_vocab_size, 
                        h_parms.length_penalty, 
                        stop_early=False, 
                        eos_id=[end]
                        ),
                        enc_output_
            )
Example #13
def create_train_data(num_samples_to_train=config.num_examples_to_train,
                      shuffle=True,
                      filter_off=False):

    if config.use_tfds:
        train_examples, _ = tfds.load(
            config.tfds_name,
            with_info=True,
            as_supervised=True,
            data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
            builder_kwargs={"version": "2.0.0"},
            split=tfds.core.ReadInstruction('train',
                                            from_=90,
                                            to=100,
                                            unit='%'))
        valid_examples, _ = tfds.load(
            config.tfds_name,
            with_info=True,
            as_supervised=True,
            data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
            builder_kwargs={"version": "2.0.0"},
            split='validation')
        train_buffer_size = 287113
        valid_buffer_size = 13368
    else:
        doc, summ = create_dataframe(file_path.train_csv_path,
                                     num_samples_to_train)
        X_train, X_test, y_train, y_test = train_test_split(
            doc, summ, test_size=config.test_size, random_state=42)
        train_examples = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        valid_examples = tf.data.Dataset.from_tensor_slices((X_test, y_test))
        train_buffer_size = len(X_train)
        valid_buffer_size = len(X_test)
    train_dataset = map_batch_shuffle(train_examples,
                                      train_buffer_size,
                                      split='train',
                                      shuffle=shuffle,
                                      batch_size=h_parms.batch_size,
                                      filter_off=filter_off)
    valid_dataset = map_batch_shuffle(valid_examples,
                                      valid_buffer_size,
                                      split='valid',
                                      batch_size=h_parms.validation_batch_size,
                                      filter_off=filter_off)
    log.info('Train and Test tf_datasets created')
    return (train_dataset, valid_dataset, train_buffer_size, valid_buffer_size)
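# The ReadInstruction above takes the last 10% of 'train' as the training slice.
# Assuming a TFDS version with the split slicing API, an equivalent spelling is:
#     tfds.load(config.tfds_name, split='train[90%:100%]', as_supervised=True)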
Example #14
def train_data_from_tfds():
    examples, metadata = tfds.load('gigaword', with_info=True, as_supervised=True)
    train_buffer_size = metadata.splits['train'].num_examples
    valid_buffer_size = metadata.splits['test'].num_examples
    train_dataset = batch_shuffle(
                                  examples['train'], 
                                  train_buffer_size, 
                                  split = 'train',
                                  batch_size=config.batch_size
                                  )
    valid_dataset = batch_shuffle(
                                  examples['test'], 
                                  valid_buffer_size, 
                                  split='test',
                                  batch_size=config.batch_size
                                  )
    log.info('Training and Test set created')
    return train_dataset, valid_dataset
Example #15
def _embedding_from_bert():

    log.info("Extracting pretrained word embeddings weights from BERT")
    vocab_of_BERT = TFBertModel.from_pretrained('bert-base-uncased',
                                                trainable=False)
    embedding_matrix = vocab_of_BERT.get_weights()[0]
    log.info(f"Embedding matrix shape '{embedding_matrix.shape}'")
    return (embedding_matrix, vocab_of_BERT)
Example #16
def monitor_run(latest_ckpt,
                ckpt_save_path, 
                val_loss, 
                val_acc,
                bert_score, 
                rouge_score, 
                valid_summary_writer,
                step,
                to_monitor=config.monitor_metric):
  
  ckpt_fold, ckpt_string = os.path.split(ckpt_save_path)
  if config.run_tensorboard:
    with valid_summary_writer.as_default():
      tf.summary.scalar('validation_total_loss', val_loss, step=step)
      tf.summary.scalar('validation_total_accuracy', val_acc, step=step)
      tf.summary.scalar('ROUGE_f1', rouge_score, step=step)
      tf.summary.scalar('BERT_f1', bert_score, step=step)
  monitor_metrics = dict()
  monitor_metrics['validation_loss'] = val_loss
  monitor_metrics['validation_accuracy'] = val_acc
  monitor_metrics['BERT_f1'] = bert_score
  monitor_metrics['ROUGE_f1'] = rouge_score
  monitor_metrics['combined_metric'] = (
                                        monitor_metrics['BERT_f1'], 
                                        monitor_metrics['ROUGE_f1'], 
                                        monitor_metrics['validation_accuracy']
                                        )
  # multiply with the weights                                    
  monitor_metrics['combined_metric'] = round(tf.reduce_sum([(i*j) for i,j in zip(monitor_metrics['combined_metric'],  
                                                                                 h_parms.combined_metric_weights)]).numpy(), 2)
  log.info(f"combined_metric {monitor_metrics['combined_metric']:4f}")
  if to_monitor != 'validation_loss':
    cond = (config.last_recorded_value < monitor_metrics[to_monitor])
  else:
    cond = (config.last_recorded_value > monitor_metrics[to_monitor])
  if (latest_ckpt > config.monitor_only_after) and cond:
    # reset tolerance to zero if the monitor_metric decreases before the tolerance threshold
    config.init_tolerance=0
    config.last_recorded_value =  monitor_metrics[to_monitor]
    ckpt_files_tocopy = [files for files in os.listdir(os.path.split(ckpt_save_path)[0]) \
                         if ckpt_string in files]
    log.info(f'{to_monitor} is {monitor_metrics[to_monitor]:.4f} so checkpoint files {ckpt_string} '
             f'will be copied to best checkpoint directory')
    # copy the best checkpoints
    shutil.copy2(os.path.join(ckpt_fold, 'checkpoint'), file_path.best_ckpt_path)
    for files in ckpt_files_tocopy:
        shutil.copy2(os.path.join(ckpt_fold, files), file_path.best_ckpt_path)
  else:
    config.init_tolerance+=1
  # Warn and early stop
  if config.init_tolerance > config.tolerance_threshold:
    log.warning('Tolerance exceeded')
  if config.early_stop and config.init_tolerance > config.tolerance_threshold:
    log.info(f'Early stopping since the {to_monitor} reached the tolerance threshold')
    return False
  else:
    return True
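# Worked example of the combined metric above (the weights here are illustrative;
# the real ones come from h_parms.combined_metric_weights):
#   weights = (0.8, 0.1, 0.1), BERT_f1 = 0.60, ROUGE_f1 = 0.40, val_acc = 0.75
#   combined = 0.8*0.60 + 0.1*0.40 + 0.1*0.75 = 0.595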
Example #17
def batch_run_check(batch, start, train_summary_writer, train_loss,
                    train_accuracy, model):
    if config.run_tensorboard:
        with train_summary_writer.as_default():
            tf.summary.scalar('train_loss', train_loss, step=batch)
            tf.summary.scalar('train_accuracy', train_accuracy, step=batch)
    if batch == 0:
        log.info(model.summary())
        log.info(batch_zero.format(time.time() - start))
    log.info(batch_run_details.format(batch, train_loss, train_accuracy))
Example #18
def batch_run_check(batch, epoch, start, train_summary_writer, train_loss,
                    train_accuracy, transformer):
    if config.run_tensorboard:
        with train_summary_writer.as_default():
            tf.summary.scalar('train_loss', train_loss, step=batch)
            tf.summary.scalar('train_accuracy', train_accuracy, step=batch)
    if batch == 0 and epoch == 0:
        log.info(transformer.summary())
        log.info(batch_zero.format(time.time() - start))
    if batch % config.print_chks == 0:
        log.info(
            batch_run_details.format(epoch + 1, batch, train_loss,
                                     train_accuracy))
Example #19
def refined_summary_sampling(inp, 
                           enc_output, 
                           draft_summary, 
                           padding_mask, 
                           sampling_type='greedy', 
                           temperature=0.9, 
                           p=0.9, 
                           k=25,
                           beam_search=False,
                           training=False):
        """
        Inference call, builds a refined summary
        
        It first masks each word in the summary draft one by one,
        then feeds the draft to BERT to generate context vectors.
        """
        
        log.info(f"Building: 'Refined {sampling_type} decoder'")
        N = tf.shape(enc_output)[0]
        refined_summary = tf.expand_dims(draft_summary,0)
        dec_outputs = []
        dec_logits = []
        attention_dists = []
        for i in (range(1, config.summ_length)):
            
            # (batch_size, seq_len)
            refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
            # (batch_size, seq_len, d_bert)
            context_vectors = model.bert_model(refined_summary_)[0]
            # (batch_size, seq_len, d_bert), (_)
            dec_output, dec_logits_i, attention_dist = model.decoder(
                                                                    inp,
                                                                    context_vectors,
                                                                    enc_output,
                                                                    training=training,
                                                                    look_ahead_mask=None,
                                                                    padding_mask=padding_mask
                                                                  )
            
            # (batch_size, 1, vocab_len)
            dec_output_i = dec_output[:, i:i+1 ,:]
            if sampling_type == 'nucleus':
              preds = tf.cast(nucleus_sampling((tf.squeeze(dec_output_i)/ temperature), p=p), tf.int32)
            elif sampling_type == 'topk':
              preds = tf.cast(top_k_sampling((tf.squeeze(dec_output_i)/ temperature), k=k), tf.int32)
            elif sampling_type == 'random_sampling':
              preds = tf.cast(sampling(tf.squeeze(dec_output_i)/ temperature), tf.int32)
            else:
              preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
            dec_outputs += [dec_output_i]
            dec_logits_i = dec_logits_i[:, i:i+1, :]
            dec_logits += [dec_logits_i]
            
            refined_summary = with_column(refined_summary, i, preds)
            attention_dists += [attention_dist[:, i:i+1, :]]
        cls_concat_dec_outputs = (tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], config.target_vocab_size), axis=0), [N, 1, 1]))
        cls_concat_dec_logits = (tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], config.d_model), axis=0), [N, 1, 1]))
        dec_outputs = tf.reshape(dec_outputs, (1, -1, config.target_vocab_size))
        dec_logits = tf.reshape(dec_logits, (1, -1, config.d_model))
        attention_dists = tf.reshape(attention_dists, (1, -1, config.doc_length))
        dec_outputs = tf.concat([cls_concat_dec_outputs, dec_outputs], axis=1)
        dec_logits = tf.concat([cls_concat_dec_logits, dec_logits], axis=1)
        
        if config.copy_gen: 
          predictions = model.decoder.pointer_generator(
                                                        dec_logits,
                                                        dec_outputs, 
                                                        attention_dists, 
                                                        inp, 
                                                        tf.shape(inp)[-1], 
                                                        tf.shape(dec_outputs)[1], 
                                                        training=training
                                                        )
          refined_summary = tf.cast(tf.argmax(predictions, axis=-1), dtype=tf.int32)
        # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)        
        return tf.squeeze(refined_summary, axis=0), attention_dist
        eval_frequency = ((step + 1) * h_parms.batch_size) % config.eval_after
        if eval_frequency == 0:
            predicted = (tokenizer.decode([
                i for i in tf.squeeze(tf.argmax(refine_predictions, axis=-1))
                if i not in [101, 102, 0]
            ]))
            target = (tokenizer.decode(
                [i for i in tf.squeeze(target_x) if i not in [101, 102, 0]]))
            print(f'the golden summary is {target}')
            print(
                f'the predicted summary is {predicted if predicted else "EMPTY"}'
            )
            ckpt_save_path = ck_pt_mgr.save()
            (val_acc, rouge_score,
             bert_score) = calc_validation_loss(val_dataset, step + 1,
                                                val_step, valid_summary_writer,
                                                validation_accuracy)

            latest_ckpt += (step + 1)
            log.info(
                model_metrics.format(step + 1, train_loss.result(),
                                     train_accuracy.result(), val_acc,
                                     rouge_score, bert_score))
            log.info(evaluation_step.format(step + 1, time.time() - start))
            log.info(checkpoint_details.format(step + 1, ckpt_save_path))
            if not monitor_run(latest_ckpt, ckpt_save_path, val_acc,
                               bert_score, rouge_score, valid_summary_writer,
                               step + 1):
                break

def tf_encode(doc, summary):
    return tf.py_function(encode, [doc, summary], [tf.int64, tf.int64])
    
def batch_shuffle(dataset, buffer_size, split, batch_size=config.batch_size):
    tf_dataset = dataset.map(tf_encode, num_parallel_calls=AUTOTUNE)                           
    tf_dataset = tf_dataset.filter(filter_token_size)
    sum_of_records = sum(1 for l in tf_dataset)                                                       #optimize
    if sum_of_records > 2000:
        tf_dataset = tf_dataset.cache()
    if buffer_size:
        tf_dataset = tf_dataset.shuffle(buffer_size, seed = 100)
    tf_dataset = tf_dataset.padded_batch(batch_size, padded_shapes=([-1], [-1]))
    tf_dataset = tf_dataset.prefetch(buffer_size=AUTOTUNE)
    log.info(f'Number of records {split} filtered {buffer_size - sum_of_records}')
    log.info(f'Number of records to be {split}ed {sum_of_records}')
    return tf_dataset
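# Minimal standalone illustration of padded_batch as used above: shorter
# sequences are zero-padded to the longest one in each batch (demo only, not
# repo code; assumes TF >= 2.4 for output_signature):
def _padded_batch_demo():
    demo = tf.data.Dataset.from_generator(
        lambda: iter([[1, 2, 3], [4, 5]]),
        output_signature=tf.TensorSpec([None], tf.int32))
    return next(iter(demo.padded_batch(2, padded_shapes=[-1])))  # [[1 2 3] [4 5 0]]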
  

def train_data_from_tfds():
    examples, metadata = tfds.load('gigaword', with_info=True, as_supervised=True)
    train_buffer_size = metadata.splits['train'].num_examples
    valid_buffer_size = metadata.splits['test'].num_examples
    train_dataset = batch_shuffle(
                                  examples['train'], 
                                  train_buffer_size, 
                                  split = 'train',
                                  batch_size=config.batch_size
                                  )
    valid_dataset = batch_shuffle(
                                  examples['test'],
                                  valid_buffer_size,
                                  split='test',
                                  batch_size=config.batch_size
                                  )
    log.info('Training and Test set created')
    return train_dataset, valid_dataset

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import shutil
import os
from configuration import config
from hyper_parameters import h_parms
from rouge import Rouge
from input_path import file_path
from create_tokenizer import tokenizer
from bert_score import score as b_score
from creates import log, monitor_metrics

log.info('Loading Pre-trained BERT model for BERT SCORE calculation')
_, _, _ = b_score(["I'm Batman"], ["I'm Spiderman"], lang='en', model_type=config.pretrained_bert_model)
rouge_all = Rouge()



class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()
    
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)
    self.warmup_steps = warmup_steps
    
  def __call__(self, step):
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
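# CustomSchedule implements the warmup rule from "Attention Is All You Need":
#     lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
# i.e. the learning rate ramps up linearly for warmup_steps, then decays with
# the inverse square root of the step. A minimal usage sketch (hyperparameters
# illustrative, not from this repo):
#     learning_rate = CustomSchedule(d_model=768)
#     optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9,
#                                          beta_2=0.98, epsilon=1e-9)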
Example #23
         ) == 'yes':
    try:
        shutil.rmtree(file_path.summary_write_path)
        shutil.rmtree(file_path.tensorboard_log)
    except FileNotFoundError:
        pass

train_dataset, val_dataset, num_of_train_examples = create_train_data()
train_loss, train_accuracy = get_loss_and_accuracy()
validation_loss, validation_accuracy = get_loss_and_accuracy()

if config.show_detokenized_samples:
    inp, tar = next(iter(train_dataset))
    for ip, ta in zip(inp.numpy(), tar.numpy()):
        log.info(
            tokenizer_en.decode([i for i in ta
                                 if i < tokenizer_en.vocab_size]))
        log.info(
            tokenizer_en.decode([i for i in ip
                                 if i < tokenizer_en.vocab_size]))
        break

transformer = Transformer(num_layers=config.num_layers,
                          d_model=config.d_model,
                          num_heads=config.num_heads,
                          dff=config.dff,
                          input_vocab_size=config.input_vocab_size,
                          target_vocab_size=config.target_vocab_size,
                          rate=h_parms.dropout_rate)
generator = Generator()
for epoch in range(h_parms.epochs):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    validation_loss.reset_states()
    validation_accuracy.reset_states()
    # inp -> document, tar -> summary
    for (batch, (inp, tar)) in enumerate(train_dataset):
        # the target is shifted right during training, hence its length is reduced by 1
        # (this can't be done inside tf.function, which doesn't allow the operation)
        train_step(inp, tar, inp.shape[1], tar.shape[1] - 1, inp.shape[0])
        batch_run_check(batch, epoch, start, train_summary_writer,
                        train_loss.result(), train_accuracy.result(),
                        transformer, pointer_generator)
    count_recs(batch, epoch, num_of_train_examples)
    (val_acc, val_loss, rouge_score,
     bert_score) = calc_validation_loss(val_dataset, epoch + 1, val_step,
                                        valid_summary_writer, validation_loss,
                                        validation_accuracy)
    ckpt_save_path = ck_pt_mgr.save()
    latest_ckpt += epoch
    log.info(
        model_metrics.format(epoch + 1, train_loss.result(),
                             train_accuracy.result(), val_loss, val_acc,
                             rouge_score, bert_score))
    log.info(epoch_timing.format(epoch + 1, time.time() - start))
    log.info(checkpoint_details.format(epoch + 1, ckpt_save_path))
    if not monitor_run(latest_ckpt, ckpt_save_path, val_loss, val_acc,
                       bert_score, rouge_score, valid_summary_writer, epoch):
        break
# if a checkpoint exists, restore the latest checkpoint.
ck_pt_mgr, latest_ckpt = check_ckpt(file_path.checkpoint_path)

for epoch in range(h_parms.epochs):
  start = time.time()  
  train_loss.reset_states()
  train_accuracy.reset_states()
  validation_loss.reset_states()
  validation_accuracy.reset_states()
  # inp -> document, tar -> summary
  for (batch, (inp, tar)) in enumerate(train_dataset):
    # the target is shifted right during training, hence its length is reduced by 1
    # (this can't be done inside tf.function, which doesn't allow the operation)
    train_step(inp, tar, inp.shape[1], tar.shape[1] - 1, inp.shape[0])
    if batch == 0 and epoch == 0:
      log.info(transformer.summary())
      if config.copy_gen:
        log.info(pointer_generator.summary())
      log.info(batch_zero.format(time.time() - start))
    if batch % config.print_chks == 0:
      log.info(batch_run_details.format(
        epoch + 1, batch, train_loss.result(), train_accuracy.result()))
  if epoch == 0:
    try:
      if batch > 0:
        num_of_recs_post_filter_atmost = (batch * h_parms.batch_size) / num_of_train_examples
        num_of_recs_post_filter_atleast = ((batch - 1) * h_parms.batch_size) / num_of_train_examples
        log.info(f'Number of records used for training should be between '
                 f'{num_of_recs_post_filter_atleast*100:.2f}% and '
                 f'{num_of_recs_post_filter_atmost*100:.2f}% of training data')
    except NameError:
      assert False, 'Training dataset is empty'