# Example No. 1
# 0
def val_step(x, y, model, criterion, bleu, device, distributed=False):
    """Evaluate a single batch with teacher forcing.

    The target batch is split into the decoder input (every token except the
    last) and the decoder target (every token except the first). The model is
    switched to eval mode and run without gradient tracking. The arg-max
    predictions are fed to the ``bleu`` accumulator.

    Args:
        x: source token batch.
        y: target token batch.
        model: transformer called as
            ``model(x, y_inp, enc_mask, look_ahead_mask, dec_mask)``; it must
            return a 2-tuple whose first element is the prediction logits.
        criterion: loss criterion forwarded to ``loss_fn``.
        bleu: metric accumulator called with (predicted ids, targets).
        device: device the batch and the masks are moved to.
        distributed: when True, device copies are made non-blocking.

    Returns:
        (batch_loss, batch_acc): the detached loss and the batch accuracy.
    """
    decoder_in = y[:, :-1]
    decoder_target = y[:, 1:]
    enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
        x, decoder_in)

    (x, decoder_in, decoder_target,
     enc_mask, look_ahead_mask, dec_mask) = to_devices(
         (x, decoder_in, decoder_target, enc_mask, look_ahead_mask, dec_mask),
         device,
         non_blocking=distributed)

    model.eval()
    with torch.no_grad():
        logits, _ = model(x, decoder_in, enc_mask, look_ahead_mask, dec_mask)
        loss = loss_fn(logits.permute(0, 2, 1), decoder_target, criterion)

    detached_loss = loss.detach()
    accuracy = accuracy_fn(logits.detach(), decoder_target)

    predicted_ids = torch.argmax(logits, axis=-1)
    bleu(predicted_ids, decoder_target)

    return detached_loss, accuracy
# Example No. 2
# 0
def aux_train_step(x,
                   y,
                   model,
                   criterion,
                   aux_criterion,
                   aux_strength,
                   frozen_layers,
                   optimizer,
                   scheduler,
                   device,
                   distributed=False):
    """Run one optimisation step with an extra loss on the encoder outputs.

    Two backward passes feed a single optimizer step:
      1. the main loss (``loss_fn``) on the decoder predictions. The graph is
         retained because the source encoding is back-propagated through
         again by the auxiliary term;
      2. ``aux_strength * auxiliary_loss_fn(...)``, which compares the
         encoding of the source with the encoding of the target input. It is
         computed after ``param_freeze(model, frozen_layers)``.
    ``param_freeze(..., unfreeze=True)`` is called after the optimizer and
    scheduler have stepped.

    Returns:
        (batch_loss, batch_aux, batch_acc): detached main loss, detached
        unscaled auxiliary loss, and accuracy of the decoder predictions.
    """
    # Decoder input drops the last token; decoder target drops the first.
    tgt_in = y[:, :-1]
    tgt_out = y[:, 1:]
    src_mask, causal_mask, cross_mask = base_transformer.create_masks(
        x, tgt_in)
    # Separate mask for running the target-side input through the encoder.
    tgt_enc_mask = base_transformer.create_mask(tgt_in)

    (x, tgt_in, tgt_out, src_mask, causal_mask, cross_mask,
     tgt_enc_mask) = to_devices(
         (x, tgt_in, tgt_out, src_mask, causal_mask, cross_mask, tgt_enc_mask),
         device,
         non_blocking=distributed)

    model.train()
    optimizer.zero_grad()

    # Call order (encode src, decode, encode tgt) matches the original so any
    # stochastic layers are invoked in the same sequence.
    src_encoded = model.encode(x, src_mask)
    decoded = model.decode(tgt_in, src_encoded, causal_mask, cross_mask)[0]
    logits = model.final_layer(decoded)
    tgt_encoded = model.encode(tgt_in, tgt_enc_mask)

    main_loss = loss_fn(logits.permute(0, 2, 1), tgt_out, criterion)
    # retain_graph: src_encoded is reused by the auxiliary loss below.
    main_loss.backward(retain_graph=True)

    # NOTE(review): presumably param_freeze stops `frozen_layers` from
    # receiving gradient from the auxiliary term — confirm in param_freeze.
    model = param_freeze(model, frozen_layers)
    aux_loss = auxiliary_loss_fn(src_encoded,
                                 tgt_encoded,
                                 aux_criterion,
                                 x_mask=src_mask,
                                 y_mask=tgt_enc_mask)
    (aux_loss * aux_strength).backward()

    optimizer.step()
    scheduler.step()
    model = param_freeze(model, frozen_layers, unfreeze=True)

    return (main_loss.detach(),
            aux_loss.detach(),
            accuracy_fn(logits.detach(), tgt_out))
