Example 1
def build_slot_decoder(opt, config, emission_scorer, emb_log):
    transition_scorer = None
    if opt.decoder == 'sms':
        decoder = SequenceLabeler()
    elif opt.decoder == 'rule':
        decoder = RuleSequenceLabeler(config['id2label']['slot'])
    elif opt.decoder == 'crf':
        # Note: only the back-off transition is trained for now.
        trans_normalizer = build_scale_controller(name=opt.trans_normalizer)
        trans_scaler = build_scale_controller(
            name=opt.trans_scaler, kwargs=make_scaler_args(opt.trans_scaler, trans_normalizer, opt.trans_scale_r))
        if opt.transition == 'learn':
            transition_scorer = FewShotTransitionScorer(
                num_tags=config['num_tags']['slot'], normalizer=trans_normalizer, scaler=trans_scaler,
                r=opt.trans_r, backoff_init=opt.backoff_init)
        elif opt.transition == 'learn_with_label':
            label_trans_normalizer = build_scale_controller(name=opt.label_trans_normalizer)
            label_trans_scaler = build_scale_controller(name=opt.label_trans_scaler, kwargs=make_scaler_args(
                opt.label_trans_scaler, label_trans_normalizer, opt.label_trans_scale_r))
            transition_scorer = FewShotTransitionScorerFromLabel(
                num_tags=config['num_tags']['slot'], normalizer=trans_normalizer, scaler=trans_scaler,
                r=opt.trans_r, backoff_init=opt.backoff_init, label_scaler=label_trans_scaler)
        else:
            raise ValueError('Wrong choice of transition.')
        if opt.add_transition_rules and 'id2label' in config:
            # 0 is the [PAD] label id: drop it and shift the remaining ids down by
            # one, since the BIO constraints expect 0-based, non-padded label ids.
            non_pad_id2label = {k - 1: v for k, v in config['id2label']['slot'].items() if k != 0}
            constraints = allowed_transitions(constraint_type='BIO', labels=non_pad_id2label)
        else:
            constraints = None
        decoder = ConditionalRandomField(
            num_tags=transition_scorer.num_tags, constraints=constraints)  # accurate tags
    else:
        raise TypeError('wrong component type')

    seq_labeler = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler
    slot_model = seq_labeler(opt=opt,
                             emission_scorer=emission_scorer,
                             decoder=decoder,
                             transition_scorer=transition_scorer,
                             config=config,
                             emb_log=emb_log)
    return slot_model
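
To make the [PAD]-stripping step above concrete, here is a small standalone check. It assumes `allowed_transitions` is AllenNLP's helper (same signature as the call above), and the toy label map is hypothetical.

from allennlp.modules.conditional_random_field import allowed_transitions

# Toy slot label map (hypothetical); id 0 is the [PAD] label.
id2label = {0: '[PAD]', 1: 'O', 2: 'B-loc', 3: 'I-loc'}
# Drop [PAD] and shift ids down by one, as in build_slot_decoder above.
non_pad_id2label = {k - 1: v for k, v in id2label.items() if k != 0}
# -> {0: 'O', 1: 'B-loc', 2: 'I-loc'}
pairs = allowed_transitions(constraint_type='BIO', labels=non_pad_id2label)
# Each (i, j) pair is a legal BIO transition, e.g. B-loc -> I-loc is allowed
# while O -> I-loc is not; ids len(labels) and len(labels) + 1 are the virtual
# start/end tags the helper appends.
print(pairs)
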
Example 2
def make_model(opt, config):
    """ Customize and build the few-shot learning model from components """

    ''' Build context_embedder '''
    if opt.context_emb == 'bert':
        context_embedder = BertSchemaContextEmbedder(opt=opt) if opt.use_schema else BertContextEmbedder(opt=opt)
    elif opt.context_emb == 'sep_bert':
        context_embedder = BertSchemaSeparateContextEmbedder(opt=opt) if opt.use_schema else \
            BertSeparateContextEmbedder(opt=opt)
    elif opt.context_emb == 'electra':
        context_embedder = ElectraSchemaContextEmbedder(opt=opt) if opt.use_schema else ElectraContextEmbedder(opt=opt)
    elif opt.context_emb == 'elmo':
        raise NotImplementedError
    elif opt.context_emb == 'glove':
        context_embedder = NormalContextEmbedder(opt=opt, num_token=len(opt.word2id))
        context_embedder.load_embedding()
    elif opt.context_emb == 'raw':
        context_embedder = NormalContextEmbedder(opt=opt, num_token=len(opt.word2id))
    else:
        raise TypeError('wrong component type')

    ''' Create log file to record testing data '''
    if opt.emb_log:
        emb_log = open(os.path.join(opt.output_dir, 'emb.log'), 'w')
        if 'id2label' in config:
            emb_log.write('id2label\t' + '\t'.join([str(k) + ':' + str(v) for k, v in config['id2label'].items()]) + '\n')
    else:
        emb_log = None

    '''Build emission scorer and similarity scorer '''
    # build scaler
    ems_normalizer = build_scale_controller(name=opt.emission_normalizer)
    ems_scaler = build_scale_controller(
        name=opt.emission_scaler, kwargs=make_scaler_args(opt.emission_scaler, ems_normalizer, opt.ems_scale_r))
    if opt.similarity == 'dot':
        sim_func = reps_dot
    elif opt.similarity == 'cosine':
        sim_func = reps_cosine_sim
    elif opt.similarity == 'l2':
        sim_func = reps_l2_sim
    else:
        raise TypeError('wrong component type')

    if opt.emission == 'mnet':
        similarity_scorer = MatchingSimilarityScorer(sim_func=sim_func, emb_log=emb_log)
        emission_scorer = MNetEmissionScorer(similarity_scorer, ems_scaler, opt.div_by_tag_num)
    elif opt.emission == 'proto':
        similarity_scorer = PrototypeSimilarityScorer(sim_func=sim_func, emb_log=emb_log)
        emission_scorer = PrototypeEmissionScorer(similarity_scorer, ems_scaler)
    elif opt.emission == 'proto_with_label':
        similarity_scorer = ProtoWithLabelSimilarityScorer(sim_func=sim_func, scaler=opt.ple_scale_r, emb_log=emb_log)
        emission_scorer = ProtoWithLabelEmissionScorer(similarity_scorer, ems_scaler)
    elif opt.emission == 'tapnet':
        # set num of anchors:
        # (1) if provided in config, use it (usually in load model case.)
        # (2) *3 is used to ensure enough anchors ( > num_tags of unseen domains )
        num_anchors = config['num_anchors'] if 'num_anchors' in config else config['num_tags'] * 3
        config['num_anchors'] = num_anchors
        anchor_dim = 256 if opt.context_emb == 'electra' else 768
        similarity_scorer = TapNetSimilarityScorer(
            sim_func=sim_func, num_anchors=num_anchors, mlp_out_dim=opt.tap_mlp_out_dim,
            random_init=opt.tap_random_init, random_init_r=opt.tap_random_init_r,
            mlp=opt.tap_mlp, emb_log=emb_log, tap_proto=opt.tap_proto, tap_proto_r=opt.tap_proto_r,
            anchor_dim=anchor_dim)
        emission_scorer = TapNetEmissionScorer(similarity_scorer, ems_scaler)
    else:
        raise TypeError('wrong component type')

    ''' Build decoder '''
    if opt.task == 'sl': # for sequence labeling
        if opt.decoder == 'sms':
            transition_scorer = None
            decoder = SequenceLabeler()
        elif opt.decoder == 'rule':
            transition_scorer = None
            decoder = RuleSequenceLabeler(config['id2label'])
        elif opt.decoder == 'crf':
            # Note: only the back-off transition is trained for now.
            trans_normalizer = build_scale_controller(name=opt.trans_normalizer)
            trans_scaler = build_scale_controller(
                name=opt.trans_scaler, kwargs=make_scaler_args(opt.trans_scaler, trans_normalizer, opt.trans_scale_r))
            if opt.transition == 'learn':
                transition_scorer = FewShotTransitionScorer(
                    num_tags=config['num_tags'], normalizer=trans_normalizer, scaler=trans_scaler,
                    r=opt.trans_r, backoff_init=opt.backoff_init)
            elif opt.transition == 'learn_with_label':
                label_trans_normalizer = build_scale_controller(name=opt.label_trans_normalizer)
                label_trans_scaler = build_scale_controller(name=opt.label_trans_scaler, kwargs=make_scaler_args(
                        opt.label_trans_scaler, label_trans_normalizer, opt.label_trans_scale_r))
                transition_scorer = FewShotTransitionScorerFromLabel(
                    num_tags=config['num_tags'], normalizer=trans_normalizer, scaler=trans_scaler,
                    r=opt.trans_r, backoff_init=opt.backoff_init, label_scaler=label_trans_scaler)
            else:
                raise ValueError('Wrong choice of transition.')
            if opt.add_transition_rules and 'id2label' in config:
                # 0 is the [PAD] label id: drop it and shift remaining ids down by one.
                non_pad_id2label = {k - 1: v for k, v in config['id2label'].items() if k != 0}
                constraints = allowed_transitions(constraint_type='BIO', labels=non_pad_id2label)
            else:
                constraints = None
            decoder = ConditionalRandomField(
                num_tags=transition_scorer.num_tags, constraints=constraints)  # accurate tags
        else:
            raise TypeError('wrong component type')
    elif opt.task == 'mlc':  # for multi-label text classification task
        grad_threshold = opt.threshold_type == 'learn'
        if opt.decoder == 'mlc':
            decoder = MultiLabelTextClassifier(opt.threshold, grad_threshold)
        elif opt.decoder == 'eamlc':
            decoder = EAMultiLabelTextClassifier(opt.threshold, grad_threshold)
        elif opt.decoder == 'msmlc':
            decoder = MetaStatsMultiLabelTextClassifier(opt.threshold, grad_threshold, meta_rate=opt.meta_rate,
                                                        ab_ea=opt.ab_ea)
        elif opt.decoder == 'krnmsmlc':
            map_dict = {
                "feature_map": opt.feature_map,
                "feature_num": opt.feature_num,
                "feature_map_dim": opt.feature_map_dim,
                "feature_map_act": opt.feature_map_act,
                "feature_map_layer_num": opt.feature_map_layer_num,
            }
            decoder = KRNMetaStatsMultiLabelTextClassifier(opt.threshold, grad_threshold, meta_rate=opt.meta_rate,
                                                           ab_ea=opt.ab_ea, kernel=opt.kernel, bandwidth=opt.bandwidth,
                                                           use_gold=opt.use_gold, learnable=opt.kernel_learnable,
                                                           map_dict=map_dict)
        else:
            raise TypeError('wrong component type')
    elif opt.task == 'sc':  # for single-label text classification task
        decoder = SingleLabelTextClassifier()
    else:
        raise TypeError('wrong task type')

    ''' Build the whole model '''
    if opt.task == 'sl':
        seq_labeler = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler
        model = seq_labeler(
            opt=opt,
            context_embedder=context_embedder,
            emission_scorer=emission_scorer,
            decoder=decoder,
            transition_scorer=transition_scorer,
            config=config,
            emb_log=emb_log
        )
    elif opt.task in ['sc', 'mlc']:
        text_classifier = SchemaFewShotTextClassifier if opt.use_schema else FewShotTextClassifier
        model = text_classifier(
            opt=opt,
            context_embedder=context_embedder,
            emission_scorer=emission_scorer,
            decoder=decoder,
            config=config,
            emb_log=emb_log
        )
    else:
        raise TypeError('wrong task type')
    return model
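
A minimal invocation sketch for the builder above. All flag values are hypothetical placeholders for the attributes make_model reads; in the repository they come from the CLI parser, and num_tags / id2label depend on the dataset split.

from argparse import Namespace

# Hypothetical flags for a prototypical-network sequence labeler with a CRF decoder.
opt = Namespace(task='sl', context_emb='bert', use_schema=False, emb_log=False,
                similarity='cosine', emission='proto',
                emission_normalizer='none', emission_scaler='none', ems_scale_r=1.0,
                decoder='crf', transition='learn',
                trans_normalizer='none', trans_scaler='none', trans_scale_r=1.0,
                trans_r=1.0, backoff_init='rand', add_transition_rules=True)
config = {'num_tags': 3,  # non-[PAD] tag count; a placeholder value
          'id2label': {0: '[PAD]', 1: 'O', 2: 'B-loc', 3: 'I-loc'}}
model = make_model(opt, config)
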
Example 3
def make_model(opt, config):
    """ Customize and build the few-shot learning model from components """

    ''' Create log file to record testing data '''
    emb_log = open(os.path.join(opt.output_dir, 'emb.log'), 'w') if opt.emb_log else None

    context_embedder = build_context_embedder(opt)

    '''Build emission scorer and similarity scorer '''
    ems_normalizer = build_scale_controller(name=opt.emission_normalizer)
    ems_scaler = build_scale_controller(
        name=opt.emission_scaler, kwargs=make_scaler_args(opt.emission_scaler, ems_normalizer, opt.ems_scale_r))

    if opt.similarity == 'dot':
        sim_func = reps_dot
    elif opt.similarity == 'cosine':
        sim_func = reps_cosine_sim
    elif opt.similarity == 'l2':
        sim_func = reps_l2_sim
    elif opt.similarity == 'relation':
        emb_dim = 768 if opt.context_emb in ['bert', 'sep_bert'] else 256
        pair_emb_dim = 2 * emb_dim
        sim_func = RelationNetSim(opt=opt, input_size=pair_emb_dim, hidden_size=opt.relation_hidden_size)
    else:
        raise TypeError('wrong component type')

    intent_emission_scorer = build_emission_scorer(opt, config, ems_scaler, sim_func, emb_log, 'intent')
    slot_emission_scorer = build_emission_scorer(opt, config, ems_scaler, sim_func, emb_log, 'slot')

    ''' Build decoder '''
    intent_decoder = None
    slot_decoder = None
    if opt.task in ['slot_filling', 'slu']:
        slot_decoder = build_slot_decoder(opt, config, slot_emission_scorer, emb_log)

    if opt.task in ['intent', 'slu']:
        intent_decoder = build_intent_decoder(opt, config, intent_emission_scorer, emb_log)

    ''' Build the whole model '''
    opt.id2label = config['id2label']
    if opt.slu_model_type == 'simple':
        few_shot_slu = SchemaFewShotSLU if opt.use_schema else FewShotSLU
    elif opt.slu_model_type == 'split_metric':
        few_shot_slu = SplitMetricFewShotSLU
    elif opt.slu_model_type == 'emission_merge_intent':
        few_shot_slu = EmissionMergeIntentFewShotSLU
    elif opt.slu_model_type == 'emission_merge_slot':
        few_shot_slu = EmissionMergeSlotFewShotSLU
    elif opt.slu_model_type == 'emission_merge_iteration':
        few_shot_slu = EmissionMergeIterationFewShotSLU
    else:
        raise NotImplementedError
    model = few_shot_slu(
        opt=opt,
        context_embedder=context_embedder,
        slot_decoder=slot_decoder,
        intent_decoder=intent_decoder,
        config=config,
        emb_log=emb_log
    )

    return model
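
A minimal invocation sketch for the SLU builder above. Flag values are hypothetical placeholders, and build_context_embedder / build_emission_scorer / build_slot_decoder / build_intent_decoder are the helpers referenced in this example; note that num_tags and id2label are keyed per task here.

from argparse import Namespace

# Hypothetical joint intent detection + slot filling ('slu') configuration.
opt = Namespace(task='slu', slu_model_type='simple', use_schema=False, emb_log=False,
                context_emb='bert', similarity='cosine',
                emission_normalizer='none', emission_scaler='none', ems_scale_r=1.0)
# ...plus the emission/decoder/transition flags consumed by build_emission_scorer,
# build_slot_decoder, and build_intent_decoder.
config = {'num_tags': {'slot': 3, 'intent': 5},
          'id2label': {'slot': {0: '[PAD]', 1: 'O', 2: 'B-loc', 3: 'I-loc'},
                       'intent': {0: 'weather', 1: 'music'}}}
model = make_model(opt, config)

A multi-task variant of make_model, taking parallel task and emission lists, follows.
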
def make_model(opt, config):
    """ Customize and build the few-shot learning model from components """
    ''' Build context_embedder '''
    if opt.context_emb == 'bert':
        context_embedder = BertSchemaContextEmbedder(
            opt=opt) if opt.use_schema else BertContextEmbedder(opt=opt)
    elif opt.context_emb == 'sep_bert':
        context_embedder = BertSchemaSeparateContextEmbedder(opt=opt) if opt.use_schema else \
            BertSeparateContextEmbedder(opt=opt)
    elif opt.context_emb == 'electra':
        context_embedder = ElectraSchemaContextEmbedder(
            opt=opt) if opt.use_schema else ElectraContextEmbedder(opt=opt)
    elif opt.context_emb == 'elmo':
        raise NotImplementedError
    elif opt.context_emb == 'glove':
        context_embedder = NormalContextEmbedder(opt=opt,
                                                 num_token=len(opt.word2id))
        context_embedder.load_embedding()
    elif opt.context_emb == 'raw':
        context_embedder = NormalContextEmbedder(opt=opt,
                                                 num_token=len(opt.word2id))
    else:
        raise TypeError('wrong component type')
    ''' Create log file to record testing data '''
    if opt.emb_log:
        emb_log = open(os.path.join(opt.output_dir, 'emb.log'), 'w')
        if 'id2label_map' in config:
            emb_log.write('id2label_map\t' + '\t'.join([
                str(k) + ':' + str(v)
                for k, v in config['id2label_map'].items()
            ]) + '\n')
    else:
        emb_log = None
    '''Build emission scorer and similarity scorer '''
    # build scaler
    ems_normalizer = build_scale_controller(name=opt.emission_normalizer)
    ems_scaler = build_scale_controller(name=opt.emission_scaler,
                                        kwargs=make_scaler_args(
                                            opt.emission_scaler,
                                            ems_normalizer, opt.ems_scale_r))
    if opt.similarity == 'dot':
        sim_func = reps_dot
    elif opt.similarity == 'cosine':
        sim_func = reps_cosine_sim
    elif opt.similarity == 'l2':
        sim_func = reps_l2_sim
    else:
        raise TypeError('wrong component type')

    assert len(opt.emission) == len(opt.task), \
        'the emission list should match the task list'

    emission_scorer_map = {}
    for task, emission in zip(opt.task, opt.emission):
        if emission == 'mnet':
            similarity_scorer = MatchingSimilarityScorer(sim_func=sim_func,
                                                         emb_log=emb_log)
            emission_scorer = MNetEmissionScorer(similarity_scorer, ems_scaler,
                                                 opt.div_by_tag_num)
        elif emission == 'proto':
            similarity_scorer = PrototypeSimilarityScorer(sim_func=sim_func,
                                                          emb_log=emb_log)
            emission_scorer = PrototypeEmissionScorer(similarity_scorer,
                                                      ems_scaler)
        elif emission == 'proto_with_label':
            similarity_scorer = ProtoWithLabelSimilarityScorer(
                sim_func=sim_func, scaler=opt.ple_scale_r, emb_log=emb_log)
            emission_scorer = ProtoWithLabelEmissionScorer(
                similarity_scorer, ems_scaler)
        elif emission == 'tapnet':
            # set num of anchors:
            # (1) if provided in config, use it (usually in load model case.)
            # (2) *3 is used to ensure enough anchors ( > num_tags of unseen domains )
            num_anchors = config['num_anchors'] if 'num_anchors' in config else config['num_tags'] * 3
            config['num_anchors'] = num_anchors
            anchor_dim = 256 if opt.context_emb == 'electra' else 768
            similarity_scorer = TapNetSimilarityScorer(
                sim_func=sim_func,
                num_anchors=num_anchors,
                mlp_out_dim=opt.tap_mlp_out_dim,
                random_init=opt.tap_random_init,
                random_init_r=opt.tap_random_init_r,
                mlp=opt.tap_mlp,
                emb_log=emb_log,
                tap_proto=opt.tap_proto,
                tap_proto_r=opt.tap_proto_r,
                anchor_dim=anchor_dim)
            emission_scorer = TapNetEmissionScorer(similarity_scorer,
                                                   ems_scaler)
        else:
            raise TypeError('wrong component type')
        emission_scorer_map[task] = emission_scorer
    ''' Build decoder '''
    model_map = {}
    transition_scorer = None
    if 'sl' in opt.task:  # for sequence labeling
        if opt.decoder == 'sms':
            decoder = SequenceLabeler()
        elif opt.decoder == 'rule':
            decoder = RuleSequenceLabeler(config['id2label'])
        elif opt.decoder == 'crf':
            # Note: only the back-off transition is trained for now.
            trans_normalizer = build_scale_controller(
                name=opt.trans_normalizer)
            trans_scaler = build_scale_controller(
                name=opt.trans_scaler,
                kwargs=make_scaler_args(opt.trans_scaler, trans_normalizer,
                                        opt.trans_scale_r))
            if opt.transition == 'learn':
                transition_scorer = FewShotTransitionScorer(
                    num_tags=config['num_tags'],
                    normalizer=trans_normalizer,
                    scaler=trans_scaler,
                    r=opt.trans_r,
                    backoff_init=opt.backoff_init)
            elif opt.transition == 'learn_with_label':
                label_trans_normalizer = build_scale_controller(
                    name=opt.label_trans_normalizer)
                label_trans_scaler = build_scale_controller(
                    name=opt.label_trans_scaler,
                    kwargs=make_scaler_args(opt.label_trans_scaler,
                                            label_trans_normalizer,
                                            opt.label_trans_scale_r))
                transition_scorer = FewShotTransitionScorerFromLabel(
                    num_tags=config['num_tags'],
                    normalizer=trans_normalizer,
                    scaler=trans_scaler,
                    r=opt.trans_r,
                    backoff_init=opt.backoff_init,
                    label_scaler=label_trans_scaler)
            else:
                raise ValueError('Wrong choice of transition.')
            if opt.add_transition_rules and 'id2label' in config:
                # 0 is the [PAD] label id: drop it and shift remaining ids down by one.
                non_pad_id2label = {
                    k - 1: v for k, v in config['id2label'].items() if k != 0
                }
                constraints = allowed_transitions(constraint_type='BIO',
                                                  labels=non_pad_id2label)
            else:
                constraints = None
            decoder = ConditionalRandomField(
                num_tags=transition_scorer.num_tags,
                constraints=constraints)  # accurate tags
        else:
            raise TypeError('wrong component type')

        seq_labeler = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler
        model_map['sl'] = seq_labeler(opt=opt,
                                      context_embedder=context_embedder,
                                      emission_scorer=emission_scorer_map['sl'],
                                      decoder=decoder,
                                      transition_scorer=transition_scorer,
                                      config=config,
                                      emb_log=emb_log)

    if 'sc' in opt.task:  # for single-label text classification task
        decoder = SingleLabelTextClassifier()
        text_classifier = SchemaFewShotTextClassifier if opt.use_schema else FewShotTextClassifier
        model_map['sc'] = text_classifier(
            opt=opt,
            context_embedder=context_embedder,
            emission_scorer=emission_scorer_map['sc'],
            decoder=decoder,
            config=config,
            emb_log=emb_log)
    ''' Build the whole model '''
    few_shot_learner = SchemaFewShotLearner if opt.use_schema else FewShotLearner
    model = few_shot_learner(opt=opt,
                             context_embedder=context_embedder,
                             model_map=model_map,
                             config=config,
                             emb_log=emb_log)

    return model
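
The distinctive piece of this variant is the parallel task/emission lists. A self-contained sketch of that contract, with placeholder values:

from argparse import Namespace

# Sketch of the multi-task contract enforced by the assert/zip above:
# one emission scorer per task, opt.task and opt.emission as parallel lists.
opt = Namespace(task=['sl', 'sc'], emission=['tapnet', 'proto'])
assert len(opt.emission) == len(opt.task), 'the emission list should match the task list'
for task, emission in zip(opt.task, opt.emission):
    print('{} -> {}'.format(task, emission))  # sl -> tapnet, sc -> proto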