Example #1
File: main.py  Project: jdddog/GTS
# Imports needed by this excerpt; load_data_instances, DataIterator,
# MultiInferBert and eval are project-local helpers defined elsewhere in GTS.
import json
import os
import random

import torch
import torch.nn.functional as F
from tqdm import trange


def train(args):

    # load dataset
    train_sentence_packs = json.load(open(args.prefix + args.dataset + '/train.json'))
    random.shuffle(train_sentence_packs)
    dev_sentence_packs = json.load(open(args.prefix + args.dataset + '/dev.json'))
    instances_train = load_data_instances(train_sentence_packs, args)
    instances_dev = load_data_instances(dev_sentence_packs, args)
    random.shuffle(instances_train)
    trainset = DataIterator(instances_train, args)
    devset = DataIterator(instances_dev, args)

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)
    model = MultiInferBert(args).to(args.device)

    # Fine-tune the BERT encoder at 5e-5; the classification head uses the
    # optimizer-level default, which is also 5e-5 here.
    optimizer = torch.optim.Adam([
        {'params': model.bert.parameters(), 'lr': 5e-5},
        {'params': model.cls_linear.parameters()}
    ], lr=5e-5)

    best_joint_f1 = 0
    best_joint_epoch = 0
    for i in range(args.epochs):
        print('Epoch:{}'.format(i))
        for j in trange(trainset.batch_count):
            _, tokens, lengths, masks, _, _, aspect_tags, tags = trainset.get_batch(j)
            preds = model(tokens, masks)

            # Flatten the (batch, len, len, classes) grid of word-pair scores;
            # padded cells are tagged -1 and skipped via ignore_index.
            preds_flatten = preds.reshape([-1, preds.shape[3]])
            tags_flatten = tags.reshape([-1])
            loss = F.cross_entropy(preds_flatten, tags_flatten, ignore_index=-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # eval is the project's evaluation helper, not the Python builtin.
        joint_precision, joint_recall, joint_f1 = eval(model, devset, args)

        # Checkpoint whenever the dev F1 improves.
        if joint_f1 > best_joint_f1:
            model_path = args.model_dir + 'bert' + args.task + '.pt'
            torch.save(model, model_path)
            best_joint_f1 = joint_f1
            best_joint_epoch = i
    print('best epoch: {}\tbest dev {} f1: {:.5f}\n\n'.format(best_joint_epoch, args.task, best_joint_f1))
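
The heart of this loop is the grid loss: the model scores every word pair, giving a (batch, len, len, classes) tensor, and padded cells are tagged -1 so that ignore_index drops them from the loss. A minimal, self-contained sketch of that pattern (all shapes and values below are invented for illustration, not taken from the repo):

import torch
import torch.nn.functional as F

batch, seq_len, num_classes = 2, 4, 6
preds = torch.randn(batch, seq_len, seq_len, num_classes)  # word-pair scores
tags = torch.randint(0, num_classes, (batch, seq_len, seq_len))
tags[0, 3, :] = -1  # mark padding cells with -1 ...
tags[0, :, 3] = -1  # ... so ignore_index excludes them

preds_flatten = preds.reshape(-1, preds.shape[3])  # (batch*len*len, classes)
tags_flatten = tags.reshape(-1)
loss = F.cross_entropy(preds_flatten, tags_flatten, ignore_index=-1)
print(loss.item())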
Example #2
File: main.py  Project: jdddog/GTS
# Imports needed by this excerpt; load_data_instances, DataIterator,
# MultiInferRNNModel, MultiInferCNNModel and eval are project-local helpers.
import json
import os
import random

import numpy
import torch
import torch.nn.functional as F
from tqdm import trange


def train(args):

    # load double embedding
    word2index = json.load(open(args.prefix + 'doubleembedding/word_idx.json'))
    general_embedding = numpy.load(args.prefix + 'doubleembedding/gen.vec.npy')
    general_embedding = torch.from_numpy(general_embedding)
    domain_embedding = numpy.load(args.prefix + 'doubleembedding/' +
                                  args.dataset + '_emb.vec.npy')
    domain_embedding = torch.from_numpy(domain_embedding)

    # load dataset
    train_sentence_packs = json.load(
        open(args.prefix + args.dataset + '/train.json'))
    random.shuffle(train_sentence_packs)
    dev_sentence_packs = json.load(
        open(args.prefix + args.dataset + '/dev.json'))

    instances_train = load_data_instances(train_sentence_packs, word2index,
                                          args)
    instances_dev = load_data_instances(dev_sentence_packs, word2index, args)

    random.shuffle(instances_train)
    trainset = DataIterator(instances_train, args)
    devset = DataIterator(instances_dev, args)

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    # build model
    if args.model == 'bilstm':
        model = MultiInferRNNModel(general_embedding, domain_embedding,
                                   args).to(args.device)
    elif args.model == 'cnn':
        model = MultiInferCNNModel(general_embedding, domain_embedding,
                                   args).to(args.device)
    else:
        # Guard against an unknown model name, which would otherwise raise a
        # NameError further down when model is first used.
        raise ValueError('unknown model: {}'.format(args.model))

    # Optimize only parameters with requires_grad=True (frozen embedding
    # tables, if any, are skipped).
    parameters = filter(lambda x: x.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=args.lr)

    # training
    best_joint_f1 = 0
    best_joint_epoch = 0
    for i in range(args.epochs):
        print('Epoch:{}'.format(i))
        for j in trange(trainset.batch_count):
            _, sentence_tokens, lengths, masks, aspect_tags, _, tags = trainset.get_batch(
                j)
            predictions = model(sentence_tokens, lengths, masks)

            # Sum the cross-entropy over every inference pass's prediction
            # grid against the tag grid, cropped at lengths[0] and flattened.
            loss = 0.
            tags_flatten = tags[:, :lengths[0], :lengths[0]].reshape([-1])
            for k in range(len(predictions)):
                prediction_flatten = predictions[k].reshape(
                    [-1, predictions[k].shape[3]])
                loss = loss + F.cross_entropy(
                    prediction_flatten, tags_flatten, ignore_index=-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        joint_precision, joint_recall, joint_f1 = eval(model, devset, args)

        if joint_f1 > best_joint_f1:
            model_path = args.model_dir + args.model + args.task + '.pt'
            torch.save(model, model_path)
            best_joint_f1 = joint_f1
            best_joint_epoch = i
    print('best epoch: {}\tbest dev {} f1: {:.5f}\n\n'.format(
        best_joint_epoch, args.task, best_joint_f1))
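
This variant feeds the model two pretrained vector tables: a general-purpose embedding (gen.vec.npy) and a domain-specific one. How they are combined lives inside MultiInferRNNModel / MultiInferCNNModel; a common choice, and a reasonable guess here, is to look both up with the same token indices and concatenate along the feature axis. A sketch under that assumption, with invented sizes:

import torch
import torch.nn as nn

vocab, gen_dim, dom_dim = 100, 300, 100  # hypothetical sizes
general = nn.Embedding.from_pretrained(torch.randn(vocab, gen_dim), freeze=True)
domain = nn.Embedding.from_pretrained(torch.randn(vocab, dom_dim), freeze=True)

tokens = torch.randint(0, vocab, (2, 5))  # (batch, seq_len) word indices
features = torch.cat([general(tokens), domain(tokens)], dim=-1)
print(features.shape)  # torch.Size([2, 5, 400])

Freezing the tables with freeze=True is consistent with the requires_grad filter in the optimizer setup above, which skips parameters that do not need gradients.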