예제 #1
0
파일: train.py 프로젝트: sufeidechabei/dgl
def _root_node_ids(graph):
    """Return the IDs of the nodes in ``graph`` whose out-degree is zero.

    These are the nodes the script reports "Root Acc" on.
    """
    return [
        i for i in range(graph.number_of_nodes()) if graph.out_degree(i) == 0
    ]


def _count_root_correct(batch, pred):
    """Count correct predictions restricted to the root nodes of ``batch``.

    Returns:
        A ``(num_correct, num_roots)`` pair.
    """
    root_ids = _root_node_ids(batch.graph)
    labels = batch.label.cpu().data.numpy()
    preds = pred.cpu().data.numpy()
    return np.sum(labels[root_ids] == preds[root_ids]), len(root_ids)


def _evaluate(model, loader, h_size, device):
    """Run ``model`` in eval mode over ``loader`` and return its accuracies.

    Returns:
        A ``(accuracy, root_accuracy)`` pair, where the first is computed over
        every label in every batch and the second only over root nodes.
    """
    accs = []
    root_accs = []
    model.eval()
    for batch in loader:
        n = batch.graph.number_of_nodes()
        with th.no_grad():
            h = th.zeros((n, h_size)).to(device)
            c = th.zeros((n, h_size)).to(device)
            logits = model(batch, h, c)

        pred = th.argmax(logits, 1)
        accs.append([th.sum(th.eq(batch.label, pred)).item(),
                     len(batch.label)])
        root_accs.append(list(_count_root_correct(batch, pred)))

    # Aggregate counts across batches (not a mean of per-batch ratios).
    acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum(
        [x[1] for x in root_accs])
    return acc, root_acc


def main(args):
    """Train a TreeLSTM on SST, early-stop on dev root accuracy, then test.

    The model with the best dev root accuracy is checkpointed to
    ``best_<seed>.pkl`` in the working directory and reloaded for the final
    test evaluation. Training stops once the dev root accuracy has not
    improved for 10 epochs.

    Args:
        args: Namespace-like object providing ``seed``, ``gpu`` (negative
            selects the CPU), ``batch_size``, ``x_size``, ``h_size``,
            ``dropout``, ``child_sum``, ``lr``, ``weight_decay``, ``epochs``
            and ``log_every``.
    """
    np.random.seed(args.seed)
    th.manual_seed(args.seed)
    th.cuda.manual_seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0
    ckpt_path = 'best_{}.pkl'.format(args.seed)

    cuda = args.gpu >= 0
    device = th.device('cuda:{}'.format(
        args.gpu)) if cuda else th.device('cpu')
    if cuda:
        th.cuda.set_device(args.gpu)

    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=data.SST.batcher(device),
                              shuffle=True,
                              num_workers=0)
    devset = data.SST(mode='dev')
    dev_loader = DataLoader(dataset=devset,
                            batch_size=100,
                            collate_fn=data.SST.batcher(device),
                            shuffle=False,
                            num_workers=0)

    testset = data.SST(mode='test')
    test_loader = DataLoader(dataset=testset,
                             batch_size=100,
                             collate_fn=data.SST.batcher(device),
                             shuffle=False,
                             num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout,
                     cell_type='childsum' if args.child_sum else 'nary',
                     pretrained_emb=trainset.pretrained_emb).to(device)
    print(model)

    # Trainable parameters whose first dimension differs from the vocabulary
    # size; the embedding parameters get their own optimizer group below.
    params_ex_emb = [
        x for x in list(model.parameters())
        if x.requires_grad and x.size(0) != trainset.num_vocabs
    ]
    params_emb = list(model.embedding.parameters())

    # Xavier-initialise every non-embedding parameter with more than one dim.
    for p in params_ex_emb:
        if p.dim() > 1:
            INIT.xavier_uniform_(p)

    # The embedding group uses a 10x smaller learning rate and no explicit
    # weight decay.
    optimizer = optim.Adagrad([{
        'params': params_ex_emb,
        'lr': args.lr,
        'weight_decay': args.weight_decay
    }, {
        'params': params_emb,
        'lr': 0.1 * args.lr
    }])

    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        model.train()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            n = g.number_of_nodes()
            h = th.zeros((n, args.h_size)).to(device)
            c = th.zeros((n, args.h_size)).to(device)
            # The first three steps of every epoch are excluded from timing.
            if step >= 3:
                t0 = time.time()  # tik

            logits = model(batch, h, c)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, batch.label, reduction='sum')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
                root_acc, num_roots = _count_root_correct(batch, pred)
                # No timings exist yet when logging before step 3; report nan
                # directly instead of letting np.mean([]) emit a warning.
                mean_dur = np.mean(dur) if dur else float('nan')

                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}"
                    .format(epoch, step, loss.item(),
                            1.0 * acc.item() / len(batch.label),
                            1.0 * root_acc / num_roots, mean_dur))
        print('Epoch {:05d} training time {:.4f}s'.format(
            epoch,
            time.time() - t_epoch))

        # eval on dev set
        dev_acc, dev_root_acc = _evaluate(model, dev_loader, args.h_size,
                                          device)
        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
            epoch, dev_acc, dev_root_acc))

        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            th.save(model.state_dict(), ckpt_path)
        elif best_epoch <= epoch - 10:
            # Early stopping: no improvement for 10 epochs.
            break

        # lr decay: shrink by 1% per epoch, never below 1e-5
        for param_group in optimizer.param_groups:
            param_group['lr'] = max(1e-5, param_group['lr'] * 0.99)
            print(param_group['lr'])

    # test
    if best_epoch >= 0:
        model.load_state_dict(th.load(ckpt_path))
    else:
        # No checkpoint was written in this run, so any file at ckpt_path is
        # stale (or missing); evaluate the in-memory weights instead.
        print('Warning: no checkpoint was saved during this run; '
              'testing the current model weights.')
    test_acc, test_root_acc = _evaluate(model, test_loader, args.h_size,
                                        device)
    print(
        '------------------------------------------------------------------------------------'
    )
    print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
        best_epoch, test_acc, test_root_acc))
