예제 #1
0
    def valid_on_epoch(self, epoch, filename, memory):
        """Evaluate ``self.model`` on one validation file and return its F-score.

        The model is put in eval mode. Every example read from ``filename`` is
        decoded with ``decode_slu`` and compared against its gold classes. The
        accumulated score is then reported through ``Fscore.output_fscore``.

        Args:
            epoch: epoch number. Used in the log lines and passed to
                ``output_fscore``.
            filename: validation data path, read with ``SLUDataset.read_file``.
            memory: forwarded unchanged to ``SLUDataset.read_file`` and
                ``decode_slu``.

        Returns:
            The value returned by ``Fscore.output_fscore``.
        """
        log = self.logger
        log.info("Epoch {:02} {} begins validing ...................".format(epoch, self.tag))

        model = self.model
        model.eval()

        tic = time.time()
        scorer = Fscore(self.tag)

        for example in SLUDataset.read_file(filename, memory):
            cnet, labels = example['cnet'], example['label']
            # Parse the gold label first, then decode the prediction.
            gold = SLUDataset.class_info(labels)
            pred = decode_slu(model, cnet, memory, self.cuda)
            scorer.update_tp_fp_fn(pred, gold)

        result = scorer.output_fscore(log, epoch)

        took = time.time() - tic
        log.info("Epoch {:02} {} ends validing elapsed_time: {:6.0f}s".format(
            epoch, self.tag, took))
        log.info('*****************************************************')
        return result
예제 #2
0
def error(opt):
    """Decode the test set and save every mis-classified example as JSON.

    The paths on ``opt`` are resolved in place: ``experiment`` is placed under
    ``root_dir`` and ``test_file`` under ``data_root``. The checkpoint
    ``<experiment>/<save_model>`` is loaded into a freshly built model. Each
    test pair is compared on its sorted, ``;``-joined gold and predicted
    classes. Pairs that differ are written to ``<experiment>/error.json`` as
    ``{"pairs": [{"cnet": ..., "label": gold, "pred": pred}, ...]}``.

    Args:
        opt: options namespace. Fields read: ``experiment``, ``save_model``,
            ``data_root``, ``test_file``, ``deviceid``, ``task``, ``memory``
            and ``cuda``. Fields overwritten: ``experiment``, ``load_chkpt``,
            ``test_file`` and ``save_file``.

    Raises:
        ValueError: if ``opt.task`` is not one of 'act', 'slot', 'value', 'slu'.
    """
    opt.experiment = os.path.join(root_dir, opt.experiment)
    opt.load_chkpt = os.path.join(opt.experiment, opt.save_model)
    opt.test_file = os.path.join(opt.data_root, opt.test_file)
    opt.save_file = os.path.join(opt.experiment, 'error.json')

    # Reject an unknown task before the (expensive) model load. Without this
    # check, an unknown task only failed later with an UnboundLocalError.
    datasets = {
        'act': ActDataset,
        'slot': SlotDataset,
        'value': ValueDataset,
        'slu': SLUDataset,
    }
    if opt.task not in datasets:
        raise ValueError('Unknown task {!r}; expected one of: {}'.format(
            opt.task, ', '.join(sorted(datasets))))
    dataset = datasets[opt.task]

    # Model loading. Returning `storage` from map_location deserializes the
    # tensors onto the CPU; the model is moved to GPU below if requested.
    model = make_model(opt)
    chkpt = torch.load(opt.load_chkpt,
                       map_location=lambda storage, loc: storage)
    model.load_state_dict(chkpt)
    if opt.deviceid >= 0:
        model = model.cuda()
    print(model)
    model.eval()

    # decode
    print('Decoding ...')
    datas = dataset.read_file(opt.test_file, opt.memory)

    errors = {'pairs': []}
    for pair in datas:
        cnet = pair['cnet']
        class_string = pair['label']
        gold_classes = dataset.class_info(class_string)
        # The slot/value decoders also receive the gold label string;
        # the act/slu decoders do not.
        if opt.task == 'act':
            pred_classes = decode_act(model, cnet, opt.memory, opt.cuda)
        elif opt.task == 'slot':
            pred_classes = decode_slot(model, cnet, class_string, opt.memory,
                                       opt.cuda)
        elif opt.task == 'value':
            pred_classes = decode_value(model, cnet, class_string, opt.memory,
                                        opt.cuda)
        else:  # 'slu' -- the task was validated above
            pred_classes = decode_slu(model, cnet, opt.memory, opt.cuda)
        gold_class = ';'.join(sorted(gold_classes))
        pred_class = ';'.join(sorted(pred_classes))
        if gold_class != pred_class:
            errors['pairs'].append(
                {'cnet': cnet, 'label': gold_class, 'pred': pred_class})

    # ': ' is the JSON key/value separator. Using ';' here would make
    # error.json unparseable by any JSON reader.
    payload = json.dumps(errors, sort_keys=True, indent=4,
                         separators=(',', ': '))
    with open(opt.save_file, 'w') as f:
        f.write(payload)
    print('Decode results saved in {}'.format(opt.save_file))
