Ejemplo n.º 1
0
def get_init_params(config, in2i, i2in, t2i, automata_path):
    """Build the initial parameters derived from a pickled automaton.

    Args:
        config: Run configuration. Reads ``dataset``, ``embed_dim``,
            ``additional_state`` and ``normalize_automata``.
        in2i: Intent-to-index mapping, forwarded to the dataset-specific
            ``create_mat_and_bias_with_empty_*`` helper.
        i2in: Index-to-intent mapping, forwarded to the same helper.
        t2i: Token-to-index mapping; used to look up every token listed in
            the automaton's ``language`` entry.
        automata_path: Path handed to ``load_pkl``; the loaded dict must
            provide ``automata``, ``V``, ``D1``, ``D2``, ``wildcard_mat``
            and ``language``.

    Returns:
        A 10-tuple ``(V_embed_extend, pretrain_embed_extend, mat, bias, D1,
        D2, language_mask, language, wildcard_mat,
        wildcard_mat_origin_extend)``.

    Raises:
        ValueError: If the dataset is not one of ATIS, TREC or SMS.
    """
    # Every dataset whose name contains 'SMS' shares the 'SMS' resources.
    dset = 'SMS' if 'SMS' in config.dataset else config.dataset

    pretrained_embed = load_glove_embed('../data/{}/'.format(dset), config.embed_dim)
    automata_dicts = load_pkl(automata_path)
    automata = automata_dicts['automata']

    V_embed, D1, D2 = automata_dicts['V'], automata_dicts['D1'], automata_dicts['D2']
    wildcard_mat, language = automata_dicts['wildcard_mat'], automata_dicts['language']

    n_vocab, rank = V_embed.shape
    n_state, _ = D1.shape
    print("DFA states: {}".format(n_state))
    _, embed_dim = pretrained_embed.shape

    if dset == 'ATIS':
        mat, bias = create_mat_and_bias_with_empty_ATIS(automata, in2i=in2i, i2in=i2in)
    elif dset == 'TREC':
        mat, bias = create_mat_and_bias_with_empty_TREC(automata, in2i=in2i, i2in=i2in)
    elif dset == 'SMS':
        mat, bias = create_mat_and_bias_with_empty_SMS(automata, in2i=in2i, i2in=i2in)
    else:
        # Without this branch `mat`/`bias` stay unbound and the function only
        # fails at the return statement with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'ATIS', 'TREC' or a name "
            "containing 'SMS'".format(config.dataset))

    # Append one all-zero row to each embedding matrix for the padding token.
    # `np.float` was removed in NumPy 1.24; `np.float64` is the same dtype.
    pretrain_embed_extend = np.append(
        pretrained_embed, np.zeros((1, config.embed_dim), dtype=np.float64), axis=0)
    V_embed_extend = np.append(
        V_embed, np.zeros((1, rank), dtype=np.float64), axis=0)

    # Language mask for regularization: 1 everywhere except at the indices of
    # the tokens listed in `language`, which are set to 0.
    n_vocab_extend, _ = V_embed_extend.shape
    language_mask = torch.ones(n_vocab_extend)
    language_mask[[t2i[i] for i in language]] = 0

    # Zero-pad the wildcard matrix so it also covers the additional states.
    S, _ = wildcard_mat.shape
    extended_size = S + config.additional_state
    wildcard_mat_origin_extend = np.zeros((extended_size, extended_size))
    wildcard_mat_origin_extend[:S, :S] = wildcard_mat
    wildcard_mat_origin_extend = torch.from_numpy(wildcard_mat_origin_extend).float()
    if torch.cuda.is_available():
        language_mask = language_mask.cuda()
        wildcard_mat_origin_extend = wildcard_mat_origin_extend.cuda()

    if config.normalize_automata != 'none':
        # Rescale D1, D2 and V by factor / own-average, where factor is the
        # cube root of the product of the three averages.
        D1_avg = get_average(D1, config.normalize_automata)
        D2_avg = get_average(D2, config.normalize_automata)
        V_embed_extend_avg = get_average(V_embed_extend, config.normalize_automata)
        factor = np.float_power(D1_avg * D2_avg * V_embed_extend_avg, 1 / 3)
        print(factor)
        print(D1_avg)
        print(D2_avg)
        print(V_embed_extend_avg)

        D1 = D1 * (factor / D1_avg)
        D2 = D2 * (factor / D2_avg)
        V_embed_extend = V_embed_extend * (factor / V_embed_extend_avg)

    return (V_embed_extend, pretrain_embed_extend, mat, bias, D1, D2,
            language_mask, language, wildcard_mat, wildcard_mat_origin_extend)