예제 #2
0
def main(args):
    """Jointly train the span extractor and answer verifier on SQuAD trees.

    Creates ``weights/<experiment>`` (deleting any previous directory of that
    name), logs progress to ``log.txt`` inside it, and after every epoch saves
    the span extractor with the best dev span F1 and the answer verifier with
    the best dev accuracy.

    Args:
        args: Namespace-like object providing ``mini``, ``squad_version``,
            ``experiment``, ``batch_size``, ``cuda``, ``span_extractor_lr``,
            ``answer_verifier_lr`` and ``epochs``.
    """
    mini_str = '/mini' if args.mini else ''  # path to mini dataset
    # gets proper dataset version (1.1 or 2.0)
    version_suffix = '_v2.0' if args.squad_version == 2.0 else ''

    # Prepare output directory under ./weights/ to store model-specific data
    # including weights
    out_dir = 'weights/%s' % args.experiment
    if os.path.exists(out_dir):
        print(
            'Warning - you are overwriting previous experiment %s. Hit Ctrl Z to abort.\n'
            % args.experiment)
        shutil.rmtree(out_dir)
    # makedirs also creates the parent ``weights`` directory on a first run.
    os.makedirs(out_dir)

    # The context manager guarantees the log file is flushed and closed even
    # if training raises part-way through.
    with open(os.path.join(out_dir, 'log.txt'), 'w') as logger:
        print_and_log(
            'Timestamp = %s for %s\n' %
            (datetime.strftime(datetime.now(), '%m/%d/%Y %H:%M'),
             args.experiment), logger)

        # Load Dev Data and save it to this model's weights dir
        print_and_log('Loading v%s Dev Data...' % args.squad_version, logger)
        dev_data = load_pk('preprocess/data%s/squad_dev_trees%s.npy' %
                           (mini_str, version_suffix))
        dev_batcher = Batcher(dev_data,
                              is_train=False,
                              target_batch_size=args.batch_size)
        save_as_pk(dev_batcher, os.path.join(out_dir, 'dev_batcher.npy'))
        print_and_log('Loaded Dev Data...', logger)

        # Load Train Data
        print_and_log('Loading v%s Train Data...' % args.squad_version,
                      logger)
        train_data = load_pk('preprocess/data%s/squad_train_trees%s.npy' %
                             (mini_str, version_suffix))
        train_batcher = Batcher(train_data,
                                is_train=True,
                                target_batch_size=args.batch_size)
        print_and_log('Loaded Train Data...', logger)

        # Create models and optimizers
        span_extractor = TreeLSTM(use_cuda=args.cuda)
        answer_verifier = AnswerVerifier(use_cuda=args.cuda)

        if args.cuda:
            span_extractor.cuda()
            answer_verifier.cuda()

        # Only parameters that require gradients are handed to the optimizers.
        span_extractor_grad_params = filter(lambda p: p.requires_grad,
                                            span_extractor.parameters())
        span_extractor_optimizer = optim.Adam(span_extractor_grad_params,
                                              args.span_extractor_lr)

        answer_verifier_grad_params = filter(lambda p: p.requires_grad,
                                             answer_verifier.parameters())
        answer_verifier_optimizer = optim.Adam(answer_verifier_grad_params,
                                               args.answer_verifier_lr)

        # Determines if question is answerable or not
        answer_verifier_logistic_loss = BCEWithLogitsLoss(
            pos_weight=span_extractor.cudify(torch.FloatTensor([0.5])))

        # Keep track of which epoch achieves the highest span-level F1 and
        # answer-verifier accuracy on the dev set (epochs reported 1-based).
        best_span_f1 = -1
        best_answer_verifier_accuracy = -1
        best_span_epoch = -1
        best_answer_verifier_epoch = -1
        for epoch_idx in range(args.epochs):
            print_and_log('Starting Epoch %d...' % (epoch_idx + 1), logger)

            # Stores predictions and returns evaluation string at the end of
            # epoch
            train_evaluator = Evaluator('train')
            dev_evaluator = Evaluator('dev')

            start_time = time()

            span_extractor.train()
            answer_verifier.train()
            while train_batcher.has_next():
                # Clear gradients and get next batch
                span_extractor_optimizer.zero_grad()
                answer_verifier_optimizer.zero_grad()

                joint_loss = _run_batch(
                    batch=train_batcher.next(),
                    span_extractor=span_extractor,
                    span_extractor_optimizer=span_extractor_optimizer,
                    answer_verifier=answer_verifier,
                    answer_verifier_optimizer=answer_verifier_optimizer,
                    answer_verifier_logistic_loss=answer_verifier_logistic_loss,
                    evaluator=train_evaluator)

                joint_loss.backward()

                # Make a gradient step
                span_extractor_optimizer.step()
                answer_verifier_optimizer.step()
            print_and_log('Took %s.' % format_seconds(time() - start_time),
                          logger)
            print_and_log('\t' + train_evaluator.eval_string(), logger)

            span_extractor.eval()
            answer_verifier.eval()
            # NOTE(review): gradients are not disabled for the dev pass;
            # wrapping it in torch.no_grad() would likely save memory, but
            # confirm that _run_batch does not rely on autograd first.
            while dev_batcher.has_next():
                _run_batch(
                    batch=dev_batcher.next(),
                    span_extractor=span_extractor,
                    span_extractor_optimizer=span_extractor_optimizer,
                    answer_verifier=answer_verifier,
                    answer_verifier_optimizer=answer_verifier_optimizer,
                    answer_verifier_logistic_loss=answer_verifier_logistic_loss,
                    evaluator=dev_evaluator)

            print_and_log('\t' + dev_evaluator.eval_string(), logger)
            dev_f1 = dev_evaluator.span_f1()
            if dev_f1 > best_span_f1:
                best_span_f1 = dev_f1
                best_span_epoch = epoch_idx + 1
                torch.save(span_extractor,
                           os.path.join(out_dir, 'best_span_extractor.tar'))

            dev_answer_verifier_accuracy = dev_evaluator.avg_answer_accuracy()
            if dev_answer_verifier_accuracy > best_answer_verifier_accuracy:
                best_answer_verifier_accuracy = dev_answer_verifier_accuracy
                best_answer_verifier_epoch = epoch_idx + 1
                torch.save(answer_verifier,
                           os.path.join(out_dir, 'best_answer_verifier.tar'))

        print_and_log(
            '\nBest span = %.4f F1 at %d epoch' %
            (best_span_f1, best_span_epoch), logger)
        print_and_log(
            '\nBest answer verifier = %.4f accuracy at %d epoch' %
            (best_answer_verifier_accuracy, best_answer_verifier_epoch),
            logger)