Esempio n. 1
0
def predict(
    input: Ignore[str],
    output: Ignore[str],
    log_prefix: Ignore[str],
    model: EncoderDecoder = None,
    batch_size_limit: int = 400,
    batch_limit_by_tokens: bool = True,
):
    """Translate every non-blank line of ``input`` and write results to ``output``.

    Args:
        input: Path of the source-language text file, one sentence per line.
            Whitespace-only lines are skipped, so ``output`` can contain fewer
            lines than ``input``.
        output: Path of the file the generated sentences are written to, one
            sentence per line.
        log_prefix: Prefix used in the progress log messages.
        model: Model used for decoding. When ``None``, the checkpoint returned
            by ``find_best_model()`` is loaded.
        batch_size_limit: Batch size limit passed to ``Corpora.iterate``.
        batch_limit_by_tokens: Passed to ``Corpora.iterate``; presumably
            selects whether ``batch_size_limit`` counts tokens rather than
            sentences — confirm against ``Corpora``.

    Raises:
        RuntimeError: If ``model`` is ``None`` and no trained model exists.
    """
    logger = get_logger()

    (src_vocab, _), (src_field, tgt_field) = get_vocabularies()

    dataset = Corpora([src_field])
    logger.info(f'{log_prefix}: Loading input file ...')
    with open(input) as src_stream:
        for src_sentence in src_stream:
            if src_sentence.strip():
                dataset.append([src_sentence])
    logger.info(f'{log_prefix}: Loading done.')

    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(src_field.vocabulary, tgt_field.vocabulary)
        # Load onto the CPU first so a checkpoint saved on a different device
        # can still be restored; the model is moved to the target device next.
        state_dict = torch.load(best_model_path, map_location='cpu')
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())

    # Inference must not use training-time behaviour such as dropout; a model
    # passed in by the caller may still be in training mode.
    model.eval()

    with open(output, 'w') as output_stream, torch.no_grad():

        # No sorting / shuffling, presumably so output lines follow the order
        # of the non-blank input lines.
        for batch in dataset.iterate(get_device(),
                                     batch_size_limit,
                                     batch_limit_by_tokens,
                                     sort_by_length=False,
                                     shuffle=False):
            x_mask = batch[0] != src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)

            x_e = model.encode(batch[0], x_mask)
            y_hat, _ = beam_search(x_e,
                                   x_mask,
                                   model,
                                   get_scores=short_sent_penalty)
            # Log one sample per batch (its last sentence) for monitoring.
            sentence = src_field.to_sentence_str(batch[0][-1].tolist())
            generated = tgt_field.to_sentence_str(y_hat[-1].tolist())

            logger.info('SENTENCE:\n ---- {}'.format(sentence))
            logger.info('GENERATED:\n ---- {}'.format(generated))

            # y_hat holds target-vocabulary ids, so it must be decoded with the
            # target field, not the source field.
            for generated in (tgt_field.to_sentence_str(s)
                              for s in y_hat.tolist()):
                output_stream.write(f'{generated}\n')