Ejemplo n.º 2
0
def train_marry_up(args):
    """Train an ``IntentMarryUp`` classifier and evaluate it on dev and test.

    The regular-expression system's outputs (from ``PredictByRE``) are fed to
    the model as ``re_tag``. Depending on ``args.model_type`` the loss is
    plain cross entropy ('MarryUp'), or cross entropy mixed with a KL term
    against a teacher distribution ('KnowledgeDistill' and 'PR').

    Args:
        args: Run configuration. Reads ``additional_state``, ``model_type``,
            ``marryup_type``, ``dataset``, ``train_portion``,
            ``use_unlabel``, ``bz``, ``embed_dim``, ``random_embed``,
            ``optimizer``, ``lr``, ``epoch``, ``early_stop``, ``l1`` and
            ``l2``. ``args.epoch`` is overwritten with 0 when there is no
            training data.

    Returns:
        A 7-tuple ``(acc_dev_init, acc_test_init, best_dev_acc,
        best_dev_test_acc, best_dev_test_p, best_dev_test_r,
        logger.record)``.

    Raises:
        ValueError: If ``args.optimizer`` is neither 'SGD' nor 'ADAM', or if
            ``args.model_type`` is not a supported type when training runs.
    """
    assert args.additional_state == 0
    if args.model_type in ('KnowledgeDistill', 'PR'):
        assert args.marryup_type == 'none'

    all_pred_train, all_pred_dev, all_pred_test, all_out_train, all_out_dev, all_out_test = PredictByRE(
        args)

    logger = Logger()

    dset = load_classification_dataset(args.dataset)
    t2i, i2t, in2i, i2in = dset['t2i'], dset['i2t'], dset['in2i'], dset['i2in']
    query_train, intent_train = dset['query_train'], dset['intent_train']
    query_dev, intent_dev = dset['query_dev'], dset['intent_dev']
    query_test, intent_test = dset['query_test'], dset['intent_test']

    len_stats(query_train)
    len_stats(query_dev)
    len_stats(query_test)

    # Register '<pad>' as the last entry of the vocabulary.
    i2t[len(i2t)] = '<pad>'
    t2i['<pad>'] = len(i2t) - 1
    pad_index = t2i['<pad>']

    train_query, _, train_lengths = pad_dataset(query_train, args, pad_index)
    dev_query, _, dev_lengths = pad_dataset(query_dev, args, pad_index)
    test_query, _, test_lengths = pad_dataset(query_test, args, pad_index)

    shots = int(len(train_query) * args.train_portion)
    if args.use_unlabel:
        intent_data_train = MarryUpIntentBatchDatasetUtilizeUnlabel(
            train_query, train_lengths, intent_train, all_pred_train,
            all_out_train, shots)
    elif args.train_portion == 0:
        # No labelled portion and no unlabelled data: nothing to train on.
        intent_data_train = None
    else:
        intent_data_train = MarryUpIntentBatchDataset(
            train_query, train_lengths, intent_train, all_out_train, shots)

    # Low-resource settings should see no / few dev examples.
    if args.train_portion == 0:
        intent_data_dev = None
    elif args.train_portion <= 0.01:
        intent_data_dev = MarryUpIntentBatchDataset(
            dev_query, dev_lengths, intent_dev, all_out_dev, shots)
    else:
        intent_data_dev = MarryUpIntentBatchDataset(
            dev_query, dev_lengths, intent_dev, all_out_dev)
    intent_data_test = MarryUpIntentBatchDataset(
        test_query, test_lengths, intent_test, all_out_test)

    print('len train dataset {}'.format(
        len(intent_data_train) if intent_data_train else 0))
    print('len dev dataset {}'.format(
        len(intent_data_dev) if intent_data_dev else 0))
    print('len test dataset {}'.format(len(intent_data_test)))

    intent_dataloader_train = DataLoader(
        intent_data_train, batch_size=args.bz) if intent_data_train else None
    intent_dataloader_dev = DataLoader(
        intent_data_dev, batch_size=args.bz) if intent_data_dev else None
    intent_dataloader_test = DataLoader(intent_data_test, batch_size=args.bz)

    pretrained_embed = load_glove_embed('../data/{}/'.format(args.dataset),
                                        args.embed_dim)
    if args.random_embed:
        pretrained_embed = np.random.random(pretrained_embed.shape)

    # Extra all-zero row for the padding token.
    # `np.float` was removed in NumPy 1.24; `np.float64` is the same dtype.
    pretrain_embed_extend = np.append(
        pretrained_embed,
        np.zeros((1, args.embed_dim), dtype=np.float64),
        axis=0)

    model = IntentMarryUp(
        pretrained_embed=pretrain_embed_extend,
        config=args,
        label_size=len(in2i),
    )

    criterion = torch.nn.CrossEntropyLoss()
    # Default reduction is kept on purpose so training numerics are unchanged.
    kl_criterion = torch.nn.KLDivLoss()

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=0)
    elif args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=0)
    else:
        # Previously this fell through and surfaced later as a NameError.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'SGD' or 'ADAM'".format(
                args.optimizer))

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print('ALL TRAINABLE PARAMETERS: {}'.format(pytorch_total_params))

    # Baseline scores of the untrained model.
    acc_dev_init, avg_loss_dev_init, p, r = val_marry(
        model, intent_dataloader_dev, 0, 'DEV', args, logger)
    acc_test_init, avg_loss_test_init, p, r = val_marry(
        model, intent_dataloader_test, 0, 'TEST', args, logger)

    best_dev_acc = acc_dev_init
    counter = 0
    best_dev_model = deepcopy(model)

    # When there is no training data, just run a test.
    if not intent_dataloader_train:
        args.epoch = 0

    # Bound up front so the final evaluation still has an epoch number when
    # the training loop does not execute at all.
    epoch = 0
    for epoch in range(1, args.epoch + 1):
        avg_loss = 0
        acc = 0

        pbar_train = tqdm(intent_dataloader_train)
        pbar_train.set_description("TRAIN EPOCH {}".format(epoch))

        model.train()
        for batch in pbar_train:

            optimizer.zero_grad()

            x = batch['x']
            label = batch['i'].view(-1)
            lengths = batch['l']
            re_tag = batch['re']

            if use_cuda:
                x = x.cuda()
                lengths = lengths.cuda()
                label = label.cuda()
                re_tag = re_tag.cuda()

            scores = model(x, lengths, re_tag)

            loss_cross_entropy = criterion(scores, label)

            if args.model_type == 'MarryUp':
                loss = loss_cross_entropy

            elif args.model_type == 'KnowledgeDistill':
                # In KD, l1 is the alpha balancing true labels against
                # imitating the teacher (the RE output).
                log_softmax_scores = torch.log_softmax(scores, 1)
                softmax_re_tag_teacher = torch.softmax(re_tag, 1)
                loss_KL = kl_criterion(log_softmax_scores,
                                       softmax_re_tag_teacher)
                loss = loss_cross_entropy * args.l1 + loss_KL * (1 - args.l1)

            elif args.model_type == 'PR':
                # In PR, l2 is the regularization strength (higher l2 means a
                # harder rule constraint) and l1 is the same alpha as in KD.
                log_softmax_scores = torch.log_softmax(scores, 1)
                softmax_scores = torch.softmax(scores, 1)
                product_term = torch.exp(re_tag - 1) * args.l2
                teacher_score = torch.mul(softmax_scores, product_term)
                softmax_teacher = torch.softmax(teacher_score, 1)
                loss_KL = kl_criterion(log_softmax_scores, softmax_teacher)
                loss = loss_cross_entropy * args.l1 + loss_KL * (1 - args.l1)

            else:
                # Previously `loss` stayed unbound and raised a NameError.
                raise ValueError(
                    "Unsupported model_type {!r}; expected 'MarryUp', "
                    "'KnowledgeDistill' or 'PR'".format(args.model_type))

            loss.backward()
            optimizer.step()
            avg_loss += loss.item()

            acc += (scores.argmax(1) == label).sum().item()

            pbar_train.set_postfix_str(
                "{} - total right: {}, total loss: {}".format(
                    'TRAIN', acc, loss))

        acc = acc / len(intent_data_train)
        # NOTE(review): batch losses are summed and then divided by the
        # number of samples, not batches — kept as in the original logging.
        avg_loss = avg_loss / len(intent_data_train)
        logger.add("{} Epoch: {} | ACC: {}, LOSS: {}".format(
            'TRAIN', epoch, acc, avg_loss))

        # DEV
        acc_dev, avg_loss_dev, p, r = val_marry(model, intent_dataloader_dev,
                                                epoch, 'DEV', args, logger)

        counter += 1  # counter for early stopping

        # With no dev score available, the latest model is always kept.
        if (acc_dev is None) or (acc_dev > best_dev_acc):
            counter = 0
            best_dev_acc = acc_dev
            best_dev_model = deepcopy(model)

        if counter > args.early_stop:
            break

    # Score the best-on-dev model on the TEST split. The original passed the
    # dev loader here, so the reported "test" metrics were dev metrics.
    best_dev_test_acc, avg_loss_test, best_dev_test_p, best_dev_test_r \
        = val_marry(best_dev_model, intent_dataloader_test, epoch, 'TEST', args, logger)

    return acc_dev_init, acc_test_init, best_dev_acc, best_dev_test_acc, best_dev_test_p, best_dev_test_r, logger.record