示例#1
0
文件: train.py 项目: zeta1999/RE2RNN
def train_marry_up(args):

    assert args.additional_state == 0
    if args.model_type == 'KnowledgeDistill':
        assert args.marryup_type == 'none'
    if args.model_type == 'PR':
        assert args.marryup_type == 'none'

    all_pred_train, all_pred_dev, all_pred_test, all_out_train, all_out_dev, all_out_test = PredictByRE(
        args)

    logger = Logger()
    # config = Config_MarryUp(args)

    dset = load_classification_dataset(args.dataset)
    t2i, i2t, in2i, i2in = dset['t2i'], dset['i2t'], dset['in2i'], dset['i2in']
    query_train, intent_train = dset['query_train'], dset['intent_train']
    query_dev, intent_dev = dset['query_dev'], dset['intent_dev']
    query_test, intent_test = dset['query_test'], dset['intent_test']

    len_stats(query_train)
    len_stats(query_dev)
    len_stats(query_test)
    # extend the padding
    # add pad <pad> to the last of vocab
    i2t[len(i2t)] = '<pad>'
    t2i['<pad>'] = len(i2t) - 1

    train_query, _, train_lengths = pad_dataset(query_train, args,
                                                t2i['<pad>'])
    dev_query, _, dev_lengths = pad_dataset(query_dev, args, t2i['<pad>'])
    test_query, _, test_lengths = pad_dataset(query_test, args, t2i['<pad>'])

    shots = int(len(train_query) * args.train_portion)
    if args.use_unlabel:
        intent_data_train = MarryUpIntentBatchDatasetUtilizeUnlabel(
            train_query, train_lengths, intent_train, all_pred_train,
            all_out_train, shots)
    elif args.train_portion == 0:
        # special case when train portion==0 and do not use unlabel data, should have no data
        intent_data_train = None
    else:
        intent_data_train = MarryUpIntentBatchDataset(train_query,
                                                      train_lengths,
                                                      intent_train,
                                                      all_out_train, shots)

    # should have no/few dev data in low-resource setting
    if args.train_portion == 0:
        intent_data_dev = None
    elif args.train_portion <= 0.01:
        intent_data_dev = MarryUpIntentBatchDataset(dev_query, dev_lengths,
                                                    intent_dev, all_out_dev,
                                                    shots)
    else:
        intent_data_dev = MarryUpIntentBatchDataset(
            dev_query,
            dev_lengths,
            intent_dev,
            all_out_dev,
        )
    intent_data_test = MarryUpIntentBatchDataset(test_query, test_lengths,
                                                 intent_test, all_out_test)

    print('len train dataset {}'.format(
        len(intent_data_train) if intent_data_train else 0))
    print('len dev dataset {}'.format(
        len(intent_data_dev) if intent_data_dev else 0))
    print('len test dataset {}'.format(len(intent_data_test)))

    intent_dataloader_train = DataLoader(
        intent_data_train, batch_size=args.bz) if intent_data_train else None
    intent_dataloader_dev = DataLoader(
        intent_data_dev, batch_size=args.bz) if intent_data_dev else None
    intent_dataloader_test = DataLoader(intent_data_test, batch_size=args.bz)

    pretrained_embed = load_glove_embed('../data/{}/'.format(args.dataset),
                                        args.embed_dim)
    if args.random_embed:
        pretrained_embed = np.random.random(pretrained_embed.shape)

    # for padding
    pretrain_embed_extend = np.append(pretrained_embed,
                                      np.zeros((1, args.embed_dim),
                                               dtype=np.float),
                                      axis=0)

    model = IntentMarryUp(
        pretrained_embed=pretrain_embed_extend,
        config=args,
        label_size=len(in2i),
    )

    criterion = torch.nn.CrossEntropyLoss()
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=0)
    if args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=0)

    if torch.cuda.is_available():
        model = model.cuda()

    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print('ALL TRAINABLE PARAMETERS: {}'.format(pytorch_total_params))
    acc_dev_init, avg_loss_dev_init, p, r = val_marry(model,
                                                      intent_dataloader_dev, 0,
                                                      'DEV', args, logger)
    # TEST
    acc_test_init, avg_loss_test_init, p, r = val_marry(
        model, intent_dataloader_test, 0, 'TEST', args, logger)

    best_dev_acc = acc_dev_init
    counter = 0
    best_dev_model = deepcopy(model)
    # when no training data, just run a test.
    if not intent_dataloader_train: args.epoch = 0

    for epoch in range(1, args.epoch + 1):
        avg_loss = 0
        acc = 0

        pbar_train = tqdm(intent_dataloader_train)
        pbar_train.set_description("TRAIN EPOCH {}".format(epoch))

        model.train()
        for batch in pbar_train:

            optimizer.zero_grad()

            x = batch['x']
            label = batch['i'].view(-1)
            lengths = batch['l']
            re_tag = batch['re']

            if torch.cuda.is_available():
                x = x.cuda()
                lengths = lengths.cuda()
                label = label.cuda()
                re_tag = re_tag.cuda()

            scores = model(x, lengths, re_tag)

            loss_cross_entropy = criterion(scores, label)

            if args.model_type == 'MarryUp':
                loss = loss_cross_entropy

            elif args.model_type == 'KnowledgeDistill':
                softmax_scores = torch.log_softmax(scores, 1)
                softmax_re_tag_teacher = torch.softmax(re_tag, 1)
                loss_KL = torch.nn.KLDivLoss()(softmax_scores,
                                               softmax_re_tag_teacher)
                loss = loss_cross_entropy * args.l1 + loss_KL * (
                    1 - args.l1
                )  # in KD, l1 stands for the alpha controlling to learn from true / imitate teacher

            elif args.model_type == 'PR':
                log_softmax_scores = torch.log_softmax(scores, 1)
                softmax_scores = torch.softmax(scores, 1)
                product_term = torch.exp(
                    re_tag - 1
                ) * args.l2  #in PR, l2 stands for the regularization term, higher l2, harder rule constraint
                teacher_score = torch.mul(softmax_scores, product_term)
                softmax_teacher = torch.softmax(teacher_score, 1)
                loss_KL = torch.nn.KLDivLoss()(log_softmax_scores,
                                               softmax_teacher)
                loss = loss_cross_entropy * args.l1 + loss_KL * (
                    1 - args.l1
                )  # in PR, l1 stands for the alpha controlling to learn from true / imitate teacher

            loss.backward()
            optimizer.step()
            avg_loss += loss.item()

            acc += (scores.argmax(1) == label).sum().item()

            pbar_train.set_postfix_str(
                "{} - total right: {}, total loss: {}".format(
                    'TRAIN', acc, loss))

        acc = acc / len(intent_data_train)
        avg_loss = avg_loss / len(intent_data_train)
        # print("{} Epoch: {} | ACC: {}, LOSS: {}".format('TRAIN', epoch, acc, avg_loss))
        logger.add("{} Epoch: {} | ACC: {}, LOSS: {}".format(
            'TRAIN', epoch, acc, avg_loss))

        # DEV
        acc_dev, avg_loss_dev, p, r = val_marry(model, intent_dataloader_dev,
                                                epoch, 'DEV', args, logger)

        counter += 1  # counter for early stopping

        if (acc_dev is None) or (acc_dev > best_dev_acc):
            counter = 0
            best_dev_acc = acc_dev
            best_dev_model = deepcopy(model)

        if counter > args.early_stop:
            break

    best_dev_test_acc, avg_loss_test, best_dev_test_p, best_dev_test_r \
        = val_marry(best_dev_model, intent_dataloader_dev, epoch, 'TEST', args, logger)

    return acc_dev_init, acc_test_init, best_dev_acc, best_dev_test_acc, best_dev_test_p, best_dev_test_r, logger.record
