예제 #1
0
def wsj_pieces(args):
    """Build the WSJ piece set from the 20k open vocabulary word list.

    The word list is read from ``<args.data_dir>/vocab20ko.txt``. It is
    expected to be a copy of the original
    "csr_2_comp/13-34.1/wsj1/doc/lng_modl/vocab/wlist20o.nvp" file.
    Lines whose first character is ``#`` are skipped; every other line is
    stripped and lower-cased. The resulting list is handed to
    ``json_set_pieces`` together with the loaded ``wsj`` dataset module.
    """
    wsj_source = os.path.join(root_dir, "datasets/wsj.py")
    wsj = module_from_file("wsj", wsj_source)

    vocab_file = os.path.join(args.data_dir, "vocab20ko.txt")
    vocab = []
    with open(vocab_file, 'r') as fid:
        for line in fid:
            # Lines starting with '#' are skipped rather than treated as words.
            if line[0] == "#":
                continue
            vocab.append(line.strip().lower())

    json_set_pieces(args, wsj, vocab)
예제 #2
0
def iamdb_pieces(args):
    """Train a sentencepiece model on IAMDB text and save the pieces.

    Reads the form metadata from ``args.data_dir`` (using "▁" as the
    separator argument), collects the keys listed in every split file named
    by ``iamdb.SPLITS``, trains a sentencepiece model with
    ``args.num_pieces`` pieces (plus one for <unk>), and saves the pieces
    together with the word vocabulary under ``args.output_prefix``.
    """
    iamdb_source = os.path.join(root_dir, "datasets/iamdb.py")
    iamdb = module_from_file("iamdb", iamdb_source)
    forms = iamdb.load_metadata(args.data_dir, "▁")

    # Gather the keys listed in the "<name>.txt" file of every split.
    ds_keys = set()
    for split_files in iamdb.SPLITS.values():
        for ds in split_files:
            split_path = os.path.join(args.data_dir, f"{ds}.txt")
            with open(split_path, "r") as fid:
                for line in fid:
                    ds_keys.add(line.strip())

    # Intended to train the sentencepiece model only on the training set.
    # NOTE(review): this keeps lines whose key is NOT listed in any split
    # file -- confirm that this matches the "training set only" intent.
    text = []
    for lines in forms.values():
        for entry in lines:
            if entry["key"] not in ds_keys:
                text.append(entry["text"])

    num_pieces = args.num_pieces
    sp = train_spm_model(
        iter(text),
        num_pieces + 1,  # to account for <unk>
        user_symbols=["/"],  # added so token is in the output set
    )

    # Unique, non-empty words across the selected text, in sorted order.
    words = set()
    for sentence in text:
        words.update(w for w in sentence.split("▁") if w)
    vocab = sorted(words)
    assert 'MOVE' in vocab
    save_pieces(sp, num_pieces, args.output_prefix, vocab)
