# NOTE(review): this print sits immediately before the entry-point guard; in the
# original file (formatting was lost) it may have been the closing statement of a
# train() function defined above this chunk — confirm placement against VCS.
print('Finished Training!')


if __name__ == "__main__":
    # Pick the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    bert_model = "Musixmatch/umberto-commoncrawl-cased-v1"
    # bert_model = "idb-ita/gilberto-uncased-from-camembert"
    num_classes = 11
    bert = UmbertoCustom(bert_model=bert_model, num_classes=num_classes).to(device)

    # `max_seq_lenght` is the callee's (misspelled) keyword parameter name —
    # it must stay spelled this way until get_train_valid_test_fine is renamed.
    train_iter, valid_iter, test_iter = get_train_valid_test_fine(
        bert_model=bert_model, max_seq_lenght=512)
    opt = optim.Adam(bert.parameters(), lr=2e-5)

    # Time the whole training run.
    init_time = time.time()
    train(model=bert,
          optimizer=opt,
          train_loader=train_iter,
          valid_loader=valid_iter,
          num_epochs=5,
          eval_every=len(train_iter) // 2,  # validate twice per epoch
          file_path="../data/models/")
    tot_time = time.time() - init_time
    print("time taken:", int(tot_time // 60), "minutes", int(tot_time % 60), "seconds")

    # Reload the best checkpoint and score it on the held-out test split.
    # NOTE(review): checkpoints are written under "../data/models/" above but the
    # file name 'model2.pt' is hard-coded here — verify it matches what train() saves.
    best_model = UmbertoCustom(bert_model=bert_model, num_classes=num_classes).to(device)
    load_checkpoint("../data/models" + '/model2.pt', best_model, device)
    evaluate(best_model, test_iter, num_classes, device)
def main_func(params):
    """Fine-tune an image-classification CNN driven entirely by *params*.

    Expects `params` to carry the CLI options (data paths, optimizer choice,
    freezing options, device, …). Trains via `train_model` and returns nothing.
    """
    # NOTE(review): assert is stripped under `python -O`; these guard required
    # CLI options and arguably should raise ValueError instead — kept as-is to
    # preserve the original contract.
    assert params.data_mean != '', "-data_mean is required"
    assert params.data_sd != '', "-data_sd is required"
    params.data_mean = [float(m) for m in params.data_mean.split(',')]
    params.data_sd = [float(s) for s in params.data_sd.split(',')]

    if params.seed > -1:
        set_seed(params.seed)
    # A seeded CPU generator makes the train/val split reproducible.
    rnd_generator = torch.Generator(device='cpu') if params.seed > -1 else None

    # Setup image training data
    training_data, num_classes, class_weights = load_dataset(
        data_path=params.data_path,
        val_percent=params.val_percent,
        batch_size=params.batch_size,
        input_mean=params.data_mean,
        input_sd=params.data_sd,
        use_caffe=not params.not_caffe,
        train_workers=params.train_workers,
        val_workers=params.val_workers,
        balance_weights=params.balance_classes,
        rnd_generator=rnd_generator)

    # Setup model definition
    cnn, is_start_model, base_model = setup_model(
        params.model_file,
        num_classes=num_classes,
        base_model=params.base_model,
        pretrained=not params.reset_weights)

    if params.optimizer == 'sgd':
        optimizer = optim.SGD(cnn.parameters(), lr=params.lr, momentum=0.9)
    elif params.optimizer == 'adam':
        optimizer = optim.Adam(cnn.parameters(), lr=params.lr)
    lrscheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.96)

    # Weight the loss per class only when class balancing was requested.
    if params.balance_classes:
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights.to(params.use_device))
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Maybe delete branches (GoogLeNet-style auxiliary classifiers).
    if params.delete_branches and not is_start_model:
        try:
            cnn.remove_branches()
            has_branches = False
        except Exception:
            # Model has no removable branches; treat it as still branched.
            has_branches = True
    else:
        has_branches = True

    # Load pretrained model weights
    start_epoch = 1
    if not params.reset_weights:
        cnn, optimizer, lrscheduler, start_epoch = load_checkpoint(
            cnn, params.model_file, optimizer, lrscheduler, num_classes,
            is_start_model=is_start_model)
        if params.delete_branches and is_start_model:
            try:
                cnn.remove_branches()
                has_branches = False
            except Exception:
                has_branches = True
        else:
            has_branches = True

    # Maybe freeze some model layers: everything *before* params.freeze_to stays frozen.
    main_layer_list = ['conv1', 'conv2', 'conv3', 'mixed3a', 'mixed3b', 'mixed4a',
                       'mixed4b', 'mixed4c', 'mixed4d', 'mixed4e', 'mixed5a', 'mixed5b']
    if params.freeze_to != 'none':
        for layer in main_layer_list:
            if params.freeze_to == layer:
                break
            for param in getattr(cnn, layer).parameters():
                param.requires_grad = False

    # Same freeze-up-to logic for each auxiliary branch.
    branch_layer_list = ['loss_conv', 'loss_fc', 'loss_classifier']
    if params.freeze_aux1_to != 'none' and has_branches:
        for layer in branch_layer_list:
            if params.freeze_aux1_to == layer:
                break
            for param in getattr(cnn.aux1, layer).parameters():
                param.requires_grad = False
    if params.freeze_aux2_to != 'none' and has_branches:
        for layer in branch_layer_list:
            if params.freeze_aux2_to == layer:
                break
            for param in getattr(cnn.aux2, layer).parameters():
                param.requires_grad = False

    # Optionally freeze/unfreeze specific layers and sub layers.
    if params.toggle_layers != 'none':
        toggle_layers = [l.replace('\\', '/').replace('.', '/').split('/')
                         for l in params.toggle_layers.split(',')]
        for layer in toggle_layers:
            target = getattr(cnn, layer[0])
            if len(layer) == 2:
                target = getattr(target, layer[1])
            for param in target.parameters():
                # BUGFIX: the original expression
                # `False if param.requires_grad == True else False`
                # evaluated to False on both branches, so "toggle" always froze.
                param.requires_grad = not param.requires_grad

    n_learnable_params = sum(param.numel() for param in cnn.parameters()
                             if param.requires_grad)
    print('Model has ' + "{:,}".format(n_learnable_params) + ' learnable parameters\n')

    cnn = cnn.to(params.use_device)
    if 'cuda' in params.use_device:
        if params.seed > -1:
            # NOTE(review): cudnn.benchmark=True picks algorithms nondeterministically,
            # which works *against* a fixed seed — the condition may have been meant
            # as `params.seed == -1`; confirm intent before changing.
            torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

    # Metadata persisted alongside checkpoints (mean/sd/channel-order, class count, …).
    save_info = [[params.data_mean, params.data_sd, 'BGR'],
                 num_classes, has_branches, base_model]

    # Train model
    train_model(model=cnn,
                dataloaders=training_data,
                criterion=criterion,
                optimizer=optimizer,
                lrscheduler=lrscheduler,
                num_epochs=params.num_epochs,
                start_epoch=start_epoch,
                save_epoch=params.save_epoch,
                output_name=params.output_name,
                device=params.use_device,
                has_branches=has_branches,
                fc_only=False,
                num_classes=num_classes,
                individual_acc=params.individual_acc,
                should_save_csv=params.save_csv,
                csv_path=params.csv_dir,
                save_info=save_info)