示例#2
0
文件: train.py 项目: zeta1999/RE2RNN
def train_fsa_rnn(args, paths):
    logger = Logger()

    # config = Config_Integrate(args)

    dset = load_classification_dataset(args.dataset)
    t2i, i2t, in2i, i2in = dset['t2i'], dset['i2t'], dset['in2i'], dset['i2in']
    query_train, intent_train = dset['query_train'], dset['intent_train']
    query_dev, intent_dev = dset['query_dev'], dset['intent_dev']
    query_test, intent_test = dset['query_test'], dset['intent_test']

    len_stats(query_train)
    len_stats(query_dev)
    len_stats(query_test)
    # extend the padding
    # add pad <pad> to the last of vocab
    i2t[len(i2t)] = '<pad>'
    t2i['<pad>'] = len(i2t) - 1

    train_query, train_query_inverse, train_lengths = pad_dataset(
        query_train, args, t2i['<pad>'])
    dev_query, dev_query_inverse, dev_lengths = pad_dataset(
        query_dev, args, t2i['<pad>'])
    test_query, test_query_inverse, test_lengths = pad_dataset(
        query_test, args, t2i['<pad>'])
    shots = int(len(train_query) * args.train_portion)
    if args.use_unlabel:
        all_pred_train, all_pred_dev, all_pred_test, all_out_train, all_out_dev, all_out_test = PredictByRE(
            args)
        intent_data_train = ATISIntentBatchDatasetUtilizeUnlabel(
            train_query, train_query_inverse, train_lengths, intent_train,
            all_pred_train, all_out_train, shots)
    elif args.train_portion == 0:
        # special case when train portion==0 and do not use unlabel data, should have no data
        intent_data_train = None
    else:
        intent_data_train = ATISIntentBatchDatasetBidirection(
            train_query, train_query_inverse, train_lengths, intent_train,
            shots)

    # should have no/few dev data in low-resource setting
    if args.train_portion == 0:
        intent_data_dev = None
    elif args.train_portion <= 0.01:
        intent_data_dev = ATISIntentBatchDatasetBidirection(
            dev_query, dev_query_inverse, dev_lengths, intent_dev, shots)
    else:
        intent_data_dev = ATISIntentBatchDatasetBidirection(
            dev_query, dev_query_inverse, dev_lengths, intent_dev)
    intent_data_test = ATISIntentBatchDatasetBidirection(
        test_query,
        test_query_inverse,
        test_lengths,
        intent_test,
    )

    intent_dataloader_train = DataLoader(
        intent_data_train, batch_size=args.bz) if intent_data_train else None
    intent_dataloader_dev = DataLoader(
        intent_data_dev, batch_size=args.bz) if intent_data_dev else None
    intent_dataloader_test = DataLoader(intent_data_test, batch_size=args.bz)

    print('len train dataset {}'.format(
        len(intent_data_train) if intent_data_train else 0))
    print('len dev dataset {}'.format(
        len(intent_data_dev) if intent_data_dev else 0))
    print('len test dataset {}'.format(len(intent_data_test)))
    print('num labels: {}'.format(len(in2i)))
    print('num vocabs: {}'.format(len(t2i)))

    forward_params = dict()
    forward_params['V_embed_extend'], forward_params['pretrain_embed_extend'], forward_params['mat'], forward_params['bias'], \
    forward_params['D1'], forward_params['D2'], forward_params['language_mask'], forward_params['language'], forward_params['wildcard_mat'], \
    forward_params['wildcard_mat_origin_extend'] = \
        get_init_params(args, in2i, i2in, t2i, paths[0])

    if args.bidirection:
        backward_params = dict()
        backward_params['V_embed_extend'], backward_params['pretrain_embed_extend'], backward_params['mat'], backward_params['bias'], \
        backward_params['D1'], backward_params['D2'], backward_params['language_mask'], backward_params['language'], backward_params['wildcard_mat'], \
        backward_params['wildcard_mat_origin_extend'] = \
            get_init_params(args, in2i, i2in, t2i, paths[1])

    # get h1 for FSAGRU
    h1_forward = None
    h1_backward = None
    if args.farnn == 1:
        args.farnn = 0
        temp_model = FSARNNIntegrateEmptyStateSaperateGRU(
            pretrained_embed=forward_params['pretrain_embed_extend'],
            trans_r_1=forward_params['D1'],
            trans_r_2=forward_params['D2'],
            embed_r=forward_params['V_embed_extend'],
            trans_wildcard=forward_params['wildcard_mat'],
            config=args,
        )
        input_x = torch.LongTensor([[t2i['BOS']]])
        if torch.cuda.is_available():
            temp_model.cuda()
            input_x = input_x.cuda()
        h1_forward = temp_model.viterbi(input_x, None).detach()
        h1_forward = h1_forward.reshape(-1)

        if args.bidirection:
            temp_model = FSARNNIntegrateEmptyStateSaperateGRU(
                pretrained_embed=backward_params['pretrain_embed_extend'],
                trans_r_1=backward_params['D1'],
                trans_r_2=backward_params['D2'],
                embed_r=backward_params['V_embed_extend'],
                trans_wildcard=backward_params['wildcard_mat'],
                config=args,
            )
            input_x = torch.LongTensor([[t2i['EOS']]])
            if torch.cuda.is_available():
                temp_model.cuda()
                input_x = input_x.cuda()
            h1_backward = temp_model.viterbi(input_x, None).detach()
            h1_backward = h1_backward.reshape(-1)

        args.farnn = 1

    if args.bidirection:
        model = IntentIntegrateSaperateBidirection_B(
            pretrained_embed=forward_params['pretrain_embed_extend'],
            forward_params=forward_params,
            backward_params=backward_params,
            config=args,
            h1_forward=h1_forward,
            h1_backward=h1_backward)
    else:
        model = IntentIntegrateSaperate_B(
            pretrained_embed=forward_params['pretrain_embed_extend'],
            trans_r_1=forward_params['D1'],
            trans_r_2=forward_params['D2'],
            embed_r=forward_params['V_embed_extend'],
            trans_wildcard=forward_params['wildcard_mat'],
            config=args,
            mat=forward_params['mat'],
            bias=forward_params['bias'],
            h1_forward=h1_forward,
        )

    if args.loss_type == 'CrossEntropy':
        criterion = torch.nn.CrossEntropyLoss()
    elif args.loss_type == 'NormalizeNLL':
        criterion = relu_normalized_NLLLoss
    else:
        print("Wrong loss function")

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=0)
    if args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=0)

    if torch.cuda.is_available():
        model = model.cuda()

    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print('ALL TRAINABLE PARAMETERS: {}'.format(pytorch_total_params))

    # TRAIN
    acc_train_init, avg_loss_train_init, train_init_p, train_init_r = val(
        model,
        intent_dataloader_train,
        epoch=0,
        mode='TRAIN',
        config=args,
        i2in=i2in,
        logger=logger,
        criterion=criterion)
    # DEV
    acc_dev_init, avg_loss_dev_init, dev_init_p, dev_init_r = val(
        model,
        intent_dataloader_dev,
        epoch=0,
        mode='DEV',
        config=args,
        i2in=i2in,
        logger=logger,
        criterion=criterion)
    # TEST
    acc_test_init, avg_loss_test_init, test_init_p, test_init_r = val(
        model,
        intent_dataloader_test,
        epoch=0,
        mode='TEST',
        config=args,
        i2in=i2in,
        logger=logger,
        criterion=criterion)

    print("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".format(
        'TRAIN', acc_train_init, avg_loss_train_init, train_init_p,
        train_init_r))
    print("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".format(
        'DEV', acc_dev_init, avg_loss_dev_init, dev_init_p, dev_init_r))
    print("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".format(
        'TEST', acc_test_init, avg_loss_test_init, test_init_p, test_init_r))
    logger.add("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".format(
        'TRAIN', acc_train_init, avg_loss_train_init, train_init_p,
        train_init_r))
    logger.add("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".format(
        'DEV', acc_dev_init, avg_loss_dev_init, dev_init_p, dev_init_r))
    logger.add("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".format(
        'TEST', acc_test_init, avg_loss_test_init, test_init_p, test_init_r))

    if args.only_probe:
        exit(0)

    best_dev_acc = acc_dev_init
    counter = 0
    best_dev_model = deepcopy(model)

    if not intent_dataloader_train: args.epoch = 0

    for epoch in range(1, args.epoch + 1):
        avg_loss = 0
        acc = 0

        pbar_train = tqdm(intent_dataloader_train)
        pbar_train.set_description("TRAIN EPOCH {}".format(epoch))

        model.train()
        for batch in pbar_train:

            optimizer.zero_grad()

            x_forward = batch['x_forward']
            x_backward = batch['x_backward']
            label = batch['i'].view(-1)
            lengths = batch['l']

            if torch.cuda.is_available():
                x_forward = batch['x_forward'].cuda()
                x_backward = batch['x_backward'].cuda()
                lengths = lengths.cuda()
                label = label.cuda()

            if args.bidirection:
                scores = model(x_forward, x_backward, lengths)
            else:
                scores = model(x_forward, lengths)

            loss = criterion(scores, label)

            loss.backward()
            optimizer.step()
            avg_loss += loss.item()

            acc += (scores.argmax(1) == label).sum().item()

            pbar_train.set_postfix_str(
                "{} - total right: {}, total loss: {}".format(
                    'TRAIN', acc, loss))

        acc = acc / len(intent_data_train)
        avg_loss = avg_loss / len(intent_data_train)
        print("{} Epoch: {} | ACC: {}, LOSS: {}".format(
            'TRAIN', epoch, acc, avg_loss))
        logger.add("{} Epoch: {} | ACC: {}, LOSS: {}".format(
            'TRAIN', epoch, acc, avg_loss))

        # DEV
        acc_dev, avg_loss_dev, p, r = val(model,
                                          intent_dataloader_dev,
                                          epoch,
                                          'DEV',
                                          logger,
                                          config=args,
                                          criterion=criterion)
        counter += 1  # counter for early stopping

        if (acc_dev is None) or (acc_dev > best_dev_acc):
            counter = 0
            best_dev_acc = acc_dev
            best_dev_model = deepcopy(model)

        if counter > args.early_stop:
            break

    best_dev_test_acc, avg_loss_test, best_dev_test_p, best_dev_test_r \
        = val(best_dev_model, intent_dataloader_test, epoch, 'TEST', logger, config=args, criterion=criterion)

    # Save the model
    datetime_str = create_datetime_str()
    model_save_path = "../model/{}/D{:.4f}-T{:.4f}-DI{:.4f}-TI{:.4f}-{}-{}-{}".format(
        args.run, best_dev_acc, best_dev_test_acc, acc_dev_init, acc_test_init,
        datetime_str, args.dataset, args.seed)
    mkdir("../model/{}/".format(args.run))
    mkdir(model_save_path)
    print("SAVING MODEL {} .....".format(model_save_path))
    torch.save(model.state_dict(), model_save_path + '.model')

    return acc_dev_init, acc_test_init, best_dev_acc, best_dev_test_acc, best_dev_test_p, best_dev_test_r, logger.record
