def multi_test(device, params, test_dataloader, tokenizer, verbose=50):
    """Test for multilingual translation. Evaluates on all possible translation
    directions among ``params.langs`` (minus any pairs in ``params.excluded``).

    Args:
        device: torch device used for inference.
        params: experiment parameter namespace (langs, excluded, location, name,
            teacher_forcing, beam_length, ...).
        test_dataloader: iterable of multilingual batches.
        tokenizer: trained tokenizer (required; used to prepend target tokens).
        verbose: print progress every ``verbose`` batches (None to disable).
    """
    logger = logging.TestLogger(params)
    logger.make_dirs()

    # Restore the trained model from the checkpoint saved under location/name.
    train_params = logging.load_params(params.location + '/' + params.name)
    model = initialiser.initialise_model(train_params, device)
    model, _, _, _ = logging.load_checkpoint(logger.checkpoint_path, device, model)

    assert tokenizer is not None
    add_targets = preprocess.AddTargetTokens(params.langs, tokenizer)

    # Fix: build the accuracy dict over the same (non-excluded) direction set as
    # the BLEU dict. Previously it covered ALL pairs, so excluded directions
    # stayed at 0.0 and dragged down the mean accuracy reported below.
    pair_accs = {s + '-' + t: 0.0
                 for s, t in get_pairs(params.langs, excluded=params.excluded)}
    pair_bleus = {}
    for s, t in get_pairs(params.langs, excluded=params.excluded):
        _bleu = BLEU()
        _bleu.set_excluded_indices([0, 2])  # ignore special tokens 0 and 2 in BLEU
        pair_bleus[s + '-' + t] = _bleu

    test_acc = 0.0
    start_ = time.time()
    print(params.__dict__)
    print("Now testing")
    for i, data in enumerate(test_dataloader):
        data = get_directions(data, params.langs, excluded=params.excluded)
        for direction, (x, y, y_lang) in data.items():
            x = add_targets(x, y_lang)
            bleu = pair_bleus[direction]
            test_batch_acc = inference_step(
                x, y, model, logger, tokenizer, device, bleu=bleu,
                teacher_forcing=params.teacher_forcing,
                beam_length=params.beam_length)
            # Running mean of accuracy, kept per direction.
            pair_accs[direction] += (test_batch_acc - pair_accs[direction]) / (i + 1)

        # report the mean accuracy and bleu across directions
        if verbose is not None:
            test_acc += (np.mean([v for v in pair_accs.values()]) - test_acc) / (i + 1)
            curr_bleu = np.mean([bleu.get_metric() for bleu in pair_bleus.values()])
            if i % verbose == 0:
                print('Batch {} Accuracy {:.4f} Bleu {:.4f} in {:.4f} s per batch'.format(
                    i, test_acc, curr_bleu, (time.time() - start_) / (i + 1)))

    # Final per-direction results, in a stable direction order.
    directions = [d for d in pair_bleus.keys()]
    test_accs = [pair_accs[d] for d in directions]
    test_bleus = [pair_bleus[d].get_metric() for d in directions]
    logger.log_results([directions, test_accs, test_bleus])
    logger.dump_examples()