# Example No. 3
# 0
 def evaluate(x, y, y_code, bleu):
     """Score beam-search generations against the references with ``bleu``.

     Relies on the enclosing-scope names ``model``, ``params``, ``device``
     and ``to_devices``. Source positions equal to 0 are masked out of
     attention (presumably padding — confirm).

     Args:
         x: source token ids.
         y: target token ids; ``y[:, 1:]`` is used as the reference.
         y_code: decoder start token id passed to ``model.generate``.
         bleu: metric accumulator called with (predictions, references).
     """
     decoder_in = y[:, :-1].contiguous()
     reference = y[:, 1:].contiguous()
     attention = x != 0
     x, decoder_in, reference, attention = to_devices(
         (x, decoder_in, reference, attention), device)

     model.eval()
     generated = model.generate(
         input_ids=x,
         decoder_start_token_id=y_code,
         attention_mask=attention,
         max_length=params.max_len + 1,
         num_beams=params.num_beams,
         length_penalty=params.length_penalty,
         early_stopping=True,
     )
     # Skip the first generated position (presumably the decoder start token).
     bleu(generated[:, 1:], reference)
# Example No. 4
# 0
def train_step(x,
               y,
               model,
               criterion,
               aux_criterion,
               optimizer,
               scheduler,
               device,
               distributed=False):
    """Single training step on the main loss, also reporting the auxiliary loss.

    Only the main loss is back-propagated. The auxiliary loss between the
    encoded source and the encoded target input is computed without gradients
    and returned purely as a metric.

    Args:
        x: source token batch.
        y: target token batch; ``y[:, :-1]`` is the decoder input and
            ``y[:, 1:]`` the decoder target.
        model: transformer exposing ``encode``, ``decode`` and ``final_layer``.
        criterion: loss criterion forwarded to ``loss_fn``.
        aux_criterion: criterion forwarded to ``auxiliary_loss_fn``.
        optimizer: optimizer stepped once.
        scheduler: scheduler stepped once after the optimizer.
        device: device the batch and the masks are moved to.
        distributed: when True, device copies are made non-blocking.

    Returns:
        (batch_loss, batch_aux, batch_acc): detached main loss, auxiliary loss
        (computed under ``no_grad``) and batch accuracy.
    """
    # get masks and targets
    y_inp, y_tar = y[:, :-1], y[:, 1:]
    enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
        x, y_inp)

    # mask for the target language encoded representation.
    enc_mask_aux = base_transformer.create_mask(y_inp)

    # devices
    x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask, enc_mask_aux = to_devices(
        (x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask, enc_mask_aux),
        device,
        non_blocking=distributed)

    # forward
    model.train()
    x_enc = model.encode(x, enc_mask)
    # y_enc only feeds the monitoring-only auxiliary loss below, so don't
    # build an autograd graph for it. Otherwise the graph's activations stay
    # alive for the whole step. It is kept at the same position so any
    # stochastic layers are called in the same order as before.
    with torch.no_grad():
        y_enc = model.encode(y_inp, enc_mask_aux)
    y_pred = model.final_layer(
        model.decode(y_inp, x_enc, look_ahead_mask, dec_mask)[0])
    loss = loss_fn(y_pred.permute(0, 2, 1), y_tar, criterion)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Auxiliary loss is reported only; it does not contribute gradients.
    with torch.no_grad():
        loss_aux = auxiliary_loss_fn(x_enc,
                                     y_enc,
                                     aux_criterion,
                                     x_mask=enc_mask,
                                     y_mask=enc_mask_aux)

    # metrics
    batch_loss = loss.detach()
    batch_aux = loss_aux
    batch_acc = accuracy_fn(y_pred.detach(), y_tar)

    return batch_loss, batch_aux, batch_acc
