print('Finished Training!')


if __name__ == "__main__":
    # Use the first CUDA device when one is visible, otherwise stay on CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Pretrained checkpoint to fine-tune and the size of the label set.
    # bert_model = "idb-ita/gilberto-uncased-from-camembert"
    bert_model = "Musixmatch/umberto-commoncrawl-cased-v1"
    num_classes = 11

    # Model to fine-tune, followed by the train/validation/test iterators.
    bert = UmbertoCustom(bert_model=bert_model, num_classes=num_classes).to(device)
    train_iter, valid_iter, test_iter = get_train_valid_test_fine(
        bert_model=bert_model,
        max_seq_lenght=512,
    )

    opt = optim.Adam(bert.parameters(), lr=2e-5)

    # Time the whole fine-tuning run; validation happens twice per epoch.
    init_time = time.time()
    train(
        model=bert,
        optimizer=opt,
        train_loader=train_iter,
        valid_loader=valid_iter,
        num_epochs=5,
        eval_every=len(train_iter) // 2,
        file_path="../data/models/",
    )
    tot_time = time.time() - init_time
    elapsed_minutes, elapsed_seconds = divmod(tot_time, 60)
    print("time taken:", int(elapsed_minutes), "minutes", int(elapsed_seconds), "seconds")

    # Score a fresh model restored from the saved checkpoint on the test set.
    best_model = UmbertoCustom(bert_model=bert_model, num_classes=num_classes).to(device)
    load_checkpoint("../data/models" + '/model2.pt', best_model, device)
    evaluate(best_model, test_iter, num_classes, device)
Exemple #2
0
def _remove_branches_if_possible(cnn):
    """Try to strip the auxiliary branches from *cnn*.

    Returns:
        True when the model still has its branches (``remove_branches`` raised),
        False when ``cnn.remove_branches()`` completed.
    """
    try:
        cnn.remove_branches()
    except Exception:
        # Best effort: a model that cannot drop its branches simply keeps them.
        return True
    return False


def _freeze_layers_before(module, layer_names, stop_layer):
    """Disable gradients for every layer in *layer_names* listed before *stop_layer*."""
    for name in layer_names:
        if name == stop_layer:
            break
        for param in getattr(module, name).parameters():
            param.requires_grad = False


