Example #1
def sanity_check():
    logger = get_logger()

    check_multihead_attention(logger)
    check_smooth_cross_entropy_loss(logger)
    check_beam_search(logger)
    check_bleu(logger)
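
A minimal entry point for running the checks (hypothetical; it assumes the module defining sanity_check is executed directly):

if __name__ == '__main__':
    # Each check logs its own pass/fail details via the shared logger.
    sanity_check()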
Example #2
def get_test_dataset(train_root_path: str,
                     src_lang_code: str,
                     tgt_lang_code: str,
                     src_lowercase: bool,
                     tgt_lowercase: bool,
                     src_normalizer: str,
                     tgt_normalizer: str,
                     src_tokenizer: str,
                     tgt_tokenizer: str,
                     shared_vocabulary: bool,
                     test_root_path: str = './data/test'):
    logger = get_logger()

    if os.path.isfile('{}.dataset'.format(test_root_path)):
        logger.info(
            'Test dataset has already been prepared. Loading from binary form ...'
        )
        corpora = Corpora.load('{}.dataset'.format(test_root_path))
        logger.info('Done.')
        return corpora

    field_specs = [[src_lowercase, src_normalizer, src_tokenizer],
                   [tgt_lowercase, tgt_normalizer, tgt_tokenizer]]
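    # field_specs holds [lowercase, normalizer, tokenizer] per field; the
    # (l, n, t) unpacking in the Field construction below follows this order.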

    vocabularies = None
    if shared_vocabulary:
        vocabulary = Vocabulary()
        vocabulary.load_from_csv('{}/vocab.shared'.format(
            os.path.dirname(train_root_path)))
        vocabularies = [vocabulary, vocabulary]
    else:
        # NOTE: load_from_csv appears to populate the vocabulary in place
        # (see the shared branch above), so build the instances first
        # rather than chaining on the constructor.
        src_vocabulary = Vocabulary()
        src_vocabulary.load_from_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), src_lang_code))
        tgt_vocabulary = Vocabulary()
        tgt_vocabulary.load_from_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), tgt_lang_code))
        vocabularies = [src_vocabulary, tgt_vocabulary]

    fields = [
        Field(l, v, n, t) for (l, n, t), v in zip(field_specs, vocabularies)
    ]

    corpora = Corpora(fields)

    logger.info('Preparing test corpora ...')
    with open('{}.{}'.format(test_root_path, src_lang_code)) as src_stream, \
            open('{}.{}'.format(test_root_path, tgt_lang_code)) as tgt_stream:
        for src_sentence, tgt_sentence in zip(src_stream, tgt_stream):
            if src_sentence.strip() and tgt_sentence.strip():
                corpora.append([src_sentence, tgt_sentence])
    logger.info('Saving dataset ...')
    corpora.save('{}.dataset'.format(test_root_path))
    logger.info('Done.')

    return corpora
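
A hypothetical call; it assumes the test corpus uses the same language codes and preprocessing settings as the training data (the values mirror the defaults of get_train_dataset below):

test_corpora = get_test_dataset(train_root_path='./data/train',
                                src_lang_code='en',
                                tgt_lang_code='fa',
                                src_lowercase=False,
                                tgt_lowercase=False,
                                src_normalizer='default',
                                tgt_normalizer='default',
                                src_tokenizer='default',
                                tgt_tokenizer='default',
                                shared_vocabulary=False)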
Example #3
def predict(
    input: Ignore[str],
    output: Ignore[str],
    log_prefix: Ignore[str],
    model: EncoderDecoder = None,
    batch_size_limit: int = 400,
    batch_limit_by_tokens: bool = True,
):

    logger = get_logger()

    (src_vocab, _), (src_field, tgt_field) = get_vocabularies()

    dataset = Corpora([src_field])
    logger.info(f'{log_prefix}: Loading input file ...')
    with open(input) as src_stream:
        for src_sentence in src_stream:
            if src_sentence.strip():
                dataset.append([src_sentence])
    logger.info(f'{log_prefix}: Loading done.')

    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(src_field.vocabulary, tgt_field.vocabulary)
        state_dict = torch.load(best_model_path)
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())

    with open(output, 'w') as output_stream, torch.no_grad():

        for batch in dataset.iterate(get_device(),
                                     batch_size_limit,
                                     batch_limit_by_tokens,
                                     sort_by_length=False,
                                     shuffle=False):
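            # batch[0] is the padded source batch; comparing against pad_index
            # and unsqueezing gives a (batch, 1, src_len) attention mask.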
            x_mask = batch[0] != src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)

            x_e = model.encode(batch[0], x_mask)
            y_hat, _ = beam_search(x_e,
                                   x_mask,
                                   model,
                                   get_scores=short_sent_penalty)
            sentence = src_field.to_sentence_str(batch[0][-1].tolist())
            generated = tgt_field.to_sentence_str(y_hat[-1].tolist())

            logger.info('SENTENCE:\n ---- {}'.format(sentence))
            logger.info('GENERATED:\n ---- {}'.format(generated))

            for generated in (tgt_field.to_sentence_str(s)
                              for s in y_hat.tolist()):
                output_stream.write(f'{generated}\n')
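
A hypothetical invocation (paths are placeholders); when model is None, the best saved checkpoint is located and loaded automatically:

predict(input='./data/test.en',
        output='./output/test.hyp.fa',
        log_prefix='Test')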
Example #4
def train(max_steps: int = 100,
          batch_size_limit: int = 400,
          batch_limit_by_tokens: bool = True,
          report_interval_steps: int = 10,
          validation_interval_steps: int = 100,
          lr_scheduler_at: str = 'every_step',
          n_ckpts_to_keep: int = 3,
          teacher_forcing: bool = True,
          random_seed: int = 42):

    set_random_seeds(random_seed)
    logger = get_logger()

    train_dataset = get_train_dataset()
    assert len(
        train_dataset.fields
    ) >= 2, "Train dataset must have at least two fields (source and target)."
    validation_dataset = get_validation_dataset()
    assert len(
        validation_dataset.fields
    ) >= 2, "Validation dataset must have at least two fields (source and target)."

    loss_function = get_loss_function(
        train_dataset.fields[1].vocabulary.pad_index)

    model = build_model(train_dataset.fields[0].vocabulary,
                        train_dataset.fields[1].vocabulary)

    model.to(get_device())
    loss_function.to(get_device())

    optimizer = build_optimizer(model.parameters())
    scheduler = build_scheduler(optimizer)

    initialize(model)

    def noop():
        return None

    def step_lr_scheduler():
        return scheduler.step()

    run_scheduler_at_step = noop
    run_scheduler_at_validation = noop
    run_scheduler_at_epoch = noop

    if scheduler is not None:
        if lr_scheduler_at == 'every_step':
            run_scheduler_at_step = step_lr_scheduler
        elif lr_scheduler_at == 'every_validation':
            run_scheduler_at_validation = step_lr_scheduler
        elif lr_scheduler_at == 'every_epoch':
            run_scheduler_at_epoch = step_lr_scheduler

    step = 0
    epoch = 0

    kept_checkpoint_path_score_map = {}

    best_checkpoint_specs = {"score": -math.inf, "step": -1}

    @configured('model')
    def maybe_save_checkpoint(score: Ignore[float], output_path: str):
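        # Keep at most n_ckpts_to_keep checkpoints on disk, evicting the
        # lowest-scoring one once the limit is reached.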

        if len(kept_checkpoint_path_score_map) < n_ckpts_to_keep or \
                any(score > s for s in kept_checkpoint_path_score_map.values()):
            if len(kept_checkpoint_path_score_map) >= n_ckpts_to_keep:
                worst_checkpoint_path = min(
                    kept_checkpoint_path_score_map,
                    key=kept_checkpoint_path_score_map.get)
                kept_checkpoint_path_score_map.pop(worst_checkpoint_path)
                try:
                    os.unlink(worst_checkpoint_path)
                except OSError:
                    logger.warning(
                        'Could not unlink {}.'.format(worst_checkpoint_path))

            if score > best_checkpoint_specs["score"]:
                logger.info(
                    'New `best model` found with score {:.3f} at step {}.'.
                    format(score, step))
                best_checkpoint_specs["score"] = score
                best_checkpoint_specs["step"] = step

            state_dict = {
                "step": step,
                "best_checkpoint_specs": best_checkpoint_specs,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": (scheduler.state_dict()
                                    if scheduler is not None else None)
            }
            checkpoint_path = '{}/step_{}_score_{:.3f}.pt'.format(
                output_path, step, score)
            torch.save(state_dict, checkpoint_path)
            kept_checkpoint_path_score_map[checkpoint_path] = score

    model.train()

    validation_done_already = False
    while step < max_steps:

        start_time = time.time()
        total_tokens_processed = 0
        for batch in train_dataset.iterate(get_device(), batch_size_limit,
                                           batch_limit_by_tokens):
            step += 1
            if step >= max_steps:
                break

            x_mask = batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)

            y_mask = batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)

            x_e = model.encode(batch[0], x_mask)
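            # Decode conditioned on the shifted target prefix y[:, :-1]; the
            # loss below compares the predictions against y[:, 1:].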
            log_probs = model.decode(batch[1][:, :-1],
                                     x_e,
                                     y_mask[:, :, :-1],
                                     x_mask,
                                     teacher_forcing=teacher_forcing)
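            # Normalize the summed loss by the number of non-pad target tokens.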
            token_count = y_mask[:, :, 1:].sum().item()
            loss = loss_function(log_probs, batch[1][:, 1:],
                                 model.get_target_embeddings()) / token_count
            loss.backward()

            optimizer.step()
            mark_optimization_step()
            optimizer.zero_grad()

            run_scheduler_at_step()

            total_tokens_processed += token_count

            if step > 0 and step % report_interval_steps == 0:
                elapsed_time = time.time() - start_time
                baseline_loss = loss_function.uniform_baseline_loss(
                    log_probs, batch[1][:, 1:])
                logger.info(
                    'Epoch_{} Step_{}: loss={:.3f}(vs {:.3f} uniform), tokens/s={:.1f}, lr={}'
                    .format(epoch, step, loss.item(), baseline_loss,
                            total_tokens_processed / elapsed_time,
                            optimizer.param_groups[0]['lr']))
                start_time = time.time()
                total_tokens_processed = 0

            if step > 0 and step % validation_interval_steps == 0:
                log_prefix = 'Epoch_{} Step_{}'.format(epoch, step)
                score = evaluate(validation_dataset, log_prefix, model,
                                 loss_function)
                maybe_save_checkpoint(score)
                model.train()
                run_scheduler_at_validation()
                start_time = time.time()
                total_tokens_processed = 0
                validation_done_already = True
            else:
                validation_done_already = False

        logger.info('Epoch {} finished.'.format(epoch))
        epoch += 1
        run_scheduler_at_epoch()

    if not validation_done_already:
        log_prefix = 'Final (epoch={} ~ step={})'.format(epoch, step)
        score = evaluate(validation_dataset, log_prefix, model, loss_function)
        maybe_save_checkpoint(score)
    logger.info('Best validation score was {:.3f} at step {}.'.format(
        best_checkpoint_specs["score"], best_checkpoint_specs["step"]))
Example #5
def get_train_dataset(train_root_path: str = './data/train',
                      src_lang_code: str = 'en',
                      tgt_lang_code: str = 'fa',
                      src_lowercase: bool = False,
                      tgt_lowercase: bool = False,
                      src_normalizer: str = 'default',
                      tgt_normalizer: str = 'default',
                      src_tokenizer: str = 'default',
                      tgt_tokenizer: str = 'default',
                      force_vocab_update: bool = True,
                      shared_vocabulary: bool = False):
    logger = get_logger()

    if os.path.isfile('{}.dataset'.format(train_root_path)):
        logger.info(
            'Train dataset has already been prepared. Loading from binary form ...'
        )
        corpora = Corpora.load('{}.dataset'.format(train_root_path))
        logger.info('Done.')
        return corpora

    field_specs = [[src_lowercase, src_normalizer, src_tokenizer],
                   [tgt_lowercase, tgt_normalizer, tgt_tokenizer]]

    vocabulary = None
    if shared_vocabulary:
        vocabulary = Vocabulary()

    fields = [
        Field(l, vocabulary, n, t, force_vocabulary_update=force_vocab_update)
        for l, n, t in field_specs
    ]

    if not force_vocab_update:
        logger.info('Not updating provided vocabularies ...')
        if shared_vocabulary:
            vocabulary.load_from_csv('{}/vocab.shared'.format(
                os.path.dirname(train_root_path)))
        else:
            fields[0].vocabulary.load_from_csv('{}/vocab.{}'.format(
                os.path.dirname(train_root_path), src_lang_code))
            fields[1].vocabulary.load_from_csv('{}/vocab.{}'.format(
                os.path.dirname(train_root_path), tgt_lang_code))

    corpora = Corpora(fields)

    logger.info('Preparing train corpora ...')
    with open('{}.{}'.format(train_root_path, src_lang_code)) as src_stream, \
            open('{}.{}'.format(train_root_path, tgt_lang_code)) as tgt_stream:
        for src_sentence, tgt_sentence in zip(src_stream, tgt_stream):
            if src_sentence.strip() and tgt_sentence.strip():
                corpora.append([src_sentence, tgt_sentence])
    logger.info('Saving dataset ...')
    corpora.save('{}.dataset'.format(train_root_path))
    logger.info('Saving vocabularies ...')
    if shared_vocabulary:
        vocabulary.save_as_csv('{}/vocab.shared'.format(
            os.path.dirname(train_root_path)))
    else:
        fields[0].vocabulary.save_as_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), src_lang_code))
        fields[1].vocabulary.save_as_csv('{}/vocab.{}'.format(
            os.path.dirname(train_root_path), tgt_lang_code))
    logger.info('Done.')

    return corpora
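
With the defaults above, a bare call expects ./data/train.en and ./data/train.fa to exist; it writes ./data/train.dataset and the vocabulary CSVs under ./data:

train_corpora = get_train_dataset()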
Example #6
def evaluate(validation_dataset: Corpora,
             log_prefix: Ignore[str],
             model: EncoderDecoder = None,
             loss_function: Callable = None,
             batch_size_limit: int = 400,
             batch_limit_by_tokens: bool = True,
             teacher_forcing: bool = True,
             metrics: Tuple[Metric] = None):
    assert len(
        validation_dataset.fields
    ) >= 2, "Validation dataset must have at least two fields (source and target)."

    logger = get_logger()

    if loss_function is None:
        loss_function = get_loss_function(
            validation_dataset.fields[1].vocabulary.pad_index)
        loss_function.to(get_device())
    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(validation_dataset.fields[0].vocabulary,
                            validation_dataset.fields[1].vocabulary)
        state_dict = torch.load(best_model_path)
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())
    pad_index = model.tgt_vocab.pad_index

    total_item_count = 0
    total_validation_loss = 0
    model.eval()

    printed_samples = 0

    if metrics is None:
        metrics = (BleuMetric(), )
    else:
        metrics = (BleuMetric(), ) + tuple(
            m for m in metrics if not isinstance(m, BleuMetric))
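    # BLEU is kept first so that metrics[0].get_score() at the end returns
    # the BLEU score (the value train uses for checkpoint selection).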

    with torch.no_grad():

        start_time = time.time()
        for validation_batch in validation_dataset.iterate(
                get_device(),
                batch_size_limit,
                batch_limit_by_tokens,
                sort_by_length=False,
                shuffle=False):
            x_mask = validation_batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)

            y_mask = validation_batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)

            x_e = model.encode(validation_batch[0], x_mask)
            log_probs = model.decode(validation_batch[1][:, :-1],
                                     x_e,
                                     y_mask[:, :, :-1],
                                     x_mask,
                                     teacher_forcing=teacher_forcing)

            loss = loss_function(log_probs, validation_batch[1][:, 1:],
                                 model.get_target_embeddings())
            total_item_count += y_mask[:, :, 1:].sum().item()
            total_validation_loss += loss.item()

            y_hat, _ = beam_search(x_e,
                                   x_mask,
                                   model,
                                   get_scores=short_sent_penalty)

            if printed_samples < 4:
                sentence = validation_dataset.fields[0].to_sentence_str(
                    validation_batch[0][-1].tolist())
                reference = validation_dataset.fields[1].to_sentence_str(
                    validation_batch[1][-1].tolist())
                generated = validation_dataset.fields[1].to_sentence_str(
                    y_hat[-1].tolist())
                logger.info('SENTENCE:\n ---- {}'.format(sentence))
                logger.info('REFERENCE:\n ---- {}'.format(reference))
                logger.info('GENERATED:\n ---- {}'.format(generated))

                printed_samples += 1

            update_metric_params(y_hat, validation_batch[1], pad_index,
                                 metrics)

    elapsed_time = time.time() - start_time
    logger.info(
        f'{log_prefix}: '
        f'evaluation_loss={total_validation_loss / total_item_count:.3f}, '
        f'elapsed_time={int(elapsed_time + 0.5)}s')
    for metric_repr in (str(m) for m in metrics):
        logger.info(f'{log_prefix}: evaluation {metric_repr}')

    return metrics[0].get_score()
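
A hypothetical standalone evaluation, assuming the same get_validation_dataset helper used in train above; with model=None the best checkpoint is loaded from disk:

validation_dataset = get_validation_dataset()
bleu_score = evaluate(validation_dataset, log_prefix='Standalone eval')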