예제 #3
0
def train(world_rank, args):
    """Run the training loop for one worker process.

    Loads the JSON config at ``args.config``, builds the dataset,
    preprocessor, criterion and model it names, then trains with SGD and a
    ``StepLR`` schedule (gamma 0.5) from epoch ``args.last_epoch`` up to
    ``config["optim"]["epochs"]``. After every epoch the validation split is
    evaluated through ``test(...)`` and rank 0 writes a checkpoint.

    Args:
        world_rank: Rank of this process. It is used as the distributed
            rank, as the CUDA device index, to pick the logging level
            (INFO on rank 0, CRITICAL otherwise) and to decide which
            process writes checkpoints (rank 0 only).
        args: Parsed options. Attributes read here: ``config``,
            ``world_size``, ``dist_backend``, ``dist_url``,
            ``disable_cuda``, ``restore``, ``checkpoint_path`` and
            ``last_epoch``.

    Raises:
        ValueError: If ``datasets/<dataset>.py`` (relative to the current
            working directory) does not exist for the configured dataset.
    """
    # setup logging: only rank 0 emits INFO messages
    level = logging.INFO
    if world_rank != 0:
        level = logging.CRITICAL
    logging.getLogger().setLevel(level)

    with open(args.config, "r") as fid:
        config = json.load(fid)
        logging.info("Using the config \n{}".format(json.dumps(config)))

    # A process group is only created when more than one worker is used.
    is_distributed_train = False
    if args.world_size > 1:
        is_distributed_train = True
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=world_rank,
        )

    if not args.disable_cuda:
        device = torch.device("cuda")
        # The rank doubles as the CUDA device index for this process.
        torch.cuda.set_device(world_rank)
    else:
        device = torch.device("cpu")

    # seed everything:
    # NOTE(review): only torch.manual_seed is called here; other RNGs
    # (random, numpy) are not seeded in this function.
    seed = config.get("seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    # setup data loaders:
    logging.info("Loading dataset ...")
    dataset = config["data"]["dataset"]
    # The dataset module path is relative to the current working directory.
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    # From here on `dataset` is the loaded module, not the config string.
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
        supervised=config["data"].get("supervised", True),
        level=config["data"].get("level", "phone"))
    # Augmentation is requested for the training split only.
    trainset = dataset.Dataset(data_path,
                               preprocessor,
                               split="train",
                               augment=True)
    valset = dataset.Dataset(data_path, preprocessor, split="validation")
    train_loader = utils.data_loader(trainset,
                                     config,
                                     world_rank,
                                     args.world_size,
                                     split='train')
    val_loader = utils.data_loader(valset,
                                   config,
                                   world_rank,
                                   args.world_size,
                                   split='validation')

    # setup criterion, model:
    logging.info("Loading model ...")
    # The criterion loader also reports the model's output size.
    criterion, output_size = models.load_criterion(
        config.get("criterion_type", "ctc"),
        preprocessor,
        config.get("criterion", {}),
    )
    criterion = criterion.to(device)
    model = models.load_model(config["model_type"], input_size, output_size,
                              config["model"]).to(device)

    if args.restore:
        # NOTE(review): the meaning of the trailing `True` is defined by
        # models.load_from_checkpoint, which is not visible here.
        models.load_from_checkpoint(model, criterion, args.checkpoint_path,
                                    True)

    n_params = sum(p.numel() for p in model.parameters())
    logging.info("Training {} model with {:,} parameters.".format(
        config["model_type"], n_params))

    # Store base module, criterion for saving checkpoints
    base_model = model
    base_criterion = criterion  # `decode` cannot be called on DDP module
    if is_distributed_train:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[world_rank])

        # The criterion is only wrapped when it has parameters of its own.
        if len(list(criterion.parameters())) > 0:
            criterion = torch.nn.parallel.DistributedDataParallel(
                criterion, device_ids=[world_rank])

    epochs = config["optim"]["epochs"]
    lr = config["optim"]["learning_rate"]
    step_size = config["optim"]["step_size"]
    max_grad_norm = config["optim"].get("max_grad_norm", None)

    # run training:
    logging.info("Starting training ...")
    # Halve the learning rate once for every `step_size` epochs already
    # completed, mirroring the gamma=0.5 StepLR schedule used below.
    scale = 0.5**(args.last_epoch // step_size)
    # "initial_lr" is set explicitly because the scheduler is constructed
    # with last_epoch=args.last_epoch (presumably required by PyTorch when
    # resuming -- confirm against the lr_scheduler docs).
    params = [{
        "params": model.parameters(),
        "initial_lr": lr * scale,
        "lr": lr * scale
    }]
    # A criterion with parameters gets its own group; its learning rate
    # falls back to the model learning rate when not configured.
    if len(list(criterion.parameters())) > 0:
        crit_params = {"params": criterion.parameters()}
        crit_lr = config["optim"].get("crit_learning_rate", lr)
        crit_params['lr'] = crit_lr * scale
        crit_params['initial_lr'] = crit_lr * scale
        params.append(crit_params)

    optimizer = torch.optim.SGD(params)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=step_size,
        gamma=0.5,
        last_epoch=args.last_epoch,
    )
    # Best validation metrics seen so far (only updated on rank 0 below).
    min_val_loss = float("inf")
    min_val_cer = float("inf")
    min_val_wer = float("inf")

    Timer = utils.CudaTimer if device.type == "cuda" else utils.Timer
    timers = Timer([
        "ds_fetch",  # dataset sample fetch
        "model_fwd",  # model forward
        "crit_fwd",  # criterion forward
        "bwd",  # backward (model + criterion)
        "optim",  # optimizer step
        "metrics",  # viterbi, cer
        "train_total",  # total training
        "test_total",  # total testing
    ])
    num_updates = 0
    for epoch in range(args.last_epoch, epochs):
        logging.info("Epoch {} started. ".format(epoch + 1))
        model.train()
        criterion.train()
        start_time = time.time()
        # Fresh meters and timers for every epoch.
        meters = utils.Meters()
        timers.reset()
        # "ds_fetch" runs while waiting for each batch from the loader; it
        # is restarted at the end of every iteration.
        timers.start("train_total").start("ds_fetch")
        for inputs, targets in train_loader:
            timers.stop("ds_fetch").start("model_fwd")
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            timers.stop("model_fwd").start("crit_fwd")
            loss = criterion(outputs, targets)
            timers.stop("crit_fwd").start("bwd")
            loss.backward()
            timers.stop("bwd").start("optim")
            # Gradient clipping covers model and criterion parameters jointly.
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(model.parameters(),
                                    criterion.parameters()),
                    max_grad_norm,
                )
            optimizer.step()
            num_updates += 1
            timers.stop("optim").start("metrics")
            # The batch loss is weighted by the number of targets in the batch.
            meters.loss += loss.item() * len(targets)
            meters.num_samples += len(targets)
            # Decoding uses the unwrapped criterion, not the DDP wrapper.
            tokens_dist, words_dist, n_tokens, n_words = compute_edit_distance(
                base_criterion.viterbi(outputs), targets, preprocessor)
            meters.edit_distance_tokens += tokens_dist
            meters.num_tokens += n_tokens
            meters.edit_distance_words += words_dist
            meters.num_words += n_words
            timers.stop("metrics").start("ds_fetch")
        timers.stop("ds_fetch").stop("train_total")
        epoch_time = time.time() - start_time
        if args.world_size > 1:
            meters.sync()
        logging.info(
            "Epoch {} complete. "
            "nUpdates {}, Loss {:.3f}, CER {:.3f}, WER {:.3f},"
            " Time {:.3f} (s), LR {:.3f}".format(
                epoch + 1,
                num_updates,
                meters.avg_loss,
                meters.cer,
                meters.wer,
                epoch_time,
                scheduler.get_last_lr()[0],
            ), )
        logging.info("Evaluating validation set..")
        timers.start("test_total")
        # Evaluation receives the (possibly DDP-wrapped) model but the
        # unwrapped criterion.
        val_loss, val_cer, val_wer = test(model,
                                          base_criterion,
                                          val_loader,
                                          preprocessor,
                                          device,
                                          args.world_size,
                                          checkpoint_path=args.checkpoint_path)
        timers.stop("test_total")
        # Only rank 0 writes checkpoints and tracks the best metrics. The
        # other ranks keep the initial `inf` values, but they log at
        # CRITICAL level so the summary below is not emitted there.
        if world_rank == 0:
            # The last argument is True when validation CER improved
            # (presumably a "save as best" flag -- confirm in `checkpoint`).
            checkpoint(
                base_model,
                base_criterion,
                args.checkpoint_path,
                (val_cer < min_val_cer),
            )

            min_val_loss = min(val_loss, min_val_loss)
            min_val_cer = min(val_cer, min_val_cer)
            min_val_wer = min(val_wer, min_val_wer)
        logging.info(
            "Validation Set: Loss {:.3f}, CER {:.3f}, WER {:.3f}, "
            "Best Loss {:.3f}, Best CER {:.3f}, Best WER {:.3f}".format(
                val_loss, val_cer, val_wer, min_val_loss, min_val_cer,
                min_val_wer), )
        # Timer values are scaled by 1000 and printed as milliseconds.
        logging.info("Timing Info: " + ", ".join([
            "{} : {:.2f}ms".format(k, v * 1000.0)
            for k, v in timers.value().items()
        ]))
        scheduler.step()
        # NOTE(review): redundant -- start_time is reassigned at the top of
        # the next iteration and is not read after the loop.
        start_time = time.time()

    if is_distributed_train:
        torch.distributed.destroy_process_group()