def main_func(params):
    """Set up data, model, optimizer and loss from *params*, then train.

    Args:
        params: Options object. ``data_mean`` and ``data_sd`` arrive as
            comma-separated strings and are replaced in place by lists of
            floats. The other attributes read here are ``seed``,
            ``data_path``, ``val_percent``, ``batch_size``, ``not_caffe``,
            ``train_workers``, ``val_workers``, ``balance_classes``,
            ``model_file``, ``base_model``, ``reset_weights``, ``optimizer``,
            ``lr``, ``use_device``, ``delete_branches``, ``freeze_to``,
            ``freeze_aux1_to``, ``freeze_aux2_to``, ``toggle_layers``,
            ``num_epochs``, ``save_epoch``, ``output_name``,
            ``individual_acc``, ``save_csv`` and ``csv_dir``.

    Raises:
        AssertionError: If ``data_mean`` or ``data_sd`` is an empty string.
        ValueError: If ``params.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    assert params.data_mean != '', "-data_mean is required"
    assert params.data_sd != '', "-data_sd is required"
    params.data_mean = [float(m) for m in params.data_mean.split(',')]
    params.data_sd = [float(s) for s in params.data_sd.split(',')]

    if params.seed > -1:
        set_seed(params.seed)
    rnd_generator = torch.Generator(device='cpu') if params.seed > -1 else None

    # Setup image training data
    training_data, num_classes, class_weights = load_dataset(
        data_path=params.data_path, val_percent=params.val_percent, batch_size=params.batch_size,
        input_mean=params.data_mean, input_sd=params.data_sd, use_caffe=not params.not_caffe,
        train_workers=params.train_workers, val_workers=params.val_workers,
        balance_weights=params.balance_classes, rnd_generator=rnd_generator)

    # Setup model definition
    cnn, is_start_model, base_model = setup_model(
        params.model_file, num_classes=num_classes, base_model=params.base_model,
        pretrained=not params.reset_weights)

    if params.optimizer == 'sgd':
        optimizer = optim.SGD(cnn.parameters(), lr=params.lr, momentum=0.9)
    elif params.optimizer == 'adam':
        optimizer = optim.Adam(cnn.parameters(), lr=params.lr)
    else:
        # Fail here with a clear message instead of an unbound-name error below.
        raise ValueError("Unsupported optimizer: " + repr(params.optimizer))

    lrscheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.96)

    if params.balance_classes:
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights.to(params.use_device))
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Maybe delete branches. A non-start model drops them before its checkpoint
    # is loaded; a start model drops them afterwards (see below).
    has_branches = True
    if params.delete_branches and not is_start_model:
        has_branches = _remove_branches_if_possible(cnn)

    # Load pretrained model weights
    start_epoch = 1
    if not params.reset_weights:
        cnn, optimizer, lrscheduler, start_epoch = load_checkpoint(
            cnn, params.model_file, optimizer, lrscheduler, num_classes,
            is_start_model=is_start_model)

    # Only update has_branches when a removal is attempted here, so a successful
    # removal performed before the checkpoint load is not reported as "still
    # has branches".
    if params.delete_branches and is_start_model:
        has_branches = _remove_branches_if_possible(cnn)

    # Maybe freeze some model layers: everything listed before the named layer.
    main_layer_list = ['conv1', 'conv2', 'conv3', 'mixed3a', 'mixed3b', 'mixed4a', 'mixed4b',
                       'mixed4c', 'mixed4d', 'mixed4e', 'mixed5a', 'mixed5b']
    if params.freeze_to != 'none':
        _freeze_layers_before(cnn, main_layer_list, params.freeze_to)
    branch_layer_list = ['loss_conv', 'loss_fc', 'loss_classifier']
    if params.freeze_aux1_to != 'none' and has_branches:
        _freeze_layers_before(getattr(cnn, 'aux1'), branch_layer_list, params.freeze_aux1_to)
    if params.freeze_aux2_to != 'none' and has_branches:
        _freeze_layers_before(getattr(cnn, 'aux2'), branch_layer_list, params.freeze_aux2_to)

    # Optionally freeze/unfreeze specific layers and sub layers. Each entry is
    # "layer" or "layer/sublayer" ('.' and '\\' are accepted as separators).
    if params.toggle_layers != 'none':
        for spec in params.toggle_layers.split(','):
            path = spec.replace('\\', '/').replace('.', '/').split('/')
            target = getattr(cnn, path[0])
            if len(path) == 2:
                target = getattr(target, path[1])
            for param in target.parameters():
                # Flip the flag so frozen parameters become trainable and
                # trainable parameters become frozen.
                param.requires_grad = not param.requires_grad

    n_learnable_params = sum(param.numel() for param in cnn.parameters() if param.requires_grad)
    print('Model has ' + "{:,}".format(n_learnable_params) + ' learnable parameters\n')

    cnn = cnn.to(params.use_device)
    if 'cuda' in params.use_device:
        # NOTE(review): cudnn.benchmark is switched on only when a seed is set;
        # benchmark mode can pick algorithms non-deterministically, so confirm
        # this is the intended condition.
        if params.seed > -1:
            torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

    # Metadata handed to train_model together with the model.
    save_info = [[params.data_mean, params.data_sd, 'BGR'], num_classes, has_branches, base_model]

    # Train model
    train_model(
        model=cnn, dataloaders=training_data, criterion=criterion, optimizer=optimizer,
        lrscheduler=lrscheduler, num_epochs=params.num_epochs, start_epoch=start_epoch,
        save_epoch=params.save_epoch, output_name=params.output_name, device=params.use_device,
        has_branches=has_branches, fc_only=False, num_classes=num_classes,
        individual_acc=params.individual_acc, should_save_csv=params.save_csv,
        csv_path=params.csv_dir, save_info=save_info)
def main():
    """Train ``ASRModel`` on the Airtel payments LMDB data with ignite engines.

    Builds the model, optimizer and CTC criterion, restores state through
    ``load_checkpoint`` and drives training with an ignite ``Engine``. Every
    ``iteration_log_step`` iterations a second engine runs the model in eval
    mode over the same ``train_loader`` and computes the 'wer' and 'cer'
    metrics.

    Relies on names that are not defined in this function (``config``,
    ``logger``, ``lmdb_airtel_payments_root_path``, ``image_val_transform``,
    ``train_batch_size``, ``epochs``, ``tb_logger``); presumably these are
    module-level objects -- confirm against the rest of the file.
    """
    # Init state params
    params = init_parms()
    device = params.get('device')

    # Loading the model, optimizer & criterion
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)
    # NOTE(review): start_epoch is not read again in this function;
    # set_init_epoch below reads params['start_epoch'] directly.
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()

    # Validation progress bars defined here.
    pbar = ProgressBar(persist=True, desc="Loss")
    pbar_valid = ProgressBar(persist=True, desc="Validate")

    # load timer and best meter to keep track of state params
    timer = Timer(average=True)

    # load all the train data
    logger.info('Begining to load Datasets')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')

    # form data loaders
    # NOTE(review): the training set is built with image_val_transform rather
    # than a training transform -- confirm this is intended.
    train = lmdbMultiDatasetTester(roots=[trainAirtelPaymentsPath],
                                   transform=image_val_transform)

    logger.info(f'loaded train & test dataset = {len(train)}')

    def train_update_function(engine, _):
        """Run one optimizer step and return the loss as a Python float.

        The batch supplied by ignite (second argument) is ignored; the batch
        is pulled from ``engine.state.train_loader_labbeled``, which
        ``set_model_train`` re-creates at the start of every epoch.
        """
        optimizer.zero_grad()
        imgs_sup, labels_sup, label_lengths, input_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.to(device)
        # No-op assignment: the labels are passed on without a device transfer.
        labels_sup = labels_sup
        probs_sup = model(imgs_sup)
        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths,
                                 input_lengths)
        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        """Run the model on one batch and return what the attached metrics consume.

        Returns:
            Tuple ``(y_pred, labels, label_lengths)``.
        """
        img, labels, label_lengths, image_lengths = batch
        y_pred = model(img.to(device))
        # Print predicted vs. ground-truth sentences for a random ~1% of batches.
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            # labels_list is consumed as one flat sequence: each sample owns the
            # next `length` entries, with idx tracking the running offset.
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    evaluator = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    # Evaluation is triggered every ~third of an epoch.
    # NOTE(review): this is 0 when train_loader has fewer than 4 batches, which
    # would make the modulo in iteration_completed raise ZeroDivisionError.
    iteration_log_step = int(0.33 * len(train_loader))
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # NOTE(review): scheduler.step() is never called in this function, so this
    # scheduler does not change the learning rate here -- confirm intent.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)

    # Progress bars: training loss per iteration, WER/CER per evaluator epoch.
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator, ['wer', 'cer'],
                      event_name=Events.EPOCH_COMPLETED,
                      closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        """Start the trainer's epoch counter at the value stored in params."""
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        """Replace the batch iterator that train_update_function reads from."""
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        """Every iteration_log_step iterations, bump the epoch and evaluate."""
        if (engine.state.iteration % iteration_log_step
                == 0) and (engine.state.iteration > 0):
            # The epoch counter is advanced by hand here, in addition to
            # ignite's own per-epoch increment.
            engine.state.epoch += 1
            train.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            # Evaluation runs over the training loader; no separate validation
            # loader is built in this function.
            evaluator.run(train_loader)
            model.train()
            logger.info('Model set back to train mode')

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        """Log the timer value for the finished epoch and reset the timer."""
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
    # NOTE(review): tb_logger is not created in this function; presumably a
    # module-level object -- confirm it exists, otherwise this raises NameError.
    tb_logger.close()
