Example #1
def test_ctc_loss_two():

    ctc_loss = warpctc_pytorch.CTCLoss()
    print("expected shape of seqLength x batchSize x alphabet_size")

    # gives cost inf
    #probs = torch.FloatTensor([[[0, 0, 0, 0, 1000]]]).transpose(0, 1).contiguous()

    # gives cost 0
    # probs = torch.FloatTensor([[[0, 100, 0, 0, 0]]]).transpose(0, 1).contiguous()

    ## Everything equally likely: gives cost 1.6094
    #probs = torch.FloatTensor([[[0, 0, 0, 0, 0]]]).transpose(0, 1).contiguous()

    # Everything equally likely: gives cost 1.6094
    probs = torch.FloatTensor([[[1, 1, 1, 1, 1]]]).transpose(0, 1).contiguous()

    print("probs.size(): " + str(probs.size()))
    labels = Variable(torch.IntTensor([1]))
    label_sizes = Variable(torch.IntTensor([1]))
    probs_sizes = Variable(torch.IntTensor([1]))
    probs = Variable(
        probs,
        requires_grad=True)  # tells autograd to compute gradients for probs
    optimizer = optim.SGD(list([probs]),
                          lr=0.001,
                          momentum=0.9,
                          weight_decay=1e-5)
    print("probs: " + str(probs))
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    cost.backward()
    print("cost: " + str(cost))
    print("update probabilities...")
    optimizer.step()
    print("probs: " + str(probs))
Example #2
    def __init__(
        self,
        odim: int,
        encoder_output_sizse: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_sizse
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type

        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
            )

        self.reduce = reduce
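
The constructor above only selects a backend; the two losses are invoked differently at call time. Below is a hedged sketch, loosely modelled on espnet's CTC module rather than copied from it, of how such a module typically dispatches the loss; th_pred is assumed to be (seqLength, batchSize, odim) and the target/length arguments are integer tensors.

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
        # Sketch only: dispatch between the two backends configured in __init__.
        if self.ctc_type == "builtin":
            # torch.nn.CTCLoss expects log-probabilities over the class axis
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            # with reduction="sum", normalise by batch size to mimic size_average=True
            return loss / th_pred.size(1)
        elif self.ctc_type == "warpctc":
            # warp-ctc applies softmax internally and expects CPU int tensors
            # for the targets and the two length vectors
            return self.ctc_loss(th_pred, th_target.cpu().int(),
                                 th_ilen.cpu().int(), th_olen.cpu().int())
        else:
            raise NotImplementedError(self.ctc_type)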
Example #3
    def __init__(self,
                 odim,
                 eprojs,
                 dropout_rate,
                 ctc_type='warpctc',
                 reduce=True):
        super(CTC, self).__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type

        if self.ctc_type == 'builtin':
            reduction_type = 'sum' if reduce else 'none'
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == 'warpctc':
            import warpctc_pytorch as warp_ctc
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "warpctc": {}'.format(
                    self.ctc_type))

        self.ignore_id = -1
        self.reduce = reduce
Example #4
File: test_loss.py Project: zpppy/espnet
def test_ctc_loss():
    pytest.importorskip("torch")
    pytest.importorskip("warpctc_pytorch")
    import torch
    import warpctc_pytorch

    from espnet.nets.e2e_asr_th import pad_list

    n_out = 7
    input_length = numpy.array([11, 17, 15], dtype=numpy.int32)
    label_length = numpy.array([4, 2, 3], dtype=numpy.int32)
    np_pred = [
        numpy.random.rand(il, n_out).astype(numpy.float32)
        for il in input_length
    ]
    np_target = [
        numpy.random.randint(0, n_out, size=ol, dtype=numpy.int32)
        for ol in label_length
    ]

    # NOTE: np_pred[i] seems to be transposed and used axis=-1 in e2e_asr.py
    ch_pred = F.separate(F.pad_sequence(np_pred), axis=-2)
    ch_target = F.pad_sequence(np_target, padding=-1)
    ch_loss = F.connectionist_temporal_classification(ch_pred, ch_target, 0,
                                                      input_length,
                                                      label_length).data

    th_pred = pad_list([torch.from_numpy(x) for x in np_pred],
                       0.0).transpose(0, 1)
    th_target = torch.from_numpy(numpy.concatenate(np_target))
    th_ilen = torch.from_numpy(input_length)
    th_olen = torch.from_numpy(label_length)
    th_loss = warpctc_pytorch.CTCLoss(size_average=True)(
        th_pred, th_target, th_ilen, th_olen).data.numpy()[0]
    numpy.testing.assert_allclose(th_loss, ch_loss, 0.05)
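
pad_list is imported from espnet above but not shown. A minimal stand-in (an assumption based on how it is used here, not the actual espnet implementation) pads a list of (T_i, n_out) tensors into one (batch, T_max, n_out) tensor:

import torch

def pad_list(xs, pad_value):
    # Pad a list of variable-length tensors along the first (time) dimension.
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    padded = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)
    for i, x in enumerate(xs):
        padded[i, :x.size(0)] = x
    return padded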
Example #5
 def __init__(self, odim, eprojs, dropout_rate):
     super(CTC, self).__init__()
     self.dropout_rate = dropout_rate
     self.loss = None
     self.ctc_lo = torch.nn.Linear(eprojs, odim)
     self.loss_fn = warp_ctc.CTCLoss(size_average=True)
     self.ignore_id = -1
Example #6
File: ctc.py Project: limjoo14/ASR_project
    def __init__(self,
                 odim,
                 eprojs,
                 dropout_rate,
                 ctc_type="warpctc",
                 reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.probs = None  # for visualization

        # In case of Pytorch >= 1.7.0, CTC will be always builtin
        self.ctc_type = (ctc_type if LooseVersion(torch.__version__) <
                         LooseVersion("1.7.0") else "builtin")
        if ctc_type != self.ctc_type:
            logging.warning(
                f"CTC was set to {self.ctc_type} due to PyTorch version.")
        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type,
                                             zero_infinity=True)
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "warpctc": {}'.format(
                    self.ctc_type))

        self.ignore_id = -1
        self.reduce = reduce
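
The version gate above uses distutils.version.LooseVersion, which is deprecated (distutils itself was removed in Python 3.12). An equivalent check, sketched here under the assumption that the packaging module is available, would be:

from packaging.version import Version
import torch

# Same decision as the LooseVersion gate above, without distutils
ctc_type = "warpctc" if Version(torch.__version__) < Version("1.7.0") else "builtin"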
Example #7
    def __init__(self,
                 odim,
                 eprojs,
                 dropout_rate,
                 ctc_type='warpctc',
                 reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)

        # In case of Pytorch >= 1.2.0, CTC will be always builtin
        self.ctc_type = ctc_type if LooseVersion(
            torch.__version__) < LooseVersion('1.2.0') else 'builtin'
        if ctc_type != self.ctc_type:
            logging.warning(
                f'CTC was set to {self.ctc_type} due to PyTorch version.')
        if self.ctc_type == 'builtin':
            reduction_type = 'sum' if reduce else 'none'
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == 'warpctc':
            import warpctc_pytorch as warp_ctc
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "warpctc": {}'.format(
                    self.ctc_type))

        self.ignore_id = -1
        self.reduce = reduce
Example #8
    def __init__(self,
                 odim,
                 eprojs,
                 dropout_rate,
                 ctc_type='builtin',
                 reduce=True,
                 ignore_id=0):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        torch.nn.init.xavier_normal_(self.ctc_lo.weight)
        self.ctc_type = ctc_type

        if self.ctc_type == 'builtin':
            reduction_type = 'sum' if reduce else 'mean'
            self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True,
                                             reduction=reduction_type)
        elif self.ctc_type == 'warpctc':
            import warpctc_pytorch as warp_ctc
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "warpctc": {}'.format(
                    self.ctc_type))

        self.ignore_id = ignore_id
        self.reduce = reduce
Example #9
File: ctc.py Project: yuekaizhang/espnet
    def __init__(
        self,
        odim: int,
        encoder_output_sizse: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_sizse
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad

        if self.ctc_type == "builtin":
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            if ignore_nan_grad:
                logging.warning(
                    "ignore_nan_grad option is not supported for warp_ctc")
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}')

        self.reduce = reduce
Example #10
    def __init__(self,
                 idim,
                 hidden_dim,
                 hidden_layers,
                 odim,
                 dropout_rate,
                 label_ignore_id=-1):
        super(CTCArch, self).__init__()
        self.idim = idim
        self.odim = odim
        self.loss = None
        self.loss_fn = warp_ctc.CTCLoss(
            size_average=False)  # normalize the loss by batch size if True
        self.ignore_id = label_ignore_id

        self.nblstm = torch.nn.LSTM(idim,
                                    hidden_dim,
                                    hidden_layers,
                                    batch_first=True,
                                    dropout=dropout_rate,
                                    bidirectional=True)
        self.l_proj = torch.nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout_layer = torch.nn.Dropout(
            p=dropout_rate)  # dropout for the BLSTM proj layer
        self.l_output = torch.nn.Linear(hidden_dim, odim)
Example #11
def test_ctc_loss_probabilities_match_labels_third_baidu_example_variant_two_extra_padding_wrong_side(
):

    ctc_loss = warpctc_pytorch.CTCLoss()
    print("expected shape of seqLength x batchSize x alphabet_size")
    # https://stackoverflow.com/questions/48915810/pytorch-contiguous
    probs = torch.FloatTensor(
        [
            [[0, 0, 0, 0, 0],
             [0, 0, 0, 0,
              0]],  # Extra padding is added at the top, which is wrong
            [[0, 0, 0, 0, 0], [1, 2, 3, 4, 5]],
            [[0, 0, 0, 0, 0], [6, 7, 8, 9, 10]],
            [[0, 0, 0, 0, 0], [11, 12, 13, 14, 15]],
        ]
    )  # .contiguous() # contiguous is just for performance, does not change results

    print("probs.size(): " + str(probs.size()))

    # labels = Variable(torch.IntTensor([ [1, 0], [3, 3], [2, 3]]))
    # See: https://github.com/SeanNaren/warp-ctc/issues/29
    # IMPORTANT !!!: All label sequences are concatenated, without blanks/padding,
    # and label sizes lists the sizes without padding
    labels = Variable(torch.IntTensor([1, 3, 3]))
    # Label sizes should equal the number of labels per sequence. Because the
    # labels are concatenated, the label sizes effectively mark where the
    # sequence boundaries are!
    label_sizes = Variable(torch.IntTensor([1, 2]))
    # probs_sizes gives the number of real probability frames per sequence,
    # distinguishing real frames from padding.
    # Padding should presumably be at the bottom (see
    # https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md),
    # but this should be checked.
    probs_sizes = Variable(torch.IntTensor([1, 3]))
    probs = Variable(
        probs,
        requires_grad=True)  # tells autograd to compute gradients for probs
    optimizer = optim.SGD(list([probs]), lr=0.001)
    print("probs: " + str(probs))
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    # cost: tensor([ 7.3557]) as in the Baidu tutorial, second example
    print("cost: " + str(cost))
    # Since padding has been added to the wrong side (top instead of bottom)
    # the results are now expected to change
    no_longer_expected_cost_tensor = torch.FloatTensor([8.965181350708008])
    print("zeros_tensor: " + str(no_longer_expected_cost_tensor))
    if TensorUtils.tensors_are_equal(no_longer_expected_cost_tensor, cost):
        raise RuntimeError("Error: cost expected to be not equal to " +
                           str(no_longer_expected_cost_tensor) + "but was:" +
                           str((float(cost))))
    cost.backward()
    print("cost: " + str(cost))
    print("update probabilities...")
    optimizer.step()
    print("probs: " + str(probs))

    print(
        ">>> Success: test_ctc_loss_probabilities_match_labels_third_baidu_example_variant_two_extra_padding_wrong_side"
    )
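
For reference, a sketch (not part of the test) of the same batch with the padding frame of the second sequence placed at the bottom of the time axis, as the Baidu tutorial layout prescribes. With probs_sizes = [1, 3] as above, the cost should then match the 8.9652 value this test asserts is no longer produced, which is consistent with the per-sequence costs appearing elsewhere in these tests (1.6094 + 7.3557).

probs_padding_at_bottom = torch.FloatTensor(
    [
        [[0, 0, 0, 0, 0], [1, 2, 3, 4, 5]],
        [[0, 0, 0, 0, 0], [6, 7, 8, 9, 10]],
        [[0, 0, 0, 0, 0], [11, 12, 13, 14, 15]],
        [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],  # padding frame for the second sequence, now at the bottom
    ]
)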
Example #12
File: ctc.py Project: ishine/neural_sp
    def __init__(self,
                 eos,
                 blank,
                 enc_n_units,
                 vocab,
                 dropout=0.,
                 lsm_prob=0.,
                 fc_list=None,
                 param_init=0.1,
                 backward=False):

        super(CTC, self).__init__()

        self.eos = eos
        self.blank = blank
        self.vocab = vocab
        self.lsm_prob = lsm_prob
        self.bwd = backward

        self.space = -1  # TODO(hirofumi): fix later

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None

        # for posterior plot
        self.prob_dict = {}
        self.data_dict = {}

        # Fully-connected layers before the softmax
        if fc_list is not None and len(fc_list) > 0:
            _fc_list = [int(fc) for fc in fc_list.split('_')]
            fc_layers = OrderedDict()
            for i in range(len(_fc_list)):
                input_dim = enc_n_units if i == 0 else _fc_list[i - 1]
                fc_layers['fc' + str(i)] = nn.Linear(input_dim, _fc_list[i])
                fc_layers['dropout' + str(i)] = nn.Dropout(p=dropout)
            fc_layers['fc' + str(len(_fc_list))] = nn.Linear(
                _fc_list[-1], vocab)
            self.output = nn.Sequential(fc_layers)
        else:
            self.output = nn.Linear(enc_n_units, vocab)

        self.use_warpctc = LooseVersion(
            torch.__version__) < LooseVersion("1.4.0")
        if self.use_warpctc:
            import warpctc_pytorch
            self.ctc_loss = warpctc_pytorch.CTCLoss(size_average=True)
        else:
            if LooseVersion(torch.__version__) < LooseVersion("1.7.0"):
                self.ctc_loss = nn.CTCLoss(reduction="sum")
            else:
                self.ctc_loss = nn.CTCLoss(reduction="sum", zero_infinity=True)

        self.forced_aligner = CTCForcedAligner()
Example #13
def test_ctc_loss_probabilities_match_labels_third_baidu_example_variant():

    ctc_loss = warpctc_pytorch.CTCLoss()
    print("expected shape of seqLength x batchSize x alphabet_size")
    probs = torch.FloatTensor([[[1, 2, 3, 4, 5], [0, 0, 0, 0, 0],
                                [-5, -4, -3, -2, -1]],
                               [[6, 7, 8, 9, 10], [0, 0, 0, 0, 0],
                                [-10, -9, -8, -7, -6]],
                               [[11, 12, 13, 14, 15], [0, 0, 0, 0, 0],
                                [-15, -14, -13, -12, -11]]]).contiguous()

    # probs = torch.FloatTensor([
    #     [[-5, -4, -3, -2, -1], [-10, -9, -8, -7, -6], [-15, -14, -13, -12, -11]]
    # ]). \
    #    transpose(0, 1).contiguous()

    print("probs.size(): " + str(probs.size()))

    # labels = Variable(torch.IntTensor([ [1, 0], [3, 3], [2, 3]]))
    # See: https://github.com/SeanNaren/warp-ctc/issues/29
    # All label sequences are concatenated, without blanks/padding,
    # and label sizes lists the sizes without padding
    labels = Variable(torch.IntTensor([3, 3, 1, 2, 3]))
    # labels = Variable(torch.IntTensor([2, 3]))
    #labels = Variable(torch.IntTensor([3, 3]))
    # Labels sizes should be equal to number of labels
    label_sizes = Variable(torch.IntTensor([2, 1, 2]))
    #label_sizes = Variable(torch.IntTensor([2]))
    # This one must be equal to the number of probabilities to avoid a crash
    probs_sizes = Variable(torch.IntTensor([3, 1, 3]))
    # probs_sizes = Variable(torch.IntTensor([3]))
    probs = Variable(
        probs,
        requires_grad=True)  # tells autograd to compute gradients for probs
    optimizer = optim.SGD(list([probs]), lr=0.001)
    print("probs: " + str(probs))
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    # cost: tensor([ 7.3557]) as in the Baidu tutorial, second example
    print("cost: " + str(cost))
    expected_cost_tensor = torch.FloatTensor([13.904030799865723])
    print("zeros_tensor: " + str(expected_cost_tensor))
    if not TensorUtils.tensors_are_equal(expected_cost_tensor, cost):
        raise RuntimeError("Error: cost expected to be " +
                           str(expected_cost_tensor) + "but was:" +
                           str((float(cost))))
    cost.backward()
    print("cost: " + str(cost))
    print("update probabilities...")
    optimizer.step()
    print("probs: " + str(probs))

    print(
        ">>> Success: test_ctc_loss_probabilities_match_labels_third_baidu_example_variant"
    )
Example #14
def test_CTCLoss():
    probs = torch.FloatTensor([[
        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]
    ]]).transpose(0, 1).contiguous().cuda()
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    probs_sizes = torch.IntTensor([2])
    probs.requires_grad_(True)

    ctc_loss = warp_ctc.CTCLoss()
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    cost.backward()
Example #15
    def __init__(self, trainer_params, args):
        self.args = args
        self.trainer_params = trainer_params

        random.seed(trainer_params.random_seed)
        torch.manual_seed(trainer_params.random_seed)
        if args.cuda:
            torch.cuda.manual_seed_all(trainer_params.random_seed)

        self.train_data = seq_mnist_train(trainer_params)
        self.val_data = seq_mnist_val(trainer_params)

        self.train_loader = DataLoader(self.train_data, batch_size=trainer_params.batch_size, \
                                        shuffle=True, num_workers=trainer_params.num_workers)

        self.val_loader = DataLoader(self.val_data, batch_size=trainer_params.test_batch_size, \
                                        shuffle=False, num_workers=trainer_params.num_workers)

        self.starting_epoch = 1
        self.prev_loss = 10000

        self.model = BiLSTM(trainer_params)
        self.criterion = wp.CTCLoss(size_average=False)
        self.labels = [i for i in range(trainer_params.num_classes - 1)]
        self.decoder = seq_mnist_decoder(labels=self.labels)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=trainer_params.lr)

        if args.cuda:
            torch.cuda.set_device(args.gpus)
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if args.resume or args.eval or args.export:
            print("Loading model from {}".format(args.resume))
            package = torch.load(args.resume,
                                 map_location=lambda storage, loc: storage)
            self.model.load_state_dict(package['state_dict'])
            self.optimizer.load_state_dict(package['optim_dict'])
            self.starting_epoch = package['starting_epoch']
            self.prev_loss = package['prev_loss']
            if args.cuda:
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        if args.init_bn_fc_fusion:
            if not trainer_params.prefused_bn_fc:
                self.model.batch_norm_fc.init_fusion()
                self.trainer_params.prefused_bn_fc = True
            else:
                raise Exception("BN and FC are already fused.")
Example #16
def test_ctc_loss_probabilities_match_labels_three():

    ctc_loss = warpctc_pytorch.CTCLoss()
    print("expected shape of seqLength x batchSize x alphabet_size")

    # Gives no loss
    probs = torch.FloatTensor([[[0, 100, 0, 0, 83],
                                [0, 0, 100, 0, 0],
                                [0, 0, 0, 100, 0]]]).\
        transpose(0, 1).contiguous()

    # # Gives small loss
    # probs = torch.FloatTensor([[[0, 100, 0, 0, 84],
    #                             [0, 0, 100, 0, 0],
    #                             [0, 0, 0, 100, 0]]]). \
    #     transpose(0, 1).contiguous()

    print("probs.size(): " + str(probs.size()))

    # No loss
    # labels = Variable(torch.IntTensor([1, 2, 3]))
    # Also no loss (possibly because not possible!)
    # becomes effectively 2-2-2-2 which is length 6!
    # labels = Variable(torch.IntTensor([2, 2, 2, 2]))
    # labels (becomes 2-2) (Why is loss also zero?)
    labels = Variable(torch.IntTensor([1, 1, 1]))
    # Labels sizes should be equal to the number of labels in the example
    label_sizes = Variable(torch.IntTensor([3]))
    # This one must be equal to the number of probabilities to avoid a crash
    probs_sizes = Variable(torch.IntTensor([3]))
    probs = Variable(
        probs,
        requires_grad=True)  # tells autograd to compute gradients for probs
    optimizer = optim.SGD(list([probs]), lr=0.001)
    print("probs: " + str(probs))
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    # cost: tensor([ 7.3557]) as in the Baidu tutorial, second example
    print("cost: " + str(cost))
    expected_cost_tensor = torch.FloatTensor([0])
    print("zeros_tensor: " + str(expected_cost_tensor))
    if not TensorUtils.tensors_are_equal(expected_cost_tensor, cost):
        raise RuntimeError("Error: cost expected to be " +
                           str(expected_cost_tensor) + "but was:" +
                           str((float(cost))))
    cost.backward()
    print("cost: " + str(cost))
    print("update probabilities...")
    optimizer.step()
    print("probs: " + str(probs))
Example #17
    def compute_ctc_loss_version_two(self, probabilities, labels_row_tensor):
        ctc_loss = warpctc_pytorch.CTCLoss()

        #probs = torch.FloatTensor([
        #    [[0, 0, 0, 0, 0], [1, 2, 3, 4, 5], [-5, -4, -3, -2, -1]],
        #    [[0, 0, 0, 0, 0], [6, 7, 8, 9, 10], [-10, -9, -8, -7, -6]],
        #    [[0, 0, 0, 0, 0], [11, 12, 13, 14, 15], [-15, -14, -13, -12, -11]]
        #])

        probs = probabilities

        print(
            "test_ctc_loss_probabilities_match_labels_third_baidu_example - probs: "
            + str(probs))

        print(
            "test_ctc_loss_probabilities_match_labels_third_baidu_example - probs.size(): "
            + str(probs.size()))

        # labels = Variable(torch.IntTensor([ [1, 0], [3, 3], [2, 3]]))
        # See: https://github.com/SeanNaren/warp-ctc/issues/29
        # All label sequences are concatenated, without blanks/padding,
        # and label sizes lists the sizes without padding
        labels = Variable(torch.IntTensor([1, 3, 3, 2, 3]))
        # labels = Variable(torch.IntTensor([2, 3]))
        # labels = Variable(torch.IntTensor([3, 3]))
        # Labels sizes should be equal to number of labels
        label_sizes = Variable(torch.IntTensor([1, 2, 2]))
        # label_sizes = Variable(torch.IntTensor([2]))
        # This one must be equal to the number of probabilities to avoid a crash
        probs_sizes = Variable(torch.IntTensor([1, 3, 3]))
        # probs_sizes = Variable(torch.IntTensor([3]))
        probs = Variable(probs, requires_grad=True
                         )  # tells autograd to compute gradients for probs
        print("probs: " + str(probs))

        if Utils.use_cuda():
            probs = probs.cuda()
            device = probs.get_device()
            ctc_loss = ctc_loss.cuda()
            # labels = labels.cuda()
            # label_sizes = label_sizes.cuda()
            # probs_sizes = probs_sizes.cuda()

        loss = ctc_loss(probs, labels, probs_sizes, label_sizes)
        print("loss: " + str(loss))

        return loss
Example #18
def test_ctc_loss_probabilities_match_labels():

    ctc_loss = warpctc_pytorch.CTCLoss()
    print("expected shape of seqLength x batchSize x alphabet_size")
    probs = torch.FloatTensor([[[0.9, 1.0, 0.0, 0.0],
                                [0.1, 0.0, 1.0, 1.0]]]).\
        transpose(0, 1).contiguous()

    print("probs.size(): " + str(probs.size()))
    # No cost
    labels = Variable(torch.IntTensor([1, 1, 2, 1]))
    # No cost
    labels = Variable(torch.IntTensor([1, 1, 1, 1]))
    # Cost
    labels = Variable(torch.IntTensor([1, 2, 2, 1]))
    # No cost
    labels = Variable(torch.IntTensor([1, 1]))
    # No cost
    labels = Variable(torch.IntTensor([2, 2]))
    # Crash (Apparently must be minimally 2 elements)
    labels = Variable(torch.IntTensor([2]))
    # No cost
    labels = Variable(torch.IntTensor([3, 3]))

    label_sizes = Variable(torch.IntTensor([2]))
    # This one must be equal to the number of probabilities to avoid a crash
    probs_sizes = Variable(torch.IntTensor([2]))
    probs = Variable(
        probs,
        requires_grad=True)  # tells autograd to compute gradients for probs
    optimizer = optim.SGD(list([probs]), lr=0.001)
    print("probs: " + str(probs))
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    print("cost: " + str(cost))
    zero_tensor = torch.zeros(1)
    print("zeros_tensor: " + str(zero_tensor))
    if not TensorUtils.tensors_are_equal(zero_tensor, cost):
        raise RuntimeError(
            "Error: loss expected to be zero, since probabilities " +
            "are maximum for the right labels, but not the case")
    cost.backward()
    print("cost: " + str(cost))
    print("update probabilities...")
    optimizer.step()
    print("probs: " + str(probs))
Example #19
def test_ctc_loss(in_length, out_length, use_warpctc):
    pytest.importorskip("torch")
    if use_warpctc:
        pytest.importorskip("warpctc_pytorch")
        import warpctc_pytorch

        torch_ctcloss = warpctc_pytorch.CTCLoss(size_average=True)
    else:
        if LooseVersion(torch.__version__) < LooseVersion("1.0"):
            pytest.skip("pytorch < 1.0 doesn't support CTCLoss")
        _ctcloss_sum = torch.nn.CTCLoss(reduction="sum")

        def torch_ctcloss(th_pred, th_target, th_ilen, th_olen):
            th_pred = th_pred.log_softmax(2)
            loss = _ctcloss_sum(th_pred, th_target, th_ilen, th_olen)
            # Batch-size average
            loss = loss / th_pred.size(1)
            return loss

    n_out = 7
    input_length = numpy.array(in_length, dtype=numpy.int32)
    label_length = numpy.array(out_length, dtype=numpy.int32)
    np_pred = [
        numpy.random.rand(il, n_out).astype(numpy.float32)
        for il in input_length
    ]
    np_target = [
        numpy.random.randint(0, n_out, size=ol, dtype=numpy.int32)
        for ol in label_length
    ]

    # NOTE: np_pred[i] seems to be transposed and used axis=-1 in e2e_asr.py
    ch_pred = F.separate(F.pad_sequence(np_pred), axis=-2)
    ch_target = F.pad_sequence(np_target, padding=-1)
    ch_loss = F.connectionist_temporal_classification(ch_pred, ch_target, 0,
                                                      input_length,
                                                      label_length).data

    th_pred = pad_list([torch.from_numpy(x) for x in np_pred],
                       0.0).transpose(0, 1)
    th_target = torch.from_numpy(numpy.concatenate(np_target))
    th_ilen = torch.from_numpy(input_length)
    th_olen = torch.from_numpy(label_length)
    th_loss = torch_ctcloss(th_pred, th_target, th_ilen, th_olen).numpy()
    numpy.testing.assert_allclose(th_loss, ch_loss, 0.05)
Example #20
    def __init__(self, output_size, hidden_size, dropout, ctc_type='warpctc', reduce=True):
        super().__init__()
        self.dropout = dropout
        self.loss = None
        self.ctc_lo = torch.nn.Linear(hidden_size, output_size)
        self.ctc_type = ctc_type

        if self.ctc_type == 'builtin':
            reduction_type = 'sum' if reduce else 'none'
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == 'warpctc':
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True)
        else:
            raise ValueError('ctc_type must be "builtin" or "warpctc": {}'
                             .format(self.ctc_type))

        self.ignore_id = -1
        self.reduce = reduce
Example #21
    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = None,
        zero_infinity: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        if ignore_nan_grad is not None:
            zero_infinity = ignore_nan_grad

        if self.ctc_type == "builtin":
            self.ctc_loss = torch.nn.CTCLoss(
                reduction="none", zero_infinity=zero_infinity
            )
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            if zero_infinity:
                logging.warning("zero_infinity option is not supported for warp_ctc")
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)

        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction

            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
            )

        self.reduce = reduce
Example #22
    def __init__(self,
                 eos,
                 blank,
                 enc_n_units,
                 vocab,
                 dropout=0.,
                 lsm_prob=0.,
                 fc_list=None,
                 param_init=0.1,
                 backward=False):

        super(CTC, self).__init__()

        self.eos = eos
        self.blank = blank
        self.vocab = vocab
        self.lsm_prob = lsm_prob
        self.bwd = backward

        self.space = -1  # TODO(hirofumi): fix later

        # Fully-connected layers before the softmax
        if fc_list is not None and len(fc_list) > 0:
            _fc_list = [int(fc) for fc in fc_list.split('_')]
            fc_layers = OrderedDict()
            for i in range(len(_fc_list)):
                input_dim = enc_n_units if i == 0 else _fc_list[i - 1]
                fc_layers['fc' + str(i)] = nn.Linear(input_dim, _fc_list[i])
                fc_layers['dropout' + str(i)] = nn.Dropout(p=dropout)
            fc_layers['fc' + str(len(_fc_list))] = nn.Linear(
                _fc_list[-1], vocab)
            self.output = nn.Sequential(fc_layers)
        else:
            self.output = nn.Linear(enc_n_units, vocab)

        import warpctc_pytorch
        self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)

        self.forced_aligner = CTCForcedAligner()
Example #23
    def __init__(self,
                 eos,
                 blank,
                 enc_n_units,
                 vocab,
                 dropout=0.0,
                 lsm_prob=0.0,
                 fc_list=[],
                 param_init=0.1):

        super(CTC, self).__init__()
        logger = logging.getLogger('training')

        self.eos = eos
        self.blank = blank
        self.vocab = vocab
        self.lsm_prob = lsm_prob

        self.space = -1
        # TODO(hirofumi): fix layer

        # Fully-connected layers before the softmax
        if len(fc_list) > 0:
            fc_layers = OrderedDict()
            for i in range(len(fc_list)):
                input_dim = enc_n_units if i == 0 else fc_list[i - 1]
                fc_layers['fc' + str(i)] = LinearND(input_dim,
                                                    fc_list[i],
                                                    dropout=dropout)
            fc_layers['fc' + str(len(fc_list))] = LinearND(fc_list[-1],
                                                           vocab,
                                                           dropout=0)
            self.output = nn.Sequential(fc_layers)
        else:
            self.output = LinearND(enc_n_units, vocab)

        import warpctc_pytorch
        self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)
Example #24
    def __init__(self, dropout_rate, ctc_type="builtin", reduce=True):
        super().__init__()

        # self.vocab_size = vocab_size
        # self.hidden_size = hidden_size
        self.padding_idx = onmt.constants.PAD

        # why do we need dropout at ctc ?
        self.dropout_rate = dropout_rate

        # In case of Pytorch >= 1.7.0, CTC will be always builtin
        self.ctc_type = (
            ctc_type
            if LooseVersion(torch.__version__) < LooseVersion("1.7.0")
            else "builtin"
        )

        if ctc_type != self.ctc_type:
            logging.warning(f"CTC was set to {self.ctc_type} due to PyTorch version.")

        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(blank=onmt.constants.PAD,
                                             reduction=reduction_type, zero_infinity=True)

        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            self.ctc_loss = warp_ctc.CTCLoss(size_average=False, length_average=False)

        else:
            raise ValueError(
                'ctc_type must be "builtin" or "warpctc": {}'.format(self.ctc_type)
            )

        self.ignore_id = -1
        self.reduce = reduce
Example #25
def test_ctc_loss_probabilities_match_labels_second_baidu_example():

    ctc_loss = warpctc_pytorch.CTCLoss()
    print("expected shape of seqLength x batchSize x alphabet_size")
    probs = torch.FloatTensor([[[1, 2, 3, 4, 5],
                                [6, 7, 8, 9, 10],
                                [11, 12, 13, 14, 15]]]).\
        transpose(0, 1).contiguous()

    probs.requires_grad_(True)

    print("probs.size(): " + str(probs.size()))

    labels = Variable(torch.IntTensor([3, 3]))
    # Labels sizes should be equal to number of labels
    label_sizes = Variable(torch.IntTensor([2]))
    # This one must be equal to the number of probabilities to avoid a crash
    probs_sizes = Variable(torch.IntTensor([3]))
    probs = Variable(
        probs,
        requires_grad=True)  # tells autograd to compute gradients for probs
    optimizer = optim.SGD(list([probs]), lr=0.001)
    print("probs: " + str(probs))
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    # cost: tensor([ 7.3557]) as in the Baidu tutorial, second example
    print("cost: " + str(cost))
    expected_cost_tensor = torch.FloatTensor([7.355742931365967])
    print("zeros_tensor: " + str(expected_cost_tensor))
    if not TensorUtils.tensors_are_equal(expected_cost_tensor, cost):
        raise RuntimeError("Error: cost expected to be " +
                           str(expected_cost_tensor) + "but was:" +
                           str((float(cost))))
    cost.backward()
    print("cost: " + str(cost))
    print("update probabilities...")
    optimizer.step()
    print("probs: " + str(probs))
Example #26
    def __init__(self,
                 sos,
                 eos,
                 pad,
                 enc_nunits,
                 attn_type,
                 attn_dim,
                 attn_sharpening_factor,
                 attn_sigmoid_smoothing,
                 attn_conv_out_channels,
                 attn_conv_kernel_size,
                 attn_nheads,
                 dropout_att,
                 rnn_type,
                 nunits,
                 nlayers,
                 residual,
                 emb_dim,
                 tie_embedding,
                 vocab,
                 logits_temp,
                 dropout,
                 dropout_emb,
                 ss_prob,
                 lsm_prob,
                 layer_norm,
                 fl_weight,
                 fl_gamma,
                 init_with_enc=False,
                 ctc_weight=0.,
                 ctc_fc_list=[],
                 input_feeding=False,
                 backward=False,
                 rnnlm_cold_fusion=False,
                 cold_fusion='hidden',
                 internal_lm=False,
                 rnnlm_init=False,
                 lmobj_weight=0.,
                 share_lm_softmax=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 vocab_char=None):

        super(Decoder, self).__init__()

        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm', 'gru']
        self.enc_nunits = enc_nunits
        self.nunits = nunits
        self.nlayers = nlayers
        self.residual = residual
        self.logits_temp = logits_temp
        self.dropout = dropout
        self.dropout_emb = dropout_emb
        self.ss_prob = ss_prob
        self.lsm_prob = lsm_prob
        self.layer_norm = layer_norm
        self.fl_weight = fl_weight
        self.fl_gamma = fl_gamma
        self.init_with_enc = init_with_enc
        self.ctc_weight = ctc_weight
        self.ctc_fc_list = ctc_fc_list
        self.backward = backward
        self.rnnlm_cf = rnnlm_cold_fusion
        self.cold_fusion = cold_fusion
        self.internal_lm = internal_lm
        self.rnnlm_init = rnnlm_init
        self.lmobj_weight = lmobj_weight
        self.share_lm_softmax = share_lm_softmax
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0 and not backward:
            # Fully-connected layers for CTC
            if len(ctc_fc_list) > 0:
                fc_layers = OrderedDict()
                for i in range(len(ctc_fc_list)):
                    input_dim = enc_nunits if i == 0 else ctc_fc_list[i - 1]
                    fc_layers['fc' + str(i)] = LinearND(input_dim,
                                                        ctc_fc_list[i],
                                                        dropout=dropout)
                fc_layers['fc' + str(len(ctc_fc_list))] = LinearND(
                    ctc_fc_list[-1], vocab, dropout=0)
                self.output_ctc = nn.Sequential(fc_layers)
            else:
                self.output_ctc = LinearND(enc_nunits, vocab)
            self.decode_ctc_greedy = GreedyDecoder(blank_index=0)
            self.decode_ctc_beam = BeamSearchDecoder(blank_index=0)
            self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)

        if ctc_weight < global_weight:
            # Attention layer
            if attn_nheads > 1:
                self.score = MultiheadAttentionMechanism(
                    enc_nunits=self.enc_nunits,
                    dec_nunits=nunits,
                    attn_type=attn_type,
                    attn_dim=attn_dim,
                    sharpening_factor=attn_sharpening_factor,
                    sigmoid_smoothing=attn_sigmoid_smoothing,
                    conv_out_channels=attn_conv_out_channels,
                    conv_kernel_size=attn_conv_kernel_size,
                    nheads=attn_nheads,
                    dropout=dropout_att)
            else:
                self.score = AttentionMechanism(
                    enc_nunits=self.enc_nunits,
                    dec_nunits=nunits,
                    attn_type=attn_type,
                    attn_dim=attn_dim,
                    sharpening_factor=attn_sharpening_factor,
                    sigmoid_smoothing=attn_sigmoid_smoothing,
                    conv_out_channels=attn_conv_out_channels,
                    conv_kernel_size=attn_conv_kernel_size,
                    dropout=dropout_att)

            # for decoder initialization with pre-trained RNNLM
            if rnnlm_init:
                assert internal_lm
                assert rnnlm_init.predictor.vocab == vocab
                assert rnnlm_init.predictor.nunits == nunits
                assert rnnlm_init.predictor.nlayers == 1  # TODO(hirofumi): on-the-fly

            # for MTL with RNNLM objective
            if lmobj_weight > 0:
                if internal_lm and not share_lm_softmax:
                    self.output_lmobj = LinearND(nunits, vocab)

            # Decoder
            self.rnn = torch.nn.ModuleList()
            self.dropout = torch.nn.ModuleList()
            if rnn_type == 'lstm':
                rnn_cell = nn.LSTMCell
            elif rnn_type == 'gru':
                rnn_cell = nn.GRUCell
            if internal_lm:
                assert nlayers >= 2
                self.rnn_inlm = rnn_cell(emb_dim, nunits)
                self.dropout_inlm = nn.Dropout(p=dropout)
                self.rnn += [rnn_cell(nunits + enc_nunits, nunits)]
            else:
                self.rnn += [rnn_cell(emb_dim + enc_nunits, nunits)]
            self.dropout += [nn.Dropout(p=dropout)]
            for l in range(1, nlayers):
                self.rnn += [rnn_cell(nunits, nunits)]
                self.dropout += [nn.Dropout(p=dropout)]

            # cold fusion
            if rnnlm_cold_fusion:
                self.cf_linear_dec_feat = LinearND(nunits + enc_nunits, nunits)
                if cold_fusion == 'hidden':
                    self.cf_linear_lm_feat = LinearND(rnnlm_cold_fusion.nunits,
                                                      nunits)
                elif cold_fusion == 'prob':
                    self.cf_linear_lm_feat = LinearND(rnnlm_cold_fusion.vocab,
                                                      nunits)
                else:
                    raise ValueError(cold_fusion)
                self.cf_linear_lm_gate = LinearND(nunits * 2, nunits)
                self.output_bn = LinearND(nunits * 2, nunits)

                # fix RNNLM parameters
                for p in self.rnnlm_cf.parameters():
                    p.requires_grad = False
            else:
                self.output_bn = LinearND(nunits + enc_nunits, nunits)

            self.output = LinearND(nunits, vocab)

            # Embedding
            self.embed = Embedding(vocab=vocab,
                                   emb_dim=emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if tie_embedding:
                if nunits != emb_dim:
                    raise ValueError(
                        'When using the tied flag, nunits must be equal to emb_dim.'
                    )
                self.output.fc.weight = self.embed.embed.weight
Example #27
                          num_workers=8)
test_loader = DataLoader(test_set,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=8)

# load CNN
logger.info('Preparing Net...')

net = HTRNet(cnn_cfg, rnn_cfg, len(classes))

if load_model_name is not None:
    my_torch_load(net, load_model_name)
net.cuda(args.gpu_id)

loss = warp_ctc.CTCLoss()
net_parameters = net.parameters()
nlr = args.learning_rate
optimizer = torch.optim.Adam(net_parameters, nlr, weight_decay=0.00005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    [int(.5 * max_epochs), int(.75 * max_epochs)])

decoder = ctcdecode.CTCBeamDecoder([c for c in classes], beam_width=100)
# decoder = ctcdecode.


def train(epoch):
    optimizer.zero_grad()

    closs = []
Example #28
    def __init__(self, eos, unk, pad, blank, enc_nunits, attn_type,
                 attn_nheads, n_layers, d_model, d_ff, pe_type, tie_embedding,
                 vocab, dropout, dropout_emb, dropout_att, lsm_prob,
                 layer_norm_eps, ctc_weight, ctc_fc_list, backward,
                 global_weight, mtl_per_batch):

        super(TransformerDecoder, self).__init__()

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.enc_nunits = enc_nunits
        self.d_model = d_model
        self.n_layers = n_layers
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.ctc_fc_list = ctc_fc_list
        self.backward = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            # Fully-connected layers for CTC
            if len(ctc_fc_list) > 0:
                fc_layers = OrderedDict()
                for i in range(len(ctc_fc_list)):
                    input_dim = d_model if i == 0 else ctc_fc_list[i - 1]
                    fc_layers['fc' + str(i)] = LinearND(input_dim,
                                                        ctc_fc_list[i],
                                                        dropout=dropout)
                fc_layers['fc' + str(len(ctc_fc_list))] = LinearND(
                    ctc_fc_list[-1], vocab, dropout=0)
                self.output_ctc = nn.Sequential(fc_layers)
            else:
                self.output_ctc = LinearND(d_model, vocab)
            self.decode_ctc_greedy = GreedyDecoder(blank=blank)
            self.decode_ctc_beam = BeamSearchDecoder(blank=blank)
            self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)

        if ctc_weight < global_weight:
            self.layers = nn.ModuleList([
                TransformerDecoderBlock(d_model, d_ff, attn_type, attn_nheads,
                                        dropout, dropout_att, layer_norm_eps)
                for _ in range(n_layers)
            ])

            self.embed = Embedding(
                vocab,
                d_model,
                dropout=0,  # NOTE: do not apply dropout here
                ignore_index=pad)
            if pe_type:
                self.pos_emb_out = PositionalEncoding(d_model, dropout_emb,
                                                      pe_type)
            self.output = LinearND(d_model, vocab)

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if tie_embedding:
                self.output.fc.weight.data = self.embed.embed.weight.data

            self.layer_norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps)
Example #29
    def __init__(self,
                 model,
                 amp_handle=None,
                 init_lr=1e-2,
                 max_norm=100,
                 use_cuda=False,
                 fp16=False,
                 log_dir='logs',
                 model_prefix='model',
                 checkpoint=False,
                 continue_from=None,
                 opt_type=None,
                 *args,
                 **kwargs):
        if fp16:
            import apex.parallel
            from apex import amp
            if not use_cuda:
                raise RuntimeError
        self.amp_handle = amp_handle

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.opt_type = opt_type
        self.epoch = 0
        self.states = None

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss
        #self.loss = nn.CTCLoss(blank=0, reduction='none')
        self.loss = wp.CTCLoss(blank=0, length_average=True)

        # setup optimizer
        if opt_type is None:
            # for test only
            self.optimizer = None
            self.lr_scheduler = None
        else:
            assert opt_type in OPTIMIZER_TYPES
            parameters = self.model.parameters()
            if opt_type == "sgdr":
                logger.debug("using SGDR")
                self.optimizer = torch.optim.SGD(parameters,
                                                 lr=self.init_lr,
                                                 momentum=0.9,
                                                 weight_decay=5e-4)
                #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adamwr":
                logger.debug("using AdamWR")
                self.optimizer = torch.optim.Adam(parameters,
                                                  lr=self.init_lr,
                                                  betas=(0.9, 0.999),
                                                  eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(
                    self.optimizer, T_max=2, T_mult=2)
            elif opt_type == "adam":
                logger.debug("using Adam")
                self.optimizer = torch.optim.Adam(parameters,
                                                  lr=self.init_lr,
                                                  betas=(0.9, 0.999),
                                                  eps=1e-8,
                                                  weight_decay=5e-4)
                self.lr_scheduler = None
            elif opt_type == "rmsprop":
                logger.debug("using RMSprop")
                self.optimizer = torch.optim.RMSprop(parameters,
                                                     lr=self.init_lr,
                                                     alpha=0.95,
                                                     eps=1e-8,
                                                     weight_decay=5e-4,
                                                     centered=True)
                self.lr_scheduler = None

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()
        self.labeler = self.decoder.labeler

        # FP16 and distributed after load
        if self.fp16:
            #self.model = network_to_half(self.model)
            #self.optimizer = FP16_Optimizer(self.optimizer, static_loss_scale=128.)
            self.optimizer = self.amp_handle.wrap_optimizer(self.optimizer)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(
                        self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(
                        self.model,
                        device_ids=[local_rank],
                        output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

        if self.states is not None:
            self.restore_state()
Example #30
import torch.autograd
import sklearn.metrics
import warpctc_pytorch

# Local imports
import config
import logutil
import utils
import models
import datasets
import parser.grammarutils

cross_entropy = torch.nn.CrossEntropyLoss().cuda()
mse_loss = torch.nn.MSELoss().cuda()
softmax = torch.nn.Softmax(dim=2)
ctc_loss = warpctc_pytorch.CTCLoss().cuda()


def loss_func(model_outputs, labels, probs, total_lengths, args):
    loss = 0
    for i_batch in range(model_outputs.size()[1]):
        gt_pred_labels = list()
        seg_length = int(total_lengths[i_batch])
        current_label = int(labels[0, i_batch])
        for f in range(seg_length):
            if int(labels[f, i_batch]) != current_label:
                current_label = int(labels[f, i_batch])
                gt_pred_labels.extend([
                    current_label for _ in range(f - len(gt_pred_labels) - 1)
                ])
        gt_pred_labels.extend([