コード例 #1
0
ファイル: remove_epsilon_test.py プロジェクト: yyht/k2
 def test1(self):
     """remove_epsilon() turns a non-epsilon-free acceptor into an
     epsilon-free one that is equivalent in the tropical semiring."""
     s = '''
         0 4 1 1
         0 1 1 1
         1 2 0 2
         1 3 0 3
         1 4 0 2
         2 7 0 4
         3 7 0 5
         4 6 1 2
         4 6 0 3
         4 8 1 3
         4 9 -1 2
         5 9 -1 4
         6 9 -1 3
         7 9 -1 5
         8 9 -1 6
         9
     '''
     fsa = k2.Fsa.from_str(s)
     # The input has label-0 (epsilon) arcs, so it must not be flagged
     # as epsilon-free ...
     self.assertFalse(fsa.properties & k2.fsa_properties.EPSILON_FREE)

     # ... while the result of remove_epsilon() must be.
     dest = k2.remove_epsilon(fsa)
     self.assertTrue(dest.properties & k2.fsa_properties.EPSILON_FREE)

     # Equivalence is checked with log_semiring=False (tropical).
     self.assertTrue(
         k2.is_rand_equivalent(fsa, dest, log_semiring=False))
コード例 #2
0
    def test1(self):
        """remove_epsilon() on CUDA yields an epsilon-free equivalent FSA
        and leaves no arc whose ``foo`` attribute equals the filler."""
        # CUDA-only test: silently skip when either torch or k2 lacks CUDA.
        if not torch.cuda.is_available() or not k2.with_cuda:
            return

        device = torch.device('cuda', 0)
        s = '''
            0 1 0 1 1
            1 2 0 2 1
            2 3 0 3 1
            3 4 4 4 1
            3 5 -1 5 1
            4 5 -1 6 1
            5
        '''
        filler = 2
        fsa = k2.Fsa.from_str(s, aux_label_names=['foo']).to(device)
        fsa.foo_filler = filler
        print("Before removing epsilons: ", fsa)
        self.assertFalse(fsa.properties & k2.fsa_properties.EPSILON_FREE)

        dest = k2.remove_epsilon(fsa)
        self.assertTrue(dest.properties & k2.fsa_properties.EPSILON_FREE)
        # Third positional argument is log_semiring=False (tropical).
        self.assertTrue(k2.is_rand_equivalent(fsa, dest, False))

        print("After removing epsilons: ", dest)
        # No value of `dest.foo` may be equal to the declared filler.
        num_filler_values = torch.where(dest.foo.values == filler)[0].numel()
        assert num_filler_values == 0
