def log_progress(epoch_i,
                 start_time,
                 tr_loss,
                 val_loss,
                 translations=None,
                 tb_writer=None):
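    """Log one epoch of metrics: compute perplexity from the train/val losses,
    optionally add BLEU when decoded translations are provided, print a summary
    and, if a writer is given, mirror every metric to TensorBoard and W&B."""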
    metrics = {
        "train": {
            "loss": tr_loss,
            "ppl": math.exp(tr_loss),
        },
        "val": {
            "loss": val_loss,
            "ppl": math.exp(val_loss),
        },
    }

    # Get additional metrics
    if translations:
        src_dec_all, hyp_dec_all, ref_dec_all = translations
        m_bleu_score = bleu_score([x.split(" ") for x in hyp_dec_all],
                                  [[x.split(" ")] for x in ref_dec_all])
        metrics["val"]["bleu"] = m_bleu_score * 100

        # Print translations
        helpers.print_translations(hyp_dec_all,
                                   ref_dec_all,
                                   src_dec_all,
                                   limit=50)

    # Print stuff
    end_time = time.time()
    epoch_hours, epoch_mins, epoch_secs = helpers.epoch_time(
        start_time, end_time)
    print("------------------------------------------------------------")
    print(f'Epoch: {epoch_i + 1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(
        f'\t- Train Loss: {metrics["train"]["loss"]:.3f} | Train PPL: {metrics["train"]["ppl"]:.3f}'
    )
    val_bleu_str = f' | Val BLEU: {metrics["val"]["bleu"]:.3f}' if "bleu" in metrics["val"] else ""
    print(
        f'\t- Val Loss: {metrics["val"]["loss"]:.3f} | Val PPL: {metrics["val"]["ppl"]:.3f}{val_bleu_str}'
    )
    print("------------------------------------------------------------")

    # Tensorboard
    if tb_writer:
        for split in ["train", "val"]:
            for k, v in metrics[split].items():
                tb_writer.add_scalar(f'{split}_{k.lower()}', v, epoch_i + 1)
                wandb.log({f'{split}_{k.lower()}': v})

    return metrics
def run_experiment(datapath, src, trg, alpha, domain=None):
    start_time = time.time()
    experiment_name = f"{domain}_a{alpha}_{src}-{trg}_small_1gpu"
    model_name = f"{MODEL_NAME}_{domain}_a{alpha}"

    ###########################################################################
    ###########################################################################

    wandb.init(project=WANDB_PROJECT,
               entity='salvacarrion',
               name=experiment_name,
               reinit=True)
    config = wandb.config
    config.model_name = MODEL_NAME
    config.domain = domain
    config.max_epochs = MAX_EPOCHS
    config.learning_rate = LEARNING_RATE
    config.batch_size = BATCH_SIZE
    config.max_tokens = MAX_TOKENS
    config.warmup_updates = WARMUP_UPDATES
    config.patience = PATIENCE
    config.acc_gradients = ACC_GRADIENTS
    config.weight_decay = WEIGHT_DECAY
    config.clip_gradients = CLIP_GRADIENTS
    config.multigpu = MULTIGPU
    config.device1 = str(DEVICE1)
    config.device2 = str(DEVICE2)
    config.num_workers = NUM_WORKERS
    config.tok_model = TOK_MODEL
    config.tok_size = TOK_SIZE
    config.tok_folder = TOK_FOLDER
    config.lowercase = LOWERCASE
    config.truncate = TRUNCATE
    config.max_length_truncate = MAX_LENGTH_TRUNC
    config.sampler_name = str(SAMPLER_NAME)
    config.start_from_checkpoint1 = START_FROM_CHECKPOINT_MODEL1
    config.start_from_checkpoint2 = START_FROM_CHECKPOINT_MODEL2
    config.model_interpolation = alpha
    print(config)

    ###########################################################################
    ###########################################################################

    # Load tokenizers
    src_tok, trg_tok = helpers.get_tokenizers(os.path.join(
        datapath, DATASET_TOK_NAME, TOK_FOLDER),
                                              src,
                                              trg,
                                              tok_model=TOK_MODEL,
                                              lower=LOWERCASE,
                                              truncation=TRUNCATE,
                                              max_length=MAX_LENGTH_TRUNC)

    # Load dataset
    datapath_clean = DATASET_CLEAN_NAME
    if TOK_MODEL == "bpe":  # Do not preprocess again when using bpe
        src_tok.apply_bpe = False
        trg_tok.apply_bpe = False
        datapath_clean = os.path.join(DATASET_TOK_NAME, TOK_FOLDER)

    # Get datasets
    train_ds = TranslationDataset(os.path.join(datapath, datapath_clean),
                                  src_tok, trg_tok, "train")
    val_ds_olddomain = TranslationDataset(
        os.path.join(DATASETS_PATH, "health_es-en", datapath_clean), src_tok,
        trg_tok, "val")
    val_ds_newdomain = TranslationDataset(
        os.path.join(datapath, datapath_clean), src_tok, trg_tok, "val")

    # Get dataloaders
    train_loader = base.get_data_loader(SAMPLER_NAME,
                                        train_ds,
                                        BATCH_SIZE,
                                        MAX_TOKENS,
                                        NUM_WORKERS,
                                        shuffle=True)
    val_loader_olddomain = base.get_data_loader(SAMPLER_NAME,
                                                val_ds_olddomain,
                                                BATCH_SIZE,
                                                MAX_TOKENS,
                                                NUM_WORKERS,
                                                shuffle=False)
    val_loader_newdomain = base.get_data_loader(SAMPLER_NAME,
                                                val_ds_newdomain,
                                                BATCH_SIZE,
                                                MAX_TOKENS,
                                                NUM_WORKERS,
                                                shuffle=False)

    # Instantiate model #1
    model1 = TransformerDyn(d_model=256,
                            enc_layers=3,
                            dec_layers=3,
                            enc_heads=8,
                            dec_heads=8,
                            enc_dff_dim=512,
                            dec_dff_dim=512,
                            enc_dropout=0.1,
                            dec_dropout=0.1,
                            max_src_len=2000,
                            max_trg_len=2000,
                            src_tok=src_tok,
                            trg_tok=trg_tok,
                            static_pos_emb=True).to(DEVICE1)
    print(
        f'The model #1 has {model1.count_parameters():,} trainable parameters')
    model1.apply(base.initialize_weights)

    # Load weights
    if START_FROM_CHECKPOINT_MODEL1:
        from_checkpoint_path = os.path.join(datapath, DATASET_CHECKPOINT_NAME,
                                            START_FROM_CHECKPOINT_MODEL1)
        print(f"(Model 1) Loading weights from: {from_checkpoint_path}")
        model1.load_state_dict(torch.load(from_checkpoint_path))

    model2 = Transformer(d_model=256,
                         enc_layers=3,
                         dec_layers=3,
                         enc_heads=8,
                         dec_heads=8,
                         enc_dff_dim=512,
                         dec_dff_dim=512,
                         enc_dropout=0.1,
                         dec_dropout=0.1,
                         max_src_len=2000,
                         max_trg_len=2000,
                         src_tok=src_tok,
                         trg_tok=trg_tok,
                         static_pos_emb=True).to(DEVICE2)
    print(
        f'The model #2 has {model2.count_parameters():,} trainable parameters')
    model2.apply(base.initialize_weights)

    # [MODEL1] Freeze all of model #1's parameters and share model #2's encoder/decoder
    for param in model1.parameters():
        param.requires_grad = False
    model1.encoder_shared = model2.encoder
    model1.decoder_shared = model2.decoder
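    # Note: model #1 is fully frozen; TransformerDyn presumably routes its forward pass
    # through encoder_shared/decoder_shared, so only model #2 (optimized below) is updated.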

    # Load weights
    if START_FROM_CHECKPOINT_MODEL2:
        from_checkpoint_path = os.path.join(datapath, DATASET_CHECKPOINT_NAME,
                                            START_FROM_CHECKPOINT_MODEL2)
        print(f"(Model 2) Loading weights from: {from_checkpoint_path}")
        model2.load_state_dict(torch.load(from_checkpoint_path))

    optimizer = torch.optim.Adam(model2.parameters(), lr=LEARNING_RATE)
    cross_entropy_loss1 = nn.CrossEntropyLoss(
        ignore_index=trg_tok.word2idx[trg_tok.PAD_WORD])
    cross_entropy_loss2 = nn.CrossEntropyLoss(
        ignore_index=trg_tok.word2idx[trg_tok.PAD_WORD])
    criterion = CustomLoss(alpha, cross_entropy_loss1, cross_entropy_loss2)
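    # CustomLoss is defined elsewhere in this repo; it presumably interpolates the two
    # cross-entropy terms with weight alpha (the model_interpolation value logged above).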

    # Tensorboard (it needs some epochs to start working ~10-20)
    tb_writer = SummaryWriter(
        os.path.join(datapath, DATASET_LOGS_NAME, f"{model_name}"))
    wandb.watch(model2)

    # Train and validate model
    fit(model1,
        model2,
        optimizer,
        train_loader=train_loader,
        val_loader_olddomain=val_loader_olddomain,
        val_loader_newdomain=val_loader_newdomain,
        epochs=MAX_EPOCHS,
        criterion=criterion,
        checkpoint_path=os.path.join(datapath, DATASET_CHECKPOINT_NAME,
                                     model_name),
        tb_writer=tb_writer)

    print("************************************************************")
    epoch_hours, epoch_mins, epoch_secs = helpers.epoch_time(
        start_time, end_time=time.time())
    print(f'Time experiment: {epoch_hours}h {epoch_mins}m {epoch_secs}s')
    print("************************************************************")
    print("Done!")
def evaluate_hbm(model, criterion, src_tok, trg_tok, train_domain, basepath,
                 datapath_clean, model_name, start_time):
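    """Evaluate `model` on the health, biological and merged test sets: compute
    the test loss, generate beam-search translations for each beam width in BEAMS,
    score them, and write the translations plus beam_metrics.json under
    DATASET_EVAL_NAME."""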
    # Get all folders in the root path
    test_datasets = [
        os.path.join(DATASETS_PATH, TOK_FOLDER, x) for x in [
            f"health_{src_tok.lang}-{trg_tok.lang}",
            f"biological_{src_tok.lang}-{trg_tok.lang}",
            f"merged_{src_tok.lang}-{trg_tok.lang}"
        ]
    ]
    for test_dataset in test_datasets:
        test_domain, (test_src, test_trg) = utils.get_dataset_ids(test_dataset)
        print("#############################################")
        print(
            f"=> TESTING MODEL FROM '{train_domain}' IN DOMAIN '{test_domain}'"
        )

        # Get datasets
        test_ds = TranslationDataset(
            os.path.join(test_dataset, datapath_clean), src_tok, trg_tok,
            "test")

        # Get dataloaders
        test_loader = base.get_data_loader(SAMPLER_NAME,
                                           test_ds,
                                           BATCH_SIZE,
                                           MAX_TOKENS,
                                           NUM_WORKERS,
                                           shuffle=False)

        # Evaluate
        start_time2 = time.time()
        val_loss, val_translations = base.evaluate(model,
                                                   test_loader,
                                                   criterion,
                                                   device=DEVICE1)

        # Log progress
        metrics = base.log_progress(epoch_i=0,
                                    start_time=start_time2,
                                    tr_loss=None,
                                    val_loss=val_loss,
                                    tb_writer=None,
                                    translations=val_translations,
                                    print_translations=False,
                                    prefix=None)

        # Create path
        eval_name = test_domain
        eval_path = os.path.join(basepath, DATASET_EVAL_NAME, model_name,
                                 eval_name)
        Path(eval_path).mkdir(parents=True, exist_ok=True)

        # Generate them
        metrics = {"beams": {}}
        for beam in BEAMS:
            print(f"Computing beam width={beam}...")

            # Create output path
            output_path = os.path.join(eval_path, f"beam{beam}")
            Path(output_path).mkdir(parents=True, exist_ok=True)

            print(f"\t- Generating translations for: {test_domain}...")
            # Get translations (using beam search)
            src_dec_all, hyp_dec_all, ref_dec_all = base.get_translations(
                test_loader,
                model,
                device=DEVICE1,
                max_length=MAX_LENGTH,
                beam_width=beam)
            # Print translations
            if PRINT_TRANSLATIONS:
                helpers.print_translations(hyp_dec_all,
                                           ref_dec_all,
                                           src_dec_all,
                                           limit=50,
                                           randomized=False)

            # Compute scores
            metrics["beams"][f"beam{beam}"] = base.compute_metrics(
                hyp_dec_all, ref_dec_all, use_ter=False)
            print(
                f'Translation scores (beam_width={beam}; max_length={MAX_LENGTH})'
            )
            print(
                f'\t- Sacrebleu (bleu): {metrics["beams"][f"beam{beam}"]["sacrebleu_bleu"]:.2f}'
            )
            # print(f'\t- Sacrebleu (ter): {metrics["beams"][f"beam{beam}"]["sacrebleu_ter"]:.2f}')
            print(
                f'\t- Sacrebleu (chrf): {metrics["beams"][f"beam{beam}"]["sacrebleu_chrf"]:.2f}'
            )
            print(
                f'\t- Torchtext (bleu): {metrics["beams"][f"beam{beam}"]["torchtext_bleu"]:.2f}'
            )

            # Save translations to file
            with open(os.path.join(output_path, 'src.txt'), 'w') as f:
                f.writelines("%s\n" % s for s in src_dec_all)
            with open(os.path.join(output_path, 'hyp.txt'), 'w') as f:
                f.writelines("%s\n" % s for s in hyp_dec_all)
            with open(os.path.join(output_path, 'ref.txt'), 'w') as f:
                f.writelines("%s\n" % s for s in ref_dec_all)
            print(f"Translations written! => Path: {output_path}")

            # Compute external beam metrics with the sacrebleu script
            print(f"\t- Computing sacrebleu scores for: {test_domain}...")
            subprocess.call(
                ['sh', './scripts/6_sacrebleu.sh', eval_path, output_path])
            metrics["beams"].update(get_beam_scores(output_path, beam))

        # Save metrics to file
        with open(os.path.join(eval_path, 'beam_metrics.json'), 'w') as f:
            json.dump(metrics, f)
        print("Metrics saved!")
        print(
            "\t- To get BLEU/CHRF/TER use: 'cat hyp.txt | sacrebleu ref.txt --metrics bleu chrf ter'"
        )
        print("\t- To get CHRF use: 'chrf -R ref.txt -H hyp.txt'")

        print("************************************************************")
        epoch_hours, epoch_mins, epoch_secs = helpers.epoch_time(
            start_time, end_time=time.time())
        print(f'Time experiment: {epoch_hours}h {epoch_mins}m {epoch_secs}s')
        print("************************************************************")
        print("Done!")
def run_experiment(datapath,
                   src,
                   trg,
                   model_name,
                   domain=None,
                   smart_batch=False):
    start_time = time.time()

    ###########################################################################
    ###########################################################################

    wandb.init(project=WANDB_PROJECT, entity='salvacarrion', reinit=True)
    config = wandb.config
    config.model_name = MODEL_NAME
    config.domain = domain
    config.max_epochs = MAX_EPOCHS
    config.learning_rate = LEARNING_RATE
    config.batch_size = BATCH_SIZE
    config.max_tokens = MAX_TOKENS
    config.warmup_updates = WARMUP_UPDATES
    config.patience = PATIENCE
    config.acc_gradients = ACC_GRADIENTS
    config.weight_decay = WEIGHT_DECAY
    config.clip_gradients = CLIP_GRADIENTS
    config.multigpu = MULTIGPU
    config.device1 = str(DEVICE1)
    config.device2 = str(DEVICE2)
    config.num_workers = NUM_WORKERS
    config.tok_model = TOK_MODEL
    config.tok_size = TOK_SIZE
    config.tok_folder = TOK_FOLDER
    config.lowercase = LOWERCASE
    config.truncate = TRUNCATE
    config.max_length_truncate = MAX_LENGTH_TRUNC
    config.sampler_name = str(SAMPLER_NAME)
    print(config)
    ###########################################################################
    ###########################################################################

    checkpoint_path = os.path.join(datapath, DATASET_CHECKPOINT_NAME,
                                   f"{model_name}_{domain}_acc")

    # Load tokenizers
    src_tok, trg_tok = helpers.get_tokenizers(os.path.join(
        datapath, DATASET_TOK_NAME, TOK_FOLDER),
                                              src,
                                              trg,
                                              tok_model=TOK_MODEL,
                                              lower=LOWERCASE,
                                              truncation=TRUNCATE,
                                              max_length=MAX_LENGTH_TRUNC)

    # Load dataset
    datapath_clean = DATASET_CLEAN_SORTED_NAME if smart_batch else DATASET_CLEAN_NAME
    if TOK_MODEL == "bpe":  # Do not preprocess again when using bpe
        src_tok.apply_bpe = False
        trg_tok.apply_bpe = False
        datapath_clean = os.path.join(DATASET_TOK_NAME, TOK_FOLDER)

    train_ds = TranslationDataset(os.path.join(datapath, datapath_clean),
                                  src_tok, trg_tok, "train")
    val_ds = TranslationDataset(os.path.join(datapath, datapath_clean),
                                src_tok, trg_tok, "val")

    # Build dataloaders
    kwargs_train = {}
    kwargs_val = {}
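    # Sampler choice: "bucket" batches sentences of similar length, "maxtokens" additionally
    # caps the number of tokens per batch; otherwise fall back to plain shuffled batching.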
    if SAMPLER_NAME == "bucket":
        train_sampler = BucketBatchSampler(
            SequentialSampler(train_ds),
            batch_size=BATCH_SIZE,
            drop_last=False,
            sort_key=lambda i: len_func(train_ds, i))
        val_sampler = BucketBatchSampler(
            SequentialSampler(val_ds),
            batch_size=BATCH_SIZE,
            drop_last=False,
            sort_key=lambda i: len_func(val_ds, i))
    elif SAMPLER_NAME == "maxtokens":
        train_sampler = MaxTokensBatchSampler(
            SequentialSampler(train_ds),
            shuffle=True,
            batch_size=BATCH_SIZE,
            max_tokens=MAX_TOKENS,
            drop_last=False,
            sort_key=lambda i: len_func(train_ds, i))
        val_sampler = MaxTokensBatchSampler(
            SequentialSampler(val_ds),
            shuffle=False,
            batch_size=BATCH_SIZE,
            max_tokens=MAX_TOKENS,
            drop_last=False,
            sort_key=lambda i: len_func(val_ds, i))
    else:
        train_sampler = val_sampler = None
        kwargs_train = {"batch_size": BATCH_SIZE, "shuffle": True}
        kwargs_val = {"batch_size": BATCH_SIZE, "shuffle": False}

    # Define dataloader
    train_loader = DataLoader(
        train_ds,
        num_workers=NUM_WORKERS,
        collate_fn=lambda x: TranslationDataset.collate_fn(x, MAX_TOKENS),
        pin_memory=True,
        batch_sampler=train_sampler,
        **kwargs_train)
    val_loader = DataLoader(
        val_ds,
        num_workers=NUM_WORKERS,
        collate_fn=lambda x: TranslationDataset.collate_fn(x, MAX_TOKENS),
        pin_memory=True,
        batch_sampler=val_sampler,
        **kwargs_val)

    # Instantiate model #1
    model = Transformer(d_model=256,
                        enc_layers=3,
                        dec_layers=3,
                        enc_heads=8,
                        dec_heads=8,
                        enc_dff_dim=512,
                        dec_dff_dim=512,
                        enc_dropout=0.1,
                        dec_dropout=0.1,
                        max_src_len=2000,
                        max_trg_len=2000,
                        src_tok=src_tok,
                        trg_tok=trg_tok,
                        static_pos_emb=True)  #.to(DEVICE1)
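    # Device placement is deferred here; accelerator.prepare() below moves the model
    # (and the dataloaders) to the appropriate device(s).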
    model.apply(initialize_weights)
    print(f'The model has {model.count_parameters():,} trainable parameters')
    criterion = nn.CrossEntropyLoss(
        ignore_index=trg_tok.word2idx[trg_tok.PAD_WORD])
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Load weights
    # checkpoint_path = os.path.join(datapath, DATASET_CHECKPOINT_NAME, "transformer_multi30k_best_new.pt")
    # print(f"Loading weights from: {checkpoint_path}")
    # model.load_state_dict(torch.load(checkpoint_path))

    # Tensorboard (it needs some epochs to start working ~10-20)
    tb_writer = SummaryWriter(
        os.path.join(datapath, DATASET_LOGS_NAME, f"{model_name}"))
    wandb.watch(model)

    # Prepare model and data for acceleration
    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader)

    # Train and validate model
    fit(model,
        optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=MAX_EPOCHS,
        criterion=criterion,
        checkpoint_path=checkpoint_path,
        src_tok=src_tok,
        trg_tok=trg_tok,
        tb_writer=tb_writer)

    print("************************************************************")
    epoch_hours, epoch_mins, epoch_secs = helpers.epoch_time(
        start_time, end_time=time.time())
    print(f'Time experiment: {epoch_hours}h {epoch_mins}m {epoch_secs}s')
    print("************************************************************")
    print("Done!")
def log_progress(epoch_i,
                 start_time,
                 tr_loss=None,
                 val_loss=None,
                 tb_writer=None,
                 translations=None,
                 print_translations=True,
                 prefix=None,
                 **kwargs):
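    """Flexible variant of log_progress: train/val losses are optional, extra
    validation metrics come from compute_metrics(), and `prefix` (if any) is
    prepended to the TensorBoard/W&B tag names."""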
    metrics = {}
    if tr_loss:
        metrics["train"] = {
            "loss": tr_loss,
            "ppl": math.exp(tr_loss),
        }

    if val_loss:
        metrics["val"] = {
            "loss": val_loss,
            "ppl": math.exp(val_loss),
        }

    # Get additional metrics
    if translations:
        src_dec_all, hyp_dec_all, ref_dec_all = translations

        if val_loss:
            val_metrics = compute_metrics(hyp_dec_all, ref_dec_all, **kwargs)
            metrics["val"].update(val_metrics)

        # Print translations
        if print_translations:
            helpers.print_translations(hyp_dec_all,
                                       ref_dec_all,
                                       src_dec_all,
                                       limit=50)

    # Print stuff
    end_time = time.time()
    epoch_hours, epoch_mins, epoch_secs = helpers.epoch_time(
        start_time, end_time)
    print("------------------------------------------------------------")
    print(
        f'Epoch: {epoch_i + 1:02} | Time: {epoch_mins}m {epoch_secs}s | [Prefix: {prefix}]'
    )
    if tr_loss:
        print(
            f'\t- Train Loss: {metrics["train"]["loss"]:.3f} | Train PPL: {metrics["train"]["ppl"]:.3f}'
        )
    if val_loss:
        extra_metrics = [
            f"Val {k.lower()}: {v:.3f}" for k, v in metrics["val"].items()
            if k not in {"loss", "ppl"}
        ]
        msg = (f'\t- Val Loss: {metrics["val"]["loss"]:.3f} | '
               f'Val PPL: {metrics["val"]["ppl"]:.3f}')
        if extra_metrics:
            msg += " | " + " | ".join(extra_metrics)
        print(msg)
    print("------------------------------------------------------------")

    # Tensorboard
    if tb_writer:
        prefix = f"{prefix}_" if prefix else ""
        for split in list(metrics.keys()):
            for k, v in metrics[split].items():
                tb_writer.add_scalar(f'{prefix}{split}_{k.lower()}', v,
                                     epoch_i + 1)
                wandb.log({f'{prefix}{split}_{k.lower()}': v})

    return metrics