# Example No. 5
# 0
    def evaluate(x, y, y_code, bleu):
        """Score pivot translation (source -> English -> target) with ``bleu``.

        ``model`` generates an English pivot, starting from the English
        language code. ``model_2`` then translates the pivot, starting from
        ``y_code``. Both generations are capped at ``x.size(1) + 1`` tokens.
        Relies on the enclosing-scope names ``tokenizer``, ``LANG_CODES``,
        ``model``, ``model_2``, ``params``, ``device``, ``to_devices`` and
        ``mask_after_stop``.

        Args:
            x: source token ids; positions equal to 0 are masked out of
                attention (presumably padding — confirm).
            y: target token ids; ``y[:, 1:]`` is used as the reference.
            y_code: decoder start token id for the target language.
            bleu: metric accumulator called with (predictions, references).
        """
        en_code = tokenizer.lang_code_to_id[LANG_CODES['en']]
        y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
        enc_mask = (x != 0)
        x, y_inp, y_tar, enc_mask = to_devices(
            (x, y_inp, y_tar, enc_mask), device)

        # Both models only run inference here, so put both in eval mode;
        # the second hop must not run with training-mode layers either.
        model.eval()
        model_2.eval()
        pivot_pred = model.generate(input_ids=x, decoder_start_token_id=en_code,
                                    attention_mask=enc_mask, max_length=x.size(1) + 1,
                                    num_beams=params.num_beams, length_penalty=params.length_penalty,
                                    early_stopping=True)
        # NOTE(review): presumably masks everything after the stop token id 2
        # so the pivot's tail is ignored via the != 0 mask below — confirm in
        # mask_after_stop.
        pivot_pred = mask_after_stop(pivot_pred, 2)
        pivot_mask = (pivot_pred != 0)
        y_pred = model_2.generate(input_ids=pivot_pred, decoder_start_token_id=y_code,
                                  attention_mask=pivot_mask, max_length=x.size(1) + 1,
                                  num_beams=params.num_beams, length_penalty=params.length_penalty,
                                  early_stopping=True)
        # Skip the first generated position (presumably the decoder start token).
        bleu(y_pred[:, 1:], y_tar)
# Example No. 6
# 0
    def train_step(x, y, aux=False):
        """One optimisation step, optionally with an auxiliary encoder loss.

        The main loss is ``loss_fn(output, y_tar)``. A cosine-embedding loss
        is always computed between the masked max-pooled encoder states of
        the source and of the target input, scaled by ``params.aux_strength``.
        It is only back-propagated when ``aux`` is True; otherwise it is
        computed without gradients and reported only.

        Relies on the enclosing-scope names ``model``, ``optimizer``,
        ``scheduler``, ``params``, ``device``, ``to_devices``, ``loss_fn``,
        ``accuracy_fn``, ``freeze_layers`` and ``_target``.

        Args:
            x: source token ids; positions equal to 0 are masked out.
            y: target token ids; ``y[:, :-1]`` is the decoder input and
                ``y[:, 1:]`` the labels.
            aux: whether to back-propagate the auxiliary loss (with
                ``params.frozen_layers`` frozen for that term).

        Returns:
            (loss, aux_loss, accuracy) as Python numbers.
        """
        y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
        enc_mask, dec_mask = (x != 0), (y_inp != 0)

        x, y_inp, y_tar, enc_mask, dec_mask = to_devices(
            (x, y_inp, y_tar, enc_mask, dec_mask), device)

        model.train()
        if aux:
            freeze_layers(params.frozen_layers, unfreeze=True)
        output = model(input_ids=x, decoder_input_ids=y_inp,
                       labels=y_tar, attention_mask=enc_mask,
                       decoder_attention_mask=dec_mask)
        optimizer.zero_grad()
        loss = loss_fn(output, y_tar)
        # Keep the graph when the auxiliary loss will back-propagate through
        # the same encoder outputs.
        loss.backward(retain_graph=aux)

        if aux:
            freeze_layers(params.frozen_layers)

        # Track gradients for the auxiliary branch only when it will be
        # back-propagated. The context-manager form restores the previous
        # grad mode on exit, even if the encoder call raises, instead of
        # leaving autograd globally switched off.
        with torch.set_grad_enabled(aux):
            x_enc = output.encoder_last_hidden_state
            y_enc = model.model.encoder(y_inp, attention_mask=dec_mask)['last_hidden_state']
            # Masked max-pool over the sequence axis: masked positions are
            # pushed down by 999 so they are effectively excluded from the max.
            x_enc = torch.max(x_enc + -999 * (1 - enc_mask.type(x_enc.dtype)).unsqueeze(-1), dim=1)[0]
            y_enc = torch.max(y_enc + -999 * (1 - dec_mask.type(y_enc.dtype)).unsqueeze(-1), dim=1)[0]
            aux_loss = F.cosine_embedding_loss(x_enc, y_enc, _target)
            scaled_aux_loss = params.aux_strength * aux_loss

        if aux:
            scaled_aux_loss.backward()

        optimizer.step()
        scheduler.step()

        accuracy = accuracy_fn(output.logits, y_tar)

        return loss.item(), aux_loss.item(), accuracy.item()