示例#3
0
def PredictByRE(
    args,
    params=None,
    dset=None,
):
    logger = Logger()
    if not dset:
        dset = load_classification_dataset(args.dataset)

    t2i, i2t, in2i, i2in = dset['t2i'], dset['i2t'], dset['in2i'], dset['i2in']
    query_train, intent_train = dset['query_train'], dset['intent_train']
    query_dev, intent_dev = dset['query_dev'], dset['intent_dev']
    query_test, intent_test = dset['query_test'], dset['intent_test']

    len_stats(query_train)
    len_stats(query_dev)
    len_stats(query_test)
    # extend the padding
    # add pad <pad> to the last of vocab
    i2t[len(i2t)] = '<pad>'
    t2i['<pad>'] = len(i2t) - 1

    train_query, train_query_inverse, train_lengths = pad_dataset(
        query_train, args, t2i['<pad>'])
    dev_query, dev_query_inverse, dev_lengths = pad_dataset(
        query_dev, args, t2i['<pad>'])
    test_query, test_query_inverse, test_lengths = pad_dataset(
        query_test, args, t2i['<pad>'])

    intent_data_train = ATISIntentBatchDataset(train_query, train_lengths,
                                               intent_train)
    intent_data_dev = ATISIntentBatchDataset(dev_query, dev_lengths,
                                             intent_dev)
    intent_data_test = ATISIntentBatchDataset(test_query, test_lengths,
                                              intent_test)

    intent_dataloader_train = DataLoader(intent_data_train, batch_size=args.bz)
    intent_dataloader_dev = DataLoader(intent_data_dev, batch_size=args.bz)
    intent_dataloader_test = DataLoader(intent_data_test, batch_size=args.bz)

    if params is None:
        automata_dicts = load_pkl(args.automata_path_forward)
        automata = automata_dicts['automata']
        language_tensor, state2idx, wildcard_mat, language = dfa_to_tensor(
            automata, t2i)
        complete_tensor = language_tensor + wildcard_mat
        if args.dataset == 'ATIS':
            mat, bias = create_mat_and_bias_with_empty_ATIS(
                automata,
                in2i=in2i,
                i2in=i2in,
            )
        elif args.dataset == 'TREC':
            mat, bias = create_mat_and_bias_with_empty_TREC(
                automata,
                in2i=in2i,
                i2in=i2in,
            )
        elif args.dataset == 'SMS':
            mat, bias = create_mat_and_bias_with_empty_SMS(
                automata,
                in2i=in2i,
                i2in=i2in,
            )
    else:
        complete_tensor = params['complete_tensor']
        mat, bias = params['mat'], params['bias']

    # for padding
    V, S1, S2 = complete_tensor.shape
    complete_tensor_extend = np.concatenate(
        (complete_tensor, np.zeros((1, S1, S2))))
    print(complete_tensor_extend.shape)
    model = IntentIntegrateOnehot(complete_tensor_extend,
                                  config=args,
                                  mat=mat,
                                  bias=bias)

    if torch.cuda.is_available():
        model.cuda()
    # # TRAIN
    print('RE TRAIN ACC')
    all_pred_train, all_out_train = REclassifier(model,
                                                 intent_dataloader_train,
                                                 config=args,
                                                 i2in=i2in)
    # DEV
    print('RE DEV ACC')
    all_pred_dev, all_out_dev = REclassifier(model,
                                             intent_dataloader_dev,
                                             config=args,
                                             i2in=i2in)
    # TEST
    print('RE TEST ACC')
    all_pred_test, all_out_test = REclassifier(model,
                                               intent_dataloader_test,
                                               config=args,
                                               i2in=i2in)

    return all_pred_train, all_pred_dev, all_pred_test, all_out_train, all_out_dev, all_out_test
