def training_loop(dataset, check_model_capacity, detokenize_samples=None):
    min_loss = 10000000
    if check_model_capacity:
        dataset = dataset.repeat(670)
    for (step, (input_ids, target_ids)) in tqdm(enumerate(dataset, 1), initial=1):
        start = time.time()
        # Apply gradients only on accumulation boundaries; None disables accumulation.
        grad_accum_flag = ((step % config.gradient_accumulation_steps) == 0
                           if config.accumulate_gradients else None)
        predictions = train_step(input_ids, target_ids, grad_accum_flag)
        if grad_accum_flag is not None:
            if grad_accum_flag:
                if (step % config.steps_to_print_training_info) == 0:
                    predicted_ids = train_sanity_check(target_tokenizer, predictions, target_ids)
                    train_loss = batch_run_check(step, start)
        else:
            if (step % config.steps_to_print_training_info) == 0:
                train_loss = batch_run_check(step, start)
        if check_model_capacity:
            if min_loss > train_loss:
                min_loss = train_loss
            else:
                log.warning('Loss is not decreasing, watch out')
            monitor_early_stop = monitor_run('not saving', 0, 0, 0.0, 1,
                                             copy_best_ckpt=False)
    if check_model_capacity:
        log.info(f'target_ids are {target_ids}')
        log.info(f'predicted ids are {predicted_ids}')
        if train_loss < config.min_train_loss:
            log.info('Minimum training loss reached')
        else:
            log.info("Loss didn't reach the specified min_train_loss; "
                     'try increasing the model parameters or the number of train steps')
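# train_step is a repository helper whose body is not shown here. As a rough
# illustration of the grad_accum_flag contract used above (None = no
# accumulation, False = accumulate only, True = accumulate and apply), here is
# a minimal TensorFlow sketch. `model`, `optimizer`, `loss_fn`, and the
# module-level `accumulated_grads` buffer are assumptions for illustration,
# not the repository's actual objects.
import tensorflow as tf

accumulated_grads = None  # one running-sum buffer per trainable variable

def sketch_train_step(model, optimizer, loss_fn, input_ids, target_ids, grad_accum_flag):
    global accumulated_grads
    with tf.GradientTape() as tape:
        predictions = model(input_ids, training=True)
        loss = loss_fn(target_ids, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    if grad_accum_flag is None:
        # Accumulation disabled: apply fresh gradients on every step.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    else:
        if accumulated_grads is None:
            accumulated_grads = [tf.zeros_like(v) for v in model.trainable_variables]
        accumulated_grads = [a + g for a, g in zip(accumulated_grads, grads)]
        if grad_accum_flag:
            # Accumulation boundary: apply the summed gradients, then reset.
            optimizer.apply_gradients(zip(accumulated_grads, model.trainable_variables))
            accumulated_grads = [tf.zeros_like(v) for v in model.trainable_variables]
    return predictions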
# If a checkpoint exists, restore the latest checkpoint.
ck_pt_mgr = check_ckpt(config.checkpoint_path)
total_steps = int(config.epochs * config.gradient_accumulation_steps)
train_dataset = train_dataset.repeat(total_steps)
try:
    for (step, (input_ids, target_ids)) in tqdm(enumerate(train_dataset, 1), initial=1):
        if step > 1695000:  # fast-forward past already-trained steps when resuming
            start_time = time.time()
            grad_accum_flag = ((step % config.gradient_accumulation_steps) == 0
                               if config.accumulate_gradients else None)
            predictions = train_step(input_ids, target_ids, grad_accum_flag)
            if (step % config.steps_to_print_training_info) == 0:
                batch_run_check(step, start_time)
            if (step % config.eval_after_steps) == 0:
                (early_stop,
                 draft_attention_weights,
                 refine_attention_weights) = save_evaluate_monitor(ck_pt_mgr,
                                                                   val_dataset,
                                                                   target_tokenizer,
                                                                   predictions,
                                                                   target_ids,
                                                                   step,
                                                                   start_time,
                                                                   return_attention=True)
                if early_stop:
                    break
        else:
            continue
except KeyboardInterrupt:
    # The handler is truncated in the original snippet; this clause is assumed
    # here only so the try block parses.
    log.info(f'Training interrupted at step {step}')
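# check_ckpt comes from model_training_helper and its body is not part of this
# snippet. The "restore the latest checkpoint if one exists" behaviour it is
# commented with usually looks like the following minimal TensorFlow sketch;
# `model` and `optimizer` are assumed stand-ins for whatever the real helper
# actually tracks.
import tensorflow as tf

def sketch_check_ckpt(checkpoint_path, model, optimizer):
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
        # Resume weights and optimizer state from the most recent save.
        ckpt.restore(ckpt_manager.latest_checkpoint)
    return ckpt_manager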
                              drop_remainder=True)  # tail of a truncated dataset-batching call
# if a checkpoint exists, restore the latest checkpoint.
ck_pt_mgr = check_ckpt(config.checkpoint_path)
total_steps = int(config.epochs * config.gradient_accumulation_steps)
train_dataset = train_dataset.repeat(total_steps)
for (step, (input_ids, target_ids_)) in tqdm(enumerate(train_dataset), initial=1):
    start = time.time()
    draft_mask, refine_mask, target_ids = mask_and_one_hot_labels(target_ids_)
    grad_accum_flag = ((step + 1) % config.gradient_accumulation_steps) == 0
    refine_predictions = train_step(input_ids, target_ids_, target_ids,
                                    draft_mask, refine_mask, grad_accum_flag)
    if grad_accum_flag:
        train_loss = batch_run_check(step + 1, start)
    evaluate = ((step + 1) * config.train_batch_size) % config.eval_after
    if evaluate == 0:
        predicted = train_sanity_check(target_tokenizer, refine_predictions, target_ids_)
        ckpt_save_path = ck_pt_mgr.save()
        if predicted:
            (rouge_score, bert_score) = evaluate_validation_set(val_dataset, step + 1)
        else:
            rouge_score, bert_score = 0, 0  # was `= 0`, which cannot unpack into a tuple
        training_results(step + 1, rouge_score, bert_score,
                         (time.time() - start), ckpt_save_path)
        monitor_early_stop = monitor_run(ckpt_save_path, bert_score, rouge_score,
                                         train_loss, step + 1)
        if not monitor_early_stop:
            break  # body truncated in the original snippet; a break is assumed
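# mask_and_one_hot_labels is not defined in this snippet. A plausible minimal
# sketch, assuming id 0 is the padding token and that the draft decoder is
# scored against the shifted targets: build float padding masks for the draft
# and refine losses and one-hot encode the labels. The real helper's mask
# shapes and vocabulary-size lookup may differ.
import tensorflow as tf

def sketch_mask_and_one_hot_labels(target_ids_, vocab_size, pad_id=0):
    draft_mask = tf.cast(tf.not_equal(target_ids_[:, 1:], pad_id), tf.float32)
    refine_mask = tf.cast(tf.not_equal(target_ids_, pad_id), tf.float32)
    target_ids = tf.one_hot(target_ids_, depth=vocab_size)  # labels for softmax cross-entropy
    return draft_mask, refine_mask, target_ids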
train_dataset = train_dataset.repeat(total_steps)
try:
    for (step, (input_ids, target_ids)) in tqdm(enumerate(train_dataset, 1), initial=1):
        if step > 1899567:  # fast-forward past already-trained steps when resuming
            start_time = time.time()
            grad_accum_flag = ((step % config.gradient_accumulation_steps) == 0
                               if config.accumulate_gradients else None)
            predictions, bert_f1_score = train_step(input_ids, target_ids, grad_accum_flag)
            if (step % config.steps_to_print_training_info) == 0:
                batch_run_check(step, start_time, bert_f1_score)
            if (step % config.eval_after_steps) == 0:
                (early_stop,
                 draft_attention_weights,
                 refine_attention_weights) = save_evaluate_monitor(ck_pt_mgr,
                                                                   val_dataset,
                                                                   target_tokenizer,
                                                                   predictions,
                                                                   target_ids,
                                                                   step,
                                                                   start_time,
                                                                   bert_f1_score)
                if early_stop:
                    break
            else:
                early_stop = True
        else:
            # body truncated in the original snippet; by analogy with the
            # earlier resume loop this presumably skips already-trained steps
            continue
except KeyboardInterrupt:
    # The handler is truncated in the original snippet; this clause is assumed
    # here only so the try block parses.
    log.info(f'Training interrupted at step {step}')
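# save_evaluate_monitor and monitor_run are repository helpers whose bodies are
# not shown; the early_stop flag they return is typically driven by a patience
# counter over the validation metric. A minimal sketch of that idea, with
# `patience` and the module-level state as assumed names:
best_metric = 0.0
evals_without_improvement = 0

def sketch_should_stop_early(metric, patience):
    """Return True once `metric` has failed to improve `patience` evaluations in a row."""
    global best_metric, evals_without_improvement
    if metric > best_metric:
        best_metric = metric
        evals_without_improvement = 0
    else:
        evals_without_improvement += 1
    return evals_without_improvement >= patience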
from model_training_helper import (check_ckpt, eval_step, train_step,
                                   batch_run_check, train_sanity_check)

train_dataset = create_dataset(split='train',
                               source_tokenizer=source_tokenizer,
                               target_tokenizer=target_tokenizer,
                               from_=0,
                               to=100,
                               batch_size=2,
                               shuffle=False)
# if a checkpoint exists, restore the latest checkpoint.
ck_pt_mgr = check_ckpt(config.checkpoint_path)
total_steps = int(config.epochs * config.gradient_accumulation_steps)
train_dataset = train_dataset.repeat(total_steps)
stop_at = 1000  # train on a small fixed slice, then stop for the sanity check
for (step, (input_ids, target_ids)) in tqdm(enumerate(train_dataset, 1), initial=1):
    start_time = time.time()
    grad_accum_flag = ((step % config.gradient_accumulation_steps) == 0
                       if config.accumulate_gradients else None)
    predictions = train_step(input_ids, target_ids, grad_accum_flag)
    if (step % config.steps_to_print_training_info) == 0:
        train_loss = batch_run_check(step, start_time)
        train_sanity_check(target_tokenizer, predictions, target_ids, log)
    if (step % stop_at) == 0:
        break
train_sanity_check(target_tokenizer, predictions, target_ids, log)
log.info(f'Training completed at step {step}')
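# train_sanity_check is imported above and receives the target tokenizer, so a
# reasonable guess at its job is to decode one prediction/target pair and log
# both for eyeballing. A minimal sketch under that assumption; the tokenizer's
# `decode` method and the greedy argmax over logits are assumptions, not the
# helper's confirmed behaviour.
import tensorflow as tf

def sketch_train_sanity_check(tokenizer, predictions, target_ids, log):
    predicted_ids = tf.argmax(predictions, axis=-1)  # greedy ids from the output logits
    sample_pred = tokenizer.decode([int(i) for i in predicted_ids[0].numpy() if i != 0])
    sample_target = tokenizer.decode([int(i) for i in target_ids[0].numpy() if i != 0])
    log.info(f'sample prediction: {sample_pred}')
    log.info(f'sample target    : {sample_target}')
    return predicted_ids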