예제 #4
0
def train(args, world_rank=0):
    """Train a ``PositionDependentUnigramBottleneck`` model.

    Loads the JSON config at ``args.config``, builds the dataset and
    preprocessor it names, wraps the model in ``nn.DataParallel`` and trains
    it with SGD and a ``StepLR`` schedule (gamma 0.5) from epoch
    ``args.last_epoch`` up to ``config["optim"]["epochs"]``. The loss
    combines a cross-entropy prediction loss with whatever
    ``model.calculate_loss`` adds. After every epoch the model is evaluated
    through ``test(...)`` and a checkpoint is written.

    Args:
        args: Parsed options. Attributes read here: ``disable_cuda``,
            ``config``, ``restore``, ``checkpoint_path``, ``last_epoch`` and
            ``eval_only``.
        world_rank: Not read anywhere in this function body.

    Raises:
        ValueError: If ``datasets/<dataset>.py`` (relative to the current
            working directory) does not exist for the configured dataset.
    """
    # setup logging
    level = logging.INFO
    logging.getLogger().setLevel(level)

    if not args.disable_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    with open(args.config, "r") as fid:
        config = json.load(fid)
        logging.info("Using the config \n{}".format(json.dumps(config)))

    # seed everything:
    # NOTE(review): only torch.manual_seed is called here; other RNGs
    # (random, numpy) are not seeded in this function.
    seed = config.get("seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    # setup data loaders:
    logging.info("Loading dataset ...")
    dataset = config["data"]["dataset"]
    # The dataset module path is relative to the current working directory.
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    # From here on `dataset` is the loaded module, not the config string.
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    batch_size = config["optim"]["batch_size"]

    # Note that "supervised" defaults to False in this function.
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
        supervised=config["data"].get("supervised", False),
        level=config["data"].get("level", "phone") 
    )    
    # The model's output size is the preprocessor's token count.
    output_size = preprocessor.num_tokens
    trainset = dataset.Dataset(data_path, preprocessor, split="train", augment=True)
    # NOTE(review): the validation data comes from split="test" and is built
    # with augment=True -- confirm both are intended.
    valset = dataset.Dataset(data_path, preprocessor, split="test", augment=True)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # setup criterion, model:
    logging.info("Loading model ...")
    criterion = nn.CrossEntropyLoss()
    
    # NOTE(review): the model is never moved to `device` in this function,
    # although inputs and targets are -- confirm placement is handled
    # inside the model or by DataParallel.
    model = PositionDependentUnigramBottleneck(input_size, output_size,config["model_type"], **config['model'])
        
    if args.restore:
        model_checkpoint = os.path.join(args.checkpoint_path, "model.checkpoint")
        # NOTE(review): torch.load is called without map_location, so the
        # tensors are restored to the devices they were saved from.
        model.load_state_dict(torch.load(model_checkpoint))

    n_params = sum(p.numel() for p in model.parameters())
    logging.info(
        "Training {} model with {:,} parameters.".format(config["model_type"], n_params)
    )

    # Store base module, criterion for saving checkpoints
    base_model = model
    if not isinstance(model, torch.nn.DataParallel):
      model = nn.DataParallel(model)

    epochs = config["optim"]["epochs"]
    lr = config["optim"]["learning_rate"]
    step_size = config["optim"]["step_size"]
    max_grad_norm = config["optim"].get("max_grad_norm", None)

    # run training:
    logging.info("Starting training ...")
    # Halve the learning rate once for every `step_size` epochs already
    # completed, mirroring the gamma=0.5 StepLR schedule used below.
    scale = 0.5 ** (args.last_epoch // step_size)
    # "initial_lr" is set explicitly because the scheduler is constructed
    # with last_epoch=args.last_epoch (presumably required by PyTorch when
    # resuming -- confirm against the lr_scheduler docs).
    params = [{"params" : model.parameters(),
               "initial_lr" : lr * scale,
               "lr" : lr * scale}]

    optimizer = torch.optim.SGD(params)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=step_size, gamma=0.5,
        last_epoch=args.last_epoch,
    )
    # Best validation accuracy seen so far.
    max_val_acc = float("-inf")

    Timer = utils.CudaTimer if device.type == "cuda" else utils.Timer
    timers = Timer(
        [
            "ds_fetch",  # dataset sample fetch
            "model_fwd",  # model forward
            "crit_fwd",  # criterion forward
            "bwd",  # backward (model + criterion)
            "optim",  # optimizer step
            "metrics",  # viterbi, cer
            "train_total",  # total training
            "test_total",  # total testing
        ]
    )
    num_updates = 0
    for epoch in range(args.last_epoch, epochs):
        logging.info("Epoch {} started. ".format(epoch + 1))
        model.train()
        start_time = time.time()
        # Fresh meters and timers for every epoch.
        meters = utils.IBMeters()
        
        timers.reset()
        timers.start("train_total").start("ds_fetch")
        for i, (inputs, targets, input_masks) in enumerate(train_loader):
            # With eval_only set, every batch is skipped, but the loader is
            # still iterated in full and the meters receive no updates.
            if args.eval_only:
                continue
            timers.stop("ds_fetch").start("model_fwd")
            optimizer.zero_grad()
            in_scores, trg_scores = model(inputs.to(device), input_masks)
            prediction_loss = criterion(trg_scores, targets.to(device))
            # calculate_loss is reached through `.module` because `model`
            # is the DataParallel wrapper at this point.
            loss, I_ZX, I_ZY = model.module.calculate_loss(in_scores, prediction_loss)
            
            # NOTE(review): the criterion and calculate_loss have already
            # run by the time "model_fwd" stops, so "crit_fwd" is started
            # and stopped back-to-back.
            timers.stop("model_fwd").start("crit_fwd")
            timers.stop("crit_fwd").start("bwd")
            loss.backward()
            timers.stop("bwd").start("optim")
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_grad_norm,
                )
            optimizer.step()
            num_updates += 1
            timers.stop("optim").start("metrics")
            # Per-batch values are weighted by the number of targets.
            meters.loss += loss.item() * len(targets)
            meters.num_samples += len(targets)
            meters.I_ZX += I_ZX.item() * len(targets)
            meters.I_ZY += I_ZY.item() * len(targets)
            meters.num_tokens += len(targets)
            timers.stop("metrics").start("ds_fetch")
            # Progress line every 1000 iterations: accumulated loss sum,
            # then the running average in parentheses.
            if i % 1000 == 0:
                info = 'Itr {} {meters.loss:.3f} ({meters.avg_loss:.3f})'.format(i, meters=meters)
                print(info)
        timers.stop("ds_fetch").stop("train_total")
        epoch_time = time.time() - start_time

        logging.info(
            "Epoch {} complete. "
            "nUpdates {}, Loss {:.3f} "
            " Time {:.3f} (s), LR {:.3f}".format(
                epoch + 1,
                num_updates,
                meters.avg_loss,
                epoch_time,
                scheduler.get_last_lr()[0],
            ),
        )

        # NOTE(review): `epoch % 1 == 0` is always true, so validation runs
        # after every epoch.
        if epoch % 1 == 0:
          logging.info("Evaluating validation set..")
          timers.start("test_total")
          val_acc = test(
             model, val_loader, device, args.checkpoint_path
          )
          
          timers.stop("test_total")
          # The last argument is True when validation accuracy improved
          # (presumably a "save as best" flag -- confirm in `checkpoint`).
          checkpoint(
                  base_model,
                  args.checkpoint_path,
                  (val_acc > max_val_acc),
          )

          max_val_acc = max(val_acc, max_val_acc) 
          # The value logged as "WER" is computed as 1 - accuracy.
          logging.info(
              "Validation Set: WER {:.3f}, "
              "Best WER {:.3f}".format(
                1-val_acc, 1-max_val_acc
            ),
          )
          # Timer values are scaled by 1000 and printed as milliseconds.
          logging.info(
              "Timing Info: "
              + ", ".join(
                  [
                      "{} : {:.2f}ms".format(k, v * 1000.0)
                      for k, v in timers.value().items()
                  ]
              )
          )
        scheduler.step()
        # NOTE(review): redundant -- start_time is reassigned at the top of
        # the next iteration and is not read after the loop.
        start_time = time.time()