Esempio n. 2
0
    def __init__(self,
                 pad_index: Ignore[int],
                 teacher_config_path: str = '/DOES_NOT_EXIST'):
        """Build the module around a teacher model loaded from its own config.

        The teacher's JSON configuration is layered on top of a clone of the
        current configuration and pushed for the duration of the teacher
        setup. It is always popped again, even if loading fails. The teacher
        model is then moved to ``get_device()`` and put in eval mode.

        Args:
            pad_index: Stored as ``self.pad_index``.
            teacher_config_path: Path of the teacher's JSON config, resolved
                with ``relative_to_config_path``. The default points to a
                non-existent file, so in practice a real path must be given.

        Raises:
            AssertionError: If the teacher config file does not exist.
            ValueError: If no trained teacher checkpoint can be found.
        """
        teacher_config_path = relative_to_config_path(teacher_config_path)
        assert os.path.exists(
            teacher_config_path), "Teacher model config does not exist."
        nn.Module.__init__(self)
        teacher_model_config = get_configuration().clone()
        with open(teacher_config_path) as f:
            teacher_model_config.load(json.load(f))

        push_configuration(teacher_model_config)
        # Everything below reads the teacher configuration; pop it in all
        # cases so a failure cannot leave it on the global configuration stack.
        try:
            update_and_ensure_model_output_path('test', None)

            best_model_path = find_best_model()
            if best_model_path is None:
                raise ValueError('Could not find the teacher model.')
            (src_vocab, tgt_vocab), _ = get_vocabularies()
            self.teacher_model = build_model(src_vocab, tgt_vocab)
            # Load onto the CPU first, then move to the configured device.
            state_dict = torch.load(best_model_path, map_location='cpu')
            self.teacher_model.load_state_dict(state_dict['model_state'])
            self.teacher_model.to(get_device())
            self.teacher_model.eval()
            self.src_pad_index = src_vocab.pad_index
            self.tgt_pad_index = tgt_vocab.pad_index
        finally:
            pop_configuration()

        self.pad_index = pad_index
Esempio n. 3
0
 def init_teacher_(model_emb, teacher_emb):
     """Initialise ``model_emb`` in place from the teacher embedding matrix.

     Both tensors are 2-D with rows indexed by vocabulary entry; their first
     dimensions must match. The second dimension decides the strategy:

     * equal: the teacher matrix is copied verbatim;
     * model wider: the teacher matrix fills the leading columns and the
       remaining columns of ``model_emb`` are left untouched;
     * model narrower: the teacher rows are projected onto the eigenvectors
       of the (uncentred) Gram matrix ``teacher_emb^T @ teacher_emb`` with the
       largest eigenvalues. This is a PCA-style reduction that keeps the
       highest-variance directions, ordered largest first.

     Args:
         model_emb: Tensor written in place.
         teacher_emb: Source tensor; it is not modified.

     Raises:
         AssertionError: If the vocabulary sizes differ.
     """
     assert model_emb.size(0) == teacher_emb.size(
         0), "Vocabulary sizes are different."
     # In-place initialisation of (possibly grad-requiring) parameters must
     # not be tracked by autograd, as in ``torch.nn.init``.
     with torch.no_grad():
         teacher_emb = teacher_emb.to(model_emb.device)
         model_dim = model_emb.size(1)
         teacher_dim = teacher_emb.size(1)
         if model_dim == teacher_dim:
             model_emb.copy_(teacher_emb)
         elif model_dim > teacher_dim:
             model_emb.narrow(1, 0, teacher_dim).copy_(teacher_emb)
         else:
             gram = torch.matmul(teacher_emb.t(), teacher_emb)
             linalg = getattr(torch, 'linalg', None)
             if linalg is not None and hasattr(linalg, 'eigh'):
                 _, phi = linalg.eigh(gram)
             else:  # PyTorch < 1.8 has no torch.linalg.eigh.
                 _, phi = torch.symeig(gram, eigenvectors=True)
             # Eigenvalues come back in ascending order, so the principal
             # directions are the LAST columns; flip to take them largest-first.
             top_components = phi.flip(1)[:, :model_dim]
             model_emb.copy_(torch.matmul(teacher_emb, top_components))