def test(device, params, test_dataloader, tokenizer, verbose=50):
    """Test loop for a single (bilingual) translation direction.

    Loads the trained checkpoint, runs inference over the test set, and logs
    the final accuracy and BLEU for the ``langs[0]-langs[1]`` direction.

    Args:
        device: torch device used for inference.
        params: experiment parameter namespace (location, name, langs,
            teacher_forcing, beam_length, alpha, beta, ...).
        test_dataloader: iterable of (x, y) batches.
        tokenizer: tokenizer passed through to ``inference_step``.
        verbose: print progress every ``verbose`` batches (None to disable).
    """
    logger = logging.TestLogger(params)
    logger.make_dirs()

    # Restore the trained model from the checkpoint saved under location/name.
    train_params = logging.load_params(params.location + '/' + params.name)
    model = initialiser.initialise_model(train_params, device)
    model, _, _, _ = logging.load_checkpoint(logger.checkpoint_path, device, model)

    bleu = BLEU()
    bleu.set_excluded_indices([0, 2])  # ignore special tokens 0 and 2 in BLEU

    test_acc = 0.0
    start_ = time.time()
    print(params.__dict__)
    print("Now testing")
    for i, (x, y) in enumerate(test_dataloader):
        test_batch_acc = inference_step(
            x, y, model, logger, tokenizer, device, bleu=bleu,
            teacher_forcing=params.teacher_forcing,
            beam_length=params.beam_length,
            alpha=params.alpha, beta=params.beta)
        # Running mean of batch accuracies. (Dropped the dead write-only
        # `test_batch_accs` list the original accumulated but never used.)
        test_acc += (test_batch_acc - test_acc) / (i + 1)
        if verbose is not None and i % verbose == 0:
            print('Batch {} Accuracy {:.4f} Bleu {:.4f} in {:.4f} s per batch'.format(
                i, test_acc, bleu.get_metric(), (time.time() - start_) / (i + 1)))

    test_bleu = bleu.get_metric()
    direction = params.langs[0] + '-' + params.langs[1]
    logger.log_results([direction, test_acc, test_bleu])
    logger.dump_examples()
