Example #1
0
File: test.py  Project: IanYHWu/NLP_project
def multi_test(device, params, test_dataloader, tokenizer, verbose=50):
    """Test for multilingual translation. Evaluates on all possible translation directions.

    Args:
        device: torch device to run inference on.
        params: test configuration namespace (langs, excluded, location, name,
            teacher_forcing, beam_length, ...).
        test_dataloader: yields batches that cover all translation directions.
        tokenizer: shared tokenizer; must not be None (needed to add target tokens).
        verbose: print progress every `verbose` batches; None disables printing.
    """

    logger = logging.TestLogger(params)
    logger.make_dirs()
    train_params = logging.load_params(params.location + '/' + params.name)

    model = initialiser.initialise_model(train_params, device)
    model, _, _, _ = logging.load_checkpoint(logger.checkpoint_path, device, model)

    assert tokenizer is not None
    add_targets = preprocess.AddTargetTokens(params.langs, tokenizer)

    # One running accuracy and one BLEU metric per direction. Both dicts must be
    # built over the SAME excluded-filtered pair set: previously pair_accs was
    # built without `excluded`, so excluded directions stayed at 0.0 forever and
    # diluted the mean accuracy reported below.
    pair_accs = {s + '-' + t: 0.0
                 for s, t in get_pairs(params.langs, excluded=params.excluded)}
    pair_bleus = {}
    for s, t in get_pairs(params.langs, excluded=params.excluded):
        _bleu = BLEU()
        _bleu.set_excluded_indices([0, 2])  # presumably pad/eos token ids — TODO confirm
        pair_bleus[s + '-' + t] = _bleu

    test_acc = 0.0
    start_ = time.time()

    print(params.__dict__)
    print("Now testing")
    for i, data in enumerate(test_dataloader):

        data = get_directions(data, params.langs, excluded=params.excluded)
        for direction, (x, y, y_lang) in data.items():
            x = add_targets(x, y_lang)
            bleu = pair_bleus[direction]
            test_batch_acc = inference_step(x, y, model, logger, tokenizer, device, bleu=bleu,
                                            teacher_forcing=params.teacher_forcing,
                                            beam_length=params.beam_length)
            # incremental running mean over batches for this direction
            pair_accs[direction] += (test_batch_acc - pair_accs[direction]) / (i + 1)

        # report the mean accuracy and bleu across directions
        if verbose is not None:
            test_acc += (np.mean([v for v in pair_accs.values()]) - test_acc) / (i + 1)
            curr_bleu = np.mean([bleu.get_metric() for bleu in pair_bleus.values()])
            if i % verbose == 0:
                print('Batch {} Accuracy {:.4f} Bleu {:.4f} in {:.4f} s per batch'.format(
                    i, test_acc, curr_bleu, (time.time() - start_) / (i + 1)))

    directions = list(pair_bleus.keys())
    test_accs = [pair_accs[d] for d in directions]
    test_bleus = [pair_bleus[d].get_metric() for d in directions]
    logger.log_results([directions, test_accs, test_bleus])
    logger.dump_examples()