예제 #5
0
def test(args):
    """Evaluate a checkpointed model on one split and print the results.

    Loads the JSON config at ``args.config``, rebuilds the dataset,
    criterion and model it names, restores their weights from
    ``args.checkpoint_path`` and runs the split ``args.split`` through the
    model. For every utterance the CER, WER, hypothesis and reference are
    printed; the aggregate loss, CER and WER are printed at the end.

    The evaluation loop runs under ``torch.no_grad()`` with both the model
    and the criterion switched to eval mode, so no autograd graph is kept.

    Args:
        args: Parsed options. Attributes read here: ``config``,
            ``disable_cuda``, ``split``, ``checkpoint_path`` and
            ``load_last``.

    Raises:
        ValueError: If ``datasets/<dataset>.py`` (relative to the current
            working directory) does not exist for the configured dataset.
    """
    with open(args.config, "r") as fid:
        config = json.load(fid)

    if not args.disable_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    dataset = config["data"]["dataset"]
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    # From here on `dataset` is the loaded module, not the config string.
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
    )
    data = dataset.Dataset(data_path, preprocessor, split=args.split)
    loader = utils.data_loader(data, config)

    # The criterion loader also reports the model's output size.
    criterion, output_size = models.load_criterion(
        config.get("criterion_type", "ctc"),
        preprocessor,
        config.get("criterion", {}),
    )
    criterion = criterion.to(device)
    model = models.load_model(config["model_type"], input_size, output_size,
                              config["model"]).to(device)
    models.load_from_checkpoint(model, criterion, args.checkpoint_path,
                                args.load_last)

    # Put both modules in eval mode; the training loop toggles the
    # criterion's mode as well, so evaluation should not leave it in the
    # default training mode.
    model.eval()
    criterion.eval()
    meters = utils.Meters()
    # Nothing is back-propagated here, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(device))
            # The batch loss is weighted by the number of targets.
            meters.loss += criterion(outputs, targets).item() * len(targets)
            meters.num_samples += len(targets)
            predictions = criterion.viterbi(outputs)
            for p, t in zip(predictions, targets):
                p, t = preprocessor.tokens_to_text(p), preprocessor.to_text(t)
                # Split into words and drop the empty pieces left by
                # leading, trailing or repeated separators.
                pw = list(filter(None, p.split(preprocessor.wordsep)))
                tw = list(filter(None, t.split(preprocessor.wordsep)))
                tokens_dist = editdistance.eval(p, t)
                words_dist = editdistance.eval(pw, tw)
                # Empty references report 0 instead of dividing by zero.
                cer = (tokens_dist * 100.0 / len(t)) if len(t) > 0 else 0
                wer = (words_dist * 100.0 / len(tw)) if len(tw) > 0 else 0
                print("CER: {:.3f}".format(cer))
                print("WER: {:.3f}".format(wer))
                print("HYP:", "".join(p))
                print("REF", "".join(t))
                print("=" * 80)
                meters.edit_distance_tokens += tokens_dist
                meters.edit_distance_words += words_dist
                meters.num_tokens += len(t)
                meters.num_words += len(tw)

    print("Loss {:.3f}, CER {:.3f}, WER {:.3f}, ".format(
        meters.avg_loss, meters.cer, meters.wer))
예제 #6
0
def librispeech_pieces(args):
    """Build the LibriSpeech piece set via ``json_set_pieces``.

    Loads ``datasets/librispeech.py`` (under ``root_dir``) as a module and
    passes it, together with ``args``, to ``json_set_pieces``.
    """
    # Intended to train the sentencepiece model only on the training set;
    # presumably json_set_pieces enforces that -- confirm in that helper.
    dataset_file = os.path.join(root_dir, "datasets/librispeech.py")
    librispeech = module_from_file("librispeech", dataset_file)
    json_set_pieces(args, librispeech)