def train(rank, device, logger, params, train_dataloader, val_dataloader=None,
          tokenizer=None, verbose=50):
    """Training Loop.

    Trains (optionally distributed across ranks) with an auxiliary cosine
    alignment loss, validates on rank 0, and periodically evaluates BLEU
    without teacher forcing.

    Args:
        rank: process rank; logging/validation/checkpointing happen on rank 0.
        device: torch device for this process.
        logger: train logger with ``checkpoint_path``, ``save_model``,
            ``log_results``.
        params: experiment parameter namespace.
        train_dataloader: training batches.
        val_dataloader: optional validation batches.
        tokenizer: required when training multilingually (len(langs) > 2).
        verbose: print progress every ``verbose`` batches (None to disable).

    Returns:
        (epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs)
    """
    multi = False
    if len(params.langs) > 2:
        # Multilingual setup: we must be able to prepend target-language tokens.
        assert tokenizer is not None
        multi = True
        add_targets = preprocess.AddTargetTokens(params.langs, tokenizer)

    model = initialiser.initialise_model(params, device)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = WarmupDecay(optimizer, params.warmup_steps, params.d_model,
                            lr_scale=params.lr_scale)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    # Auxiliary loss: cosine embedding loss pulling representations together
    # (target fixed at +1, i.e. "similar").
    _aux_criterion = torch.nn.CosineEmbeddingLoss(reduction='mean')
    _target = torch.tensor(1.0).to(device)
    aux_criterion = lambda x, y: _aux_criterion(x, y, _target)

    epoch = 0
    if params.checkpoint:
        model, optimizer, epoch, scheduler = logging.load_checkpoint(
            logger.checkpoint_path, device, model,
            optimizer=optimizer, scheduler=scheduler)

    if params.distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device.index], find_unused_parameters=True)

    if rank == 0:
        if params.wandb:
            wandb.watch(model)

    batch_losses, batch_auxs, batch_accs = [], [], []
    epoch_losses, epoch_auxs, epoch_accs = [], [], []
    val_epoch_losses, val_epoch_accs, val_epoch_bleus = [], [], []

    while epoch < params.epochs:
        start_ = time.time()

        # train
        if params.FLAGS:
            print('training')
        epoch_loss = 0.0
        epoch_aux = 0.0
        epoch_acc = 0.0
        for i, data in enumerate(train_dataloader):
            if multi:
                # sample a translation direction and add target tokens
                (x, y), (x_lang, y_lang) = sample_direction(
                    data, params.langs, excluded=params.excluded)
                x = add_targets(x, y_lang)
            else:
                x, y = data

            if params.auxiliary:
                batch_loss, batch_aux, batch_acc = aux_train_step(
                    x, y, model, criterion, aux_criterion,
                    params.aux_strength, params.frozen_layers,
                    optimizer, scheduler, device,
                    distributed=params.distributed)
            else:
                batch_loss, batch_aux, batch_acc = train_step(
                    x, y, model, criterion, aux_criterion,
                    optimizer, scheduler, device,
                    distributed=params.distributed)

            if rank == 0:
                batch_loss = batch_loss.item()
                batch_aux = batch_aux.item()
                batch_acc = batch_acc.item()

                batch_losses.append(batch_loss)
                batch_auxs.append(batch_aux)
                batch_accs.append(batch_acc)

                # Running means over the epoch.
                epoch_loss += (batch_loss - epoch_loss) / (i + 1)
                epoch_aux += (batch_aux - epoch_aux) / (i + 1)
                epoch_acc += (batch_acc - epoch_acc) / (i + 1)

                if verbose is not None:
                    if i % verbose == 0:
                        print(
                            'Batch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} in {:.4f} s per batch'
                            .format(i, epoch_loss, epoch_aux, epoch_acc,
                                    (time.time() - start_) / (i + 1)))
                if params.wandb:
                    wandb.log({
                        'loss': batch_loss,
                        'aux_loss': batch_aux,
                        'accuracy': batch_acc
                    })

        if rank == 0:
            epoch_losses.append(epoch_loss)
            epoch_auxs.append(epoch_aux)
            epoch_accs.append(epoch_acc)

        # val only on rank 0
        if rank == 0:
            if params.FLAGS:
                print('validating')
            val_epoch_loss = 0.0
            val_epoch_acc = 0.0
            val_bleu = 0.0
            test_bleu = 0.0
            if val_dataloader is not None:
                bleu = BLEU()
                bleu.set_excluded_indices([0, 2])
                for i, data in enumerate(val_dataloader):
                    if multi:
                        # sample a translation direction and add target tokens
                        (x, y), (x_lang, y_lang) = sample_direction(
                            data, params.langs, excluded=params.excluded)
                        x = add_targets(x, y_lang)
                    else:
                        x, y = data

                    batch_loss, batch_acc = val_step(
                        x, y, model, criterion, bleu, device,
                        distributed=params.distributed)
                    batch_loss = batch_loss.item()
                    batch_acc = batch_acc.item()

                    val_epoch_loss += (batch_loss - val_epoch_loss) / (i + 1)
                    val_epoch_acc += (batch_acc - val_epoch_acc) / (i + 1)

                val_epoch_losses.append(val_epoch_loss)
                val_epoch_accs.append(val_epoch_acc)
                val_bleu = bleu.get_metric()

                # evaluate without teacher forcing
                if params.test_freq is not None:
                    if epoch % params.test_freq == 0:
                        bleu_no_tf = BLEU()
                        bleu_no_tf.set_excluded_indices([0, 2])
                        for i, data in enumerate(val_dataloader):
                            if i > params.test_batches:
                                break
                            else:
                                if multi:
                                    # sample a translation direction and add target tokens
                                    (x, y), (x_lang, y_lang) = sample_direction(
                                        data, params.langs, excluded=params.excluded)
                                    x = add_targets(x, y_lang)
                                else:
                                    x, y = data

                                # First target token seeds decoding; the rest
                                # is the reference for BLEU.
                                y, y_tar = y[:, 0].unsqueeze(-1), y[:, 1:]
                                enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
                                    x, y_tar)

                                # devices
                                x, y, y_tar, enc_mask = to_devices(
                                    (x, y, y_tar, enc_mask), device)

                                y_pred = beam_search(
                                    x, y, y_tar, model, enc_mask=enc_mask,
                                    beam_length=params.beam_length,
                                    alpha=params.alpha, beta=params.beta)
                                bleu_no_tf(y_pred, y_tar)
                        test_bleu = bleu_no_tf.get_metric()
                        print(test_bleu)

                if verbose is not None:
                    print(
                        'Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} Val Loss {:.4f} Val Accuracy {:.4f} Val Bleu {:.4f}'
                        ' Test Bleu {:.4f} in {:.4f} secs \n'.format(
                            epoch, epoch_loss, epoch_aux, epoch_acc,
                            val_epoch_loss, val_epoch_acc, val_bleu,
                            test_bleu, time.time() - start_))
                if params.wandb:
                    wandb.log({
                        'loss': epoch_loss,
                        'aux_loss': epoch_aux,
                        'accuracy': epoch_acc,
                        'val_loss': val_epoch_loss,
                        'val_accuracy': val_epoch_acc,
                        'val_bleu': val_bleu,
                        'test_bleu': test_bleu
                    })
            else:
                if verbose is not None:
                    # Fix: the original passed epoch_loss twice here, so the
                    # "Aux Loss" field printed the main loss instead of epoch_aux.
                    print(
                        'Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} in {:.4f} secs \n'
                        .format(epoch, epoch_loss, epoch_aux, epoch_acc,
                                time.time() - start_))
                if params.wandb:
                    wandb.log({
                        'loss': epoch_loss,
                        'aux_loss': epoch_aux,
                        'accuracy': epoch_acc
                    })

            if params.FLAGS:
                print('logging results')
            logger.save_model(epoch, model, optimizer, scheduler=scheduler)
            logger.log_results([
                epoch_loss, epoch_aux, epoch_acc, val_epoch_loss,
                val_epoch_acc, val_bleu, test_bleu
            ])

        epoch += 1

    return epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs
# --- checkpoint save/load smoke test -----------------------------------------
# NOTE(review): this block reads `params` and `device`, which are not defined
# in the visible chunk — presumably `params` is built elsewhere from the
# constants below; confirm against the full file.
d_model = 10
dff = 20
layers = 2
heads = 2
max_pe = 1000
vocab_size = 100
dropout = 0.1
location = '.'
name = 'test_logging'

model = initialise_model(params, device)
optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): elsewhere WarmupDecay is constructed as
# WarmupDecay(optimizer, warmup_steps, d_model, ...); the argument order here
# (d_model before 1000) looks inconsistent — confirm the intended signature.
scheduler = WarmupDecay(optimizer, params.d_model, 1000)
epoch = 10

# test logger
logger = logging.TrainLogger(params)
logger.make_dirs()
logger.save_model(epoch, model, optimizer, scheduler=scheduler)

# Build a *fresh* model/optimizer/scheduler to load the checkpoint into.
model2 = initialise_model(params, device)
optimizer2 = torch.optim.Adam(model2.parameters())
scheduler2 = WarmupDecay(optimizer2, params.d_model, 1000)
path = './test_logging/checkpoint/checkpoint'
# Fix: restore into the second set of objects. The original passed the
# already-trained model/optimizer/scheduler back in, so the save->load
# round trip was never actually exercised.
model2, optimizer2, epoch, scheduler2 = logging.load_checkpoint(
    path, device, model2, optimizer=optimizer2, scheduler=scheduler2)