# Example No. 7
# 0
def train(rank,
          device,
          logger,
          params,
          train_dataloader,
          val_dataloader=None,
          tokenizer=None,
          verbose=50):
    """Training loop.

    Builds the model, an Adam optimizer, a ``WarmupDecay`` scheduler and the
    loss criteria from ``params``. It optionally resumes from a checkpoint and
    wraps the model in ``DistributedDataParallel`` when
    ``params.distributed`` is set. It then trains until ``params.epochs``.

    Each epoch trains on every batch of ``train_dataloader``, using
    ``aux_train_step`` when ``params.auxiliary`` is set and ``train_step``
    otherwise. Rank 0 additionally:
      - records metrics;
      - validates on ``val_dataloader`` (teacher-forced loss, accuracy and
        BLEU, plus a beam-search BLEU on a few batches every
        ``params.test_freq`` epochs);
      - logs to stdout, wandb and ``logger``, and saves the model.

    Args:
        rank: process rank; only rank 0 records metrics, validates and saves.
        device: torch device used for the model and the batches.
        logger: project logger providing ``checkpoint_path``, ``save_model``
            and ``log_results``.
        params: experiment configuration object.
        train_dataloader: iterable of training batches.
        val_dataloader: optional iterable of validation batches.
        tokenizer: required when more than two languages are configured
            (used to add target-language tokens to the source).
        verbose: print progress every ``verbose`` batches; None disables it.

    Returns:
        (epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs).
        On ranks other than 0 these lists are empty.
    """
    multi = False
    if len(params.langs) > 2:
        assert tokenizer is not None
        multi = True
        add_targets = preprocess.AddTargetTokens(params.langs, tokenizer)

    def _unpack_batch(data):
        """Return (x, y); in multilingual mode sample a direction and tag x with the target language."""
        if not multi:
            x, y = data
            return x, y
        # sample a translation direction and add target tokens
        (x, y), (_x_lang, y_lang) = sample_direction(data,
                                                     params.langs,
                                                     excluded=params.excluded)
        return add_targets(x, y_lang), y

    model = initialiser.initialise_model(params, device)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = WarmupDecay(optimizer,
                            params.warmup_steps,
                            params.d_model,
                            lr_scale=params.lr_scale)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    _aux_criterion = torch.nn.CosineEmbeddingLoss(reduction='mean')
    # A target of 1 makes the cosine-embedding loss pull the two inputs together.
    _target = torch.tensor(1.0).to(device)
    aux_criterion = lambda x, y: _aux_criterion(x, y, _target)

    epoch = 0
    if params.checkpoint:
        model, optimizer, epoch, scheduler = logging.load_checkpoint(
            logger.checkpoint_path,
            device,
            model,
            optimizer=optimizer,
            scheduler=scheduler)

    if params.distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device.index], find_unused_parameters=True)

    if rank == 0 and params.wandb:
        wandb.watch(model)

    # Metric histories are created on every rank (only rank 0 fills them) so
    # the final return does not raise UnboundLocalError on other ranks.
    batch_losses, batch_auxs, batch_accs = [], [], []
    epoch_losses, epoch_auxs, epoch_accs = [], [], []
    val_epoch_losses, val_epoch_accs, val_epoch_bleus = [], [], []

    while epoch < params.epochs:
        start_ = time.time()

        # train
        if params.FLAGS:
            print('training')
        epoch_loss = 0.0
        epoch_aux = 0.0
        epoch_acc = 0.0
        for i, data in enumerate(train_dataloader):
            x, y = _unpack_batch(data)

            if params.auxiliary:
                batch_loss, batch_aux, batch_acc = aux_train_step(
                    x,
                    y,
                    model,
                    criterion,
                    aux_criterion,
                    params.aux_strength,
                    params.frozen_layers,
                    optimizer,
                    scheduler,
                    device,
                    distributed=params.distributed)
            else:
                batch_loss, batch_aux, batch_acc = train_step(
                    x,
                    y,
                    model,
                    criterion,
                    aux_criterion,
                    optimizer,
                    scheduler,
                    device,
                    distributed=params.distributed)

            if rank == 0:
                batch_loss = batch_loss.item()
                batch_aux = batch_aux.item()
                batch_acc = batch_acc.item()
                batch_losses.append(batch_loss)
                batch_auxs.append(batch_aux)
                batch_accs.append(batch_acc)
                # Running means over the batches seen so far this epoch.
                epoch_loss += (batch_loss - epoch_loss) / (i + 1)
                epoch_aux += (batch_aux - epoch_aux) / (i + 1)
                epoch_acc += (batch_acc - epoch_acc) / (i + 1)

                if verbose is not None:
                    if i % verbose == 0:
                        print(
                            'Batch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} in {:.4f} s per batch'
                            .format(i, epoch_loss, epoch_aux, epoch_acc,
                                    (time.time() - start_) / (i + 1)))
                if params.wandb:
                    wandb.log({
                        'loss': batch_loss,
                        'aux_loss': batch_aux,
                        'accuracy': batch_acc
                    })

        # val, logging and saving only on rank 0
        if rank == 0:
            epoch_losses.append(epoch_loss)
            epoch_auxs.append(epoch_aux)
            epoch_accs.append(epoch_acc)

            if params.FLAGS:
                print('validating')
            val_epoch_loss = 0.0
            val_epoch_acc = 0.0
            val_bleu = 0.0
            test_bleu = 0.0
            if val_dataloader is not None:
                bleu = BLEU()
                bleu.set_excluded_indices([0, 2])
                for i, data in enumerate(val_dataloader):
                    x, y = _unpack_batch(data)

                    batch_loss, batch_acc = val_step(
                        x,
                        y,
                        model,
                        criterion,
                        bleu,
                        device,
                        distributed=params.distributed)

                    batch_loss = batch_loss.item()
                    batch_acc = batch_acc.item()
                    val_epoch_loss += (batch_loss - val_epoch_loss) / (i + 1)
                    val_epoch_acc += (batch_acc - val_epoch_acc) / (i + 1)

                val_epoch_losses.append(val_epoch_loss)
                val_epoch_accs.append(val_epoch_acc)
                val_bleu = bleu.get_metric()

                # evaluate without teacher forcing
                if params.test_freq is not None:
                    if epoch % params.test_freq == 0:
                        bleu_no_tf = BLEU()
                        bleu_no_tf.set_excluded_indices([0, 2])
                        for i, data in enumerate(val_dataloader):
                            # NOTE(review): `>` runs test_batches + 1 batches;
                            # confirm whether that is intended.
                            if i > params.test_batches:
                                break
                            x, y = _unpack_batch(data)

                            # Seed the decoder with the first target token only.
                            y, y_tar = y[:, 0].unsqueeze(-1), y[:, 1:]
                            enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
                                x, y_tar)

                            # devices
                            x, y, y_tar, enc_mask = to_devices(
                                (x, y, y_tar, enc_mask), device)

                            y_pred = beam_search(
                                x,
                                y,
                                y_tar,
                                model,
                                enc_mask=enc_mask,
                                beam_length=params.beam_length,
                                alpha=params.alpha,
                                beta=params.beta)
                            bleu_no_tf(y_pred, y_tar)

                        test_bleu = bleu_no_tf.get_metric()
                        print(test_bleu)

                if verbose is not None:
                    print(
                        'Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} Val Loss {:.4f} Val Accuracy {:.4f} Val Bleu {:.4f}'
                        ' Test Bleu {:.4f} in {:.4f} secs \n'.format(
                            epoch, epoch_loss, epoch_aux, epoch_acc,
                            val_epoch_loss, val_epoch_acc, val_bleu, test_bleu,
                            time.time() - start_))
                if params.wandb:
                    wandb.log({
                        'loss': epoch_loss,
                        'aux_loss': epoch_aux,
                        'accuracy': epoch_acc,
                        'val_loss': val_epoch_loss,
                        'val_accuracy': val_epoch_acc,
                        'val_bleu': val_bleu,
                        'test_bleu': test_bleu
                    })
            else:
                if verbose is not None:
                    # The "Aux Loss" slot reports epoch_aux (it previously
                    # repeated epoch_loss by mistake).
                    print(
                        'Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} in {:.4f} secs \n'
                        .format(epoch, epoch_loss, epoch_aux, epoch_acc,
                                time.time() - start_))
                if params.wandb:
                    wandb.log({
                        'loss': epoch_loss,
                        'aux_loss': epoch_aux,
                        'accuracy': epoch_acc
                    })

            if params.FLAGS:
                print('logging results')
            logger.save_model(epoch, model, optimizer, scheduler=scheduler)
            logger.log_results([
                epoch_loss, epoch_aux, epoch_acc, val_epoch_loss,
                val_epoch_acc, val_bleu, test_bleu
            ])

        epoch += 1

    return epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs
# Example No. 8
# 0
def inference_step(x,
                   y,
                   model,
                   logger,
                   tokenizer,
                   device,
                   bleu=None,
                   teacher_forcing=False,
                   pivot_mode=False,
                   beam_length=1,
                   alpha=0.0,
                   beta=0.0):
    """
    Inference step.

    x: source language batch.
    y: target language batch.

    With ``teacher_forcing``, the model is run once on ``y[:, :-1]`` under
    ``no_grad`` and scored against ``y[:, 1:]``. Otherwise only ``y[:, 0]`` is
    used to seed decoding: greedy search when ``beam_length == 1``, otherwise
    beam search with ``alpha`` / ``beta``. ``y[:, 1:]`` is the reference.

    When ``pivot_mode`` is False, the predictions are fed to ``bleu`` (if one
    is given) and logged via ``logger.log_examples``.

    Returns:
        pivot_mode False: the batch accuracy under teacher forcing, or 0
            without teacher forcing (accuracy is not computed there).
        pivot_mode True: the predicted token ids.
    """
    if teacher_forcing:
        y_inp, y_tar = y[:, :-1], y[:, 1:]
        enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
            x, y_inp)

        # devices
        x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask = to_devices(
            (x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask), device)

        # inference
        model.eval()
        with torch.no_grad():
            y_pred, _ = model(x, y_inp, enc_mask, look_ahead_mask, dec_mask)
        y_ids = torch.argmax(y_pred, axis=-1)

        if pivot_mode:
            return y_ids

        batch_acc = accuracy_fn(y_pred.detach(), y_tar).cpu().item()
        # bleu defaults to None; guard it here as the decoding branch does.
        if bleu is not None:
            bleu(y_ids, y_tar)
        logger.log_examples(x, y_tar, y_ids, tokenizer)
        return batch_acc

    # Retrieve the start of sequence token and the target translation
    y, y_tar = y[:, 0].unsqueeze(-1), y[:, 1:]
    enc_mask, _, _ = base_transformer.create_masks(x, y_tar)

    # devices
    x, y, y_tar, enc_mask = to_devices((x, y, y_tar, enc_mask), device)

    # inference; decoding never needs gradients, matching the no_grad used
    # in the teacher-forcing branch.
    model.eval()
    with torch.no_grad():
        if beam_length == 1:
            y_pred = greedy_search(x, y, y_tar, model, enc_mask=enc_mask)
        else:
            y_pred = beam_search(x,
                                 y,
                                 y_tar,
                                 model,
                                 enc_mask=enc_mask,
                                 beam_length=beam_length,
                                 alpha=alpha,
                                 beta=beta)

    if pivot_mode:
        return y_pred

    # Accuracy is not computed without teacher forcing.
    batch_acc = 0
    if bleu is not None:
        bleu(y_pred, y_tar)
    logger.log_examples(x, y_tar, y_pred, tokenizer)
    return batch_acc