def decode(data_feats, data_tags, data_class, output_path):
    data_index = np.arange(len(data_feats))
    losses = []
    TP, FP, FN, TN = 0.0, 0.0, 0.0, 0.0
    TP2, FP2, FN2, TN2 = 0.0, 0.0, 0.0, 0.0
    with open(output_path, 'w') as f:
        for j in range(0, len(data_index), opt.test_batchSize):
            if opt.testing:
                words, tags, raw_tags, classes, raw_classes, lens, line_nums = data_reader.get_minibatch_with_class(
                    data_feats,
                    data_tags,
                    data_class,
                    tag_to_idx,
                    class_to_idx,
                    data_index,
                    j,
                    opt.test_batchSize,
                    add_start_end=opt.bos_eos,
                    multiClass=opt.multiClass,
                    keep_order=opt.testing,
                    enc_dec_focus=opt.enc_dec,
                    device=opt.device)
            else:
                words, tags, raw_tags, classes, raw_classes, lens = data_reader.get_minibatch_with_class(
                    data_feats,
                    data_tags,
                    data_class,
                    tag_to_idx,
                    class_to_idx,
                    data_index,
                    j,
                    opt.test_batchSize,
                    add_start_end=opt.bos_eos,
                    multiClass=opt.multiClass,
                    keep_order=opt.testing,
                    enc_dec_focus=opt.enc_dec,
                    device=opt.device)

            inputs = prepare_inputs_for_bert_xlnet(
                words,
                lens,
                tokenizer,
                # xlnet has a cls token at the end
                cls_token_at_end=bool(opt.pretrained_model_type in ['xlnet']),
                cls_token=tokenizer.cls_token,
                sep_token=tokenizer.sep_token,
                cls_token_segment_id=2 if opt.pretrained_model_type in ['xlnet'] else 0,
                # pad on the left for xlnet
                pad_on_left=bool(opt.pretrained_model_type in ['xlnet']),
                pad_token_segment_id=4 if opt.pretrained_model_type in ['xlnet'] else 0,
                device=opt.device)

            if opt.enc_dec:
                # greedy decoding is forced on here, so the beam-search branch
                # below is effectively unreachable and kept only for reference
                opt.greed_decoding = True
                if opt.greed_decoding:
                    tag_scores_1best, outputs_1best, encoder_info = model_tag.decode_greed(
                        inputs, tags[:, 0:1], lens, with_snt_classifier=True)
                    tag_loss = tag_loss_function(
                        tag_scores_1best.contiguous().view(
                            -1, len(tag_to_idx)),
                        tags[:, 1:].contiguous().view(-1))
                    top_pred_slots = outputs_1best.cpu().numpy()
                else:
                    beam_size = 2
                    beam_scores_1best, top_path_slots, encoder_info = model_tag.decode_beam_search(
                        inputs,
                        lens,
                        beam_size,
                        tag_to_idx,
                        with_snt_classifier=True)
                    top_pred_slots = [[item[0].item() for item in seq]
                                      for seq in top_path_slots]
                    ppl = beam_scores_1best.cpu() / torch.tensor(
                        lens, dtype=torch.float)
                    tag_loss = ppl.exp().sum()
                #tags = tags[:, 1:].data.cpu().numpy()
            elif opt.crf:
                max_len = max(lens)
                masks = [([1] * l) + ([0] * (max_len - l)) for l in lens]
                masks = torch.tensor(masks,
                                     dtype=torch.uint8,
                                     device=opt.device)
                crf_feats, encoder_info = model_tag._get_lstm_features(
                    inputs, lens, with_snt_classifier=True)
                tag_path_scores, tag_path = model_tag.forward(crf_feats, masks)
                tag_loss = model_tag.neg_log_likelihood(crf_feats, masks, tags)
                top_pred_slots = tag_path.data.cpu().numpy()
            else:
                tag_scores, encoder_info = model_tag(inputs,
                                                     lens,
                                                     with_snt_classifier=True)
                tag_loss = tag_loss_function(
                    tag_scores.contiguous().view(-1, len(tag_to_idx)),
                    tags.view(-1))
                top_pred_slots = tag_scores.data.cpu().numpy().argmax(axis=-1)
                #tags = tags.data.cpu().numpy()
            if opt.task_sc:
                class_scores = model_class(encoder_info_filter(encoder_info))
                class_loss = class_loss_function(class_scores, classes)
                if opt.multiClass:
                    snt_probs = class_scores.data.cpu().numpy()
                else:
                    snt_probs = class_scores.data.cpu().numpy().argmax(axis=-1)
                losses.append([
                    tag_loss.item() / sum(lens),
                    class_loss.item() / len(lens)
                ])
            else:
                losses.append([tag_loss.item() / sum(lens), 0])

            #classes = classes.data.cpu().numpy()
            for idx, pred_line in enumerate(top_pred_slots):
                length = lens[idx]
                pred_seq = [idx_to_tag[tag] for tag in pred_line][:length]
                lab_seq = [
                    idx_to_tag[tag] if type(tag) == int else tag
                    for tag in raw_tags[idx]
                ]
                pred_chunks = acc.get_chunks(['O'] + pred_seq + ['O'])
                label_chunks = acc.get_chunks(['O'] + lab_seq + ['O'])
                for pred_chunk in pred_chunks:
                    if pred_chunk in label_chunks:
                        TP += 1
                    else:
                        FP += 1
                for label_chunk in label_chunks:
                    if label_chunk not in pred_chunks:
                        FN += 1

                input_line = words[idx]
                word_tag_line = [
                    input_line[_idx] + ':' + lab_seq[_idx] + ':' +
                    pred_seq[_idx] for _idx in range(len(input_line))
                ]

                if opt.task_sc:
                    if opt.multiClass:
                        pred_classes = [
                            idx_to_class[i]
                            for i, p in enumerate(snt_probs[idx]) if p > 0.5
                        ]
                        gold_classes = [
                            idx_to_class[i] for i in raw_classes[idx]
                        ]
                        for pred_class in pred_classes:
                            if pred_class in gold_classes:
                                TP2 += 1
                            else:
                                FP2 += 1
                        for gold_class in gold_classes:
                            if gold_class not in pred_classes:
                                FN2 += 1
                        gold_class_str = ';'.join(gold_classes)
                        pred_class_str = ';'.join(pred_classes)
                    else:
                        pred_class = idx_to_class[snt_probs[idx]]
                        if type(raw_classes[idx]) == int:
                            gold_classes = {idx_to_class[raw_classes[idx]]}
                        else:
                            gold_classes = set(raw_classes[idx])
                        if pred_class in gold_classes:
                            TP2 += 1
                        else:
                            FP2 += 1
                            FN2 += 1
                        gold_class_str = ';'.join(list(gold_classes))
                        pred_class_str = pred_class
                else:
                    gold_class_str = ''
                    pred_class_str = ''

                if opt.testing:
                    f.write(
                        str(line_nums[idx]) + ' : ' + ' '.join(word_tag_line) +
                        ' <=> ' + gold_class_str + ' <=> ' + pred_class_str +
                        '\n')
                else:
                    f.write(' '.join(word_tag_line) + ' <=> ' +
                            gold_class_str + ' <=> ' + pred_class_str + '\n')

    if TP == 0:
        p, r, f = 0, 0, 0
    else:
        p, r, f = 100 * TP / (TP + FP), 100 * TP / (TP + FN), 100 * 2 * TP / (
            2 * TP + FN + FP)

    mean_losses = np.mean(losses, axis=0)
    if 2 * TP2 + FN2 + FP2 == 0:
        class_f1 = 0
    else:
        class_f1 = 100 * 2 * TP2 / (2 * TP2 + FN2 + FP2)
    return mean_losses, p, r, f, class_f1
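The decode function's slot F1 depends on acc.get_chunks, which is not shown here. Below is a minimal sketch, assuming a standard BIO tagging scheme, of the kind of chunk extraction it presumably performs; the function name and boundary convention are illustrative, not the project's actual implementation.

def get_chunks_sketch(tags):
    '''Return (label, start, end) chunks from a BIO tag sequence (sketch).'''
    chunks = []
    start, label = None, None
    for i, tag in enumerate(tags):
        if tag.startswith('B-'):
            if label is not None:
                chunks.append((label, start, i))  # close the previous chunk
            start, label = i, tag[2:]
        elif tag.startswith('I-') and label == tag[2:]:
            continue  # current chunk keeps growing
        else:  # 'O', or an I- tag that does not continue the open chunk
            if label is not None:
                chunks.append((label, start, i))
            start, label = None, None
    if label is not None:
        chunks.append((label, start, len(tags)))
    return chunks

# e.g. get_chunks_sketch(['O', 'B-food', 'I-food', 'O']) -> [('food', 1, 3)]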
Example #3
def eval_epoch(model, data, opt, memory, fp, efp):
    '''Epoch operation in evaluating phase'''

    model.eval()

    TP, FP, FN = 0, 0, 0
    corr, tot = 0, 0
    losses = []

    all_cases = []
    err_cases = []
    utt_id = 0

    for j, batch in enumerate(data):
        # prepare data
        batch_in, batch_pos, batch_score, \
            batch_sa, batch_sa_parent, batch_sa_sib, batch_sa_type, \
            batch_labels, raw_in, raw_sa, raw_labels = batch
        lens_utt = [len(utt) + 1 for utt in raw_in]  # +1 due to [CLS]
        lens_sysact = [len(seq) for seq in raw_sa]
        raw_lens = [len(utt) for utt in raw_in]

        # prepare inputs for BERT/XLNET
        inputs = {}
        if opt.with_system_act:
            batch_sa_pos = get_sequential_pos(batch_sa)
            pretrained_inputs = prepare_inputs_for_bert_xlnet_one_seq(
                raw_in,
                raw_lens,
                batch_pos,
                batch_score,
                raw_sa,
                lens_sysact,
                batch_sa_pos,
                batch_sa_parent,
                batch_sa_sib,
                batch_sa_type,
                opt.tokenizer,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        else:
            pretrained_inputs = prepare_inputs_for_bert_xlnet(
                raw_in,
                raw_lens,
                opt.tokenizer,
                batch_pos,
                batch_score,
                cls_token_at_end=False,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        inputs['pretrained_inputs'] = pretrained_inputs
        masks = prepare_mask(pretrained_inputs)

        # forward
        scores = model(inputs, masks, return_attns=False)
        loss, _ = cal_total_loss(scores, batch_labels, opt)
        losses.append(loss)

        # calculate performance
        batch_pred_classes = []
        batch_ids = []

        for i, (score, gold,
                raw) in enumerate(zip(scores.tolist(), raw_labels, raw_in)):
            pred_classes = pred_one_sample(score, memory, opt)

            # ontology filter
            if opt.ontology is not None:
                pred_classes = filter_informative(pred_classes, opt.ontology)
                gold = filter_informative(gold, opt.ontology)

            TP, FP, FN = update_f1(pred_classes, gold, TP, FP, FN)

            tot += 1
            if set(pred_classes) == set(gold):
                corr += 1

            batch_pred_classes.append(pred_classes)

            batch_ids.append(utt_id)
            utt_id += 1

            # keep intermediate results
            res_info = '%s\t<=>\t%s\t<=>\t%s\n' % (
                ' '.join(raw), ';'.join(pred_classes), ';'.join(gold))
            fp.write(res_info)
            if set(pred_classes) != set(gold):
                efp.write(res_info)
                err_cases.append((raw, pred_classes, gold))
            all_cases.append((raw, pred_classes, gold))

    mean_loss = np.mean(losses)
    p, r, f = compute_f1(TP, FP, FN)
    acc = corr / tot * 100

    # err_analysis(err_cases)

    if opt.testing:
        return mean_loss, (p, r, f), acc, all_cases
    else:
        return mean_loss, (p, r, f), acc
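eval_epoch relies on pred_one_sample to turn per-class scores into label strings. The sketch below is a hedged guess at its behaviour, mirroring the p > 0.5 multi-label branch in the decode example above; the idx2label mapping and the threshold are assumptions.

def pred_one_sample_sketch(score, idx2label, threshold=0.5):
    '''Keep classes whose (sigmoid) score exceeds the threshold (sketch).'''
    return [idx2label[i] for i, p in enumerate(score) if p > threshold]

# e.g. pred_one_sample_sketch([0.9, 0.2, 0.7],
#                             {0: 'inform-food', 1: 'request-area', 2: 'inform-area'})
# -> ['inform-food', 'inform-area']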
Example #4
def train_epoch(model, data, opt, memory):
    '''Epoch operation in training phase'''

    model.train()
    opt.optimizer.zero_grad()

    TP, FP, FN = 0, 0, 0
    corr, tot = 0, 0
    losses = []

    for step, batch in enumerate(data):
        # prepare data
        batch_in, batch_pos, batch_score, \
            batch_sa, batch_sa_parent, batch_sa_sib, batch_sa_type, \
            batch_labels, raw_in, raw_sa, raw_labels = batch
        lens_utt = [len(utt) + 1 for utt in raw_in]  # +1 due to [CLS]
        lens_sysact = [len(seq) for seq in raw_sa]
        raw_lens = [len(utt) for utt in raw_in]

        # prepare inputs for BERT/XLNET
        inputs = {}
        if opt.with_system_act:
            batch_sa_pos = get_sequential_pos(batch_sa)
            pretrained_inputs = prepare_inputs_for_bert_xlnet_one_seq(
                raw_in,
                raw_lens,
                batch_pos,
                batch_score,
                raw_sa,
                lens_sysact,
                batch_sa_pos,
                batch_sa_parent,
                batch_sa_sib,
                batch_sa_type,
                opt.tokenizer,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        else:
            pretrained_inputs = prepare_inputs_for_bert_xlnet(
                raw_in,
                raw_lens,
                opt.tokenizer,
                batch_pos,
                batch_score,
                cls_token_at_end=False,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        inputs['pretrained_inputs'] = pretrained_inputs
        masks = prepare_mask(pretrained_inputs)

        # forward
        scores = model(inputs, masks)

        # backward
        loss_record, total_loss = cal_total_loss(scores, batch_labels, opt)
        losses.append(loss_record)
        total_loss.backward()

        if (step + 1) % opt.n_accum_steps == 0:
            # clip gradient
            if opt.optim_choice.lower() != 'bertadam' and opt.max_norm > 0:
                params = list(
                    filter(lambda p: p.requires_grad,
                           list(model.parameters())))
                torch.nn.utils.clip_grad_norm_(params, opt.max_norm)

            # update parameters
            if opt.optim_choice.lower() in ['adam', 'bertadam']:
                opt.optimizer.step()
            elif opt.optim_choice.lower() == 'adamw':
                opt.optimizer.step()
                opt.scheduler.step()

            # clear gradients
            opt.optimizer.zero_grad()

        # calculate performance
        for i, (score, gold) in enumerate(zip(scores.tolist(), raw_labels)):
            pred_classes = pred_one_sample(score, memory, opt)
            TP, FP, FN = update_f1(pred_classes, gold, TP, FP, FN)
            tot += 1
            if set(pred_classes) == set(gold):
                corr += 1

    mean_loss = np.mean(losses)
    p, r, f = compute_f1(TP, FP, FN)
    acc = corr / tot * 100

    return mean_loss, (p, r, f), acc
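Both epoch functions accumulate TP/FP/FN through update_f1 and convert them with compute_f1. Below are minimal sketches of these helpers, consistent with the inline precision/recall/F1 arithmetic in the decode example; the exact signatures are assumptions.

def update_f1_sketch(pred, gold, TP, FP, FN):
    '''Accumulate set-level true/false positives and false negatives (sketch).'''
    pred, gold = set(pred), set(gold)
    TP += len(pred & gold)  # predicted and gold
    FP += len(pred - gold)  # predicted but not gold
    FN += len(gold - pred)  # gold but missed
    return TP, FP, FN

def compute_f1_sketch(TP, FP, FN):
    '''Precision, recall and F1 in percent, guarding the TP == 0 case (sketch).'''
    if TP == 0:
        return 0.0, 0.0, 0.0
    p = 100 * TP / (TP + FP)
    r = 100 * TP / (TP + FN)
    f = 100 * 2 * TP / (2 * TP + FP + FN)
    return p, r, f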
Example #5
def eval_epoch(model, data, opt, memory, fp, efp):
    '''Epoch operation in evaluating phase'''

    model.eval()

    TP, FP, FN = 0, 0, 0
    corr, tot = 0, 0
    losses = []

    all_cases = []
    err_cases = []
    utt_id = 0

    for j, batch in enumerate(data):
        batch_in, batch_pos, batch_score, \
            batch_sa, batch_sa_parent, batch_sa_sib, batch_sa_type, \
            batch_labels, raw_in, raw_sa, raw_labels, \
            act_labels, act_inputs, slot_labels, act_slot_pairs, value_inps, value_outs, \
            enc_batch_extend_vocab_idx, oov_lists = batch
        lens_utt = [len(utt) + 1 for utt in raw_in]  # +1 due to [CLS]
        lens_sysact = [len(seq) for seq in raw_sa]
        raw_lens = [len(utt) for utt in raw_in]

        # prepare inputs for BERT/XLNET
        inputs = {}
        if opt.with_system_act:
            batch_sa_pos = get_sequential_pos(batch_sa)
            pretrained_inputs = prepare_inputs_for_bert_xlnet_one_seq(
                raw_in,
                raw_lens,
                batch_pos,
                batch_score,
                raw_sa,
                lens_sysact,
                batch_sa_pos,
                batch_sa_parent,
                batch_sa_sib,
                batch_sa_type,
                opt.tokenizer,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        else:
            pretrained_inputs = prepare_inputs_for_bert_xlnet(
                raw_in,
                raw_lens,
                opt.tokenizer,
                batch_pos,
                batch_score,
                cls_token_at_end=False,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        inputs['pretrained_inputs'] = pretrained_inputs
        masks = prepare_mask(pretrained_inputs)

        pretrained_inputs_hd = prepare_inputs_for_bert_xlnet_act_slot_value(
            act_inputs,
            act_slot_pairs,
            value_inps,
            value_outs,
            raw_labels,
            memory,
            opt.tokenizer,
            cls_token_at_end=False,
            cls_token=opt.tokenizer.cls_token,
            sep_token=opt.tokenizer.sep_token,
            cls_token_segment_id=0,
            pad_on_left=False,
            pad_token_segment_id=0,
            device=opt.device)
        tokens = pretrained_inputs['tokens']
        if opt.with_system_act:
            utt_lens = pretrained_inputs['utt_token_lens']
            extend_ids = [
                tokens[j][:utt_lens[j]].unsqueeze(0)
                for j in range(len(utt_lens))
            ]
        else:
            extend_ids = [tokens[j].unsqueeze(0) for j in range(len(raw_in))]
        pretrained_inputs_hd['enc_batch_extend_vocab_idx'] = extend_ids
        pretrained_inputs_hd['oov_lists'] = oov_lists = [
            [] for _ in range(batch_in.size(0))
        ]

        inputs['hd_inputs'] = pretrained_inputs_hd

        # forward
        batch_preds = model.decode_batch_tf_hd(inputs,
                                               masks,
                                               memory,
                                               opt.device,
                                               tokenizer=opt.tokenizer)

        # calculate performance
        batch_pred_classes = []
        batch_ids = []
        for pred_classes, gold, raw in zip(batch_preds, raw_labels, raw_in):

            # ontology filter
            if opt.ontology is not None:
                pred_classes = filter_informative(pred_classes, opt.ontology)
                gold = filter_informative(gold, opt.ontology)

            TP, FP, FN = update_f1(pred_classes, gold, TP, FP, FN)
            tot += 1
            if set(pred_classes) == set(gold):
                corr += 1

            batch_pred_classes.append(pred_classes)

            batch_ids.append(utt_id)
            utt_id += 1

            # keep intermediate results
            res_info = '%s\t<=>\t%s\t<=>\t%s\n' % (
                ' '.join(raw), ';'.join(pred_classes), ';'.join(gold))
            fp.write(res_info)
            if set(pred_classes) != set(gold):
                efp.write(res_info)
                err_cases.append((raw, pred_classes, gold))
            all_cases.append((raw, pred_classes, gold))

    # mean_loss = np.mean(losses)
    p, r, f = compute_f1(TP, FP, FN)
    acc = corr / tot * 100

    # err_analysis(err_cases)

    if opt.testing:
        return (p, r, f), acc, all_cases
    else:
        return (p, r, f), acc
Example #6
def train_epoch(model, data, opt, memory):
    '''Epoch operation in training phase'''

    model.train()
    opt.optimizer.zero_grad()

    TP, FP, FN = 0, 0, 0
    corr, tot = 0, 0
    losses = []
    total_loss_num_pairs = [[0., 0.], [0., 0.], [0., 0.]]

    for step, batch in enumerate(data):
        # prepare data
        batch_in, batch_pos, batch_score, \
            batch_sa, batch_sa_parent, batch_sa_sib, batch_sa_type, \
            batch_labels, raw_in, raw_sa, raw_labels, \
            act_labels, act_inputs, slot_labels, act_slot_pairs, value_inps, value_outs, \
            enc_batch_extend_vocab_idx, oov_lists = batch
        lens_utt = [len(utt) + 1 for utt in raw_in]  # +1 due to [CLS]
        lens_sysact = [len(seq) for seq in raw_sa]
        raw_lens = [len(utt) for utt in raw_in]

        # prepare inputs for BERT/XLNET
        inputs = {}
        if opt.with_system_act:
            batch_sa_pos = get_sequential_pos(batch_sa)
            pretrained_inputs = prepare_inputs_for_bert_xlnet_one_seq(
                raw_in,
                raw_lens,
                batch_pos,
                batch_score,
                raw_sa,
                lens_sysact,
                batch_sa_pos,
                batch_sa_parent,
                batch_sa_sib,
                batch_sa_type,
                opt.tokenizer,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        else:
            pretrained_inputs = prepare_inputs_for_bert_xlnet(
                raw_in,
                raw_lens,
                opt.tokenizer,
                batch_pos,
                batch_score,
                cls_token_at_end=False,
                cls_token=opt.tokenizer.cls_token,
                sep_token=opt.tokenizer.sep_token,
                cls_token_segment_id=0,
                pad_on_left=False,
                pad_token_segment_id=0,
                device=opt.device)
        inputs['pretrained_inputs'] = pretrained_inputs
        masks = prepare_mask(pretrained_inputs)

        pretrained_inputs_hd = prepare_inputs_for_bert_xlnet_act_slot_value(
            act_inputs,
            act_slot_pairs,
            value_inps,
            value_outs,
            raw_labels,
            memory,
            opt.tokenizer,
            cls_token_at_end=False,
            cls_token=opt.tokenizer.cls_token,
            sep_token=opt.tokenizer.sep_token,
            cls_token_segment_id=0,
            pad_on_left=False,
            pad_token_segment_id=0,
            device=opt.device)
        tokens = pretrained_inputs['tokens']
        if opt.with_system_act:
            utt_lens = pretrained_inputs['utt_token_lens']
            extend_ids = [
                tokens[j][:utt_lens[j]].unsqueeze(0)
                for j in range(len(utt_lens))
            ]
        else:
            extend_ids = [tokens[j].unsqueeze(0) for j in range(len(raw_in))]

        pretrained_inputs_hd['enc_batch_extend_vocab_idx'] = extend_ids
        pretrained_inputs_hd['oov_lists'] = oov_lists = [
            [] for _ in range(batch_in.size(0))
        ]

        inputs['hd_inputs'] = pretrained_inputs_hd

        # forward
        batch_preds = model(inputs, masks)
        act_scores, slot_scores, value_scores = batch_preds
        batch_gold_values = pretrained_inputs_hd['value_outs']

        batch_loss_num_pairs = [[0., 0.], [0., 0.], [0., 0.]]
        for i in range(len(oov_lists)):
            # act loss
            act_loss = opt.class_loss_function(act_scores[i],
                                               act_labels[i].unsqueeze(0))
            batch_loss_num_pairs[0][0] += act_loss
            batch_loss_num_pairs[0][1] += 1

            # NOTE: loss normalization for slot_loss & value_loss
            # slot loss
            if slot_scores[i] is not None:
                slot_loss = opt.class_loss_function(slot_scores[i],
                                                    slot_labels[i])
                batch_loss_num_pairs[1][0] += slot_loss / slot_scores[i].size(0)
                batch_loss_num_pairs[1][1] += 1

            # value loss
            if value_scores[i] is not None:
                gold_values = batch_gold_values[i]
                sum_value_lens = gold_values.gt(0).sum().item()
                value_loss = opt.nll_loss_function(
                    value_scores[i].contiguous().view(
                        -1, opt.dec_word_vocab_size +
                        len(oov_lists[i]) * opt.with_ptr),
                    gold_values.contiguous().view(-1))
                batch_loss_num_pairs[2][0] += value_loss / sum_value_lens
                batch_loss_num_pairs[2][1] += 1

        total_loss = 0.
        loss_ratios = [1, 1, 1]
        for k, loss_num in enumerate(batch_loss_num_pairs):
            if loss_num[1] > 0:
                total_loss += (loss_num[0] / loss_num[1]) * loss_ratios[k]
        total_loss.backward()

        if (step + 1) % opt.n_accum_steps == 0:
            # clip gradient
            if opt.optim_choice.lower() != 'bertadam' and opt.max_norm > 0:
                params = list(
                    filter(lambda p: p.requires_grad,
                           list(model.parameters())))
                torch.nn.utils.clip_grad_norm_(params, opt.max_norm)

            # update parameters
            if opt.optim_choice.lower() in ['adam', 'bertadam']:
                opt.optimizer.step()
            elif opt.optim_choice.lower() == 'adamw':
                opt.optimizer.step()
                opt.scheduler.step()

            # clear gradients
            opt.optimizer.zero_grad()

        for i in range(len(batch_loss_num_pairs)):
            if isinstance(batch_loss_num_pairs[i][0], float):
                total_loss_num_pairs[i][0] += batch_loss_num_pairs[i][0]
            else:
                total_loss_num_pairs[i][0] += batch_loss_num_pairs[i][0].item()
            total_loss_num_pairs[i][1] += batch_loss_num_pairs[i][1]

        # calculate performance
        batch_preds = model.decode_batch_tf_hd(inputs,
                                               masks,
                                               memory,
                                               opt.device,
                                               tokenizer=opt.tokenizer)
        for pred_classes, gold in zip(batch_preds, raw_labels):
            TP, FP, FN = update_f1(pred_classes, gold, TP, FP, FN)
            tot += 1
            if set(pred_classes) == set(gold):
                corr += 1

    act_avg_loss = total_loss_num_pairs[0][0] / total_loss_num_pairs[0][1]
    slot_avg_loss = total_loss_num_pairs[1][0] / total_loss_num_pairs[1][1]
    value_avg_loss = total_loss_num_pairs[2][0] / total_loss_num_pairs[2][1]
    total_avg_loss = act_avg_loss + slot_avg_loss + value_avg_loss
    losses = (act_avg_loss, slot_avg_loss, value_avg_loss, total_avg_loss)
    p, r, f = compute_f1(TP, FP, FN)
    acc = corr / tot * 100

    return losses, (p, r, f), acc
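A hypothetical outer loop showing how this train_epoch / eval_epoch pair might be wired together; the DataLoader names, file paths, epoch count and checkpointing are assumptions, not part of the original project.

import torch

def run_training_sketch(model, train_loader, valid_loader, opt, memory, n_epochs=20):
    best_f1 = 0.0
    for epoch in range(n_epochs):
        # train_epoch returns (act, slot, value, total) average losses, (p, r, f), acc
        losses, _, train_acc = train_epoch(model, train_loader, opt, memory)
        with open('valid.res', 'w') as fp, open('valid.err', 'w') as efp:
            (p, r, f), valid_acc = eval_epoch(model, valid_loader, opt, memory, fp, efp)
        print('epoch %d: train loss %.4f, train acc %.2f, valid f1 %.2f, valid acc %.2f'
              % (epoch, losses[-1], train_acc, f, valid_acc))
        if f > best_f1:  # keep the best checkpoint by validation F1
            best_f1 = f
            torch.save(model.state_dict(), 'best_model.pt')  # illustrative path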