print(epoch)
def main(params):
    """ Evaluates a finetuned model on the test or validation dataset."""
    # load model and tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    config = MBartConfig.from_pretrained("facebook/mbart-large-50")
    model = MBartForConditionalGeneration(config).to(device)
    # Restore finetuned weights; only the model state is used here.
    checkpoint_location = params.location + '/' + params.name + '/checkpoint/checkpoint'
    model, _, _, _ = logging.load_checkpoint(checkpoint_location, device, model)

    def pipeline(dataset, langs, batch_size, max_len):
        # Tokenize, pad, and batch the dataset for one language pair.
        # Returns (dataloader, num_examples).
        cols = ['input_ids_' + l for l in langs]

        def tokenize_fn(example):
            """apply tokenization"""
            l_tok = []
            for lang in langs:
                encoded = tokenizer.encode(example[lang])
                # Replace the leading token with the mBART language-code id.
                encoded[0] = tokenizer.lang_code_to_id[LANG_CODES[lang]]
                l_tok.append(encoded)
            return {'input_ids_' + l: tok for l, tok in zip(langs, l_tok)}

        def pad_seqs(examples):
            """Apply padding"""
            # Transpose the batch from per-example tuples to per-language lists.
            ex_langs = list(
                zip(*[tuple(ex[col] for col in cols) for ex in examples]))
            # NOTE(review): `max_len=` is not a torch.nn.utils.rnn.pad_sequence
            # kwarg — presumably a project-local pad_sequence; confirm.
            ex_langs = tuple(
                pad_sequence(x, batch_first=True, max_len=max_len)
                for x in ex_langs)
            return ex_langs

        dataset = filter_languages(dataset, langs)
        dataset = dataset.map(tokenize_fn)
        dataset.set_format(type='torch', columns=cols)
        num_examples = len(dataset)
        print('-'.join(langs) + ' : {} examples.'.format(num_examples))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 collate_fn=pad_seqs)
        return dataloader, num_examples

    # load data
    if params.split == 'val':
        test_dataset = load_dataset('ted_multi', split='validation')
    elif params.split == 'test':
        test_dataset = load_dataset('ted_multi', split='test')
    elif params.split == 'combine':
        test_dataset = load_dataset('ted_multi', split='validation+test')
    else:
        raise NotImplementedError

    # preprocess splits for each direction
    test_dataloaders = {}
    for l1, l2 in combinations(params.langs, 2):
        test_dataloaders[l1 + '-' + l2], _ = pipeline(test_dataset, [l1, l2],
                                                      params.batch_size,
                                                      params.max_len)

    # evaluate the model
    def evaluate(x, y, y_code, bleu):
        # Generate translations of x (no teacher forcing) and accumulate BLEU
        # against the shifted target y_tar.
        y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
        enc_mask = (x != 0)  # attention mask: positions with token id 0 are padding

        x, y_inp, y_tar, enc_mask = to_devices((x, y_inp, y_tar, enc_mask),
                                               device)

        model.eval()
        y_pred = model.generate(input_ids=x, decoder_start_token_id=y_code,
                                attention_mask=enc_mask,
                                max_length=x.size(1) + 1,
                                num_beams=params.num_beams,
                                length_penalty=params.length_penalty,
                                early_stopping=True)
        # Drop the forced decoder start token before scoring.
        bleu(y_pred[:, 1:], y_tar)

    test_results = {}
    for direction, loader in test_dataloaders.items():
        alt_direction = '-'.join(reversed(direction.split('-')))
        # bleu1 scores direction, bleu2 scores the reverse direction.
        bleu1, bleu2 = BLEU(), BLEU()
        bleu1.set_excluded_indices([0, 2])
        bleu2.set_excluded_indices([0, 2])

        # Language-code token ids for source and target of this direction.
        x_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[0]]]
        y_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')
                                                     [-1]]]

        start_ = time.time()
        for i, (x, y) in enumerate(loader):
            if params.test_batches is not None:
                if i > params.test_batches:
                    break

            evaluate(x, y, y_code, bleu1)
            if not params.single_direction:
                evaluate(y, x, x_code, bleu2)

            if i % params.verbose == 0:
                bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
                print(
                    'Batch {} Bleu1 {:.4f} Bleu2 {:.4f} in {:.4f} secs per batch'
                    .format(i, bl1, bl2, (time.time() - start_) / (i + 1)))

        bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
        test_results[direction] = [bl1]
        # NOTE(review): when single_direction is set, bleu2 is never updated,
        # so the alt_direction entry records a default metric — confirm intended.
        test_results[alt_direction] = [bl2]
        print(direction, bl1, bl2)

    # save test_results
    pd.DataFrame(test_results).to_csv(params.location + '/' + params.name +
                                      '/test_results.csv',
                                      index=False)