def wsj_pieces(args):
    wsj = module_from_file("wsj", os.path.join(root_dir, "datasets/wsj.py"))
    # Load the 20k open vocabulary:
    # Expects the original 20k vocab to be copied from
    # "csr_2_comp/13-34.1/wsj1/doc/lng_modl/vocab/wlist20o.nvp"
    # to "<data_dir>/vocab20ko.txt"
    vocab_file = os.path.join(args.data_dir, "vocab20ko.txt")
    with open(vocab_file, "r") as fid:
        vocab = [l.strip().lower() for l in fid if l[0] != "#"]
    json_set_pieces(args, wsj, vocab)
def iamdb_pieces(args):
    iamdb = module_from_file("iamdb", os.path.join(root_dir, "datasets/iamdb.py"))
    forms = iamdb.load_metadata(args.data_dir, "▁")
    # Collect the keys of every sample that appears in a held-out split:
    ds_keys = set()
    for _, v in iamdb.SPLITS.items():
        for ds in v:
            with open(os.path.join(args.data_dir, f"{ds}.txt"), "r") as fid:
                ds_keys.update(l.strip() for l in fid)
    # Train sentencepiece model only on the training set
    text = [
        l["text"]
        for _, lines in forms.items()
        for l in lines
        if l["key"] not in ds_keys
    ]
    num_pieces = args.num_pieces
    sp = train_spm_model(
        iter(text),
        num_pieces + 1,  # to account for <unk>
        user_symbols=["/"],  # added so the "/" token is in the output set
    )
    vocab = sorted(set(w for t in text for w in t.split("▁") if w))
    assert "MOVE" in vocab  # sanity check: expected word from the IAM training text
    save_pieces(sp, num_pieces, args.output_prefix, vocab)
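# The `train_spm_model` helper used above is defined elsewhere in this script.
# The function below is a minimal sketch of what it is assumed to do, written
# against the real `sentencepiece` Python API; the name and exact training
# flags are illustrative assumptions, not the repository's implementation.
def train_spm_model_sketch(sentences, vocab_size, user_symbols=()):
    """Train a sentencepiece model in memory from an iterator of sentences."""
    import io
    import sentencepiece as spm

    model = io.BytesIO()
    spm.SentencePieceTrainer.train(
        sentence_iterator=sentences,   # consume sentences without a temp file
        model_writer=model,            # write the model proto to memory
        vocab_size=vocab_size,
        user_defined_symbols=list(user_symbols),
    )
    # Load the trained model directly from the serialized model proto:
    return spm.SentencePieceProcessor(model_proto=model.getvalue())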
def train(world_rank, args):
    # setup logging
    level = logging.INFO
    if world_rank != 0:
        level = logging.CRITICAL
    logging.getLogger().setLevel(level)

    with open(args.config, "r") as fid:
        config = json.load(fid)
    logging.info("Using the config \n{}".format(json.dumps(config)))

    is_distributed_train = False
    if args.world_size > 1:
        is_distributed_train = True
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=world_rank,
        )

    if not args.disable_cuda:
        device = torch.device("cuda")
        torch.cuda.set_device(world_rank)
    else:
        device = torch.device("cpu")

    # seed everything:
    seed = config.get("seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    # setup data loaders:
    logging.info("Loading dataset ...")
    dataset = config["data"]["dataset"]
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
        supervised=config["data"].get("supervised", True),
        level=config["data"].get("level", "phone"),
    )
    trainset = dataset.Dataset(data_path, preprocessor, split="train", augment=True)
    valset = dataset.Dataset(data_path, preprocessor, split="validation")
    train_loader = utils.data_loader(
        trainset, config, world_rank, args.world_size, split="train"
    )
    val_loader = utils.data_loader(
        valset, config, world_rank, args.world_size, split="validation"
    )

    # setup criterion, model:
    logging.info("Loading model ...")
    criterion, output_size = models.load_criterion(
        config.get("criterion_type", "ctc"),
        preprocessor,
        config.get("criterion", {}),
    )
    criterion = criterion.to(device)
    model = models.load_model(
        config["model_type"], input_size, output_size, config["model"]
    ).to(device)
    if args.restore:
        models.load_from_checkpoint(model, criterion, args.checkpoint_path, True)

    n_params = sum(p.numel() for p in model.parameters())
    logging.info(
        "Training {} model with {:,} parameters.".format(config["model_type"], n_params)
    )

    # Store base module, criterion for saving checkpoints;
    # `decode` cannot be called on a DDP-wrapped module.
    base_model = model
    base_criterion = criterion
    if is_distributed_train:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[world_rank]
        )
        if len(list(criterion.parameters())) > 0:
            criterion = torch.nn.parallel.DistributedDataParallel(
                criterion, device_ids=[world_rank]
            )

    epochs = config["optim"]["epochs"]
    lr = config["optim"]["learning_rate"]
    step_size = config["optim"]["step_size"]
    max_grad_norm = config["optim"].get("max_grad_norm", None)

    # run training:
    logging.info("Starting training ...")
    scale = 0.5 ** (args.last_epoch // step_size)
    params = [
        {"params": model.parameters(), "initial_lr": lr * scale, "lr": lr * scale}
    ]
    if len(list(criterion.parameters())) > 0:
        crit_params = {"params": criterion.parameters()}
        crit_lr = config["optim"].get("crit_learning_rate", lr)
        crit_params["lr"] = crit_lr * scale
        crit_params["initial_lr"] = crit_lr * scale
        params.append(crit_params)
    optimizer = torch.optim.SGD(params)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=step_size,
        gamma=0.5,
        last_epoch=args.last_epoch,
    )

    min_val_loss = float("inf")
    min_val_cer = float("inf")
    min_val_wer = float("inf")

    Timer = utils.CudaTimer if device.type == "cuda" else utils.Timer
    timers = Timer([
        "ds_fetch",     # dataset sample fetch
        "model_fwd",    # model forward
        "crit_fwd",     # criterion forward
        "bwd",          # backward (model + criterion)
        "optim",        # optimizer step
        "metrics",      # viterbi, cer
        "train_total",  # total training
        "test_total",   # total testing
    ])
    num_updates = 0
    for epoch in range(args.last_epoch, epochs):
        logging.info("Epoch {} started.".format(epoch + 1))
        model.train()
        criterion.train()
        start_time = time.time()
        meters = utils.Meters()
        timers.reset()
        timers.start("train_total").start("ds_fetch")
        for inputs, targets in train_loader:
            timers.stop("ds_fetch").start("model_fwd")
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            timers.stop("model_fwd").start("crit_fwd")
            loss = criterion(outputs, targets)
            timers.stop("crit_fwd").start("bwd")
            loss.backward()
            timers.stop("bwd").start("optim")
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(model.parameters(), criterion.parameters()),
                    max_grad_norm,
                )
            optimizer.step()
            num_updates += 1
            timers.stop("optim").start("metrics")
            meters.loss += loss.item() * len(targets)
            meters.num_samples += len(targets)
            tokens_dist, words_dist, n_tokens, n_words = compute_edit_distance(
                base_criterion.viterbi(outputs), targets, preprocessor
            )
            meters.edit_distance_tokens += tokens_dist
            meters.num_tokens += n_tokens
            meters.edit_distance_words += words_dist
            meters.num_words += n_words
            timers.stop("metrics").start("ds_fetch")
        timers.stop("ds_fetch").stop("train_total")
        epoch_time = time.time() - start_time

        if args.world_size > 1:
            meters.sync()
        logging.info(
            "Epoch {} complete. "
            "nUpdates {}, Loss {:.3f}, CER {:.3f}, WER {:.3f}, "
            "Time {:.3f} (s), LR {:.3f}".format(
                epoch + 1,
                num_updates,
                meters.avg_loss,
                meters.cer,
                meters.wer,
                epoch_time,
                scheduler.get_last_lr()[0],
            ),
        )

        logging.info("Evaluating validation set..")
        timers.start("test_total")
        val_loss, val_cer, val_wer = test(
            model,
            base_criterion,
            val_loader,
            preprocessor,
            device,
            args.world_size,
            checkpoint_path=args.checkpoint_path,
        )
        timers.stop("test_total")

        if world_rank == 0:
            checkpoint(
                base_model,
                base_criterion,
                args.checkpoint_path,
                (val_cer < min_val_cer),
            )
        min_val_loss = min(val_loss, min_val_loss)
        min_val_cer = min(val_cer, min_val_cer)
        min_val_wer = min(val_wer, min_val_wer)
        logging.info(
            "Validation Set: Loss {:.3f}, CER {:.3f}, WER {:.3f}, "
            "Best Loss {:.3f}, Best CER {:.3f}, Best WER {:.3f}".format(
                val_loss, val_cer, val_wer, min_val_loss, min_val_cer, min_val_wer
            ),
        )
        logging.info("Timing Info: " + ", ".join([
            "{} : {:.2f}ms".format(k, v * 1000.0) for k, v in timers.value().items()
        ]))
        scheduler.step()
        start_time = time.time()

    if is_distributed_train:
        torch.distributed.destroy_process_group()
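# A minimal sketch of how the `train` entry point above could be launched for
# multi-GPU training. This launcher is an assumption, not the repository's
# actual entry point: `torch.multiprocessing.spawn` invokes
# `train(world_rank, args)` once per process, matching the signature above.
def launch_train_sketch(args):
    """Hypothetical launcher: spawn one training process per device."""
    import torch.multiprocessing as mp

    if args.world_size > 1:
        # Each child process i runs train(i, args) and joins a process group
        # via args.dist_backend / args.dist_url, as `train` expects.
        mp.spawn(train, args=(args,), nprocs=args.world_size)
    else:
        train(0, args)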
def train(args, world_rank=0):
    # setup logging
    level = logging.INFO
    logging.getLogger().setLevel(level)

    if not args.disable_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    with open(args.config, "r") as fid:
        config = json.load(fid)
    logging.info("Using the config \n{}".format(json.dumps(config)))

    # seed everything:
    seed = config.get("seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    # setup data loaders:
    logging.info("Loading dataset ...")
    dataset = config["data"]["dataset"]
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    batch_size = config["optim"]["batch_size"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
        supervised=config["data"].get("supervised", False),
        level=config["data"].get("level", "phone"),
    )
    output_size = preprocessor.num_tokens
    trainset = dataset.Dataset(data_path, preprocessor, split="train", augment=True)
    valset = dataset.Dataset(data_path, preprocessor, split="test", augment=True)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    # setup criterion, model:
    logging.info("Loading model ...")
    criterion = nn.CrossEntropyLoss()
    model = PositionDependentUnigramBottleneck(
        input_size, output_size, config["model_type"], **config["model"]
    )
    if args.restore:
        model_checkpoint = os.path.join(args.checkpoint_path, "model.checkpoint")
        model.load_state_dict(torch.load(model_checkpoint))

    n_params = sum(p.numel() for p in model.parameters())
    logging.info(
        "Training {} model with {:,} parameters.".format(config["model_type"], n_params)
    )

    # Store base module for saving checkpoints
    base_model = model
    if not isinstance(model, torch.nn.DataParallel):
        model = nn.DataParallel(model)

    epochs = config["optim"]["epochs"]
    lr = config["optim"]["learning_rate"]
    step_size = config["optim"]["step_size"]
    max_grad_norm = config["optim"].get("max_grad_norm", None)

    # run training:
    logging.info("Starting training ...")
    scale = 0.5 ** (args.last_epoch // step_size)
    params = [
        {"params": model.parameters(), "initial_lr": lr * scale, "lr": lr * scale}
    ]
    optimizer = torch.optim.SGD(params)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=step_size,
        gamma=0.5,
        last_epoch=args.last_epoch,
    )

    max_val_acc = float("-inf")

    Timer = utils.CudaTimer if device.type == "cuda" else utils.Timer
    timers = Timer([
        "ds_fetch",     # dataset sample fetch
        "model_fwd",    # model forward
        "crit_fwd",     # criterion forward
        "bwd",          # backward (model + criterion)
        "optim",        # optimizer step
        "metrics",      # viterbi, cer
        "train_total",  # total training
        "test_total",   # total testing
    ])
    num_updates = 0
    for epoch in range(args.last_epoch, epochs):
        logging.info("Epoch {} started.".format(epoch + 1))
        model.train()
        start_time = time.time()
        meters = utils.IBMeters()
        timers.reset()
        timers.start("train_total").start("ds_fetch")
        for i, (inputs, targets, input_masks) in enumerate(train_loader):
            if args.eval_only:
                continue
            timers.stop("ds_fetch").start("model_fwd")
            optimizer.zero_grad()
            in_scores, trg_scores = model(inputs.to(device), input_masks)
            timers.stop("model_fwd").start("crit_fwd")
            prediction_loss = criterion(trg_scores, targets.to(device))
            # Combine the prediction loss with the bottleneck terms I(Z;X), I(Z;Y):
            loss, I_ZX, I_ZY = model.module.calculate_loss(in_scores, prediction_loss)
            timers.stop("crit_fwd").start("bwd")
            loss.backward()
            timers.stop("bwd").start("optim")
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            num_updates += 1
            timers.stop("optim").start("metrics")
            meters.loss += loss.item() * len(targets)
            meters.num_samples += len(targets)
            meters.I_ZX += I_ZX.item() * len(targets)
            meters.I_ZY += I_ZY.item() * len(targets)
            meters.num_tokens += len(targets)
            timers.stop("metrics").start("ds_fetch")
            if i % 1000 == 0:
                logging.info(
                    "Itr {} {meters.loss:.3f} ({meters.avg_loss:.3f})".format(
                        i, meters=meters
                    )
                )
        timers.stop("ds_fetch").stop("train_total")
        epoch_time = time.time() - start_time

        logging.info(
            "Epoch {} complete. "
            "nUpdates {}, Loss {:.3f}, "
            "Time {:.3f} (s), LR {:.3f}".format(
                epoch + 1,
                num_updates,
                meters.avg_loss,
                epoch_time,
                scheduler.get_last_lr()[0],
            ),
        )

        logging.info("Evaluating validation set..")
        timers.start("test_total")
        val_acc = test(model, val_loader, device, args.checkpoint_path)
        timers.stop("test_total")
        checkpoint(
            base_model,
            args.checkpoint_path,
            (val_acc > max_val_acc),
        )
        max_val_acc = max(val_acc, max_val_acc)
        logging.info(
            "Validation Set: WER {:.3f}, "
            "Best WER {:.3f}".format(1 - val_acc, 1 - max_val_acc),
        )
        logging.info("Timing Info: " + ", ".join([
            "{} : {:.2f}ms".format(k, v * 1000.0) for k, v in timers.value().items()
        ]))
        scheduler.step()
        start_time = time.time()
def test(args):
    with open(args.config, "r") as fid:
        config = json.load(fid)

    if not args.disable_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    dataset = config["data"]["dataset"]
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
    )
    data = dataset.Dataset(data_path, preprocessor, split=args.split)
    loader = utils.data_loader(data, config)

    criterion, output_size = models.load_criterion(
        config.get("criterion_type", "ctc"),
        preprocessor,
        config.get("criterion", {}),
    )
    criterion = criterion.to(device)
    model = models.load_model(
        config["model_type"], input_size, output_size, config["model"]
    ).to(device)
    models.load_from_checkpoint(model, criterion, args.checkpoint_path, args.load_last)

    model.eval()
    meters = utils.Meters()
    with torch.no_grad():  # no gradients needed at evaluation time
        for inputs, targets in loader:
            outputs = model(inputs.to(device))
            meters.loss += criterion(outputs, targets).item() * len(targets)
            meters.num_samples += len(targets)
            predictions = criterion.viterbi(outputs)
            for p, t in zip(predictions, targets):
                p, t = preprocessor.tokens_to_text(p), preprocessor.to_text(t)
                pw, tw = p.split(preprocessor.wordsep), t.split(preprocessor.wordsep)
                pw, tw = list(filter(None, pw)), list(filter(None, tw))
                tokens_dist = editdistance.eval(p, t)
                words_dist = editdistance.eval(pw, tw)
                print("CER: {:.3f}".format(
                    tokens_dist * 100.0 / len(t) if len(t) > 0 else 0))
                print("WER: {:.3f}".format(
                    words_dist * 100.0 / len(tw) if len(tw) > 0 else 0))
                print("HYP:", p)
                print("REF:", t)
                print("=" * 80)
                meters.edit_distance_tokens += tokens_dist
                meters.edit_distance_words += words_dist
                meters.num_tokens += len(t)
                meters.num_words += len(tw)

    print("Loss {:.3f}, CER {:.3f}, WER {:.3f}".format(
        meters.avg_loss, meters.cer, meters.wer))
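# A small self-contained illustration of the CER/WER computation used in
# `test` above. The strings here are toy data, not from any dataset;
# `editdistance.eval` computes the Levenshtein distance between two sequences.
def cer_wer_example():
    import editdistance

    ref, hyp = "the cat sat", "the cat sap"
    # Character error rate: edits over reference characters.
    cer = editdistance.eval(hyp, ref) * 100.0 / len(ref)
    # Word error rate: edits over reference words.
    wer = editdistance.eval(hyp.split(), ref.split()) * 100.0 / len(ref.split())
    print("CER: {:.3f}".format(cer))  # 9.091: one wrong char out of eleven
    print("WER: {:.3f}".format(wer))  # 33.333: one wrong word out of three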
def librispeech_pieces(args):
    librispeech = module_from_file(
        "librispeech", os.path.join(root_dir, "datasets/librispeech.py")
    )
    # Train sentencepiece model only on the training set
    json_set_pieces(args, librispeech)