Example #1
    def validate(self):
        kaldi.InstantiateKaldiCuda()
        chain_opts = self.chain_opts
        den_fst_path = os.path.join(chain_opts.dir, "den.fst")

        # load the base model and put it in eval mode for validation
        model = self.initialize_model()
        self.load_base_model(model)
        model.eval()

        training_opts = kaldi.chain.CreateChainTrainingOptions(
            chain_opts.l2_regularize,
            chain_opts.out_of_range_regularize,
            chain_opts.leaky_hmm_coefficient,
            chain_opts.xent_regularize,
        )
        compute_chain_objf(
            model,
            chain_opts.egs,
            den_fst_path,
            training_opts,
            minibatch_size="1:64",
            left_context=chain_opts.context,
            right_context=chain_opts.context,
            use_ivector=bool(chain_opts.ivector_dir),
        )
Example #2
    def validate(self):
        kaldi.InstantiateKaldiCuda()
        chain_opts = self.chain_opts
        den_fst_path = os.path.join(chain_opts.dir, "den.fst")

        # construct the network and restore the base model's weights
        model = self.Net(self.chain_opts.feat_dim, self.chain_opts.output_dim)
        model.load_state_dict(torch.load(chain_opts.base_model))
        model.eval()

        training_opts = kaldi.chain.CreateChainTrainingOptions(
            chain_opts.l2_regularize,
            chain_opts.out_of_range_regularize,
            chain_opts.leaky_hmm_coefficient,
            chain_opts.xent_regularize,
        )
        compute_chain_objf(
            model,
            chain_opts.egs,
            den_fst_path,
            training_opts,
            minibatch_size="1:64",
            left_context=chain_opts.context,
            right_context=chain_opts.context,
        )
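
The two validate() variants above differ only in how the network is obtained: the first goes through initialize_model()/load_base_model(), while the second builds self.Net(feat_dim, output_dim) directly and restores its weights with load_state_dict. A minimal, self-contained sketch of that construct-then-restore pattern; TinyNet, the dimensions, and the checkpoint name are placeholders, not part of the code above:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Stand-in for the recipe's Net(feat_dim, output_dim)."""
    def __init__(self, feat_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(feat_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

model = TinyNet(40, 3000)                      # placeholder feat_dim / output_dim
torch.save(model.state_dict(), "base.pt")      # plays the role of chain_opts.base_model

restored = TinyNet(40, 3000)
restored.load_state_dict(torch.load("base.pt"))
restored.eval()                                # freeze dropout/batch-norm statistics for validation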
Example #3
    def train(self):
        """Run one iteration of LF-MMI training

        This is called by 
        >>> self.train() 

        It will probably be renamed as self.fit() since this seems to be
        the standard way other libraries call the training function.
        """
        kaldi.InstantiateKaldiCuda()
        chain_opts = self.chain_opts
        lr = chain_opts.lr
        den_fst_path = os.path.join(chain_opts.dir, "den.fst")

        # load the model to be trained
        model = self.initialize_model()
        self.load_base_model(model)

        training_opts = kaldi.chain.CreateChainTrainingOptions(
            chain_opts.l2_regularize,
            chain_opts.out_of_range_regularize,
            chain_opts.leaky_hmm_coefficient,
            chain_opts.xent_regularize,
        )
        context = chain_opts.context
        model = model.cuda()
        optimizer = self.get_optimizer(
            model,
            lr=lr,
            weight_decay=chain_opts.l2_regularize_factor)
        new_model = train_lfmmi_one_iter(
            model,
            chain_opts.egs,
            den_fst_path,
            training_opts,
            chain_opts.feat_dim,
            minibatch_size=chain_opts.minibatch_size,
            left_context=context,
            right_context=context,
            lr=lr,
            weight_decay=chain_opts.l2_regularize_factor,
            frame_shift=chain_opts.frame_shift,
            optimizer=optimizer,
            e2e=True)
        torch.save(new_model.state_dict(), chain_opts.new_model)
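
train() delegates optimizer construction to self.get_optimizer(), passing the L2 factor as the weight decay. A hypothetical sketch of what such a hook could look like (plain SGD; the actual implementation may differ):

import torch.nn as nn
import torch.optim as optim

def get_optimizer(model, lr=0.001, weight_decay=0.0):
    # Plain SGD with L2 weight decay, mirroring the lr / l2_regularize_factor
    # arguments forwarded by train() above; this is only an illustrative guess.
    return optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

opt = get_optimizer(nn.Linear(40, 10), lr=0.0005, weight_decay=0.01)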
Example #4
    def combine_final_model(self):
        """Implements Kaldi-style model ensembling"""
        kaldi.InstantiateKaldiCuda()
        chain_opts = self.chain_opts
        den_fst_path = os.path.join(chain_opts.dir, "den.fst")
        base_models = chain_opts.base_model.split(',')
        assert len(base_models) > 0
        training_opts = kaldi.chain.CreateChainTrainingOptions(
            chain_opts.l2_regularize,
            chain_opts.out_of_range_regularize,
            chain_opts.leaky_hmm_coefficient,
            chain_opts.xent_regularize,
        )

        moving_average = self.Net(self.chain_opts.feat_dim,
                                  self.chain_opts.output_dim)
        best_mdl = self.Net(self.chain_opts.feat_dim,
                            self.chain_opts.output_dim)
        moving_average.load_state_dict(torch.load(base_models[0]))
        moving_average.cuda()
        best_mdl = moving_average
        compute_objf = lambda mdl: compute_chain_objf(
            mdl,
            chain_opts.egs,
            den_fst_path,
            training_opts,
            minibatch_size="1:64",  # TODO: this should come from a config
            left_context=chain_opts.context,
            right_context=chain_opts.context,
            frame_shift=chain_opts.frame_shift,
        )

        _, init_objf = compute_objf(moving_average)
        best_objf = init_objf

        model_acc = dict(moving_average.named_parameters())
        num_accumulated = torch.Tensor([1.0]).reshape(1).cuda()
        best_num_to_combine = 1
        for mdl_name in base_models[1:]:
            this_mdl = self.Net(self.chain_opts.feat_dim,
                                self.chain_opts.output_dim)
            logging.info("Combining model {}".format(mdl_name))
            this_mdl.load_state_dict(torch.load(mdl_name))
            this_mdl = this_mdl.cuda()
            # TODO(srikanth): check why is this even necessary
            moving_average.cuda()
            num_accumulated += 1.
            for name, params in this_mdl.named_parameters():
                # running mean: scale the accumulated average by (n - 1)/n and
                # add the new model's parameters weighted by 1/n
                model_acc[name].data.mul_(
                    (num_accumulated - 1.) / num_accumulated)
                model_acc[name].data.add_(
                    params.data / num_accumulated)
            _, this_objf = compute_objf(moving_average)
            if this_objf > best_objf:
                best_objf = this_objf
                best_mdl = moving_average
                best_num_to_combine = num_accumulated.clone().detach()
                logging.info("Found best model")
            else:
                logging.info("Won't update best model")

        logging.info("Combined {} models".format(best_num_to_combine.cpu()))
        logging.info("Initial objf = {}, Final objf = {}".format(
            init_objf, best_objf))
        best_mdl.cpu()
        torch.save(best_mdl.state_dict(), chain_opts.new_model)
        return self
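
The parameter loop in combine_final_model() maintains a running mean: after n models, each accumulated tensor is scaled by (n - 1)/n and the new parameters are added with weight 1/n, which is algebraically the same as averaging all n models at once. A small self-contained check of that identity on plain tensors (shapes and values are arbitrary):

import torch

tensors = [torch.randn(4) for _ in range(5)]   # stand-ins for one parameter across five models

running = tensors[0].clone()
n = 1.0
for t in tensors[1:]:
    n += 1.0
    # same update as model_acc above: old * (n - 1)/n + new * 1/n
    running.mul_((n - 1.0) / n).add_(t / n)

assert torch.allclose(running, torch.stack(tensors).mean(dim=0), atol=1e-6)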
Example #5
def train_lfmmi_one_iter(
    model,
    egs_file,
    den_fst_path,
    training_opts,
    feat_dim,
    minibatch_size="64",
    use_gpu=True,
    lr=0.0001,
    weight_decay=0.25,
    frame_shift=0,
    left_context=0,
    right_context=0,
    print_interval=10,
    frame_subsampling_factor=3,
):
    """Run one iteration of LF-MMI training

    The function loads the latest model, takes a list of egs, path to denominator
    fst and runs through the merged egs for one iteration of training. This is 
    similar to how one iteration of training is completed in Kaldi.

    Args:
        model: Path to pytorch model (.pt file)
        egs_file: scp or ark file (a string), should be prefix accordingly just like Kaldi
        den_fst_path: path to den.fst file
        training_opts: options of type ChainTrainingOpts
        feat_dim: dimension of features (e.g. 40 for MFCC hires features)
        minibatch_size: a string of minibatch sizes separated by commas. E.g "64" or "128,64"
        use_gpu: a boolean to set or unset the use of GPUs while training
        lr: learning rate
        frame_shift: an integer (usually 0, 1, or 2) used to shift the training features
        print_interval: the interval (a positive integer) to print the loss value

    Returns:
        updated model in CPU
    """
    # this is required to make sure Kaldi uses GPU
    kaldi.InstantiateKaldiCuda()
    if training_opts is None:
        training_opts = kaldi.chain.CreateChainTrainingOptionsDefault()
    den_graph = kaldi.chain.LoadDenominatorGraph(den_fst_path,
                                                 model.output_dim)
    criterion = KaldiChainObjfFunction.apply
    if use_gpu:
        model = model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    acc_sum = torch.tensor(0., requires_grad=False)
    for mb_id, merged_egs in enumerate(
            prepare_minibatch(egs_file, minibatch_size)):
        chunk_size = kaldi.chain.GetFramesPerSequence(
            merged_egs) * frame_subsampling_factor
        features = kaldi.chain.GetFeaturesFromEgs(merged_egs)
        features = features[:, frame_shift:frame_shift + chunk_size +
                            left_context + right_context, :]
        if use_gpu:
            features = features.cuda()
        output, xent_output = model(features)
        sup = kaldi.chain.GetSupervisionFromEgs(merged_egs)
        deriv = criterion(training_opts, den_graph, sup, output, xent_output)
        acc_sum.add_(deriv[0])
        if mb_id > 0 and mb_id % print_interval == 0:
            logging.info("Overall objf={}\n".format(acc_sum / print_interval))
            acc_sum.zero_()
        optimizer.zero_grad()
        deriv.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()
    model = model.cpu()
    return model
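
Inside the minibatch loop, the feature matrix from each merged eg is wider than the training chunk: the slice keeps chunk_size plus the left and right context frames, starting at frame_shift. A small self-contained illustration of that indexing arithmetic (all dimensions below are made-up placeholders):

import torch

frame_subsampling_factor = 3
chunk_size = 50 * frame_subsampling_factor      # GetFramesPerSequence(...) * subsampling factor
left_context, right_context, frame_shift = 14, 14, 1

features = torch.randn(8, 200, 40)              # (minibatch, input frames, feat_dim)
window = features[:, frame_shift:frame_shift + chunk_size + left_context + right_context, :]
print(window.shape)                             # torch.Size([8, 178, 40])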