def sanity_check():
    logger = get_logger()
    check_multihead_attention(logger)
    check_smooth_cross_entropy_loss(logger)
    check_beam_search(logger)
    check_bleu(logger)
def get_test_dataset(train_root_path: str,
                     src_lang_code: str,
                     tgt_lang_code: str,
                     src_lowercase: bool,
                     tgt_lowercase: bool,
                     src_normalizer: str,
                     tgt_normalizer: str,
                     src_tokenizer: str,
                     tgt_tokenizer: str,
                     shared_vocabulary: bool,
                     test_root_path: str = './data/test'):
    logger = get_logger()
    if os.path.isfile('{}.dataset'.format(test_root_path)):
        logger.info(
            'Test dataset has already been prepared. Loading from binary form ...'
        )
        corpora = Corpora.load('{}.dataset'.format(test_root_path))
        logger.info('Done.')
        return corpora
    field_specs = [[src_lowercase, src_normalizer, src_tokenizer],
                   [tgt_lowercase, tgt_normalizer, tgt_tokenizer]]
    # Test-time fields always reuse the vocabularies produced during training.
    if shared_vocabulary:
        vocabulary = Vocabulary()
        vocabulary.load_from_csv('{}/vocab.shared'.format(
            os.path.dirname(train_root_path)))
        vocabularies = [vocabulary, vocabulary]
    else:
        src_vocabulary = Vocabulary()
        src_vocabulary.load_from_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), src_lang_code))
        tgt_vocabulary = Vocabulary()
        tgt_vocabulary.load_from_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), tgt_lang_code))
        vocabularies = [src_vocabulary, tgt_vocabulary]
    fields = [
        Field(l, v, n, t) for (l, n, t), v in zip(field_specs, vocabularies)
    ]
    corpora = Corpora(fields)
    logger.info('Preparing test corpora ...')
    with open('{}.{}'.format(test_root_path, src_lang_code)) as src_stream, \
            open('{}.{}'.format(test_root_path, tgt_lang_code)) as tgt_stream:
        for src_sentence, tgt_sentence in zip(src_stream, tgt_stream):
            if src_sentence.strip() and tgt_sentence.strip():
                corpora.append([src_sentence, tgt_sentence])
    logger.info('Saving dataset ...')
    corpora.save('{}.dataset'.format(test_root_path))
    logger.info('Done.')
    return corpora
def predict(
        input: Ignore[str],
        output: Ignore[str],
        log_prefix: Ignore[str],
        model: EncoderDecoder = None,
        batch_size_limit: int = 400,
        batch_limit_by_tokens: bool = True,
):
    logger = get_logger()
    (src_vocab, _), (src_field, tgt_field) = get_vocabularies()
    dataset = Corpora([src_field])
    logger.info(f'{log_prefix}: Loading input file ...')
    with open(input) as src_stream:
        for src_sentence in src_stream:
            if src_sentence.strip():
                dataset.append([src_sentence])
    logger.info(f'{log_prefix}: Loading done.')
    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(src_field.vocabulary, tgt_field.vocabulary)
        state_dict = torch.load(best_model_path)
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())
    # Inference only: make sure dropout and other training-time layers are off.
    model.eval()
    with open(output, 'w') as output_stream, torch.no_grad():
        for batch in dataset.iterate(get_device(),
                                     batch_size_limit,
                                     batch_limit_by_tokens,
                                     sort_by_length=False,
                                     shuffle=False):
            x_mask = batch[0] != src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)
            x_e = model.encode(batch[0], x_mask)
            y_hat, _ = beam_search(x_e,
                                   x_mask,
                                   model,
                                   get_scores=short_sent_penalty)
            sentence = src_field.to_sentence_str(batch[0][-1].tolist())
            generated = tgt_field.to_sentence_str(y_hat[-1].tolist())
            logger.info('SENTENCE:\n ---- {}'.format(sentence))
            logger.info('GENERATED:\n ---- {}'.format(generated))
            # Hypotheses are target-side token ids, so decode them with the
            # target field before writing them out.
            for generated in (tgt_field.to_sentence_str(s)
                              for s in y_hat.tolist()):
                output_stream.write(f'{generated}\n')
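# A minimal usage sketch for predict(); the file names below are hypothetical.
# When `model` is omitted, the best checkpoint found by find_best_model() is
# loaded and run on get_device():
#
#   predict(input='./data/test.en',
#           output='./output/test.hyp.fa',
#           log_prefix='Predict')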
def train(max_steps: int = 100,
          batch_size_limit: int = 400,
          batch_limit_by_tokens: bool = True,
          report_interval_steps: int = 10,
          validation_interval_steps: int = 100,
          lr_scheduler_at: str = 'every_step',
          n_ckpts_to_keep: int = 3,
          teacher_forcing: bool = True,
          random_seed: int = 42):
    set_random_seeds(random_seed)
    logger = get_logger()
    train_dataset = get_train_dataset()
    assert len(
        train_dataset.fields
    ) >= 2, "Train dataset must have at least two fields (source and target)."
    validation_dataset = get_validation_dataset()
    assert len(
        validation_dataset.fields
    ) >= 2, "Validation dataset must have at least two fields (source and target)."
    loss_function = get_loss_function(
        train_dataset.fields[1].vocabulary.pad_index)
    model = build_model(train_dataset.fields[0].vocabulary,
                        train_dataset.fields[1].vocabulary)
    model.to(get_device())
    loss_function.to(get_device())
    optimizer = build_optimizer(model.parameters())
    scheduler = build_scheduler(optimizer)
    initialize(model)

    def noop():
        return None

    def step_lr_scheduler():
        return scheduler.step()

    run_scheduler_at_step = noop
    run_scheduler_at_validation = noop
    run_scheduler_at_epoch = noop
    if scheduler is not None:
        if lr_scheduler_at == 'every_step':
            run_scheduler_at_step = step_lr_scheduler
        elif lr_scheduler_at == 'every_validation':
            run_scheduler_at_validation = step_lr_scheduler
        elif lr_scheduler_at == 'every_epoch':
            run_scheduler_at_epoch = step_lr_scheduler

    step = 0
    epoch = 0
    kept_checkpoint_path_score_map = {}
    best_checkpoint_specs = {"score": -math.inf, "step": -1}

    @configured('model')
    def maybe_save_checkpoint(score: Ignore[float], output_path: str):
        # Save only when there is room or the new score beats a kept one;
        # keep at most n_ckpts_to_keep checkpoints on disk.
        if len(kept_checkpoint_path_score_map) < n_ckpts_to_keep or \
                any(score > s for s in kept_checkpoint_path_score_map.values()):
            if len(kept_checkpoint_path_score_map) >= n_ckpts_to_keep:
                worst_checkpoint_path = sorted(
                    kept_checkpoint_path_score_map.keys(),
                    key=lambda p: kept_checkpoint_path_score_map[p])
                worst_checkpoint_path = worst_checkpoint_path[0]
                kept_checkpoint_path_score_map.pop(worst_checkpoint_path)
                try:
                    os.unlink(worst_checkpoint_path)
                except OSError:
                    logger.warning(
                        'Could not unlink {}.'.format(worst_checkpoint_path))
            if score > best_checkpoint_specs["score"]:
                logger.info(
                    'New `best model` found with score {:.3f} at step {}.'.
                    format(score, step))
                best_checkpoint_specs["score"] = score
                best_checkpoint_specs["step"] = step
            state_dict = {
                "step": step,
                "best_checkpoint_specs": best_checkpoint_specs,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict()
                if scheduler is not None else None
            }
            checkpoint_path = '{}/step_{}_score_{:.3f}.pt'.format(
                output_path, step, score)
            torch.save(state_dict, checkpoint_path)
            kept_checkpoint_path_score_map[checkpoint_path] = score

    model.train()
    validation_done_already = False
    while step < max_steps:
        start_time = time.time()
        total_tokens_processed = 0
        for batch in train_dataset.iterate(get_device(), batch_size_limit,
                                           batch_limit_by_tokens):
            step += 1
            if step >= max_steps:
                break
            # Masks mark non-pad positions; the decoder sees y[:, :-1] and is
            # trained to predict the shifted targets y[:, 1:].
            x_mask = batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)
            y_mask = batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)
            x_e = model.encode(batch[0], x_mask)
            log_probs = model.decode(batch[1][:, :-1],
                                     x_e,
                                     y_mask[:, :, :-1],
                                     x_mask,
                                     teacher_forcing=teacher_forcing)
            token_count = y_mask[:, :, 1:].sum().item()
            loss = loss_function(log_probs, batch[1][:, 1:],
                                 model.get_target_embeddings()) / token_count
            loss.backward()
            optimizer.step()
            mark_optimization_step()
            optimizer.zero_grad()
            run_scheduler_at_step()
            total_tokens_processed += token_count
            if step > 0 and step % report_interval_steps == 0:
                elapsed_time = time.time() - start_time
                baseline_loss = loss_function.uniform_baseline_loss(
                    log_probs, batch[1][:, 1:])
                logger.info(
                    'Epoch_{} Step_{}: loss={:.3f}(vs {:.3f} uniform), tokens/s={:.1f}, lr={}'
                    .format(epoch, step, loss.item(), baseline_loss,
                            total_tokens_processed / elapsed_time,
                            optimizer.param_groups[0]['lr']))
                start_time = time.time()
                total_tokens_processed = 0
            if step > 0 and step % validation_interval_steps == 0:
                log_prefix = 'Epoch_{} Step_{}'.format(epoch, step)
                score = evaluate(validation_dataset, log_prefix, model,
                                 loss_function)
                maybe_save_checkpoint(score)
                model.train()
                run_scheduler_at_validation()
                start_time = time.time()
                total_tokens_processed = 0
                validation_done_already = True
            else:
                validation_done_already = False
        epoch += 1
        logger.info('Epoch {} finished.'.format(epoch))
        run_scheduler_at_epoch()
    if not validation_done_already:
        log_prefix = 'Final (epoch={} ~ step={})'.format(epoch, step)
        score = evaluate(validation_dataset, log_prefix, model, loss_function)
        maybe_save_checkpoint(score)
    logger.info('Best validation score was {:.3f} at step {}.'.format(
        best_checkpoint_specs["score"], best_checkpoint_specs["step"]))
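# A minimal sketch of a training run; the values below are illustrative and
# assume the configuration read by the @configured helpers (model, optimizer,
# scheduler, dataset paths) is already in place:
#
#   train(max_steps=100000,
#         batch_size_limit=4096,
#         batch_limit_by_tokens=True,
#         report_interval_steps=100,
#         validation_interval_steps=1000,
#         lr_scheduler_at='every_step')
#
# Each validation pass scores the model via evaluate() (BLEU), and
# maybe_save_checkpoint() keeps the n_ckpts_to_keep best checkpoints, named
# '{output_path}/step_{step}_score_{score:.3f}.pt'.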
def get_train_dataset(train_root_path: str = './data/train',
                      src_lang_code: str = 'en',
                      tgt_lang_code: str = 'fa',
                      src_lowercase: bool = False,
                      tgt_lowercase: bool = False,
                      src_normalizer: str = 'default',
                      tgt_normalizer: str = 'default',
                      src_tokenizer: str = 'default',
                      tgt_tokenizer: str = 'default',
                      force_vocab_update: bool = True,
                      shared_vocabulary: bool = False):
    logger = get_logger()
    if os.path.isfile('{}.dataset'.format(train_root_path)):
        logger.info(
            'Train dataset has already been prepared. Loading from binary form ...'
        )
        corpora = Corpora.load('{}.dataset'.format(train_root_path))
        logger.info('Done.')
        return corpora
    field_specs = [[src_lowercase, src_normalizer, src_tokenizer],
                   [tgt_lowercase, tgt_normalizer, tgt_tokenizer]]
    vocabulary = None
    if shared_vocabulary:
        vocabulary = Vocabulary()
    # With a shared vocabulary both fields point at the same Vocabulary
    # object; otherwise each Field manages its own.
    fields = [
        Field(l, vocabulary, n, t, force_vocabulary_update=force_vocab_update)
        for l, n, t in field_specs
    ]
    if not force_vocab_update:
        logger.info('Not updating provided vocabularies ...')
        if shared_vocabulary:
            vocabulary.load_from_csv('{}/vocab.shared'.format(
                os.path.dirname(train_root_path)))
        else:
            fields[0].vocabulary.load_from_csv('{}/vocab.{}'.format(
                os.path.dirname(train_root_path), src_lang_code))
            fields[1].vocabulary.load_from_csv('{}/vocab.{}'.format(
                os.path.dirname(train_root_path), tgt_lang_code))
    corpora = Corpora(fields)
    logger.info('Preparing train corpora ...')
    with open('{}.{}'.format(train_root_path, src_lang_code)) as src_stream, \
            open('{}.{}'.format(train_root_path, tgt_lang_code)) as tgt_stream:
        for src_sentence, tgt_sentence in zip(src_stream, tgt_stream):
            if src_sentence.strip() and tgt_sentence.strip():
                corpora.append([src_sentence, tgt_sentence])
    logger.info('Saving dataset ...')
    corpora.save('{}.dataset'.format(train_root_path))
    logger.info('Saving vocabular(y|ies) ...')
    if shared_vocabulary:
        vocabulary.save_as_csv('{}/vocab.shared'.format(
            os.path.dirname(train_root_path)))
    else:
        fields[0].vocabulary.save_as_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), src_lang_code))
        fields[1].vocabulary.save_as_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), tgt_lang_code))
    logger.info('Done.')
    return corpora
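# Expected on-disk layout for the defaults above (a sketch inferred from the
# open()/save_as_csv() calls in this function, not enforced anywhere else):
#
#   ./data/train.en       source sentences, one per line
#   ./data/train.fa       target sentences, one per line
#   ./data/train.dataset  binary cache written by corpora.save()
#   ./data/vocab.en, ./data/vocab.fa   per-language vocabularies, or
#   ./data/vocab.shared                when shared_vocabulary=True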
def evaluate(validation_dataset: Corpora,
             log_prefix: Ignore[str],
             model: EncoderDecoder = None,
             loss_function: Callable = None,
             batch_size_limit: int = 400,
             batch_limit_by_tokens: bool = True,
             teacher_forcing: bool = True,
             metrics: Tuple[Metric] = None):
    assert len(
        validation_dataset.fields
    ) >= 2, "Validation dataset must have at least two fields (source and target)."
    logger = get_logger()
    if loss_function is None:
        loss_function = get_loss_function(
            validation_dataset.fields[1].vocabulary.pad_index)
        loss_function.to(get_device())
    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(validation_dataset.fields[0].vocabulary,
                            validation_dataset.fields[1].vocabulary)
        state_dict = torch.load(best_model_path)
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())
    pad_index = model.tgt_vocab.pad_index
    total_item_count = 0
    total_validation_loss = 0
    model.eval()
    printed_samples = 0
    if metrics is None:
        metrics = (BleuMetric(), )
    else:
        metrics = (BleuMetric(), ) + tuple(
            m for m in metrics if not isinstance(m, BleuMetric))
    with torch.no_grad():
        start_time = time.time()
        for validation_batch in validation_dataset.iterate(
                get_device(),
                batch_size_limit,
                batch_limit_by_tokens,
                sort_by_length=False,
                shuffle=False):
            x_mask = validation_batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)
            y_mask = validation_batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)
            x_e = model.encode(validation_batch[0], x_mask)
            log_probs = model.decode(validation_batch[1][:, :-1],
                                     x_e,
                                     y_mask[:, :, :-1],
                                     x_mask,
                                     teacher_forcing=teacher_forcing)
            loss = loss_function(log_probs, validation_batch[1][:, 1:],
                                 model.get_target_embeddings())
            total_item_count += y_mask[:, :, 1:].sum().item()
            total_validation_loss += loss.item()
            y_hat, _ = beam_search(x_e,
                                   x_mask,
                                   model,
                                   get_scores=short_sent_penalty)
            if printed_samples < 4:
                sentence = validation_dataset.fields[0].to_sentence_str(
                    validation_batch[0][-1].tolist())
                reference = validation_dataset.fields[1].to_sentence_str(
                    validation_batch[1][-1].tolist())
                generated = validation_dataset.fields[1].to_sentence_str(
                    y_hat[-1].tolist())
                logger.info('SENTENCE:\n ---- {}'.format(sentence))
                logger.info('REFERENCE:\n ---- {}'.format(reference))
                logger.info('GENERATED:\n ---- {}'.format(generated))
                printed_samples += 1
            update_metric_params(y_hat, validation_batch[1], pad_index,
                                 metrics)
        elapsed_time = time.time() - start_time
        logger.info(
            f'{log_prefix}: '
            f'evaluation_loss={total_validation_loss / total_item_count:.3f}, '
            f'elapsed_time={int(elapsed_time + 0.5)}s')
        for metric_repr in (str(m) for m in metrics):
            logger.info(f'{log_prefix}: evaluation {metric_repr}')
    return metrics[0].get_score()
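# A minimal, hypothetical convenience wrapper (not part of the original
# module) showing how evaluate() can be driven for a held-out test set. The
# argument values mirror get_train_dataset()'s defaults and should be kept
# consistent with the actual training configuration. evaluate() always
# reports BLEU first and returns that score.
def evaluate_on_test_set(log_prefix: str = 'Test'):
    test_corpora = get_test_dataset(train_root_path='./data/train',
                                    src_lang_code='en',
                                    tgt_lang_code='fa',
                                    src_lowercase=False,
                                    tgt_lowercase=False,
                                    src_normalizer='default',
                                    tgt_normalizer='default',
                                    src_tokenizer='default',
                                    tgt_tokenizer='default',
                                    shared_vocabulary=False)
    return evaluate(test_corpora, log_prefix)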