Esempio n. 4
0
def train(max_steps: int = 100,
          batch_size_limit: int = 400,
          batch_limit_by_tokens: bool = True,
          report_interval_steps: int = 10,
          validation_interval_steps: int = 100,
          lr_scheduler_at: str = 'every_step',
          n_ckpts_to_keep: int = 3,
          teacher_forcing: bool = True,
          random_seed: int = 42):
    """Train a freshly built model for exactly ``max_steps`` optimisation steps.

    The model is validated every ``validation_interval_steps`` steps, and once
    more at the end unless the last step was a validation step. After each
    validation the checkpoint is written if its score ranks among the
    ``n_ckpts_to_keep`` best seen so far (higher scores are better). The
    lowest-scoring file on disk is deleted to make room.

    Args:
        max_steps: Total number of optimisation steps.
        batch_size_limit: Passed to ``Corpora.iterate``.
        batch_limit_by_tokens: Passed to ``Corpora.iterate``.
        report_interval_steps: Log loss / throughput every this many steps.
        validation_interval_steps: Run ``evaluate`` every this many steps.
        lr_scheduler_at: When ``scheduler.step()`` is called:
            ``'every_step'``, ``'every_validation'`` or ``'every_epoch'``.
            Any other value never steps the scheduler; a warning is logged.
        n_ckpts_to_keep: Maximum number of checkpoint files kept on disk.
        teacher_forcing: Passed to ``model.decode``.
        random_seed: Passed to ``set_random_seeds``.

    Raises:
        RuntimeError: If the training dataset yields no batches.
    """
    set_random_seeds(random_seed)
    logger = get_logger()

    train_dataset = get_train_dataset()
    assert len(
        train_dataset.fields
    ) >= 2, "Train dataset must have at least two fields (source and target)."
    validation_dataset = get_validation_dataset()
    assert len(
        validation_dataset.fields
    ) >= 2, "Validation dataset must have at least two fields (source and target)."

    loss_function = get_loss_function(
        train_dataset.fields[1].vocabulary.pad_index)

    model = build_model(train_dataset.fields[0].vocabulary,
                        train_dataset.fields[1].vocabulary)

    model.to(get_device())
    loss_function.to(get_device())

    optimizer = build_optimizer(model.parameters())
    scheduler = build_scheduler(optimizer)

    initialize(model)

    def noop():
        return None

    def step_lr_scheduler():
        return scheduler.step()

    # Exactly one of these hooks steps the scheduler; the others do nothing.
    run_scheduler_at_step = noop
    run_scheduler_at_validation = noop
    run_scheduler_at_epoch = noop

    if scheduler is not None:
        if lr_scheduler_at == 'every_step':
            run_scheduler_at_step = step_lr_scheduler
        elif lr_scheduler_at == 'every_validation':
            run_scheduler_at_validation = step_lr_scheduler
        elif lr_scheduler_at == 'every_epoch':
            run_scheduler_at_epoch = step_lr_scheduler
        else:
            logger.warning(
                'Unknown lr_scheduler_at value {!r}; the learning rate '
                'scheduler will never be stepped.'.format(lr_scheduler_at))

    step = 0
    epoch = 0

    # Checkpoint file path -> validation score, for checkpoints kept on disk.
    kept_checkpoint_path_score_map = {}

    best_checkpoint_specs = {"score": -math.inf, "step": -1}

    @configured('model')
    def maybe_save_checkpoint(score: Ignore[float], output_path: str):
        """Save a checkpoint when ``score`` ranks among the kept best ones.

        Callers pass only ``score``; ``output_path`` is presumably injected
        from the 'model' configuration by ``@configured``.
        """
        if len(kept_checkpoint_path_score_map) < n_ckpts_to_keep or \
                any(score > s for s in kept_checkpoint_path_score_map.values()):
            if len(kept_checkpoint_path_score_map) >= n_ckpts_to_keep:
                # Evict the lowest-scoring checkpoint to make room.
                worst_checkpoint_path = min(
                    kept_checkpoint_path_score_map,
                    key=kept_checkpoint_path_score_map.get)
                kept_checkpoint_path_score_map.pop(worst_checkpoint_path)
                try:
                    os.unlink(worst_checkpoint_path)
                except OSError:
                    logger.warning(
                        'Could not unlink {}.'.format(worst_checkpoint_path))

            if score > best_checkpoint_specs["score"]:
                logger.info(
                    'New `best model` found with score {:.3f} at step {}.'.
                    format(score, step))
                best_checkpoint_specs["score"] = score
                best_checkpoint_specs["step"] = step

            state_dict = {
                "step":
                step,
                "best_checkpoint_specs":
                best_checkpoint_specs,
                "model_state":
                model.state_dict(),
                "optimizer_state":
                optimizer.state_dict(),
                "scheduler_state":
                scheduler.state_dict() if scheduler is not None else None
            }
            checkpoint_path = '{}/step_{}_score_{:.3f}.pt'.format(
                output_path, step, score)
            torch.save(state_dict, checkpoint_path)
            kept_checkpoint_path_score_map[checkpoint_path] = score

    model.train()

    validation_done_already = False
    while step < max_steps:

        start_time = time.time()
        total_tokens_processed = 0
        batches_this_epoch = 0
        for batch in train_dataset.iterate(get_device(), batch_size_limit,
                                           batch_limit_by_tokens):
            # Check the budget before counting the step, so that exactly
            # ``max_steps`` optimisation steps are performed.
            if step >= max_steps:
                break
            step += 1
            batches_this_epoch += 1

            x_mask = batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)

            y_mask = batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)

            # Decoder input is target[:, :-1]; outputs are scored against
            # target[:, 1:].
            x_e = model.encode(batch[0], x_mask)
            log_probs = model.decode(batch[1][:, :-1],
                                     x_e,
                                     y_mask[:, :, :-1],
                                     x_mask,
                                     teacher_forcing=teacher_forcing)
            token_count = y_mask[:, :, 1:].sum().item()
            # Normalise by the number of non-padding target tokens.
            loss = loss_function(log_probs, batch[1][:, 1:],
                                 model.get_target_embeddings()) / token_count
            loss.backward()

            optimizer.step()
            mark_optimization_step()
            optimizer.zero_grad()

            run_scheduler_at_step()

            total_tokens_processed += token_count

            # ``step`` is always >= 1 here.
            if step % report_interval_steps == 0:
                elapsed_time = time.time() - start_time
                baseline_loss = loss_function.uniform_baseline_loss(
                    log_probs, batch[1][:, 1:])
                logger.info(
                    'Epoch_{} Step_{}: loss={:.3f}(vs {:.3f} uniform), tokens/s={:.1f}, lr={}'
                    .format(epoch, step, loss.item(), baseline_loss,
                            total_tokens_processed / elapsed_time,
                            optimizer.param_groups[0]['lr']))
                start_time = time.time()
                total_tokens_processed = 0

            if step % validation_interval_steps == 0:
                log_prefix = 'Epoch_{} Step_{}'.format(epoch, step)
                score = evaluate(validation_dataset, log_prefix, model,
                                 loss_function)
                maybe_save_checkpoint(score)
                # evaluate() may switch the model to eval mode; switch back.
                model.train()
                run_scheduler_at_validation()
                start_time = time.time()
                total_tokens_processed = 0
                validation_done_already = True
            else:
                validation_done_already = False

        if batches_this_epoch == 0:
            # Without this guard ``step`` never advances and training hangs.
            raise RuntimeError('Train dataset produced no batches.')

        epoch += 1
        logger.info('Epoch {} finished.'.format(epoch))
        run_scheduler_at_epoch()

    if not validation_done_already:
        log_prefix = 'Final (epoch={} ~ step={})'.format(epoch, step)
        score = evaluate(validation_dataset, log_prefix, model, loss_function)
        maybe_save_checkpoint(score)
    logger.info('Best validation score was {:.3f} at step {}.'.format(
        best_checkpoint_specs["score"], best_checkpoint_specs["step"]))