def main():
    """Train the ASR model (DataParallel, single node) on the Airtel payments set."""
    # Init state params
    params = init_parms()
    device = params.get('device')

    # Loading the model, optimizer & criterion
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)
    sup_criterion = CustomCTCLoss()

    # Validation progress bars defined here.
    pbar = ProgressBar(persist=True, desc="Loss")
    pbar_valid = ProgressBar(persist=True, desc="Validate")

    # load timer and best meter to keep track of state params
    timer = Timer(average=True)

    # load all the train data
    logger.info('Begining to load Datasets')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')

    # form data loaders
    train = lmdbMultiDatasetTester(roots=[trainAirtelPaymentsPath],
                                   transform=image_val_transform)
    logger.info(f'loaded train & test dataset = {len(train)}')

    def train_update_function(engine, _):
        """One optimizer step on the next labelled batch; returns the CTC loss."""
        optimizer.zero_grad()
        imgs_sup, labels_sup, label_lengths, input_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.to(device)
        probs_sup = model(imgs_sup)
        sup_loss = sup_criterion(probs_sup, labels_sup,
                                 label_lengths, input_lengths)
        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        """Forward pass for evaluation; occasionally prints a decoded sample."""
        img, labels, label_lengths, image_lengths = batch
        y_pred = model(img.to(device))
        # ~1% of batches: decode and print predictions next to ground truth.
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    evaluator = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    # Evaluate roughly three times per pass over the loader.
    iteration_log_step = int(0.33 * len(train_loader))
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # NOTE(review): this scheduler is created but `scheduler.step(...)` is never
    # called in this function — either wire it into a validation handler or drop it.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)

    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator, ['wer', 'cer'],
                      event_name=Events.EPOCH_COMPLETED,
                      closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        # Resume epoch counting from the loaded checkpoint.
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        # Recreate the iterator each epoch so train_update_function can pull
        # batches with next() across the whole epoch.
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step == 0) and (
                engine.state.iteration > 0):
            # Manually advance the epoch counter every `iteration_log_step` steps.
            engine.state.epoch += 1
            train.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            # NOTE(review): evaluation runs on the *training* loader — presumably
            # intentional for this tester script, but confirm.
            evaluator.run(train_loader)
            model.train()
            logger.info('Model set back to train mode')

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
    # NOTE(review): `tb_logger` is not defined in this function — it must be a
    # module-level global, otherwise this raises NameError at shutdown; verify.
    tb_logger.close()
