Example 1
# Assumed imports for this snippet; `args`, `read_data`, and `GNN` are
# defined elsewhere in the source project.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def train(seed):

    print('random seed:', seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.enabled = False

    dataset = read_data('../data', '../graph')
    label2id = dataset.label2id
    print(label2id)

    vocab_size = dataset.vocab_size
    output_dim = len(label2id)

    def acc_to_str(acc):
        s = ['%s:%.3f' % (label, acc[label]) for label in acc]
        return '{' + ', '.join(s) + '}'

    cross_res = {label: [] for label in label2id if label != 'O'}
    output_file = open('%s.mistakes' % args.output, 'w')

    # 5-fold cross-validation: each fold rotates the 80/10/10 split below.
    for cross_valid in range(5):

        model = GNN(vocab_size=vocab_size, output_dim=output_dim, args=args)
        model.cuda()
        # print vocab_size

        dataset.split_train_valid_test([0.8, 0.1, 0.1], 5, cross_valid)
        print('train:', len(dataset.train), 'valid:', len(dataset.valid),
              'test:', len(dataset.test))

        def evaluate(model, datalist, output_file=None):
            # Per-label top-1 evaluation: for each non-'O' label, take the
            # masked position with the highest predicted probability and count
            # a hit if the gold label at that position matches.
            if output_file is not None:
                output_file.write(
                    '#############################################\n')
            correct = {label: 0 for label in label2id if label != 'O'}
            total = len(datalist)
            model.eval()
            print_cnt = 0
            for data in datalist:
                word = Variable(data.input_word).cuda()
                feat = Variable(data.input_feat).cuda()
                a_ud = Variable(data.a_ud, requires_grad=False).cuda()
                a_lr = Variable(data.a_lr, requires_grad=False).cuda()
                mask = Variable(data.mask, requires_grad=False).cuda()
                if args.globalnode:
                    logprob, form = model(word, feat, mask, a_ud, a_lr)
                    logprob = logprob.data.view(-1, output_dim)
                else:
                    logprob = model(word, feat, mask, a_ud, a_lr)
                    logprob = logprob.data.view(-1, output_dim)
                mask = mask.data.view(-1)
                y_pred = torch.LongTensor(output_dim)
                for i in range(output_dim):
                    prob = logprob[:, i].exp() * mask
                    y_pred[i] = prob.topk(k=1)[1][0]
                # y_pred = logprob.topk(k=1,dim=0)[1].view(-1)
                for label in label2id:
                    if label == 'O':
                        continue
                    labelid = label2id[label]
                    if data.output.view(-1)[y_pred[labelid]] == labelid:
                        correct[label] += 1
                    else:
                        if output_file is not None:
                            # Log the mistake: which word was picked for this
                            # label on this page.
                            num_sent, sent_len, word_len = data.input_word.size()
                            id = y_pred[label2id[label]]
                            word = data.words[data.sents[int(id / sent_len)][id % sent_len]]
                            output_file.write(
                                '%d %d %s %s\n' %
                                (data.set_id, data.fax_id, label, word))
            return {label: float(correct[label]) / total for label in correct}

        batch = 1  # gradient-accumulation size; lr below is scaled by 1/batch

        # Upweight non-'O' labels 10x to counter class imbalance.
        weight = torch.zeros(len(label2id))
        for label, id in label2id.items():
            weight[id] = 1 if label == 'O' else 10
        loss_function = nn.NLLLoss(weight.cuda(), reduction='none')  # 'none' replaces the deprecated reduce=False
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr / float(batch),
                                     weight_decay=args.wd)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

        best_acc = -1
        wait = 0

        for epoch in range(args.epochs):
            sum_loss = 0
            model.train()
            # random.shuffle(dataset.train)
            for idx, data in enumerate(dataset.train):
                word = Variable(data.input_word).cuda()
                feat = Variable(data.input_feat).cuda()
                a_ud = Variable(data.a_ud, requires_grad=False).cuda()
                a_lr = Variable(data.a_lr, requires_grad=False).cuda()
                mask = Variable(data.mask, requires_grad=False).cuda()
                true_output = Variable(data.output).cuda()
                if args.globalnode:
                    logprob, form = model(word, feat, mask, a_ud, a_lr)
                else:
                    logprob = model(word, feat, mask, a_ud, a_lr)
                loss = torch.mean(
                    mask.view(-1) * loss_function(logprob.view(-1, output_dim),
                                                  true_output.view(-1)))
                if args.globalnode:
                    # Auxiliary loss: the global node also predicts the form id.
                    true_form = Variable(torch.LongTensor([data.set_id - 1])).cuda()
                    loss = loss + 0.1 * F.nll_loss(form, true_form)
                sum_loss += loss.data.sum()
                loss.backward()
                if (idx + 1) % batch == 0 or idx + 1 == len(dataset.train):
                    # Step once every `batch` accumulated examples (batch == 1 here).
                    optimizer.step()
                    optimizer.zero_grad()
            train_acc = evaluate(model, dataset.train)
            valid_acc = evaluate(model, dataset.valid)
            test_acc = evaluate(model, dataset.test)
            print('Epoch %d:  Train Loss: %.3f  Train: %s  Valid: %s  Test: %s' \
                % (epoch, sum_loss, acc_to_str(train_acc), acc_to_str(valid_acc), acc_to_str(test_acc)))
            # scheduler.step()

            # Early-stopping score: the sum of log accuracies, i.e. the log of
            # their product, so a single weak label drags the score down.
            acc = np.log(list(valid_acc.values())).sum()
            if epoch < 6:  # warm-up epochs: no checkpointing or patience yet
                continue
            if acc >= best_acc:
                torch.save(model.state_dict(), args.output + '.model')
            wait = 0 if acc > best_acc else wait + 1
            best_acc = max(acc, best_acc)
            if wait >= args.patience:
                break

        # Reload the best checkpoint and score the test split, logging mistakes.
        model.load_state_dict(torch.load(args.output + '.model'))
        test_acc = evaluate(model, dataset.test, output_file=output_file)
        print('########', acc_to_str(test_acc))
        for label in test_acc:
            cross_res[label].append(test_acc[label])

    print("Cross Validation Result:")
    for label in cross_res:
        cross_res[label] = np.mean(cross_res[label])
    print(acc_to_str(cross_res))
    return cross_res
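
Both examples train with the same masked-loss pattern: compute a per-token loss with reduction disabled, zero out padded positions with a mask, then average. Below is a minimal, self-contained sketch of Example 1's variant on dummy tensors; the 10x class weighting mirrors the code above, while the label set, shapes, and tensors are illustrative stand-ins.

import torch
import torch.nn as nn

# Toy label set; 'O' is the background class, as in the examples.
label2id = {'O': 0, 'NAME': 1, 'DATE': 2}
output_dim = len(label2id)

# Upweight non-'O' labels 10x, mirroring Example 1.
weight = torch.zeros(output_dim)
for label, idx in label2id.items():
    weight[idx] = 1 if label == 'O' else 10

loss_function = nn.NLLLoss(weight, reduction='none')

# Stand-ins for one page: 6 token positions, the last 2 padded.
logit = torch.randn(6, output_dim, requires_grad=True)
logprob = torch.log_softmax(logit, dim=-1)
target = torch.tensor([0, 1, 0, 2, 0, 0])
mask = torch.tensor([1., 1., 1., 1., 0., 0.])

per_token = loss_function(logprob, target)  # shape (6,), one loss per position
loss = torch.mean(mask * per_token)         # padded positions contribute zero
loss.backward()

Note that, as in Example 1, the mean divides by the total number of positions rather than by mask.sum(); Example 2 below instead uses torch.masked_select so the average runs over real tokens only.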
Example 2
# Assumed imports for this snippet; `args`, `label2id`, `d_output`,
# `WORD_VOCAB_SIZE`, `CHAR_VOCAB_SIZE`, `GNN`, `to_var`, `evaluate`, `test`,
# and `result_obj` are defined elsewhere in the source project.
import math
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm


def train(dataset):

    # Seed every RNG in play (PyTorch CPU and CUDA, Python, NumPy) and force
    # deterministic cuDNN kernels for reproducibility.
    print('random seed:', args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.enabled = False

    cross_res = {label: [] for label in label2id if label != 'O'}

    for cross_valid in range(1):  # single split; the loop is vestigial here

        # print('cross_valid', cross_valid)

        model = GNN(word_vocab_size=WORD_VOCAB_SIZE,
                    char_vocab_size=CHAR_VOCAB_SIZE,
                    d_output=d_output,
                    args=args)
        model.cuda()
        # print vocab_size

        # print('split dataset')
        # dataset.split_train_valid_test_bycase([0.5, 0.1, 0.4], 5, cross_valid)
        print('train:', len(dataset.train), 'valid:', len(dataset.valid),
              'test:', len(dataset.test))
        sys.stdout.flush()

        train_dataloader = DataLoader(dataset.train,
                                      batch_size=args.batch,
                                      shuffle=True)
        valid_dataloader = DataLoader(dataset.valid, batch_size=args.batch)
        test_dataloader = DataLoader(dataset.test, batch_size=args.batch)

        # Upweight non-'O' labels 2x to counter class imbalance.
        weight = torch.zeros(len(label2id))
        for label, idx in label2id.items():
            weight[idx] = 1 if label == 'O' else 2
        loss_function = nn.CrossEntropyLoss(weight.cuda(), reduction='none')  # 'none' replaces the deprecated reduce=False
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     weight_decay=args.wd)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

        best_acc = -1
        wait = 0
        batch_cnt = 0

        for epoch in range(args.epochs):
            total_loss = 0
            pending_loss = None
            model.train()
            # random.shuffle(dataset.train)
            load_time, forward_time, backward_time = 0, 0, 0
            model.clear_time()

            # Route tqdm's progress bar to a scratch log file (removed after
            # the epoch), refreshing at most once a minute.
            train_log = open(args.save_path + '_train.log', 'w')
            for tensors, batch in tqdm(train_dataloader,
                                       file=train_log,
                                       mininterval=60):
                # print(batch[0].case_id, batch[0].doc_id, batch[0].page_id)
                start = time.time()
                data, data_word, pos, length, mask, label, adjs = to_var(
                    tensors, cuda=args.cuda)
                batch_size, docu_len, sent_len, word_len = data.size()
                load_time += (time.time() - start)

                start = time.time()
                logit = model(data, data_word, pos, length, mask, adjs)
                forward_time += (time.time() - start)

                start = time.time()
                if args.crf:
                    # Flatten (batch, document) into one axis of sentences and
                    # minimise the negative CRF log-likelihood; zero-length
                    # sentences are dropped from the mean.
                    logit = logit.view(batch_size * docu_len, sent_len, -1)
                    mask = mask.view(batch_size * docu_len, -1)
                    length = length.view(batch_size * docu_len)
                    label = label.view(batch_size * docu_len, -1)
                    loss = -model.crf_layer.loglikelihood(
                        logit, mask, length, label)
                    loss = torch.masked_select(loss, torch.gt(length, 0)).mean()
                else:
                    # Token-level cross-entropy, averaged over unmasked tokens.
                    loss = loss_function(logit.view(-1, d_output),
                                         label.view(-1))
                    loss = torch.masked_select(loss, mask.view(-1)).mean()
                total_loss += loss.data.sum()
                # print(total_loss, batch[0].case_id, batch[0].doc_id, batch[0].page_id)
                if math.isnan(total_loss):
                    print('Loss is NaN!')
                    exit()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                backward_time += (time.time() - start)

                batch_cnt += 1
                if batch_cnt % 20000 != 0:
                    continue  # validate (and maybe checkpoint) every 20,000 batches
                # print('load %f   forward %f   backward %f'%(load_time, forward_time, backward_time))
                # model.print_time()
                valid_acc, valid_prec, valid_recall, valid_f1 = evaluate(
                    model, valid_dataloader, args=args)

                print('Epoch %d:  Train Loss: %.3f  Valid Acc: %.5f' %
                      (epoch, total_loss, valid_acc))
                # print(acc_to_str(valid_f1))
                # scheduler.step()

                # Early-stopping criterion: mean F1 across labels.
                acc = np.mean(list(valid_f1.values()))
                print(acc)
                if acc >= best_acc:
                    # Checkpoint the best model (with its args) and record the
                    # matching validation metrics.
                    obj = {'args': args, 'model': model.state_dict()}
                    torch.save(obj, args.save_path + '.model')
                    result_obj['valid_prec'] = np.mean(
                        list(valid_prec.values()))
                    result_obj['valid_recall'] = np.mean(
                        list(valid_recall.values()))
                    result_obj['valid_f1'] = np.mean(list(valid_f1.values()))
                wait = 0 if acc > best_acc else wait + 1
                best_acc = max(acc, best_acc)

                model.train()
                sys.stdout.flush()
                if wait >= args.patience:
                    break

            train_log.close()
            os.remove(args.save_path + '_train.log')

            if wait >= args.patience:
                break

        # Reload the best checkpoint and evaluate it on the test set.
        obj = torch.load(args.save_path + '.model')
        model.load_state_dict(obj['model'])

        test(test_dataloader, model)

    # print("Cross Validation Result:")
    # for label in cross_res:
    #     cross_res[label] = np.mean(cross_res[label])
    # print(acc_to_str(cross_res))
    return cross_res  # never populated in this variant; metrics go to result_obj
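
For reference, here is a self-contained sketch of Example 2's non-CRF branch with modern reduction semantics; the shapes and tensors are illustrative, and the project-specific CRF layer is omitted.

import torch
import torch.nn as nn

d_output = 3
loss_function = nn.CrossEntropyLoss(reduction='none')

# Stand-ins: a batch of 2 sentences of up to 4 tokens; False marks padding.
logit = torch.randn(2, 4, d_output, requires_grad=True)
label = torch.randint(0, d_output, (2, 4))
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])

loss = loss_function(logit.view(-1, d_output), label.view(-1))
loss = torch.masked_select(loss, mask.view(-1)).mean()  # mean over real tokens
loss.backward()

Because masked_select drops the padded entries before the mean, heavily padded batches do not deflate the loss, unlike the multiply-then-mean variant in Example 1.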