Esempio n. 5
0
def evaluate(validation_dataset: Corpora,
             log_prefix: Ignore[str],
             model: EncoderDecoder = None,
             loss_function: Callable = None,
             batch_size_limit: int = 400,
             batch_limit_by_tokens: bool = True,
             teacher_forcing: bool = True,
             metrics: Tuple[Metric, ...] = None):
    """Evaluate ``model`` on ``validation_dataset`` and return its BLEU score.

    The per-token loss is logged. It is the accumulated batch losses divided
    by the number of non-padding target positions after the first one.
    Beam-search output is scored with every metric in ``metrics``; a
    ``BleuMetric`` is always included and placed first. The model is left in
    eval mode.

    Args:
        validation_dataset: Corpus whose fields 0 and 1 are source and target.
        log_prefix: Prefix used in the log messages.
        model: Model to evaluate. When ``None``, the checkpoint returned by
            ``find_best_model()`` is loaded.
        loss_function: Loss to report. When ``None``, one is built with
            ``get_loss_function`` for the target padding index.
        batch_size_limit: Passed to ``Corpora.iterate``.
        batch_limit_by_tokens: Passed to ``Corpora.iterate``.
        teacher_forcing: Passed to ``model.decode``.
        metrics: Extra metrics to compute; any ``BleuMetric`` instances are
            replaced by the built-in one.

    Returns:
        ``get_score()`` of the ``BleuMetric``.

    Raises:
        RuntimeError: If ``model`` is ``None`` and no trained model exists.
    """
    assert len(
        validation_dataset.fields
    ) >= 2, "Validation dataset must have at least two fields (source and target)."

    logger = get_logger()

    if loss_function is None:
        loss_function = get_loss_function(
            validation_dataset.fields[1].vocabulary.pad_index)
        loss_function.to(get_device())
    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(validation_dataset.fields[0].vocabulary,
                            validation_dataset.fields[1].vocabulary)
        # Load onto the CPU first, then move to the configured device.
        state_dict = torch.load(best_model_path, map_location='cpu')
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())
    pad_index = model.tgt_vocab.pad_index

    total_item_count = 0
    total_validation_loss = 0
    model.eval()

    printed_samples = 0

    # BLEU always comes first because its score is the return value.
    if metrics is None:
        metrics = (BleuMetric(), )
    else:
        metrics = (BleuMetric(), ) + tuple(
            m for m in metrics if not isinstance(m, BleuMetric))

    with torch.no_grad():

        start_time = time.time()
        for validation_batch in validation_dataset.iterate(
                get_device(),
                batch_size_limit,
                batch_limit_by_tokens,
                sort_by_length=False,
                shuffle=False):
            x_mask = validation_batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)

            y_mask = validation_batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)

            # Decoder input is target[:, :-1]; outputs are scored against
            # target[:, 1:].
            x_e = model.encode(validation_batch[0], x_mask)
            log_probs = model.decode(validation_batch[1][:, :-1],
                                     x_e,
                                     y_mask[:, :, :-1],
                                     x_mask,
                                     teacher_forcing=teacher_forcing)

            loss = loss_function(log_probs, validation_batch[1][:, 1:],
                                 model.get_target_embeddings())
            total_item_count += y_mask[:, :, 1:].sum().item()
            total_validation_loss += loss.item()

            y_hat, _ = beam_search(x_e,
                                   x_mask,
                                   model,
                                   get_scores=short_sent_penalty)

            # Log the last sentence of each of the first four batches.
            if printed_samples < 4:
                sentence = validation_dataset.fields[0].to_sentence_str(
                    validation_batch[0][-1].tolist())
                reference = validation_dataset.fields[1].to_sentence_str(
                    validation_batch[1][-1].tolist())
                generated = validation_dataset.fields[1].to_sentence_str(
                    y_hat[-1].tolist())
                logger.info('SENTENCE:\n ---- {}'.format(sentence))
                logger.info('REFERENCE:\n ---- {}'.format(reference))
                logger.info('GENERATED:\n ---- {}'.format(generated))

                printed_samples += 1

            update_metric_params(y_hat, validation_batch[1], pad_index,
                                 metrics)

    elapsed_time = time.time() - start_time
    # An empty validation set has no target tokens; report NaN instead of
    # crashing with ZeroDivisionError.
    if total_item_count:
        evaluation_loss = total_validation_loss / total_item_count
    else:
        evaluation_loss = float('nan')
    logger.info(
        f'{log_prefix}: '
        f'evaluation_loss={evaluation_loss:.3f}, '
        f'elapsed_time={int(elapsed_time + 0.5)}s')
    for metric_repr in (str(m) for m in metrics):
        logger.info(f'{log_prefix}: evaluation {metric_repr}')

    return metrics[0].get_score()