Exemple #4
0
def main(local_rank):
    """Run one process of distributed ASR training with ignite engines.

    Wraps ``ASRModel`` in ``DistributedDataParallel``, restores state through
    ``load_checkpoint``, trains with ``CustomCTCLoss`` on the labelled LMDB
    datasets and, every ``iteration_log_step`` iterations, computes 'wer' and
    'cer' on four test sets. TensorBoard logging, progress bars, timing and
    checkpoint saving are only set up when ``args.local_rank == 0``.

    Relies on names that are not defined in this function (``args``,
    ``config``, ``logger``, ``log_path``, the ``lmdb_*_root_path`` values,
    the image transforms, ``train_batch_size``, ``epochs``); presumably these
    are module-level objects -- confirm against the rest of the file.

    Args:
        local_rank: Passed to ``init_parms`` and used as the CUDA device index
            for the DDP wrapper and for moving batches to the GPU.
    """
    params = init_parms(local_rank)
    device = params.get('device')
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    # The optimizer is built from the unwrapped model's parameters, before the
    # DistributedDataParallel wrapper is applied.
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    model = DistributedDataParallel(model,
                                    device_ids=[local_rank],
                                    output_device=local_rank,
                                    check_reduction=True)
    load_checkpoint(model, optimizer, params)
    print(f"Loaded model on {local_rank}")
    # NOTE(review): start_epoch is not read again in this function;
    # set_init_epoch below reads params['start_epoch'] directly.
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()
    # Only referenced by the commented-out unsupervised path further down.
    unsup_criterion = UDALoss()
    # NOTE(review): this function mixes `args.local_rank` with the `local_rank`
    # parameter; presumably both hold the same value -- confirm.
    # The objects below exist only on rank 0, and every later use is guarded by
    # the same rank check.
    if args.local_rank == 0:
        tb_logger = TensorboardLogger(log_dir=log_path)
        pbar = ProgressBar(persist=True, desc="Training")
        pbar_valid = ProgressBar(persist=True, desc="Validation Clean")
        pbar_valid_other = ProgressBar(persist=True, desc="Validation Other")
        pbar_valid_airtel = ProgressBar(persist=True, desc="Validation Airtel")
        pbar_valid_airtel_payments = ProgressBar(
            persist=True, desc="Validation Airtel Payments")
        timer = Timer(average=True)
        best_meter = params.get('best_stats', BestMeter())

    # LMDB locations for every train / test / dev split.
    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path,
                                        'train-labelled-en')
    trainAirtelPath = os.path.join(lmdb_airtel_root_path, 'train-labelled-en')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')
    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    testAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                          'test-labelled-en')
    devOtherPath = os.path.join(lmdb_root_path, 'dev-other')

    # The supervised training set merges five roots, including the
    # 'train-unlabelled' path; train_other is built from 'dev-other'.
    train_clean = lmdbMultiDataset(roots=[
        trainCleanPath, trainOtherPath, trainCommonVoicePath, trainAirtelPath,
        trainAirtelPaymentsPath
    ],
                                   transform=image_train_transform)
    train_other = lmdbMultiDataset(roots=[devOtherPath],
                                   transform=image_train_transform)

    test_clean = lmdbMultiDataset(roots=[testCleanPath],
                                  transform=image_val_transform)
    test_other = lmdbMultiDataset(roots=[testOtherPath],
                                  transform=image_val_transform)
    test_airtel = lmdbMultiDataset(roots=[testAirtelPath],
                                   transform=image_val_transform)
    test_payments_airtel = lmdbMultiDataset(roots=[testAirtelPaymentsPath],
                                            transform=image_val_transform)

    logger.info(
        f'Loaded Train & Test Datasets, train_labbeled={len(train_clean)}, train_unlabbeled={len(train_other)}, test_clean={len(test_clean)}, test_other={len(test_other)}, test_airtel={len(test_airtel)}, test_payments_airtel={len(test_payments_airtel)} examples'
    )

    def train_update_function(engine, _):
        """Run one supervised optimizer step and return the loss as a float.

        The batch supplied by ignite (second argument) is ignored; the batch
        is pulled from ``engine.state.train_loader_labbeled``, which
        ``set_model_train`` re-creates at the start of every epoch.
        """
        optimizer.zero_grad()
        # Supervised gt, pred
        imgs_sup, labels_sup, label_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.cuda(local_rank, non_blocking=True)
        # No-op assignment: the labels are passed on without a device transfer.
        labels_sup = labels_sup
        probs_sup = model(imgs_sup)

        # Unsupervised gt, pred
        # imgs_unsup, augmented_imgs_unsup = next(engine.state.train_loader_unlabbeled)
        # with torch.no_grad():
        #     probs_unsup = model(imgs_unsup.to(device))
        # probs_aug_unsup = model(augmented_imgs_unsup.to(device))

        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths)
        # unsup_loss = unsup_criterion(probs_unsup, probs_aug_unsup)

        # Blend supervised and unsupervised losses till unsupervision_warmup_epoch
        # alpha = get_alpha(engine.state.epoch)
        # final_loss = ((1 - alpha) * sup_loss) + (alpha * unsup_loss)

        # final_loss = sup_loss
        sup_loss.backward()
        optimizer.step()

        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        """Run the model on one batch and return what the attached metrics consume.

        Returns:
            Tuple ``(y_pred, labels, label_lengths)``.
        """
        img, labels, label_lengths = batch
        y_pred = model(img.cuda(local_rank, non_blocking=True))
        # Print predicted vs. ground-truth sentences for a random ~1% of batches.
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            # labels_list is consumed as one flat sequence: each sample owns the
            # next `length` entries, with idx tracking the running offset.
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = get_sentence(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    # One sampler per dataset so each process sees its own shard.
    # NOTE(review): num_replicas is hard-coded to 3 (matching the `// 3` batch
    # and worker splits below); presumably the job always runs with exactly
    # three processes -- confirm.
    train_sampler_labbeled = torch.utils.data.distributed.DistributedSampler(
        train_clean, num_replicas=3, rank=args.local_rank)
    train_sampler_unlabbeled = torch.utils.data.distributed.DistributedSampler(
        train_other, num_replicas=3, rank=args.local_rank)
    test_sampler_clean = torch.utils.data.distributed.DistributedSampler(
        test_clean, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_other = torch.utils.data.distributed.DistributedSampler(
        test_other, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel = torch.utils.data.distributed.DistributedSampler(
        test_airtel, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel_payments = torch.utils.data.distributed.DistributedSampler(
        test_payments_airtel,
        num_replicas=3,
        rank=args.local_rank,
        shuffle=False)

    train_loader_labbeled_loader = torch.utils.data.DataLoader(
        train_clean,
        batch_size=train_batch_size // 3,
        sampler=train_sampler_labbeled,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    # Only referenced by the commented-out unsupervised path in set_model_train.
    train_loader_unlabbeled_loader = torch.utils.data.DataLoader(
        train_other,
        batch_size=train_batch_size * 4,
        sampler=train_sampler_unlabbeled,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    # All test loaders evaluate one sample per batch.
    test_loader_clean = torch.utils.data.DataLoader(
        test_clean,
        batch_size=1,
        sampler=test_sampler_clean,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_other = torch.utils.data.DataLoader(
        test_other,
        batch_size=1,
        sampler=test_sampler_other,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_airtel = torch.utils.data.DataLoader(
        test_airtel,
        batch_size=1,
        sampler=test_sampler_airtel,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_airtel_payments = torch.utils.data.DataLoader(
        test_payments_airtel,
        batch_size=1,
        sampler=test_sampler_airtel_payments,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    # Evaluation is triggered every ~third of an epoch.
    # NOTE(review): this is 0 when the loader has fewer than 4 batches, which
    # would make the modulo in iteration_completed raise ZeroDivisionError.
    iteration_log_step = int(0.33 * len(train_loader_labbeled_loader))
    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    evaluator_airtel = Engine(validate_update_function)
    evaluator_airtel_payments = Engine(validate_update_function)
    # The same two metric instances are attached to all four evaluators.
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)
        metric.attach(evaluator_airtel, name)
        metric.attach(evaluator_airtel_payments, name)

    # Stepped with the 'wer' of evaluator_other inside save_checkpoints below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)
    # Rank 0 only: TensorBoard handlers, progress bars and the timer.
    if args.local_rank == 0:
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=GradsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_clean,
                         log_handler=OutputHandler(tag="validation_clean",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_other,
                         log_handler=OutputHandler(tag="validation_other",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel,
                         log_handler=OutputHandler(tag="validation_airtel",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel_payments,
                         log_handler=OutputHandler(
                             tag="validation_airtel_payments",
                             metric_names=["wer", "cer"],
                             another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        pbar.attach(trainer, output_transform=lambda x: {'loss': x})
        pbar_valid.attach(evaluator_clean, ['wer', 'cer'],
                          event_name=Events.EPOCH_COMPLETED,
                          closing_event_name=Events.COMPLETED)
        pbar_valid_other.attach(evaluator_other, ['wer', 'cer'],
                                event_name=Events.EPOCH_COMPLETED,
                                closing_event_name=Events.COMPLETED)
        pbar_valid_airtel.attach(evaluator_airtel, ['wer', 'cer'],
                                 event_name=Events.EPOCH_COMPLETED,
                                 closing_event_name=Events.COMPLETED)
        pbar_valid_airtel_payments.attach(evaluator_airtel_payments,
                                          ['wer', 'cer'],
                                          event_name=Events.EPOCH_COMPLETED,
                                          closing_event_name=Events.COMPLETED)
        timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        """Start the trainer's epoch counter at the value stored in params."""
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        """Replace the batch iterator that train_update_function reads from."""
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader_labbeled_loader)
        # engine.state.train_loader_unlabbeled = iter(train_loader_unlabbeled_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        """Every iteration_log_step iterations, bump the epoch and run all evaluators."""
        if (engine.state.iteration % iteration_log_step
                == 0) and (engine.state.iteration > 0):
            # The epoch counter is advanced by hand here, in addition to
            # ignite's own per-epoch increment.
            engine.state.epoch += 1
            train_clean.set_epochs(engine.state.epoch)
            train_other.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator_clean.run(test_loader_clean)
            evaluator_other.run(test_loader_other)
            evaluator_airtel.run(test_loader_airtel)
            evaluator_airtel_payments.run(test_loader_airtel_payments)
            model.train()
            logger.info('Model set back to train mode')

    if args.local_rank == 0:

        @evaluator_other.on(Events.EPOCH_COMPLETED)
        def save_checkpoints(engine):
            """Step the LR scheduler on WER, save a checkpoint, update best_meter."""
            metrics = engine.state.metrics
            wer = metrics['wer']
            cer = metrics['cer']
            epoch = trainer.state.epoch
            # NOTE(review): this handler is registered on rank 0 only, so the
            # other ranks never call scheduler.step(); their optimizers'
            # learning rates may diverge from rank 0 -- confirm intent.
            scheduler.step(wer)
            # save_checkpoint receives best_meter before it is updated with
            # the current wer/cer.
            save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
            best_meter.update(wer, cer, epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def after_complete(engine):
            """Log the timer value for the finished epoch and reset the timer."""
            logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
                engine.state.epoch, timer.value()))
            timer.reset()

    trainer.run(train_loader_labbeled_loader, max_epochs=epochs)
    if args.local_rank == 0:
        tb_logger.close()
Exemple #5
0
def main():
    """Evaluate a checkpointed ``ASRModel`` on the test-clean and test-other sets.

    Restores the model through ``load_checkpoint``, runs one ignite evaluator
    per dataset with the 'wer' and 'cer' metrics attached, and prints both
    metrics for each dataset.
    """
    params = init_parms()
    device = params.get('device')

    # Build the network in eval mode; the optimizer is created only so that
    # load_checkpoint can be called with it.
    net = torch.nn.DataParallel(
        ASRModel(input_features=config.num_mel_banks,
                 num_classes=config.vocab_size).to(device))
    net.eval()
    ranger = Ranger(net.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(net, ranger, params)

    test_clean = lmdbMultiDataset(
        roots=[os.path.join(lmdb_root_path, 'test-clean')])
    test_other = lmdbMultiDataset(
        roots=[os.path.join(lmdb_root_path, 'test-other')])

    logger.info(
        f'Loaded Test Datasets, test_clean={len(test_clean)} & test_other={len(test_other)} examples'
    )

    def show_samples(y_pred, labels, label_lengths):
        # labels are consumed as one flat sequence: each sample owns the next
        # `length` entries, starting at the running offset.
        pred_sentences = get_most_probable_beam(y_pred)
        flat_labels = labels.tolist()
        offset = 0
        for i, length in enumerate(label_lengths.cpu().tolist()):
            pred_sentence = pred_sentences[i]
            gt_sentence = sequence_to_string(flat_labels[offset:offset + length])
            offset += length
            print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")

    @torch.no_grad()
    def evaluate_batch(engine, batch):
        img, labels, label_lengths = batch
        y_pred = net(img.to(device))
        # Print predictions next to the ground truth for ~1% of batches.
        if np.random.rand() > 0.99:
            show_samples(y_pred, labels, label_lengths)
        return (y_pred, labels, label_lengths)

    # Only the validation collate function is used by the loaders below; the
    # other two partials are still constructed, as before.
    train_collate = partial(allign_collate, device=device)
    unlabelled_collate = partial(align_collate_unlabelled, device=device)
    val_collate = partial(allign_collate_val, device=device)

    def make_loader(dataset):
        # One sample per batch, in dataset order.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=config.workers,
            pin_memory=False,
            collate_fn=val_collate)

    test_loader_clean = make_loader(test_clean)
    test_loader_other = make_loader(test_other)

    evaluator_clean = Engine(evaluate_batch)
    evaluator_other = Engine(evaluate_batch)

    # The same metric instances are attached to both evaluators.
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    for metric_name, metric in metrics.items():
        for evaluator in (evaluator_clean, evaluator_other):
            metric.attach(evaluator, metric_name)

    evaluator_clean.run(test_loader_clean)
    evaluator_other.run(test_loader_other)

    metrics_clean = evaluator_clean.state.metrics
    metrics_other = evaluator_other.state.metrics

    print(
        f"Clean wer: {metrics_clean['wer']} Clean cer: {metrics_clean['cer']}")
    print(
        f"Other wer: {metrics_other['wer']} Other cer: {metrics_other['cer']}")