def main(local_rank):
    """Distributed (DDP) ASR training entry point for one local rank.

    Trains on the combined labelled LMDB datasets and validates on four test
    splits; rank 0 additionally owns TensorBoard logging and checkpointing.
    """
    params = init_parms(local_rank)
    device = params.get('device')

    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    # NOTE(review): `check_reduction` is deprecated in recent PyTorch releases —
    # confirm against the pinned torch version before upgrading.
    model = DistributedDataParallel(model,
                                    device_ids=[local_rank],
                                    output_device=local_rank,
                                    check_reduction=True)
    load_checkpoint(model, optimizer, params)
    print(f"Loaded model on {local_rank}")
    sup_criterion = CustomCTCLoss()
    # Kept for the (currently disabled) unsupervised UDA branch below.
    unsup_criterion = UDALoss()

    # Only rank 0 writes TensorBoard logs.
    if args.local_rank == 0:
        tb_logger = TensorboardLogger(log_dir=log_path)
    pbar = ProgressBar(persist=True, desc="Training")
    pbar_valid = ProgressBar(persist=True, desc="Validation Clean")
    pbar_valid_other = ProgressBar(persist=True, desc="Validation Other")
    pbar_valid_airtel = ProgressBar(persist=True, desc="Validation Airtel")
    pbar_valid_airtel_payments = ProgressBar(
        persist=True, desc="Validation Airtel Payments")
    timer = Timer(average=True)
    best_meter = params.get('best_stats', BestMeter())

    # Dataset roots: labelled training shards plus the four evaluation splits.
    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path,
                                        'train-labelled-en')
    trainAirtelPath = os.path.join(lmdb_airtel_root_path, 'train-labelled-en')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')
    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    testAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                          'test-labelled-en')
    devOtherPath = os.path.join(lmdb_root_path, 'dev-other')

    train_clean = lmdbMultiDataset(roots=[
        trainCleanPath, trainOtherPath, trainCommonVoicePath, trainAirtelPath,
        trainAirtelPaymentsPath
    ], transform=image_train_transform)
    train_other = lmdbMultiDataset(roots=[devOtherPath],
                                   transform=image_train_transform)
    test_clean = lmdbMultiDataset(roots=[testCleanPath],
                                  transform=image_val_transform)
    test_other = lmdbMultiDataset(roots=[testOtherPath],
                                  transform=image_val_transform)
    test_airtel = lmdbMultiDataset(roots=[testAirtelPath],
                                   transform=image_val_transform)
    test_payments_airtel = lmdbMultiDataset(roots=[testAirtelPaymentsPath],
                                            transform=image_val_transform)
    logger.info(
        f'Loaded Train & Test Datasets, train_labbeled={len(train_clean)}, train_unlabbeled={len(train_other)}, test_clean={len(test_clean)}, test_other={len(test_other)}, test_airtel={len(test_airtel)}, test_payments_airtel={len(test_payments_airtel)} examples'
    )

    def train_update_function(engine, _):
        """One supervised optimizer step; returns the scalar CTC loss."""
        optimizer.zero_grad()
        imgs_sup, labels_sup, label_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.cuda(local_rank, non_blocking=True)
        probs_sup = model(imgs_sup)
        # NOTE: the unsupervised UDA branch (probs_unsup/probs_aug_unsup fed to
        # unsup_criterion and alpha-blended with sup_loss until the
        # unsupervision_warmup_epoch) is currently disabled.
        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths)
        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        """Forward pass for evaluation; occasionally prints a decoded sample."""
        img, labels, label_lengths = batch
        y_pred = model(img.cuda(local_rank, non_blocking=True))
        # ~1% of batches: decode and print predictions next to ground truth.
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = get_sentence(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    # Distributed samplers shard each dataset across the 3 replicas.
    train_sampler_labbeled = torch.utils.data.distributed.DistributedSampler(
        train_clean, num_replicas=3, rank=args.local_rank)
    train_sampler_unlabbeled = torch.utils.data.distributed.DistributedSampler(
        train_other, num_replicas=3, rank=args.local_rank)
    test_sampler_clean = torch.utils.data.distributed.DistributedSampler(
        test_clean, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_other = torch.utils.data.distributed.DistributedSampler(
        test_other, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel = torch.utils.data.distributed.DistributedSampler(
        test_airtel, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel_payments = torch.utils.data.distributed.DistributedSampler(
        test_payments_airtel, num_replicas=3, rank=args.local_rank, shuffle=False)

    train_loader_labbeled_loader = torch.utils.data.DataLoader(
        train_clean,
        batch_size=train_batch_size // 3,
        sampler=train_sampler_labbeled,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    train_loader_unlabbeled_loader = torch.utils.data.DataLoader(
        train_other,
        batch_size=train_batch_size * 4,
        sampler=train_sampler_unlabbeled,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_clean = torch.utils.data.DataLoader(
        test_clean,
        batch_size=1,
        sampler=test_sampler_clean,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_other = torch.utils.data.DataLoader(
        test_other,
        batch_size=1,
        sampler=test_sampler_other,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_airtel = torch.utils.data.DataLoader(
        test_airtel,
        batch_size=1,
        sampler=test_sampler_airtel,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_airtel_payments = torch.utils.data.DataLoader(
        test_payments_airtel,
        batch_size=1,
        sampler=test_sampler_airtel_payments,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)

    trainer = Engine(train_update_function)
    # Evaluate roughly three times per pass over the labelled loader.
    iteration_log_step = int(0.33 * len(train_loader_labbeled_loader))
    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    evaluator_airtel = Engine(validate_update_function)
    evaluator_airtel_payments = Engine(validate_update_function)

    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)
        metric.attach(evaluator_airtel, name)
        metric.attach(evaluator_airtel_payments, name)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)

    if args.local_rank == 0:
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=GradsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_clean,
                         log_handler=OutputHandler(tag="validation_clean",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_other,
                         log_handler=OutputHandler(tag="validation_other",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel,
                         log_handler=OutputHandler(tag="validation_airtel",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel_payments,
                         log_handler=OutputHandler(
                             tag="validation_airtel_payments",
                             metric_names=["wer", "cer"],
                             another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator_clean, ['wer', 'cer'],
                      event_name=Events.EPOCH_COMPLETED,
                      closing_event_name=Events.COMPLETED)
    pbar_valid_other.attach(evaluator_other, ['wer', 'cer'],
                            event_name=Events.EPOCH_COMPLETED,
                            closing_event_name=Events.COMPLETED)
    pbar_valid_airtel.attach(evaluator_airtel, ['wer', 'cer'],
                             event_name=Events.EPOCH_COMPLETED,
                             closing_event_name=Events.COMPLETED)
    pbar_valid_airtel_payments.attach(evaluator_airtel_payments,
                                      ['wer', 'cer'],
                                      event_name=Events.EPOCH_COMPLETED,
                                      closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        # Resume epoch counting from the loaded checkpoint.
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        # Recreate the labelled iterator each epoch for next()-style consumption.
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader_labbeled_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step == 0) and (
                engine.state.iteration > 0):
            # Manually advance the epoch counter every `iteration_log_step` steps,
            # then run all four evaluators.
            engine.state.epoch += 1
            train_clean.set_epochs(engine.state.epoch)
            train_other.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator_clean.run(test_loader_clean)
            evaluator_other.run(test_loader_other)
            evaluator_airtel.run(test_loader_airtel)
            evaluator_airtel_payments.run(test_loader_airtel_payments)
            model.train()
            logger.info('Model set back to train mode')

    # Rank 0 checkpoints on the "other" evaluator's WER and drives the scheduler.
    if args.local_rank == 0:
        @evaluator_other.on(Events.EPOCH_COMPLETED)
        def save_checkpoints(engine):
            metrics = engine.state.metrics
            wer = metrics['wer']
            cer = metrics['cer']
            epoch = trainer.state.epoch
            scheduler.step(wer)
            save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
            best_meter.update(wer, cer, epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    trainer.run(train_loader_labbeled_loader, max_epochs=epochs)
    if args.local_rank == 0:
        tb_logger.close()
def main():
    """Evaluate a checkpointed ASR model on the clean and other test splits."""
    params = init_parms()
    device = params.get('device')

    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    model.eval()
    # Optimizer is only needed so load_checkpoint can restore its state.
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)

    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    test_clean = lmdbMultiDataset(roots=[testCleanPath])
    test_other = lmdbMultiDataset(roots=[testOtherPath])
    logger.info(
        f'Loaded Test Datasets, test_clean={len(test_clean)} & test_other={len(test_other)} examples'
    )

    @torch.no_grad()
    def validate_update_function(engine, batch):
        """Forward pass for evaluation; occasionally prints a beam-decoded sample."""
        img, labels, label_lengths = batch
        y_pred = model(img.to(device))
        # ~1% of batches: decode and print predictions next to ground truth.
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable_beam(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    # Only the validation collate function is needed here (the original also
    # built unused train/unlabelled partials — removed).
    allign_collate_val_partial = partial(allign_collate_val, device=device)

    test_loader_clean = torch.utils.data.DataLoader(
        test_clean,
        batch_size=1,
        shuffle=False,
        num_workers=config.workers,
        pin_memory=False,
        collate_fn=allign_collate_val_partial)
    test_loader_other = torch.utils.data.DataLoader(
        test_other,
        batch_size=1,
        shuffle=False,
        num_workers=config.workers,
        pin_memory=False,
        collate_fn=allign_collate_val_partial)

    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)

    evaluator_clean.run(test_loader_clean)
    evaluator_other.run(test_loader_other)

    metrics_clean = evaluator_clean.state.metrics
    metrics_other = evaluator_other.state.metrics
    print(
        f"Clean wer: {metrics_clean['wer']} Clean cer: {metrics_clean['cer']}")
    print(
        f"Other wer: {metrics_other['wer']} Other cer: {metrics_other['cer']}")