Example #2
0
File: test.py  Project: IanYHWu/NLP_project
def test(device, params, test_dataloader, tokenizer, verbose=50):
    """Evaluate a trained model on a single translation direction.

    Loads the checkpointed model named by `params`, runs inference over the
    whole dataloader while accumulating accuracy and BLEU, and logs the final
    results for the `params.langs[0]-params.langs[1]` direction.
    """

    logger = logging.TestLogger(params)
    logger.make_dirs()
    train_params = logging.load_params(params.location + '/' + params.name)

    model = initialiser.initialise_model(train_params, device)
    model, _, _, _ = logging.load_checkpoint(logger.checkpoint_path, device, model)

    batch_accuracies = []
    bleu = BLEU()
    bleu.set_excluded_indices([0, 2])

    running_acc = 0.0
    t0 = time.time()

    print(params.__dict__)
    print("Now testing")
    for batch_idx, batch in enumerate(test_dataloader):

        x, y = batch
        acc = inference_step(x, y, model, logger, tokenizer, device, bleu=bleu,
                             teacher_forcing=params.teacher_forcing,
                             beam_length=params.beam_length,
                             alpha=params.alpha, beta=params.beta)
        batch_accuracies.append(acc)

        # incremental running mean of accuracy over batches seen so far
        running_acc += (acc - running_acc) / (batch_idx + 1)
        current_bleu = bleu.get_metric()

        if verbose is not None and batch_idx % verbose == 0:
            print('Batch {} Accuracy {:.4f} Bleu {:.4f} in {:.4f} s per batch'.format(
                batch_idx, running_acc, current_bleu,
                (time.time() - t0) / (batch_idx + 1)))

    final_bleu = bleu.get_metric()
    direction = params.langs[0] + '-' + params.langs[1]
    logger.log_results([direction, running_acc, final_bleu])
    logger.dump_examples()