コード例 #3
0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Each word sequence is turned into a linear FSA, intersected with
    ``lexicon``, cleaned up (connect, arc sort, epsilon removal), and the
    best path of the result is decoded with ``get_texts``.

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.
            If ``None``, ``ys`` is returned unchanged.

    Returns:
        List[Tensor]: Phone level transcripts, one tensor per utterance
        (or ``ys`` itself when ``lexicon`` is ``None``).
    """
    if lexicon is None:
        return ys
    L_inv = lexicon

    # Build the transcript FSAs on the lexicon's device so both operands of
    # k2.intersect live on the same device (previously they were always
    # created on the CPU, which breaks for a GPU-resident lexicon).
    device = L_inv.device

    n_batch = len(ys)
    # torch.arange always yields an int64 index tensor; torch.tensor(range(0))
    # would produce a float32 tensor for an empty batch.
    indices = torch.arange(n_batch)

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
コード例 #4
0
ファイル: remove_epsilon_test.py プロジェクト: jimbozhang/k2
    def test_autograd(self):
        """Gradients through remove_epsilon() on CUDA follow its arc map."""
        # CUDA-only test: silently skip when either torch or k2 lacks CUDA.
        if not torch.cuda.is_available() or not k2.with_cuda:
            return

        device = torch.device('cuda', 0)
        s = '''
            0 1 0 0.1
            0 1 1 0.2
            1 2 -1 0.3
            2
        '''
        fsa = k2.Fsa.from_str(s).to(device).requires_grad_(True)
        ans = k2.remove_epsilon(fsa)
        print("ans = ", ans)

        # The arc map of `ans` is [[1] [0 2] [2]]. Output arc i depends on
        # the listed input arcs, so input arc j's gradient is the sum of the
        # scales of every output arc that uses it.
        scale = torch.tensor([10, 20, 30]).to(device)
        (ans.scores * scale).sum().backward()

        # input arc -> output arcs that consume it
        consumers = [[1], [0], [1, 2]]
        expected_grad = torch.empty_like(fsa.scores)
        for in_arc, out_arcs in enumerate(consumers):
            expected_grad[in_arc] = sum(scale[o] for o in out_arcs)

        print("fsa.grad = ", fsa.grad)
        print("expected_grad = ", expected_grad)
        assert torch.all(torch.eq(fsa.grad, expected_grad))
コード例 #5
0
ファイル: remove_epsilon_test.py プロジェクト: jimbozhang/k2
    def test1(self):
        """remove_epsilon() on a CUDA transducer, plus a check of how
        remove_epsilon_and_add_self_loops() relates to it."""
        # CUDA-only test: silently skip when either torch or k2 lacks CUDA.
        if not torch.cuda.is_available() or not k2.with_cuda:
            return

        device = torch.device('cuda', 0)
        s = '''
            0 1 0 1 1
            1 2 0 2 1
            2 3 0 3 1
            3 4 4 4 1
            3 5 -1 5 1
            4 5 -1 6 1
            5
        '''
        fsa = k2.Fsa.from_str(s, num_aux_labels=1).to(device)
        print(fsa.aux_labels)
        self.assertFalse(fsa.properties & k2.fsa_properties.EPSILON_FREE)

        dest = k2.remove_epsilon(fsa)
        self.assertTrue(dest.properties & k2.fsa_properties.EPSILON_FREE)

        log_semiring = False
        self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))

        # Mostly a smoke test: these calls just have to run.
        dest2 = k2.remove_epsilon_and_add_self_loops(fsa)
        dest3 = k2.remove_epsilon(dest2)

        def equivalent(a, b, **kwargs):
            # All comparisons below use the same (tropical) semiring.
            return k2.is_rand_equivalent(a, b, log_semiring, **kwargs)

        self.assertTrue(
            equivalent(dest, dest3, treat_epsilons_specially=False))
        # Treating epsilons as ordinary symbols, dest2 must differ from dest
        # (presumably because of the epsilon self-loops it carries) ...
        self.assertFalse(
            equivalent(dest,
                       dest2,
                       treat_epsilons_specially=False,
                       npath=10000))
        # ... but it must match once epsilons are treated specially.
        self.assertTrue(
            equivalent(dest, dest2, treat_epsilons_specially=True))
コード例 #6
0
    def test_autograd(self):
        """Check how remove_epsilon() propagates attributes and gradients.

        Non-tensor attributes are copied unchanged; int and ragged
        attributes become ragged, gathered per output arc via the arc map;
        float attributes and scores are summed over each output arc's source
        arcs. Gradients must equal those of a hand-built reference.
        """
        s = '''
            0 1 0 0.1
            0 1 1 0.2
            1 2 -1 0.3
            2
        '''
        src = k2.Fsa.from_str(s).requires_grad_(True)
        # Independent leaf copy of the scores: used to build the expected
        # scores and to compare gradients against `src.scores.grad`.
        scores_copy = src.scores.detach().clone().requires_grad_(True)

        # Non-tensor attributes, expected to be carried over as-is.
        src.attr1 = "hello"
        src.attr2 = "k2"
        # Reference leaf tensor; `src.float_attr` is a separate detached copy
        # so the two gradient paths do not share storage.
        float_attr = torch.tensor([0.1, 0.2, 0.3],
                                  dtype=torch.float32,
                                  requires_grad=True)

        src.float_attr = float_attr.detach().clone().requires_grad_(True)
        src.int_attr = torch.tensor([1, 2, 3], dtype=torch.int32)
        src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50], [60, 70]])

        dest = k2.remove_epsilon(src)
        # arc map is [[1] [0 2] [2]]

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        # Int attribute: values of the source arcs listed in the arc map.
        expected_int_attr = k2.RaggedTensor([[2], [1, 3], [3]])
        assert dest.int_attr == expected_int_attr

        # Ragged attribute: the sublists of the source arcs, concatenated.
        expected_ragged_attr = k2.RaggedTensor([[30, 40, 50], [10, 20, 60, 70],
                                                [60, 70]])
        assert dest.ragged_attr == expected_ragged_attr

        # Float attribute: summed over the source arcs of each output arc.
        expected_float_attr = torch.empty_like(dest.float_attr)
        expected_float_attr[0] = float_attr[1]
        expected_float_attr[1] = float_attr[0] + float_attr[2]
        expected_float_attr[2] = float_attr[2]

        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        # Scores follow the same summation rule as float attributes.
        expected_scores = torch.empty_like(dest.scores)
        expected_scores[0] = scores_copy[1]
        expected_scores[1] = scores_copy[0] + scores_copy[2]
        expected_scores[2] = scores_copy[2]

        assert torch.all(torch.eq(dest.scores, expected_scores))

        scale = torch.tensor([10, 20, 30]).to(float_attr)

        # Backprop the same weighted sum through dest and through the
        # hand-built reference; the leaf gradients must agree.
        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
コード例 #7
0
ファイル: remove_epsilon_test.py プロジェクト: yyht/k2
    def test_autograd(self):
        """Gradients flow through remove_epsilon() according to its arc map."""
        s = '''
            0 1 0 0.1
            0 1 1 0.2
            1 2 -1 0.3
            2
        '''
        fsa = k2.Fsa.from_str(s).requires_grad_(True)
        ans = k2.remove_epsilon(fsa)

        # arc map is [[1] [0 2] [2]]; weight each output arc's score.
        scale = torch.tensor([10, 20, 30])
        loss = (ans.scores * scale).sum()
        loss.backward()

        # Input arc j collects the scale of every output arc that uses it.
        expected_grad = torch.empty_like(fsa.scores)
        expected_grad[0] = scale[1]
        expected_grad[1] = scale[0]
        expected_grad[2] = scale[1] + scale[2]
        assert torch.eq(fsa.grad, expected_grad).all()
コード例 #8
0
    def test_composition_equivalence(self):
        """replace_fsa() must agree with an equivalent intersect-based
        construction (compared in the log semiring)."""
        # `index` is generated before `src`, matching the original call order.
        index = k2.arc_sort(k2.connect(k2.remove_epsilon(_generate_fsa_vec())))
        src = _generate_fsa_vec()

        # Reference result: substitute `index` into `src` with label 1.
        replace = k2.top_sort(k2.replace_fsa(src, index, 1))

        # Alternative route: intersect `index` with the FSA built from `src`,
        # invert it (swapping labels and aux_labels), then drop aux_labels.
        f_fsa = k2.arc_sort(_construct_f(src))
        intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
        intersect = k2.top_sort(k2.invert(intersect))
        del intersect.aux_labels

        assert k2.is_rand_equivalent(replace,
                                     intersect,
                                     log_semiring=True,
                                     delta=1e-3)
コード例 #9
0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Each word sequence is turned into a linear FSA on the lexicon's device,
    given epsilon self-loops, intersected with ``lexicon``, cleaned up
    (epsilon removal, top sort), and the best path is decoded with
    ``get_texts``.

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.
            If ``None``, ``ys`` is returned unchanged.

    Returns:
        List[Tensor]: Phone level transcripts, one tensor per utterance
        (or ``ys`` itself when ``lexicon`` is ``None``).
    """
    if lexicon is None:
        return ys
    L_inv = lexicon

    n_batch = len(ys)
    # torch.arange always yields an int64 index tensor; torch.tensor(range(0))
    # would produce a float32 tensor for an empty batch.
    indices = torch.arange(n_batch)
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    # Explicit epsilon self-loops, presumably so lexicon arcs carrying
    # epsilon can be matched even though epsilons are not treated specially.
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
コード例 #10
0
 def test_random(self):
     """Weight-pushed determinization (tropical and log) of a random
     non-deterministic FSA must be randomly equivalent to the input."""
     # Draw random FSAs until one is NOT already arc-sorted-and-deterministic.
     # The number of attempts is bounded so the test cannot hang forever;
     # if no suitable FSA turns up, skip rather than fail.
     max_attempts = 100
     for _ in range(max_attempts):
         fsa = k2.random_fsa(max_symbol=20,
                             min_num_arcs=50,
                             max_num_arcs=500)
         fsa = k2.arc_sort(k2.connect(k2.remove_epsilon(fsa)))
         prop = fsa.properties
         # we need non-deterministic fsa
         if not prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC:
             break
     else:
         self.skipTest('could not generate a non-deterministic random FSA '
                       f'in {max_attempts} attempts')
     log_semiring = False
     # test weight pushing tropical
     dest_max = k2.determinize(
         fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
     self.assertTrue(
         k2.is_rand_equivalent(fsa, dest_max, log_semiring, delta=1e-3))
     # test weight pushing log
     dest_log = k2.determinize(
         fsa, k2.DeterminizeWeightPushingType.kLogWeightPushing)
     self.assertTrue(
         k2.is_rand_equivalent(fsa, dest_log, log_semiring, delta=1e-3))
コード例 #11
0
ファイル: nbest.py プロジェクト: k2-fsa/k2
    def intersect(self, lats: Fsa) -> 'Nbest':
        '''Intersect this Nbest object with a lattice and keep the 1-best
        path of every resulting FSA.

        Caution:
          We assume FSAs in `self.fsa` don't have epsilon self-loops.
          We also assume `self.fsa.labels` and `lats.labels` are token IDs.

        Args:
          lats:
            An FsaVec. It can be the return value of
            :func:`whole_lattice_rescoring`.
        Returns:
          Return a new Nbest. This new Nbest shares the same shape with `self`,
          while its `fsa` is the 1-best path from intersecting `self.fsa` and
          `lats`.
        '''
        assert self.fsa.device == lats.device, \
            f'{self.fsa.device} vs {lats.device}'
        assert len(lats.shape) == 3, f'{lats.shape}'
        assert lats.arcs.dim0() == self.shape.dim0(), \
            f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

        # intersect_device() is called with sorted_match_a=True, so the
        # lattices are arc sorted first (a no-op if they already are).
        sorted_lats = k2.arc_sort(lats)

        # Give every path epsilon self-loops before intersecting.
        paths_with_loops = k2.add_epsilon_self_loops(self.fsa)

        # For each path, the index of the sequence (i.e. lattice) it
        # belongs to.
        path_to_seq_map = self.shape.row_ids(1)

        ans_lats = k2.intersect_device(a_fsas=sorted_lats,
                                       b_fsas=paths_with_loops,
                                       b_to_a_map=path_to_seq_map,
                                       sorted_match_a=True)

        one_best = k2.remove_epsilon(
            k2.shortest_path(ans_lats, use_double_scores=True))
        return Nbest(fsa=one_best, shape=self.shape)
コード例 #12
0
ファイル: graph.py プロジェクト: zhichaowang/snowfall
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph from a lexicon fst ``L``, a language model fsa
    ``G`` and a topology fst ``H``.

    Steps: arc-sort and compose L with G, connect, determinize, connect,
    replace disambiguation symbols with 0, remove epsilons, connect, drop 0s
    from the aux_labels, arc-sort LG, compose H with LG, connect and arc-sort
    the result.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
                and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        H:  An ``Fsa`` that represents a specific topology used to convert the network
            outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears on the left
            side, it represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
        The arc-sorted HLG graph, with an extra ``lm_scores`` attribute that
        holds a copy of its ``scores``.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    # Disambiguation symbols are mapped to 0 (epsilon) so that the
    # following remove_epsilon() step eliminates them.
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    # aux_labels may be either a plain tensor or a ragged tensor here.
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consists of two parts:
    #  scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()
    return HLG
コード例 #13
0
def run(rank, world_size, args):
    '''
    Train the model on GigaSpeech in one (possibly DDP) worker process.

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate
    use_pruned_intersect = args.use_pruned_intersect

    # Fail fast: DDP setup, the CUDA device, the graph compiler and the model
    # below all assume a GPU. Previously this check ran only after those
    # steps, so it could never trigger before a less clear CUDA error.
    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, args.master_port)

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    # NOTE(review): every rank executes this block, so when P_*.pt does not
    # exist yet several processes may build and torch.save() it concurrently.
    if not Path(lang_dir / f'P_{args.subset}.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P_{args.subset}.fst.txt')
        with open(lang_dir / f'P_{args.subset}.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / f'P_{args.subset}.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / f'P_{args.subset}.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()

    gigaspeech = GigaSpeechAsrDataModule(args)
    train_dl = gigaspeech.train_dataloaders()
    valid_dl = gigaspeech.valid_dataloaders()

    if use_pruned_intersect:
        logging.info('Use pruned intersect for den_lats')
    else:
        logging.info("Don't use pruned intersect for den_lats")

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True,
            is_espnet_structure=True)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80, num_classes=len(phone_ids) +
                           1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)

    model.to(device)
    describe(model)

    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
                f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(
            torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        # Resume from the checkpoint written at the end of the previous epoch.
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank,
                            torchscript=args.torchscript_epoch != -1
                            and epoch >= args.torchscript_epoch)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank,
                        torchscript=args.torchscript_epoch != -1
                        and epoch >= args.torchscript_epoch)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
コード例 #14
0
def main():
    """Train a TDNN-LSTM acoustic model with the MMI-bigram objective.

    Outline:
      * parse the command line, call ``setup_dist`` with the given
        world size / local rank / master port, and seed the RNGs;
      * abort via ``sys.exit(-1)`` if no CUDA device is available;
      * load the lexicon from ``data/lang_nosp`` and the phone bigram ``P``,
        compiling it from ``P.fst.txt`` and caching it as ``P.pt`` on the
        first run;
      * build the train dataloader (cut concatenation + MUSAN noise mixing)
        and the dev dataloader from the cut manifests in ``exp/data``;
      * train for ``num_epochs`` epochs, saving a checkpoint and an info
        file for every epoch and tracking the epoch with the lowest
        validation objective as ``best_model.pt``.

    The graph compiler and the model share a single device,
    ``cuda:<local_rank>``, so the compiled graphs and the network outputs
    are always on the same GPU.
    """
    args = get_parser().parse_args()
    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank,
               world_size=args.world_size,
               master_port=args.master_port)
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = 'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # Fail fast: the graph compiler and the model below both target CUDA,
    # so there is no point loading manifests on a machine without a GPU.
    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    # One device for everything (graph compiler, model, training loop).
    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    phone_ids = lexicon.phone_symbols()

    if not (lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=[
                                           CutConcatenate(),
                                           CutMix(cuts=cuts_musan,
                                                  prob=0.5,
                                                  snr=(10, 20))
                                       ])
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=40000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    logging.info("About to create model")
    # NOTE: the model deliberately reuses ``device`` from above instead of
    # hard-coding cuda:0, so it matches the graph compiler's device.
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    # Make the bigram scores trainable alongside the network weights.
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        # Resume from the checkpoint of the previous epoch.
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
    # Flush any buffered TensorBoard events to disk.
    tb_writer.close()
    # NOTE(review): setup_dist() is called at the top but there is no
    # matching cleanup_dist() here (the distributed variant of this script
    # calls it) — confirm whether the process group should be torn down.
コード例 #15
0
    def test_autograd(self):
        if not torch.cuda.is_available():
            return

        if not k2.with_cuda:
            return

        devices = [torch.device('cuda', 0)]
        if torch.cuda.device_count() > 1:
            torch.cuda.set_device(1)
            devices.append(torch.device('cuda', 1))

        s = '''
            0 1 0 0.1
            0 1 1 0.2
            1 2 -1 0.3
            2
        '''
        for device in devices:
            src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
            scores_copy = src.scores.detach().clone().requires_grad_(True)

            src.attr1 = "hello"
            src.attr2 = "k2"
            float_attr = torch.tensor([0.1, 0.2, 0.3],
                                      dtype=torch.float32,
                                      requires_grad=True,
                                      device=device)

            src.float_attr = float_attr.detach().clone().requires_grad_(True)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=device)
            src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                               [60, 70]]).to(device)

            dest = k2.remove_epsilon(src)
            # arc map is [[1] [0 2] [2]]

            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            expected_int_attr = k2.RaggedTensor([[2], [1, 3], [3]]).to(device)
            assert dest.int_attr == expected_int_attr

            expected_ragged_attr = k2.RaggedTensor([[30, 40, 50],
                                                    [10, 20, 60, 70],
                                                    [60, 70]]).to(device)
            assert dest.ragged_attr == expected_ragged_attr

            expected_float_attr = torch.empty_like(dest.float_attr)
            expected_float_attr[0] = float_attr[1]
            expected_float_attr[1] = float_attr[0] + float_attr[2]
            expected_float_attr[2] = float_attr[2]

            assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

            expected_scores = torch.empty_like(dest.scores)
            expected_scores[0] = scores_copy[1]
            expected_scores[1] = scores_copy[0] + scores_copy[2]
            expected_scores[2] = scores_copy[2]

            assert torch.all(torch.eq(dest.scores, expected_scores))

            scale = torch.tensor([10, 20, 30]).to(float_attr)

            (dest.float_attr * scale).sum().backward()
            (expected_float_attr * scale).sum().backward()
            assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

            (dest.scores * scale).sum().backward()
            (expected_scores * scale).sum().backward()
            assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
コード例 #16
0
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """Compile a CTC decoding graph from a lexicon, an LM and a CTC topology.

    Pipeline: arc-sort ``L`` and ``G`` and compose them; connect;
    determinize; connect; map every disambiguation symbol (on labels and
    aux_labels) to 0; remove epsilons; connect; drop the 0 entries from the
    aux_labels; arc-sort; compose with ``ctc_topo`` using
    ``inner_labels='phones'``; connect; and arc-sort the result. The graph's
    ``shape`` is logged after each of the early stages.

    Args:
        L:
            Lexicon FST whose labels are phones and whose aux_labels are
            words.
        G:
            Language-model acceptor whose labels are words.
        ctc_topo:
            CTC topology FST. A 0 on its input side represents the CTC
            blank; a 0 on its output side represents an epsilon.
        labels_disambig_id_start:
            ID of the first phone disambiguation symbol. Every label
            ``>=`` this value is replaced by 0 before epsilon removal.
        aux_labels_disambig_id_start:
            ID of the first word disambiguation symbol. Every aux_label
            ``>=`` this value is replaced by 0 before epsilon removal.

    Returns:
        The arc-sorted decoding graph ``ctc_topo o rm-eps(det(L o G))``.
    """

    def log_shape(fsa):
        logging.info(f'LG shape = {fsa.shape}')

    def staged(message, transform, fsa):
        """Log ``message``, apply ``transform`` to ``fsa``, then log its shape."""
        logging.info(message)
        fsa = transform(fsa)
        log_shape(fsa)
        return fsa

    lexicon = k2.arc_sort(L)
    grammar = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    graph = k2.compose(lexicon, grammar)
    log_shape(graph)

    graph = staged("Connecting L*G", k2.connect, graph)
    graph = staged("Determinizing L*G", k2.determinize, graph)
    graph = staged("Connecting det(L*G)", k2.connect, graph)

    # Turn disambiguation symbols into epsilons so that remove_epsilon
    # eliminates them.
    logging.info("Removing disambiguation symbols on L*G")
    graph.labels[graph.labels >= labels_disambig_id_start] = 0
    aux_labels = graph.aux_labels
    if isinstance(aux_labels, torch.Tensor):
        aux_labels[aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        # Ragged aux_labels: edit the flat values buffer in place.
        flat = aux_labels.values()
        flat[flat >= aux_labels_disambig_id_start] = 0

    graph = staged("Removing epsilons", k2.remove_epsilon, graph)
    graph = staged("Connecting rm-eps(det(L*G))", k2.connect, graph)
    graph.aux_labels = k2.ragged.remove_values_eq(graph.aux_labels, 0)

    logging.info("Arc sorting LG")
    graph = k2.arc_sort(graph)

    logging.info("Composing ctc_topo LG")
    graph = k2.compose(ctc_topo, graph, inner_labels='phones')

    logging.info("Connecting LG")
    graph = k2.connect(graph)

    logging.info("Arc sorting LG")
    graph = k2.arc_sort(graph)
    arc_sorted = (graph.properties & k2.fsa_properties.ARC_SORTED) != 0
    logging.info(f'LG is arc sorted: {arc_sorted}')
    return graph
コード例 #17
0
def main():
    args = get_parser().parse_args()
    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank, world_size=args.world_size, master_port=args.master_port)
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist'
    setup_logger('{}/log/log-train'.format(exp_dir), use_console=args.local_rank == 0)
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.local_rank == 0 else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    phone_ids = lexicon.phone_symbols()

    if not Path(lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )


    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir /
                                  'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if not args.bucketing_sampler:
        # We don't mix concatenating the cuts and bucketing
        # Here we insert concatenation before mixing so that the
        # noises from Musan are mixed onto almost-zero-energy
        # padding frames.
        transforms = [CutConcatenate(duration_factor=1)] + transforms
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(
            cuts_train,
            max_frames=40000,
            shuffle=True,
            num_buckets=30
        )
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_frames=30000,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4
    )
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    # Note: we explicitly set world_size to 1 to disable the auto-detection of
    #       distributed training inside the sampler. This way, every GPU will
    #       perform the computation on the full dev set. It is a bit wasteful,
    #       but unfortunately loss aggregation between multiple processes with
    #       torch.distributed.all_reduce() tends to hang indefinitely inside
    #       NCCL after ~3000 steps. With the current approach, we can still report
    #       the loss on the full validation set.
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000, world_size=1, rank=0)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(
        validate,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=1
    )

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    model = TdnnLstm1b(num_features=80,
                       num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
                       subsampling_factor=3)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda ep: 1.0 if ep < 7 else 0.8 ** (ep - 6)
        )
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay
        )
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer,
            gamma=lr_schedule_gamma
        )

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}")

    if args.world_size > 1:
        logging.info('Using DistributedDataParallel in training. '
                     'The reported loss, num_frames, etc. for training steps include '
                     'only the batches seen in the master process (the actual loss '
                     'includes batches from all GPUs, and the actual num_frames is '
                     f'approx. {args.world_size}x larger.')
        # For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            local_rank=args.local_rank,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        local_rank=args.local_rank,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
    cleanup_dist()