def draft_summary_beam_search(model, input_ids, enc_output, dec_padding_mask, beam_size):
    log.info(f"Building: 'Draft beam search decoder'")
    input_ids = tfa.seq2seq.tile_batch(input_ids, multiplier=beam_size)
    enc_output = tfa.seq2seq.tile_batch(enc_output, multiplier=beam_size)
    dec_padding_mask = tfa.seq2seq.tile_batch(dec_padding_mask, multiplier=beam_size)

    def beam_search_decoder(output):
        # (batch_size, seq_len, d_bert)
        embeddings = model.embedding(output)
        predictions, attention_weights = model.decoder(input_ids, embeddings,
                                                       enc_output, False, None,
                                                       dec_padding_mask)
        # (batch_size, 1, target_vocab_size)
        return predictions[:, -1:, :]

    return beam_search(beam_search_decoder,
                       [CLS_ID] * h_parms.batch_size,
                       beam_size,
                       config.summ_length,
                       config.input_vocab_size,
                       h_parms.length_penalty,
                       stop_early=False,
                       eos_id=[[SEP_ID]])
def _embedding_from_bert(): log.info("Extracting pretrained word embeddings weights from BERT") vocab_of_BERT = TFBertModel.from_pretrained(config.pretrained_bert_model, trainable=False) embedding_matrix = vocab_of_BERT.get_weights()[0] log.info(f"Embedding matrix shape '{embedding_matrix.shape}'") return (embedding_matrix, vocab_of_BERT)
def _embedding_from_bert(): log.info("Extracting pretrained word embeddings weights from BERT") BERT_MODEL_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1" vocab_of_BERT = hub.KerasLayer(BERT_MODEL_URL, trainable=False) embedding_matrix = vocab_of_BERT.get_weights()[0] log.info(f"Embedding matrix shape '{embedding_matrix.shape}'") return (embedding_matrix, vocab_of_BERT)
def draft_summary_beam_search(input_ids, enc_output, dec_padding_mask, beam_size):
    log.info(f"Building: 'Draft beam search decoder'")
    input_ids = tfa.seq2seq.tile_batch(input_ids, multiplier=beam_size)
    enc_output = tfa.seq2seq.tile_batch(enc_output, multiplier=beam_size)
    dec_padding_mask = tfa.seq2seq.tile_batch(dec_padding_mask, multiplier=beam_size)
    # print(f'output_before {tf.shape(output)}')

    def beam_search_decoder(output):
        # (batch_size, seq_len, d_bert)
        embeddings = model.embedding(output)
        predictions, dec_op, attention_weights = model.decoder(
            embeddings, enc_output, False, None, dec_padding_mask)
        if config.copy_gen:
            predictions = model.decoder.pointer_generator(
                dec_op[:, -1:, :],
                predictions[:, -1:, :],
                attention_weights[:, :, -1:, :],
                input_ids,
                tf.shape(input_ids)[1],
                tf.shape(predictions[:, -1:, :])[1],
                training=False)
        # (batch_size, 1, target_vocab_size)
        return predictions[:, -1:, :]

    return beam_search(beam_search_decoder,
                       [CLS_ID] * h_parms.batch_size,
                       beam_size,
                       config.summ_length,
                       config.input_vocab_size,
                       h_parms.length_penalty,
                       stop_early=False,
                       eos_id=[[SEP_ID]])
def refined_summary_sampling(inp, enc_output, draft_summary, padding_mask,
                             sampling_type='greedy', temperature=0.9, p=0.9,
                             k=25, beam_search=False, training=False):
    """
    Inference call, builds a refined summary.

    It first masks each word in the summary draft one by one, then feeds
    the draft to BERT to generate context vectors.
    """
    log.info(f"Building: 'Refined {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    refined_summary = draft_summary
    batch = tf.shape(draft_summary)[0]
    print(f'draft_summary {tf.shape(draft_summary)}')
    dec_outputs = []
    dec_logits = []
    attention_dists = []
    for i in range(1, config.summ_length):
        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]
        # (batch_size, seq_len, d_bert), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            context_vectors, enc_output, training=training,
            look_ahead_mask=None, padding_mask=padding_mask)
        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(nucleus_sampling(dec_output_i / temperature, p=p), tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(top_k_sampling(dec_output_i / temperature, k=k), tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(topp_topk(dec_output_i / temperature, p=p, k=k), tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(dec_output_i / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        refined_summary = with_column(refined_summary, i, preds)
    # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
    return refined_summary, attention_dist
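
# `mask_timestamp` and `with_column` are not defined in this section. The
# sketches below are plausible implementations, consistent with how they are
# called above: both rewrite column i of a (batch, seq_len) int tensor.
def mask_timestamp(x, i, mask_with):
    # replace position i of every row with the mask id
    N = tf.shape(x)[0]
    mask_col = tf.cast(tf.fill([N, 1], mask_with), x.dtype)
    return tf.concat([x[:, :i], mask_col, x[:, i + 1:]], axis=-1)


def with_column(x, i, column):
    # write the (batch, 1) predictions into position i of every row;
    # when i equals the current length, this appends a new column
    return tf.concat([x[:, :i], tf.cast(column, x.dtype), x[:, i + 1:]], axis=-1)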
def draft_summary_sampling(model, inp, enc_output, look_ahead_mask, padding_mask,
                           sampling_type='greedy', temperature=0.9, p=0.9, k=25,
                           training=False):
    """
    Inference call, builds a draft summary auto-regressively.
    """
    log.info(f"Building: 'Draft {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    T = tf.shape(enc_output)[1]
    # (batch_size, 1)
    dec_input = tf.ones([N, 1], dtype=tf.int32) * CLS_ID
    summary, dec_outputs, dec_logits, attention_dists = [], [], [], []
    summary += [dec_input]
    for i in range(0, config.summ_length):
        _, _, dec_padding_mask = create_masks(inp, dec_input)
        # (batch_size, i+1, d_bert)
        embeddings = model.embedding(dec_input)
        # (batch_size, i+1, vocab), (_)
        dec_output, attention_dist = model.decoder(inp, embeddings, enc_output,
                                                   training, look_ahead_mask,
                                                   padding_mask)
        # (batch_size, 1, vocab)
        dec_output_i = dec_output[:, -1:, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(nucleus_sampling(dec_output_i / temperature, p=p), tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(top_k_sampling(dec_output_i / temperature, k=k), tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(dec_output_i / temperature), tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(topp_topk(dec_output_i / temperature, p=p, k=k), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        dec_outputs += [dec_output_i]
        # dec_logits_i = dec_logits_i[:, -1:, :]
        # dec_logits += [dec_logits_i]
        summary += [preds]
        dec_input = with_column(dec_input, i + 1, preds)
    summary = tf.concat(summary, axis=1)
    # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
    return summary, attention_dist
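
# The sampling helpers (`top_k_sampling` and friends) are also defined outside
# this section. A minimal sketch of the usual top-k scheme, assuming the
# (batch, 1, vocab) logits passed above: keep the k largest logits per row,
# mask the rest, and draw one token id.
def top_k_sampling(logits, k=25):
    logits = tf.squeeze(logits, axis=1)                # (batch, vocab)
    top_values, _ = tf.math.top_k(logits, k=k)
    kth_value = tf.reduce_min(top_values, axis=-1, keepdims=True)
    # push everything below the k-th largest logit to -inf before sampling
    logits = tf.where(logits < kth_value,
                      tf.fill(tf.shape(logits), logits.dtype.min),
                      logits)
    # (batch, 1)
    return tf.random.categorical(logits, num_samples=1, dtype=tf.int32)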
def infer_data_from_df(num_of_infer_examples=config.num_examples_to_infer):
    doc, summ = create_dataframe(file_path.infer_csv_path, num_of_infer_examples)
    infer_examples = tf.data.Dataset.from_tensor_slices((doc, summ))
    infer_buffer_size = len(doc)
    infer_dataset = map_batch_shuffle(infer_examples,
                                      infer_buffer_size,
                                      split='infer',
                                      batch_size=1)
    log.info('infer tf_dataset created')
    return infer_dataset
def count_recs(batch, epoch, num_of_train_examples):
    if epoch == 0:
        try:
            if batch > 0:
                num_of_recs_post_filter_atmost = (batch * h_parms.batch_size) / num_of_train_examples
                log.info(f'Percentage of records used for training should be close to '
                         f'{num_of_recs_post_filter_atmost * 100:.2f}')
        except NameError:
            log.info('End of epoch')
def check_ckpt(checkpoint_path):
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10)
    if tf.train.latest_checkpoint(checkpoint_path):
        ckpt.restore(ckpt_manager.latest_checkpoint)
        log.info(ckpt_manager.latest_checkpoint + ' restored')
        latest_ckpt = int(ckpt_manager.latest_checkpoint[-2:])
    else:
        latest_ckpt = 0
        log.info('Training from scratch')
    return (ckpt_manager, latest_ckpt, ckpt)
def create_train_data(num_samples_to_train=config.num_examples_to_train,
                      shuffle=True, filter_off=False):
    if config.use_tfds:
        examples, metadata = tfds.load(
            config.tfds_name,
            with_info=True,
            as_supervised=True,
            data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
            builder_kwargs={"version": "3.0.0"})
        length = 100
        start = np.random.randint(2e5 - length - 1, size=1)[0]
        # examples, metadata = tfds.load(
        #     config.tfds_name,
        #     with_info=True,
        #     as_supervised=True,
        #     data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
        #     builder_kwargs={"version": "3.0.0"},
        #     split=tfds.core.ReadInstruction('train', from_=start, to=start + length, unit='abs'))
        other_ds = 'validation' if 'validation' in examples else 'test'
        train_examples = examples['train']
        valid_examples = examples[other_ds]
        # train_examples = examples
        # valid_examples = examples
        train_buffer_size = metadata.splits['train'].num_examples
        valid_buffer_size = metadata.splits[other_ds].num_examples
    else:
        doc, summ = create_dataframe(file_path.train_csv_path, num_samples_to_train)
        X_train, X_test, y_train, y_test = train_test_split(
            doc, summ, test_size=config.test_size, random_state=42)
        train_examples = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        valid_examples = tf.data.Dataset.from_tensor_slices((X_test, y_test))
        train_buffer_size = len(X_train)
        valid_buffer_size = len(X_test)
    train_dataset = map_batch_shuffle(train_examples,
                                      train_buffer_size,
                                      split='train',
                                      shuffle=shuffle,
                                      batch_size=h_parms.batch_size,
                                      filter_off=filter_off)
    valid_dataset = map_batch_shuffle(valid_examples,
                                      valid_buffer_size,
                                      split='valid',
                                      batch_size=h_parms.validation_batch_size,
                                      filter_off=filter_off)
    log.info('Train and Test tf_datasets created')
    return (train_dataset, valid_dataset, train_buffer_size, valid_buffer_size)
def check_ckpt(checkpoint_path):
    ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              keep_checkpoint_every_n_hours=0.3,
                                              max_to_keep=20)
    if tf.train.latest_checkpoint(checkpoint_path):
        ckpt.restore(ckpt_manager.latest_checkpoint)
        log.info(ckpt_manager.latest_checkpoint + ' restored')
        latest_ckpt = int(ckpt_manager.latest_checkpoint[-2:])
    else:
        latest_ckpt = 0
        log.info('Training from scratch')
    return (ckpt_manager, latest_ckpt)
def draft_summary_beam_search(input_ids, beam_size):
    log.info(f"Building: 'Draft beam search decoder'")
    batch = tf.shape(input_ids)[0]
    end = [SEP_ID]
    # (batch_size, seq_len, d_bert)
    enc_output_ = model.bert_model(input_ids)[0]
    enc_output = tf.tile(enc_output_, multiples=[beam_size, 1, 1])
    input_ids = tf.tile(input_ids, multiples=[beam_size, 1])
    # (batch_size, 1, 1, seq_len), (_), (batch_size, 1, 1, seq_len)
    dec_input = tf.convert_to_tensor([CLS_ID] * batch)
    output = tf.expand_dims(dec_input, 0)

    def beam_search_decoder(output):
        _, _, dec_padding_mask = create_masks(input_ids, output)
        embeddings = model.embedding(output)
        predictions, dec_op, attention_weights = model.decoder(
            input_ids, embeddings, enc_output, False, None, dec_padding_mask)
        if config.copy_gen:
            predictions = model.decoder.pointer_generator(
                dec_op,
                predictions,
                attention_weights,
                input_ids,
                tf.shape(input_ids)[1],
                tf.shape(output)[-1],
                False)
        # (batch_size, 1, target_vocab_size)
        return predictions[:, -1:, :]

    return (beam_search(beam_search_decoder,
                        dec_input,
                        beam_size,
                        config.summ_length,
                        config.input_vocab_size,
                        h_parms.length_penalty,
                        stop_early=False,
                        eos_id=[end]),
            enc_output_)
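
# Note (an assumption, not confirmed by this section): `beam_search` appears
# to follow the tensor2tensor-style interface,
#   beam_search(symbols_to_logits_fn, initial_ids, beam_size, decode_length,
#               vocab_size, alpha, stop_early=..., eos_id=...)
# returning (ids, log_probs) with ids shaped (batch, beam_size, decode_length).
# Under that assumption, a caller would pick the top beam like this:
(beam_ids, beam_log_probs), enc_output = draft_summary_beam_search(input_ids, beam_size=4)
draft = beam_ids[:, 0, :]  # highest-scoring hypothesis per input example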
def create_train_data(num_samples_to_train=config.num_examples_to_train,
                      shuffle=True, filter_off=False):
    if config.use_tfds:
        train_examples, _ = tfds.load(
            config.tfds_name,
            with_info=True,
            as_supervised=True,
            data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
            builder_kwargs={"version": "2.0.0"},
            split=tfds.core.ReadInstruction('train', from_=90, to=100, unit='%'))
        valid_examples, _ = tfds.load(
            config.tfds_name,
            with_info=True,
            as_supervised=True,
            data_dir='/content/drive/My Drive/Text_summarization/cnn_dataset',
            builder_kwargs={"version": "2.0.0"},
            split='validation')
        train_buffer_size = 287113
        valid_buffer_size = 13368
    else:
        doc, summ = create_dataframe(file_path.train_csv_path, num_samples_to_train)
        X_train, X_test, y_train, y_test = train_test_split(
            doc, summ, test_size=config.test_size, random_state=42)
        train_examples = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        valid_examples = tf.data.Dataset.from_tensor_slices((X_test, y_test))
        train_buffer_size = len(X_train)
        valid_buffer_size = len(X_test)
    train_dataset = map_batch_shuffle(train_examples,
                                      train_buffer_size,
                                      split='train',
                                      shuffle=shuffle,
                                      batch_size=h_parms.batch_size,
                                      filter_off=filter_off)
    valid_dataset = map_batch_shuffle(valid_examples,
                                      valid_buffer_size,
                                      split='valid',
                                      batch_size=h_parms.validation_batch_size,
                                      filter_off=filter_off)
    log.info('Train and Test tf_datasets created')
    return (train_dataset, valid_dataset, train_buffer_size, valid_buffer_size)
def train_data_from_tfds():
    examples, metadata = tfds.load('gigaword', with_info=True, as_supervised=True)
    train_buffer_size = metadata.splits['train'].num_examples
    valid_buffer_size = metadata.splits['test'].num_examples
    train_dataset = batch_shuffle(examples['train'],
                                  train_buffer_size,
                                  split='train',
                                  batch_size=config.batch_size)
    valid_dataset = batch_shuffle(examples['test'],
                                  valid_buffer_size,
                                  split='test',
                                  batch_size=config.batch_size)
    log.info('Training and Test set created')
    return train_dataset, valid_dataset
def _embedding_from_bert(): log.info("Extracting pretrained word embeddings weights from BERT") #BERT_MODEL_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1" # dinput_word_ids = tf.keras.layers.Input(shape=(config.summ_length,), dtype=tf.int32, # name="input_word_ids") # dinput_mask = tf.keras.layers.Input(shape=(config.summ_length,), dtype=tf.int32, # name="input_mask") # dsegment_ids = tf.keras.layers.Input(shape=(config.summ_length,), dtype=tf.int32, # name="segment_ids") # einput_word_ids = tf.keras.layers.Input(shape=(config.doc_length,), dtype=tf.int32, # name="input_word_ids") # einput_mask = tf.keras.layers.Input(shape=(config.doc_length,), dtype=tf.int32, # name="input_mask") # esegment_ids = tf.keras.layers.Input(shape=(config.doc_length,), dtype=tf.int32, # name="segment_ids") #bert_layer = hub.KerasLayer(BERT_MODEL_URL, trainable=False) vocab_of_BERT = TFBertModel.from_pretrained('bert-base-uncased', trainable=False) embedding_matrix = vocab_of_BERT.get_weights()[0] # trainable_vars = vocab_of_BERT.variables # # Remove unused layers # trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name] # # Select how many layers to fine tune # trainable_vars = [] # # Add to trainable weights # for var in trainable_vars: # vocab_of_BERT.trainable_weights.append(var) # for var in vocab_of_BERT.variables: # if var not in vocab_of_BERT.trainable_weights: # vocab_of_BERT.non_trainable_weights.append(var) #_, dsequence_output = vocab_of_BERT([dinput_word_ids, dinput_mask, dsegment_ids]) #_, esequence_output = vocab_of_BERT([einput_word_ids, einput_mask, esegment_ids]) #dec_model = tf.keras.models.Model(inputs=[dinput_word_ids, dinput_mask, dsegment_ids], outputs=dsequence_output) #enc_model = tf.keras.models.Model(inputs=[einput_word_ids, einput_mask, esegment_ids], outputs=esequence_output) log.info(f"Embedding matrix shape '{embedding_matrix.shape}'") return (embedding_matrix, vocab_of_BERT)
def monitor_run(latest_ckpt, ckpt_save_path, val_loss, val_acc, bert_score,
                rouge_score, valid_summary_writer, step,
                to_monitor=config.monitor_metric):
    ckpt_fold, ckpt_string = os.path.split(ckpt_save_path)
    if config.run_tensorboard:
        with valid_summary_writer.as_default():
            tf.summary.scalar('validation_total_loss', val_loss, step=step)
            tf.summary.scalar('validation_total_accuracy', val_acc, step=step)
            tf.summary.scalar('ROUGE_f1', rouge_score, step=step)
            tf.summary.scalar('BERT_f1', bert_score, step=step)
    monitor_metrics = dict()
    monitor_metrics['validation_loss'] = val_loss
    monitor_metrics['validation_accuracy'] = val_acc
    monitor_metrics['BERT_f1'] = bert_score
    monitor_metrics['ROUGE_f1'] = rouge_score
    monitor_metrics['combined_metric'] = (monitor_metrics['BERT_f1'],
                                          monitor_metrics['ROUGE_f1'],
                                          monitor_metrics['validation_accuracy'])
    # multiply the metrics with their corresponding weights
    monitor_metrics['combined_metric'] = round(
        tf.reduce_sum([(i * j) for i, j in zip(monitor_metrics['combined_metric'],
                                               h_parms.combined_metric_weights)]).numpy(), 2)
    log.info(f"combined_metric {monitor_metrics['combined_metric']:4f}")
    if to_monitor != 'validation_loss':
        cond = (config.last_recorded_value < monitor_metrics[to_monitor])
    else:
        cond = (config.last_recorded_value > monitor_metrics[to_monitor])
    if (latest_ckpt > config.monitor_only_after) and cond:
        # reset tolerance to zero if the monitored metric improves
        # before the tolerance threshold is reached
        config.init_tolerance = 0
        config.last_recorded_value = monitor_metrics[to_monitor]
        ckpt_files_tocopy = [files for files in os.listdir(os.path.split(ckpt_save_path)[0])
                             if ckpt_string in files]
        log.info(f'{to_monitor} is {monitor_metrics[to_monitor]:4f} so checkpoint files '
                 f'{ckpt_string} will be copied to the best checkpoint directory')
        # copy the best checkpoints
        shutil.copy2(os.path.join(ckpt_fold, 'checkpoint'), file_path.best_ckpt_path)
        for files in ckpt_files_tocopy:
            shutil.copy2(os.path.join(ckpt_fold, files), file_path.best_ckpt_path)
    else:
        config.init_tolerance += 1
    # warn and stop early
    if config.init_tolerance > config.tolerance_threshold:
        log.warning('Tolerance exceeded')
    if config.early_stop and config.init_tolerance > config.tolerance_threshold:
        log.info(f'Early stopping since the {to_monitor} reached the tolerance threshold')
        return False
    else:
        return True
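
# Worked example of the combined metric (the weights below are hypothetical;
# the real values come from h_parms.combined_metric_weights): with weights
# [0.4, 0.4, 0.2], BERT_f1 = 0.60, ROUGE_f1 = 0.50 and validation
# accuracy = 0.70, the result is
#   combined_metric = 0.4 * 0.60 + 0.4 * 0.50 + 0.2 * 0.70 = 0.58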
def batch_run_check(batch, start, train_summary_writer, train_loss,
                    train_accuracy, model):
    if config.run_tensorboard:
        with train_summary_writer.as_default():
            tf.summary.scalar('train_loss', train_loss, step=batch)
            tf.summary.scalar('train_accuracy', train_accuracy, step=batch)
    if batch == 0:
        log.info(model.summary())
        log.info(batch_zero.format(time.time() - start))
    log.info(batch_run_details.format(batch, train_loss, train_accuracy))
def batch_run_check(batch, epoch, start, train_summary_writer, train_loss,
                    train_accuracy, transformer):
    if config.run_tensorboard:
        with train_summary_writer.as_default():
            tf.summary.scalar('train_loss', train_loss, step=batch)
            tf.summary.scalar('train_accuracy', train_accuracy, step=batch)
    if batch == 0 and epoch == 0:
        log.info(transformer.summary())
        log.info(batch_zero.format(time.time() - start))
    if batch % config.print_chks == 0:
        log.info(batch_run_details.format(epoch + 1, batch, train_loss, train_accuracy))
def refined_summary_sampling(inp, enc_output, draft_summary, padding_mask,
                             sampling_type='greedy', temperature=0.9, p=0.9,
                             k=25, beam_search=False, training=False):
    """
    Inference call, builds a refined summary.

    It first masks each word in the summary draft one by one, then feeds
    the draft to BERT to generate context vectors.
    """
    log.info(f"Building: 'Refined {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    refined_summary = tf.expand_dims(draft_summary, 0)
    dec_outputs = []
    dec_logits = []
    attention_dists = []
    for i in range(1, config.summ_length):
        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]
        # (batch_size, seq_len, d_bert), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            inp, context_vectors, enc_output, training=training,
            look_ahead_mask=None, padding_mask=padding_mask)
        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(nucleus_sampling(tf.squeeze(dec_output_i) / temperature, p=p), tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(top_k_sampling(tf.squeeze(dec_output_i) / temperature, k=k), tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(tf.squeeze(dec_output_i) / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        dec_outputs += [dec_output_i]
        dec_logits_i = dec_logits_i[:, i:i + 1, :]
        dec_logits += [dec_logits_i]
        refined_summary = with_column(refined_summary, i, preds)
        attention_dists += [attention_dist[:, i:i + 1, :]]
    cls_concat_dec_outputs = tf.tile(
        tf.expand_dims(tf.one_hot([CLS_ID], config.target_vocab_size), axis=0),
        [N, 1, 1])
    cls_concat_dec_logits = tf.tile(
        tf.expand_dims(tf.one_hot([CLS_ID], config.d_model), axis=0),
        [N, 1, 1])
    dec_outputs = tf.reshape(dec_outputs, (1, -1, config.target_vocab_size))
    dec_logits = tf.reshape(dec_logits, (1, -1, config.d_model))
    attention_dists = tf.reshape(attention_dists, (1, -1, config.doc_length))
    dec_outputs = tf.concat([cls_concat_dec_outputs, dec_outputs], axis=1)
    dec_logits = tf.concat([cls_concat_dec_logits, dec_logits], axis=1)
    if config.copy_gen:
        predictions = model.decoder.pointer_generator(
            dec_logits,
            dec_outputs,
            attention_dists,
            inp,
            tf.shape(inp)[-1],
            tf.shape(dec_outputs)[1],
            training=training)
        refined_summary = tf.cast(tf.argmax(predictions, axis=-1), dtype=tf.int32)
    # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
    return tf.squeeze(refined_summary, axis=0), attention_dist
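
# `nucleus_sampling` is likewise defined outside this section. A plausible
# top-p sketch, assuming 2-D (batch, vocab) logits (the repo's helper may
# also handle the fully squeezed 1-D case used above): keep the smallest set
# of tokens whose cumulative probability exceeds p, then sample one id.
def nucleus_sampling(logits, p=0.9):
    sorted_indices = tf.argsort(logits, direction='DESCENDING', axis=-1)
    sorted_logits = tf.gather(logits, sorted_indices, batch_dims=1)
    cum_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    # always keep at least the highest-probability token
    keep = tf.concat([tf.ones_like(cum_probs[:, :1], dtype=tf.bool),
                      cum_probs[:, 1:] <= p], axis=-1)
    filtered = tf.where(keep, sorted_logits,
                        tf.fill(tf.shape(sorted_logits), sorted_logits.dtype.min))
    # sample a position in sorted order, then map back to a vocabulary id
    position = tf.random.categorical(filtered, num_samples=1)
    return tf.gather(sorted_indices, position, batch_dims=1)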
batch_run_check(step, start, train_summary_writer, train_loss.result(),
                train_accuracy.result(), model)
eval_frequency = ((step + 1) * h_parms.batch_size) % config.eval_after
if eval_frequency == 0:
    predicted = tokenizer.decode(
        [i for i in tf.squeeze(tf.argmax(refine_predictions, axis=-1))
         if i not in [101, 102, 0]])
    target = tokenizer.decode(
        [i for i in tf.squeeze(target_x) if i not in [101, 102, 0]])
    print(f'the golden summary is {target}')
    print(f'the predicted summary is {predicted if predicted else "EMPTY"}')
    ckpt_save_path = ck_pt_mgr.save()
    (val_acc, rouge_score, bert_score) = calc_validation_loss(
        val_dataset, step + 1, val_step, valid_summary_writer,
        validation_accuracy)
    latest_ckpt += (step + 1)
    log.info(model_metrics.format(step + 1,
                                  train_loss.result(),
                                  train_accuracy.result(),
                                  val_acc,
                                  rouge_score,
                                  bert_score))
    log.info(evaluation_step.format(step + 1, time.time() - start))
    log.info(checkpoint_details.format(step + 1, ckpt_save_path))
    if not monitor_run(latest_ckpt, ckpt_save_path, val_acc, bert_score,
                       rouge_score, valid_summary_writer, step + 1):
        break
def tf_encode(doc, summary):
    return tf.py_function(encode, [doc, summary], [tf.int64, tf.int64])


def batch_shuffle(dataset, buffer_size, split, batch_size=config.batch_size):
    tf_dataset = dataset.map(tf_encode, num_parallel_calls=AUTOTUNE)
    tf_dataset = tf_dataset.filter(filter_token_size)
    sum_of_records = sum(1 for l in tf_dataset)  # optimize
    if sum_of_records > 2000:
        tf_dataset = tf_dataset.cache()
    if buffer_size:
        tf_dataset = tf_dataset.shuffle(buffer_size, seed=100)
    tf_dataset = tf_dataset.padded_batch(batch_size, padded_shapes=([-1], [-1]))
    tf_dataset = tf_dataset.prefetch(buffer_size=AUTOTUNE)
    log.info(f'Number of {split} records filtered: {buffer_size - sum_of_records}')
    log.info(f'Number of records to be used for {split}: {sum_of_records}')
    return tf_dataset
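
# `filter_token_size` is referenced above but not defined in this section.
# A plausible sketch, assuming it drops examples whose tokenized document or
# summary exceed the configured maximum lengths:
def filter_token_size(document, summary):
    return tf.logical_and(tf.size(document) <= config.doc_length,
                          tf.size(summary) <= config.summ_length)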
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import shutil
import os
from configuration import config
from hyper_parameters import h_parms
from rouge import Rouge
from input_path import file_path
from create_tokenizer import tokenizer
from bert_score import score as b_score
from creates import log, monitor_metrics

log.info('Loading Pre-trained BERT model for BERT SCORE calculation')
_, _, _ = b_score(["I'm Batman"], ["I'm Spiderman"], lang='en',
                  model_type=config.pretrained_bert_model)
rouge_all = Rouge()


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
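
# Usage sketch, mirroring the TensorFlow Transformer tutorial this warmup
# schedule comes from: drive Adam with the schedule above. (How the repo
# actually constructs its optimizer is not shown in this section.)
learning_rate = CustomSchedule(config.d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)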
if input('Remove existing summaries and tensorboard logs? (yes/no): ') == 'yes':
    try:
        shutil.rmtree(file_path.summary_write_path)
        shutil.rmtree(file_path.tensorboard_log)
    except FileNotFoundError:
        pass
train_dataset, val_dataset, num_of_train_examples = create_train_data()
train_loss, train_accuracy = get_loss_and_accuracy()
validation_loss, validation_accuracy = get_loss_and_accuracy()
if config.show_detokenized_samples:
    inp, tar = next(iter(train_dataset))
    for ip, ta in zip(inp.numpy(), tar.numpy()):
        log.info(tokenizer_en.decode([i for i in ta if i < tokenizer_en.vocab_size]))
        log.info(tokenizer_en.decode([i for i in ip if i < tokenizer_en.vocab_size]))
        break
transformer = Transformer(num_layers=config.num_layers,
                          d_model=config.d_model,
                          num_heads=config.num_heads,
                          dff=config.dff,
                          input_vocab_size=config.input_vocab_size,
                          target_vocab_size=config.target_vocab_size,
                          rate=h_parms.dropout_rate)
generator = Generator()
for epoch in range(h_parms.epochs):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    validation_loss.reset_states()
    validation_accuracy.reset_states()
    # inp -> document, tar -> summary
    for (batch, (inp, tar)) in enumerate(train_dataset):
        # the target is shifted right during training, hence its length is reduced by 1;
        # this cannot be done inside tf.function since it doesn't allow the operation
        train_step(inp, tar, inp.shape[1], tar.shape[1] - 1, inp.shape[0])
        batch_run_check(batch, epoch, start, train_summary_writer,
                        train_loss.result(), train_accuracy.result(),
                        transformer, pointer_generator)
        count_recs(batch, epoch, num_of_train_examples)
    (val_acc, val_loss, rouge_score, bert_score) = calc_validation_loss(
        val_dataset, epoch + 1, val_step, valid_summary_writer,
        validation_loss, validation_accuracy)
    ckpt_save_path = ck_pt_mgr.save()
    latest_ckpt += epoch
    log.info(model_metrics.format(epoch + 1,
                                  train_loss.result(),
                                  train_accuracy.result(),
                                  val_loss,
                                  val_acc,
                                  rouge_score,
                                  bert_score))
    log.info(epoch_timing.format(epoch + 1, time.time() - start))
    log.info(checkpoint_details.format(epoch + 1, ckpt_save_path))
    if not monitor_run(latest_ckpt, ckpt_save_path, val_loss, val_acc,
                       bert_score, rouge_score, valid_summary_writer, epoch):
        break
# if a checkpoint exists, restore the latest checkpoint.
ck_pt_mgr, latest_ckpt = check_ckpt(file_path.checkpoint_path)
for epoch in range(h_parms.epochs):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    validation_loss.reset_states()
    validation_accuracy.reset_states()
    # inp -> document, tar -> summary
    for (batch, (inp, tar)) in enumerate(train_dataset):
        # the target is shifted right during training, hence its length is reduced by 1;
        # this cannot be done inside tf.function since it doesn't allow the operation
        train_step(inp, tar, inp.shape[1], tar.shape[1] - 1, inp.shape[0])
        if batch == 0 and epoch == 0:
            log.info(transformer.summary())
            if config.copy_gen:
                log.info(pointer_generator.summary())
            log.info(batch_zero.format(time.time() - start))
        if batch % config.print_chks == 0:
            log.info(batch_run_details.format(epoch + 1, batch,
                                              train_loss.result(),
                                              train_accuracy.result()))
        if epoch == 0:
            try:
                if batch > 0:
                    num_of_recs_post_filter_atmost = (batch * h_parms.batch_size) / num_of_train_examples
                    num_of_recs_post_filter_atleast = ((batch - 1) * h_parms.batch_size) / num_of_train_examples
                    log.info(f'Number of records used for training should be between '
                             f'{num_of_recs_post_filter_atleast * 100} and '
                             f'{num_of_recs_post_filter_atmost * 100}% of the training data')
            except NameError:
                assert False, 'Training dataset is empty'