Example #3
0
def train(rank,
          device,
          logger,
          params,
          train_dataloader,
          val_dataloader=None,
          tokenizer=None,
          verbose=50):
    """Training loop, usable both single-process and distributed.

    Args:
        rank: process rank; only rank 0 logs, validates and checkpoints.
        device: torch device for this process.
        logger: TrainLogger used for checkpointing and result logging.
        params: training configuration namespace (langs, epochs, wandb, ...).
        train_dataloader: training batches.
        val_dataloader: optional validation batches; validated on rank 0 only.
        tokenizer: required when training on more than two languages.
        verbose: print progress every `verbose` batches; None disables printing.

    Returns:
        (epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs) as
        accumulated on rank 0. NOTE(review): on rank != 0 these lists are
        never defined, so the final `return` would raise NameError there —
        presumably only rank 0's return value is used; verify against caller.
    """

    # Multilingual mode: with more than two languages we sample a translation
    # direction per batch and prepend target-language tokens to the source.
    multi = False
    if len(params.langs) > 2:
        assert tokenizer is not None
        multi = True
        add_targets = preprocess.AddTargetTokens(params.langs, tokenizer)

    model = initialiser.initialise_model(params, device)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = WarmupDecay(optimizer,
                            params.warmup_steps,
                            params.d_model,
                            lr_scale=params.lr_scale)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    # Auxiliary loss: cosine-embedding loss with the target fixed to +1
    # (i.e. pull the two representations together).
    _aux_criterion = torch.nn.CosineEmbeddingLoss(reduction='mean')
    _target = torch.tensor(1.0).to(device)
    aux_criterion = lambda x, y: _aux_criterion(x, y, _target)

    epoch = 0
    if params.checkpoint:
        model, optimizer, epoch, scheduler = logging.load_checkpoint(
            logger.checkpoint_path,
            device,
            model,
            optimizer=optimizer,
            scheduler=scheduler)

    if params.distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device.index], find_unused_parameters=True)

    if rank == 0:
        if params.wandb:
            wandb.watch(model)
        batch_losses, batch_auxs, batch_accs = [], [], []
        epoch_losses, epoch_auxs, epoch_accs = [], [], []
        val_epoch_losses, val_epoch_accs, val_epoch_bleus = [], [], []

    while epoch < params.epochs:
        start_ = time.time()

        # train
        if params.FLAGS:
            print('training')
        epoch_loss = 0.0
        epoch_aux = 0.0
        epoch_acc = 0.0
        for i, data in enumerate(train_dataloader):

            if multi:
                # sample a translation direction and add target tokens
                (x, y), (x_lang,
                         y_lang) = sample_direction(data,
                                                    params.langs,
                                                    excluded=params.excluded)
                x = add_targets(x, y_lang)
            else:
                x, y = data

            if params.auxiliary:
                batch_loss, batch_aux, batch_acc = aux_train_step(
                    x,
                    y,
                    model,
                    criterion,
                    aux_criterion,
                    params.aux_strength,
                    params.frozen_layers,
                    optimizer,
                    scheduler,
                    device,
                    distributed=params.distributed)
            else:
                batch_loss, batch_aux, batch_acc = train_step(
                    x,
                    y,
                    model,
                    criterion,
                    aux_criterion,
                    optimizer,
                    scheduler,
                    device,
                    distributed=params.distributed)

            if rank == 0:
                batch_loss = batch_loss.item()
                batch_aux = batch_aux.item()
                batch_acc = batch_acc.item()
                batch_losses.append(batch_loss)
                batch_auxs.append(batch_aux)
                batch_accs.append(batch_acc)
                # incremental running means over batches within the epoch
                epoch_loss += (batch_loss - epoch_loss) / (i + 1)
                epoch_aux += (batch_aux - epoch_aux) / (i + 1)
                epoch_acc += (batch_acc - epoch_acc) / (i + 1)

                if verbose is not None:
                    if i % verbose == 0:
                        print(
                            'Batch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} in {:.4f} s per batch'
                            .format(i, epoch_loss, epoch_aux, epoch_acc,
                                    (time.time() - start_) / (i + 1)))
                if params.wandb:
                    wandb.log({
                        'loss': batch_loss,
                        'aux_loss': batch_aux,
                        'accuracy': batch_acc
                    })

        if rank == 0:
            epoch_losses.append(epoch_loss)
            epoch_auxs.append(epoch_aux)
            epoch_accs.append(epoch_acc)

        # val only on rank 0
        if rank == 0:
            if params.FLAGS:
                print('validating')
            val_epoch_loss = 0.0
            val_epoch_acc = 0.0
            val_bleu = 0.0
            test_bleu = 0.0
            if val_dataloader is not None:
                bleu = BLEU()
                bleu.set_excluded_indices([0, 2])
                for i, data in enumerate(val_dataloader):

                    if multi:
                        # sample a translation direction and add target tokens
                        (x, y), (x_lang, y_lang) = sample_direction(
                            data, params.langs, excluded=params.excluded)
                        x = add_targets(x, y_lang)
                    else:
                        x, y = data

                    batch_loss, batch_acc = val_step(
                        x,
                        y,
                        model,
                        criterion,
                        bleu,
                        device,
                        distributed=params.distributed)

                    batch_loss = batch_loss.item()
                    batch_acc = batch_acc.item()
                    val_epoch_loss += (batch_loss - val_epoch_loss) / (i + 1)
                    val_epoch_acc += (batch_acc - val_epoch_acc) / (i + 1)

                val_epoch_losses.append(val_epoch_loss)
                val_epoch_accs.append(val_epoch_acc)
                val_bleu = bleu.get_metric()

                # evaluate without teacher forcing
                if params.test_freq is not None:
                    if epoch % params.test_freq == 0:
                        bleu_no_tf = BLEU()
                        bleu_no_tf.set_excluded_indices([0, 2])
                        for i, data in enumerate(val_dataloader):
                            # only decode a bounded number of batches: beam
                            # search is expensive
                            if i > params.test_batches:
                                break
                            else:
                                if multi:
                                    # sample a translation direction and add target tokens
                                    (x, y), (x_lang,
                                             y_lang) = sample_direction(
                                                 data,
                                                 params.langs,
                                                 excluded=params.excluded)
                                    x = add_targets(x, y_lang)
                                else:
                                    x, y = data

                                # y keeps only the start token; y_tar is the
                                # reference continuation to score against
                                y, y_tar = y[:, 0].unsqueeze(-1), y[:, 1:]
                                enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
                                    x, y_tar)

                                # devices
                                x, y, y_tar, enc_mask = to_devices(
                                    (x, y, y_tar, enc_mask), device)

                                y_pred = beam_search(
                                    x,
                                    y,
                                    y_tar,
                                    model,
                                    enc_mask=enc_mask,
                                    beam_length=params.beam_length,
                                    alpha=params.alpha,
                                    beta=params.beta)
                                bleu_no_tf(y_pred, y_tar)

                        test_bleu = bleu_no_tf.get_metric()
                        print(test_bleu)

                if verbose is not None:
                    print(
                        'Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} Val Loss {:.4f} Val Accuracy {:.4f} Val Bleu {:.4f}'
                        ' Test Bleu {:.4f} in {:.4f} secs \n'.format(
                            epoch, epoch_loss, epoch_aux, epoch_acc,
                            val_epoch_loss, val_epoch_acc, val_bleu, test_bleu,
                            time.time() - start_))
                if params.wandb:
                    wandb.log({
                        'loss': epoch_loss,
                        'aux_loss': epoch_aux,
                        'accuracy': epoch_acc,
                        'val_loss': val_epoch_loss,
                        'val_accuracy': val_epoch_acc,
                        'val_bleu': val_bleu,
                        'test_bleu': test_bleu
                    })
            else:
                if verbose is not None:
                    # BUGFIX: the "Aux Loss" slot previously received
                    # epoch_loss twice instead of epoch_aux.
                    print(
                        'Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} in {:.4f} secs \n'
                        .format(epoch, epoch_loss, epoch_aux, epoch_acc,
                                time.time() - start_))
                if params.wandb:
                    wandb.log({
                        'loss': epoch_loss,
                        'aux_loss': epoch_aux,
                        'accuracy': epoch_acc
                    })

            if params.FLAGS:
                print('logging results')
            logger.save_model(epoch, model, optimizer, scheduler=scheduler)
            logger.log_results([
                epoch_loss, epoch_aux, epoch_acc, val_epoch_loss,
                val_epoch_acc, val_bleu, test_bleu
            ])

        epoch += 1

    return epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs
Example #4
0
    d_model = 10
    dff = 20
    layers = 2
    heads = 2
    max_pe = 1000
    vocab_size = 100
    dropout = 0.1
    location = '.'
    name = 'test_logging'


# Round-trip test for checkpoint logging: save a checkpoint from one
# model/optimizer/scheduler, then restore it into a freshly built set.
model = initialise_model(params, device)
optimizer = torch.optim.Adam(model.parameters())
scheduler = WarmupDecay(optimizer, params.d_model, 1000)
epoch = 10

# test logger
logger = logging.TrainLogger(params)
logger.make_dirs()
logger.save_model(epoch, model, optimizer, scheduler=scheduler)

model2 = initialise_model(params, device)
optimizer2 = torch.optim.Adam(model2.parameters())
scheduler2 = WarmupDecay(optimizer2, params.d_model, 1000)

path = './test_logging/checkpoint/checkpoint'
# BUGFIX: restore into the *second* set of objects. The original passed
# model/optimizer/scheduler here, so model2/optimizer2/scheduler2 were built
# but never actually loaded into, defeating the round-trip check.
model2, optimizer2, epoch, scheduler2 = logging.load_checkpoint(
    path, device, model2, optimizer=optimizer2, scheduler=scheduler2)

print(epoch)
Example #5
0
def main(params):
    """ Evaluates a finetuned model on the test or validation dataset.

    Loads an mBART-50 checkpoint named by `params`, builds one dataloader per
    unordered language pair, scores BLEU in both directions (unless
    params.single_direction), and writes a CSV of per-direction BLEU scores.
    """

    # load model and tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    config = MBartConfig.from_pretrained("facebook/mbart-large-50")
    model = MBartForConditionalGeneration(config).to(device)
    checkpoint_location = params.location + '/' + params.name + '/checkpoint/checkpoint'
    model, _, _, _ = logging.load_checkpoint(checkpoint_location, device,
                                             model)

    def pipeline(dataset, langs, batch_size, max_len):
        # Build a dataloader of padded token-id batches for the given language
        # pair: filter -> tokenize -> pad/collate.

        cols = ['input_ids_' + l for l in langs]

        def tokenize_fn(example):
            """apply tokenization"""
            l_tok = []
            for lang in langs:
                encoded = tokenizer.encode(example[lang])
                # replace the leading token with the language code id so the
                # encoder knows the source language
                encoded[0] = tokenizer.lang_code_to_id[LANG_CODES[lang]]
                l_tok.append(encoded)
            return {'input_ids_' + l: tok for l, tok in zip(langs, l_tok)}

        def pad_seqs(examples):
            """Apply padding"""
            # transpose list-of-examples into per-language tuples of sequences
            ex_langs = list(
                zip(*[tuple(ex[col] for col in cols) for ex in examples]))
            ex_langs = tuple(
                pad_sequence(x, batch_first=True, max_len=max_len)
                for x in ex_langs)
            return ex_langs

        dataset = filter_languages(dataset, langs)
        dataset = dataset.map(tokenize_fn)
        dataset.set_format(type='torch', columns=cols)
        num_examples = len(dataset)
        print('-'.join(langs) + ' : {} examples.'.format(num_examples))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 collate_fn=pad_seqs)
        return dataloader, num_examples

    # load data
    if params.split == 'val':
        test_dataset = load_dataset('ted_multi', split='validation')
    elif params.split == 'test':
        test_dataset = load_dataset('ted_multi', split='test')
    elif params.split == 'combine':
        test_dataset = load_dataset('ted_multi', split='validation+test')
    else:
        raise NotImplementedError

    # preprocess splits for each direction
    test_dataloaders = {}
    for l1, l2 in combinations(params.langs, 2):
        test_dataloaders[l1 + '-' + l2], _ = pipeline(test_dataset, [l1, l2],
                                                      params.batch_size,
                                                      params.max_len)

    # evaluate the model
    def evaluate(x, y, y_code, bleu):
        # Generate translations of x with y_code as the decoder start token
        # and accumulate BLEU against the shifted reference y.
        y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
        # assumes 0 is the pad token id, so nonzero positions are attended —
        # TODO confirm against the tokenizer config
        enc_mask = (x != 0)
        x, y_inp, y_tar, enc_mask = to_devices((x, y_inp, y_tar, enc_mask),
                                               device)

        model.eval()
        y_pred = model.generate(input_ids=x,
                                decoder_start_token_id=y_code,
                                attention_mask=enc_mask,
                                max_length=x.size(1) + 1,
                                num_beams=params.num_beams,
                                length_penalty=params.length_penalty,
                                early_stopping=True)
        # drop the forced start token before scoring
        bleu(y_pred[:, 1:], y_tar)

    test_results = {}
    for direction, loader in test_dataloaders.items():
        # alt_direction is the reverse pair, e.g. 'en-fr' -> 'fr-en'
        alt_direction = '-'.join(reversed(direction.split('-')))
        bleu1, bleu2 = BLEU(), BLEU()
        bleu1.set_excluded_indices([0, 2])
        bleu2.set_excluded_indices([0, 2])
        x_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[0]]]
        y_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')
                                                      [-1]]]

        start_ = time.time()
        for i, (x, y) in enumerate(loader):
            if params.test_batches is not None:
                if i > params.test_batches:
                    break

            evaluate(x, y, y_code, bleu1)
            if not params.single_direction:
                evaluate(y, x, x_code, bleu2)
            if i % params.verbose == 0:
                # NOTE(review): bleu2.get_metric() is also read when
                # single_direction is set and bleu2 saw no data — confirm
                # BLEU.get_metric() handles the empty case.
                bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
                print(
                    'Batch {} Bleu1 {:.4f} Bleu2 {:.4f} in {:.4f} secs per batch'
                    .format(i, bl1, bl2, (time.time() - start_) / (i + 1)))

        bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
        test_results[direction] = [bl1]
        test_results[alt_direction] = [bl2]
        print(direction, bl1, bl2)

    # save test_results
    pd.DataFrame(test_results).to_csv(params.location + '/' + params.name +
                                      '/test_results.csv',
                                      index=False)