示例#4
0
文件: train.py 项目: zeta1999/RE2RNN
def train_onehot(args, paths):
    logger = Logger()

    dset = load_classification_dataset(args.dataset)
    t2i, i2t, in2i, i2in = dset['t2i'], dset['i2t'], dset['in2i'], dset['i2in']
    query_train, intent_train = dset['query_train'], dset['intent_train']
    query_dev, intent_dev = dset['query_dev'], dset['intent_dev']
    query_test, intent_test = dset['query_test'], dset['intent_test']

    len_stats(query_train)
    len_stats(query_dev)
    len_stats(query_test)
    # extend the padding
    # add pad <pad> to the last of vocab
    i2t[len(i2t)] = '<pad>'
    t2i['<pad>'] = len(i2t) - 1

    train_query, train_query_inverse, train_lengths = pad_dataset(
        query_train, args, t2i['<pad>'])
    dev_query, dev_query_inverse, dev_lengths = pad_dataset(
        query_dev, args, t2i['<pad>'])
    test_query, test_query_inverse, test_lengths = pad_dataset(
        query_test, args, t2i['<pad>'])

    shots = int(len(train_query) * args.train_portion)
    assert args.train_portion == 1.0
    # We currently not support ublabel and low-resource for onehot
    intent_data_train = ATISIntentBatchDataset(train_query, train_lengths,
                                               intent_train, shots)
    intent_data_dev = ATISIntentBatchDataset(dev_query, dev_lengths,
                                             intent_dev, shots)
    intent_data_test = ATISIntentBatchDataset(test_query, test_lengths,
                                              intent_test)

    intent_dataloader_train = DataLoader(intent_data_train, batch_size=args.bz)
    intent_dataloader_dev = DataLoader(intent_data_dev, batch_size=args.bz)
    intent_dataloader_test = DataLoader(intent_data_test, batch_size=args.bz)

    automata_dicts = load_pkl(paths[0])
    if 'automata' not in automata_dicts:
        automata = automata_dicts
    else:
        automata = automata_dicts['automata']

    language_tensor, state2idx, wildcard_mat, language = dfa_to_tensor(
        automata, t2i)
    complete_tensor = language_tensor + wildcard_mat

    assert args.additional_state == 0

    if args.dataset == 'ATIS':
        mat, bias = create_mat_and_bias_with_empty_ATIS(
            automata,
            in2i=in2i,
            i2in=i2in,
        )
    elif args.dataset == 'TREC':
        mat, bias = create_mat_and_bias_with_empty_TREC(
            automata,
            in2i=in2i,
            i2in=i2in,
        )
    elif args.dataset == 'SMS':
        mat, bias = create_mat_and_bias_with_empty_SMS(
            automata,
            in2i=in2i,
            i2in=i2in,
        )

    # for padding
    V, S1, S2 = complete_tensor.shape
    complete_tensor_extend = np.concatenate(
        (complete_tensor, np.zeros((1, S1, S2))))
    print(complete_tensor_extend.shape)
    model = IntentIntegrateOnehot(complete_tensor_extend,
                                  config=args,
                                  mat=mat,
                                  bias=bias)

    mode = 'onehot'
    if args.loss_type == 'CrossEntropy':
        criterion = torch.nn.CrossEntropyLoss()
    elif args.loss_type == 'NormalizeNLL':
        criterion = relu_normalized_NLLLoss
    else:
        print("Wrong loss function")

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=0)
    if args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=0)

    if torch.cuda.is_available():
        model = model.cuda()

    acc_train_init, avg_loss_train_init, p, r = val(model,
                                                    intent_dataloader_train,
                                                    epoch=0,
                                                    mode='TRAIN',
                                                    config=args,
                                                    i2in=i2in,
                                                    criterion=criterion)
    # DEV
    acc_dev_init, avg_loss_dev_init, p, r = val(model,
                                                intent_dataloader_dev,
                                                epoch=0,
                                                mode='DEV',
                                                config=args,
                                                i2in=i2in,
                                                criterion=criterion)
    # TEST
    acc_test_init, avg_loss_test_init, p, r = val(model,
                                                  intent_dataloader_test,
                                                  epoch=0,
                                                  mode='TEST',
                                                  config=args,
                                                  i2in=i2in,
                                                  criterion=criterion)

    best_dev_acc = acc_dev_init
    counter = 0
    best_dev_test_acc = acc_test_init

    for epoch in range(1, args.epoch + 1):
        avg_loss = 0
        acc = 0

        pbar_train = tqdm(intent_dataloader_train)
        pbar_train.set_description("TRAIN EPOCH {}".format(epoch))

        model.train()
        for batch in pbar_train:

            optimizer.zero_grad()

            x = batch['x']
            label = batch['i'].view(-1)
            lengths = batch['l']

            if torch.cuda.is_available():
                x = x.cuda()
                lengths = lengths.cuda()
                label = label.cuda()

            scores = model(x, lengths)
            loss_cross_entropy = criterion(scores, label)
            loss = loss_cross_entropy

            loss.backward()
            optimizer.step()
            avg_loss += loss.item()

            acc += (scores.argmax(1) == label).sum().item()

            pbar_train.set_postfix_str(
                "{} - total right: {}, total loss: {}".format(
                    'TRAIN', acc, loss))

        acc = acc / len(intent_data_train)
        avg_loss = avg_loss / len(intent_data_train)
        print("{} Epoch: {} | ACC: {}, LOSS: {}".format(
            'TRAIN', epoch, acc, avg_loss))
        logger.add("{} Epoch: {} | ACC: {}, LOSS: {}".format(
            'TRAIN', epoch, acc, avg_loss))

        # DEV
        acc_dev, avg_loss_dev, p, r = val(model,
                                          intent_dataloader_dev,
                                          epoch,
                                          'DEV',
                                          logger,
                                          config=args,
                                          criterion=criterion)
        # TEST
        acc_test, avg_loss_test, p, r = val(model,
                                            intent_dataloader_test,
                                            epoch,
                                            'TEST',
                                            logger,
                                            config=args,
                                            criterion=criterion)

        counter += 1  # counter for early stopping

        if (acc_dev is None) or (acc_dev > best_dev_acc):
            counter = 0
            best_dev_acc = acc_dev
            best_dev_test_acc = acc_test

        if counter > args.early_stop:
            break

    return acc_dev_init, acc_test_init, best_dev_acc, best_dev_test_acc, logger.record