예제 #3
0
def error(opt):
    """Decode the test set and write every mis-classified utterance to a text file.

    The paths on ``opt`` are resolved in place: ``experiment`` is placed under
    ``root_dir`` and ``test_file`` under ``data_root``. The checkpoint
    ``<experiment>/<save_model>`` is loaded into a freshly built model. Each
    test utterance is compared on its sorted, ``;``-joined gold and predicted
    classes. Mismatches are written to ``<experiment>/error.info``, one per
    line: utterance, gold classes and predicted classes, separated by a
    tab-delimited ``<=>`` marker.

    Args:
        opt: options namespace. Fields read: ``experiment``, ``save_model``,
            ``data_root``, ``test_file``, ``deviceid``, ``task``, ``memory``
            and ``cuda``. Fields overwritten: ``experiment``, ``load_chkpt``,
            ``test_file`` and ``save_file``.

    Raises:
        ValueError: if ``opt.task`` is not one of 'act', 'slot', 'value', 'slu'.
    """
    opt.experiment = os.path.join(root_dir, opt.experiment)
    opt.load_chkpt = os.path.join(opt.experiment, opt.save_model)
    opt.test_file = os.path.join(opt.data_root, opt.test_file)
    opt.save_file = os.path.join(opt.experiment, 'error.info')

    # Reject an unknown task before the (expensive) model load. Without this
    # check, an unknown task only failed later with an UnboundLocalError.
    datasets = {
        'act': ActDataset,
        'slot': SlotDataset,
        'value': ValueDataset,
        'slu': SLUDataset,
    }
    if opt.task not in datasets:
        raise ValueError('Unknown task {!r}; expected one of: {}'.format(
            opt.task, ', '.join(sorted(datasets))))
    dataset = datasets[opt.task]

    # Model loading. Returning `storage` from map_location deserializes the
    # tensors onto the CPU; the model is moved to GPU below if requested.
    model = make_model(opt)
    chkpt = torch.load(opt.load_chkpt,
                       map_location=lambda storage, loc: storage)
    model.load_state_dict(chkpt)
    if opt.deviceid >= 0:
        model = model.cuda()
    print(model)
    model.eval()

    # decode
    print('Decoding ...')
    lines = dataset.read_file(opt.test_file)

    # `with` closes the report even if decoding raises part-way through.
    with open(opt.save_file, 'w') as g:
        for (utterance, class_string) in lines:
            gold_classes = dataset.class_info(class_string)
            # The slot/value decoders also receive the gold label string;
            # the act/slu decoders do not.
            if opt.task == 'act':
                pred_classes = decode_act(model, utterance, opt.memory,
                                          opt.cuda)
            elif opt.task == 'slot':
                pred_classes = decode_slot(model, utterance, class_string,
                                           opt.memory, opt.cuda)
            elif opt.task == 'value':
                pred_classes = decode_value(model, utterance, class_string,
                                            opt.memory, opt.cuda)
            else:  # 'slu' -- the task was validated above
                pred_classes = decode_slu(model, utterance, opt.memory,
                                          opt.cuda)
            gold_class = ';'.join(sorted(gold_classes))
            pred_class = ';'.join(sorted(pred_classes))
            if gold_class != pred_class:
                g.write('{}\t<=>\t{}\t<=>\t{}\n'.format(utterance, gold_class,
                                                        pred_class))
    print('Decode results saved in {}'.format(opt.save_file))