示例#5
0
def train_fsa_rnn(args, paths):
    logger = Logger()

    dset = load_classification_dataset(args.dataset)
    t2i, i2t, in2i, i2in = dset['t2i'], dset['i2t'], dset['in2i'], dset['i2in']
    query_train, intent_train = dset['query_train'], dset['intent_train']
    query_dev, intent_dev = dset['query_dev'], dset['intent_dev']
    query_test, intent_test = dset['query_test'], dset['intent_test']

    len_stats(query_train)
    len_stats(query_dev)
    len_stats(query_test)
    # extend the padding
    # add pad <pad> to the last of vocab
    i2t[len(i2t)] = '<pad>'
    t2i['<pad>'] = len(i2t) - 1

    train_query, train_query_inverse, train_lengths = pad_dataset(query_train, args, t2i['<pad>'])
    dev_query, dev_query_inverse, dev_lengths = pad_dataset(query_dev, args, t2i['<pad>'])
    test_query, test_query_inverse, test_lengths = pad_dataset(query_test, args, t2i['<pad>'])
    shots = int(len(train_query) * args.train_portion)

    intent_data_train = ATISIntentBatchDatasetBidirection(train_query, train_query_inverse, train_lengths, intent_train, shots)
    intent_data_dev = ATISIntentBatchDatasetBidirection(dev_query, dev_query_inverse, dev_lengths, intent_dev)
    intent_data_test = ATISIntentBatchDatasetBidirection(test_query, test_query_inverse, test_lengths, intent_test, )

    intent_dataloader_train = DataLoader(intent_data_train, batch_size=args.bz) if intent_data_train else None
    intent_dataloader_dev = DataLoader(intent_data_dev, batch_size=args.bz) if intent_data_dev else None
    intent_dataloader_test = DataLoader(intent_data_test, batch_size=args.bz)

    print('len train dataset {}'.format(len(intent_data_train) if intent_data_train else 0))
    print('len dev dataset {}'.format(len(intent_data_dev) if intent_data_dev else 0))
    print('len test dataset {}'.format(len(intent_data_test)))
    print('num labels: {}'.format(len(in2i)))
    print('num vocabs: {}'.format(len(t2i)))

    forward_params = dict()
    forward_params['V_embed_extend'], forward_params['pretrain_embed_extend'], forward_params['mat'], forward_params['bias'], \
    forward_params['D1'], forward_params['D2'], forward_params['language_mask'], forward_params['language'], forward_params['wildcard_mat'], \
    forward_params['wildcard_mat_origin_extend'] = \
        get_init_params(args, in2i, i2in, t2i, paths[0])

    model = IntentClassification(pretrained_embed=forward_params['pretrain_embed_extend'],
                                 trans_r_1=forward_params['D1'],
                                 trans_r_2=forward_params['D2'],
                                 embed_r=forward_params['V_embed_extend'],
                                 trans_wildcard=forward_params['wildcard_mat'],
                                 config=args,
                                 mat=forward_params['mat'],
                                 bias=forward_params['bias'])

    criterion = torch.nn.CrossEntropyLoss()

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=0)
    if args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)

    if torch.cuda.is_available():
        model = model.cuda()

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('ALL TRAINABLE PARAMETERS: {}'.format(pytorch_total_params))

    # TRAIN
    acc_train_init, avg_loss_train_init, train_init_p, train_init_r = \
        val(model, intent_dataloader_train, epoch=0, mode='TRAIN', config=args,
            i2in=i2in, logger=logger, criterion=criterion)

    # DEV
    acc_dev_init, avg_loss_dev_init, dev_init_p, dev_init_r = \
        val(model, intent_dataloader_dev, epoch=0, mode='DEV', config=args,
            i2in=i2in, logger=logger, criterion=criterion)
    # TEST
    acc_test_init, avg_loss_test_init, test_init_p, test_init_r = \
        val(model, intent_dataloader_test, epoch=0, mode='TEST', config=args,
            i2in=i2in, logger=logger, criterion=criterion)

    print("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".
          format('TRAIN', acc_train_init, avg_loss_train_init, train_init_p, train_init_r))
    print("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".
          format('DEV', acc_dev_init, avg_loss_dev_init, dev_init_p, dev_init_r))
    print("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".
          format('TEST', acc_test_init, avg_loss_test_init, test_init_p, test_init_r))

    logger.add("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".
               format('TRAIN', acc_train_init, avg_loss_train_init, train_init_p, train_init_r))
    logger.add("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".
               format('DEV', acc_dev_init, avg_loss_dev_init, dev_init_p, dev_init_r))
    logger.add("{} INITIAL: ACC: {}, LOSS: {}, P: {}, R: {}".
               format('TEST', acc_test_init, avg_loss_test_init, test_init_p, test_init_r))

    best_dev_acc = acc_dev_init
    counter = 0
    best_dev_model = deepcopy(model)

    for epoch in range(1, args.epoch + 1):
        avg_loss = 0
        acc = 0

        pbar_train = tqdm(intent_dataloader_train)
        pbar_train.set_description("TRAIN EPOCH {}".format(epoch))

        model.train()
        for batch in pbar_train:

            optimizer.zero_grad()

            x_forward = batch['x_forward']
            label = batch['i'].view(-1)
            lengths = batch['l']

            if torch.cuda.is_available():
                x_forward = batch['x_forward'].cuda()
                lengths = lengths.cuda()
                label = label.cuda()

            scores = model(x_forward, lengths)
            loss = criterion(scores, label)

            loss.backward()
            optimizer.step()
            avg_loss += loss.item()

            acc += (scores.argmax(1) == label).sum().item()

            pbar_train.set_postfix_str("{} - total right: {}, total loss: {}".format('TRAIN', acc, loss))

        acc = acc / len(intent_data_train)
        avg_loss = avg_loss / len(intent_data_train)
        print("{} Epoch: {} | ACC: {}, LOSS: {}".format('TRAIN', epoch, acc, avg_loss))
        logger.add("{} Epoch: {} | ACC: {}, LOSS: {}".format('TRAIN', epoch, acc, avg_loss))

        # DEV
        acc_dev, avg_loss_dev, p, r = val(model, intent_dataloader_dev, epoch, 'DEV', logger, config=args, criterion=criterion)
        counter += 1  # counter for early stopping

        if (acc_dev is None) or (acc_dev > best_dev_acc):
            counter = 0
            best_dev_acc = acc_dev
            best_dev_model = deepcopy(model)

        if counter > args.early_stop:
            break

    best_dev_test_acc, avg_loss_test, best_dev_test_p, best_dev_test_r \
        = val(best_dev_model, intent_dataloader_test, epoch, 'TEST', logger, config=args, criterion=criterion)

    results = [acc_dev_init, acc_test_init, best_dev_acc, best_dev_test_acc]
    save_args_and_results(args, results, logger)