Example #1
def evaluate_ent(model, dev_iter):
    model.eval()
    acc, ent, size = 0, 0, 0
    for premise, hypothesis, label in dev_iter:
        output = F.softmax(forward(model, premise, hypothesis), 1)
        ent += entropy(output).sum().item()
        _, pred = output.max(dim=1)
        if isinstance(label, Variable):
            label = label.data.cpu()
        acc += (pred.data.cpu() == label).sum().item()
        size += hypothesis.shape[0]
    ent /= size
    acc /= size
    return ent, acc
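
Every example here calls an `entropy` helper that the snippets do not define. Judging from the call sites, which apply it to a batch of softmax outputs and sum the result per batch, a minimal sketch could look like the following; the epsilon guard against log(0) is an assumption:

def entropy(p, eps=1e-12):
    # p: (batch, n_classes) softmax probabilities.
    # Returns the Shannon entropy of each row, shape (batch,).
    return -(p * (p + eps).log()).sum(dim=1)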
Example #2
def evaluate_ent(model, data):
    model.eval()
    ent, acc, size = 0, 0, 0
    for i, (v, q, a, idx, q_len) in enumerate(data):
        if i > 10:  # quick evaluation on the first few batches only
            break
        # non_blocking replaces the pre-0.4 `async` flag (a keyword in Python 3.7+)
        v = Variable(v.cuda(non_blocking=True))
        q = Variable(q.cuda(non_blocking=True))
        a = Variable(a.cuda(non_blocking=True))
        q_len = Variable(q_len.cuda(non_blocking=True))
        out = F.softmax(model(v, q, q_len), 1)
        ent += entropy(out).sum().item()
        acc += utils.batch_accuracy(out.data, a.data).mean().item()
        size += 1
    return ent / size, acc / size
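
utils.batch_accuracy belongs to the surrounding VQA project and is not shown. The standard VQA metric scores a predicted answer as min(#annotators who gave it / 3, 1); a sketch of that metric, assuming `true` holds per-answer annotator counts (not necessarily the project's exact implementation):

def batch_accuracy(predicted, true):
    # predicted: (batch, n_answers) scores; true: (batch, n_answers) counts.
    _, idx = predicted.max(dim=1, keepdim=True)
    agreeing = true.gather(dim=1, index=idx)  # count for the predicted answer
    return (agreeing / 3).clamp(max=1)        # min(count / 3, 1)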
Example #3
def main():
    from args import conf, tune_conf
    parser = argparse.ArgumentParser()
    parser.add_argument('--baseline', default='results/baseline.pt')
    parser.add_argument(
        '--ent-train',
        default='/scratch0/shifeng/rawr/new_snli/rawr.train.pkl')
    parser.add_argument('--ent-dev',
                        default='/scratch0/shifeng/rawr/new_snli/rawr.dev.pkl')
    # args.root_dir is read below but never defined; this flag is an assumed fix
    parser.add_argument('--root-dir', default='.')
    args = parser.parse_args()

    out_dir = prepare_output_dir(args, args.root_dir)
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)
    fh = logging.FileHandler(os.path.join(out_dir, 'output.log'))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s',
                                  datefmt='%m/%d/%Y %I:%M:%S')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    log.addHandler(fh)
    log.addHandler(ch)
    log.info('===== {} ====='.format(out_dir))
    ''' load regular data '''
    log.info('loading regular training data')
    data = SNLI(conf)
    conf.char_vocab_size = len(data.char_vocab)
    conf.word_vocab_size = len(data.TEXT.vocab)
    conf.class_size = len(data.LABEL.vocab)
    conf.max_word_len = data.max_word_len

    log.info('loading entropy dev data {}'.format(args.ent_dev))
    with open(args.ent_dev, 'rb') as f:
        ent_dev = pickle.load(f)
    if isinstance(ent_dev[0], list):
        ent_dev = list(itertools.chain(*ent_dev))
    log.info('{} entropy dev examples'.format(len(ent_dev)))
    ent_dev = [[
        x['data']['premise'], x['data']['hypothesis'], x['data']['label']
    ] for x in ent_dev]

    log.info('loading entropy training data {}'.format(args.ent_train))
    with open(args.ent_train, 'rb') as f:
        ent_train = pickle.load(f)
    if isinstance(ent_train[0], list):
        ent_train = list(itertools.chain(*ent_train))
    log.info('{} entropy training examples'.format(len(ent_train)))
    ent_train = [[
        x['data']['premise'], x['data']['hypothesis'], x['data']['label']
    ] for x in ent_train]

    train_ent_batches = batchify(ent_train, tune_conf.batch_size)
    log.info('{} entropy training batches'.format(len(train_ent_batches)))

    log.info('loading model from {}'.format(args.baseline))
    model = BIMPM(conf, data)
    model.load_state_dict(torch.load(args.baseline))
    # model.word_emb.weight.requires_grad = True
    model.cuda(conf.gpu)

    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.Adam(parameters, lr=tune_conf.lr)
    ent_optimizer = optim.Adam(parameters, lr=tune_conf.ent_lr)
    criterion = nn.CrossEntropyLoss()

    init_loss, init_acc = evaluate(model, data.dev_iter)
    log.info("initial loss {:.4f} accuracy {:.4f}".format(init_loss, init_acc))
    best_acc = init_acc

    dev_ent_batches = batchify(ent_dev, tune_conf.batch_size)
    init_ent, init_ent_acc = evaluate_ent(model, dev_ent_batches)
    log.info("initial entropy {:.4f} ent_acc {:.4f}".format(
        init_ent, init_ent_acc))

    epoch = 0
    i_ent, i_mle = 0, 0  # i_ent counts examples, i_mle counts batches
    train_loss, train_ent = 0, 0
    train_mle_iter = iter(data.train_iter)
    train_ent_iter = iter(train_ent_batches)
    while True:
        model.train()
        for i in range(tune_conf.n_ent):
            try:
                prem, hypo, label = next(train_ent_iter)
            except StopIteration:
                random.shuffle(train_ent_batches)
                train_ent_iter = iter(train_ent_batches)
                i_ent = 0
                train_ent = 0
                break
            output = forward(model, prem, hypo, conf.max_sent_len)
            output = F.softmax(output, 1)
            ent = entropy(output).sum()
            train_ent += ent.item()
            loss = -tune_conf.gamma * ent
            ent_optimizer.zero_grad()
            loss.backward()
            ent_optimizer.step()
            i_ent += prem.shape[0]

        end_of_epoch = False
        for i in range(tune_conf.n_mle):
            if i_mle >= len(data.train_iter):
                epoch += 1
                end_of_epoch = True
                data.train_iter.init_epoch()
                train_mle_iter = iter(data.train_iter)
                i_mle = 0
                train_loss = 0
                break
            batch = next(train_mle_iter)
            output = forward(model, batch.premise, batch.hypothesis,
                             conf.max_sent_len)
            loss = criterion(output, batch.label)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            i_mle += 1  # count batches, matching len(data.train_iter) above

        if i_mle % 1000 == 0:
            _loss = train_loss / i_mle if i_mle != 0 else 0
            _ent = train_ent / i_ent if i_ent != 0 else 0
            log.info(
                'epoch [{:2}] [{} / {}] loss[{:.5f}] entropy[{:.5f}]'.format(
                    epoch, i_mle, len(data.train_iter), _loss, _ent))

        if end_of_epoch or i_mle % 1e5 == 0:
            dev_loss, dev_acc = evaluate(model, data.dev_iter)
            dev_ent, dev_ent_acc = evaluate_ent(model, dev_ent_batches)
            log.info("dev acc: {:.4f} ent: {:.4f} ent_acc: {:.4f}".format(
                dev_acc, dev_ent, dev_ent_acc))
            model_path = os.path.join(out_dir,
                                      'checkpoint_epoch_{}.pt'.format(epoch))
            torch.save(model.state_dict(), model_path)
            if dev_acc > best_acc:
                best_acc = dev_acc
                model_path = os.path.join(out_dir, 'best_model.pt')
                torch.save(model.state_dict(), model_path)
                log.info("best model saved {}".format(dev_acc))

        if epoch > 40:
            break
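
Example #3 leans on two project helpers that are not shown: `forward`, which scores a premise/hypothesis pair with the BIMPM model, and `batchify`, which groups the unpacked RAWR examples into fixed-size batches. Hypothetical sketches consistent with the call sites; the truncation to max_sent_len, the model's call signature, and the omitted padding are all assumptions:

def forward(model, premise, hypothesis, max_sent_len=None):
    # Optionally truncate long sequences, then score the pair.
    if max_sent_len is not None:
        premise = premise[:, :max_sent_len]
        hypothesis = hypothesis[:, :max_sent_len]
    return model(premise, hypothesis)

def batchify(examples, batch_size):
    # examples: list of [premise, hypothesis, label] triples. Groups them into
    # batches that unpack as (premises, hypotheses, labels); converting each
    # group to padded tensors is left out of this sketch.
    batches = []
    for i in range(0, len(examples), batch_size):
        premises, hypotheses, labels = zip(*examples[i:i + batch_size])
        batches.append((premises, hypotheses, labels))
    return batches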
Example #4
def main():
    args = argparse.Namespace()
    args.gamma = 2e-4
    args.n_mle = 2
    args.n_ent = 2
    args.ent_lr = 1e-4

    out_dir = prepare_output_dir(user_dir='/scratch0/shifeng/rawr_data/vqa')
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)
    fh = logging.FileHandler(os.path.join(out_dir, 'output.log'))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s',
                                  datefmt='%m/%d/%Y %I:%M:%S')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    log.addHandler(fh)
    log.addHandler(ch)
    log.info('===== {} ====='.format(out_dir))

    ckp = torch.load('2017-08-04_00.55.19.pth')
    vocab_size = len(ckp['vocab']['question']) + 1
    net = nn.DataParallel(model.Net(vocab_size))
    net.load_state_dict(ckp['weights'])
    net.cuda()
    parameters = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters)
    ent_optimizer = optim.Adam(parameters, lr=args.ent_lr)
    log_softmax = nn.LogSoftmax(dim=1).cuda()

    dev_ent_iter = data.get_reduced_loader('results/rawr.dev.pkl', val=True)
    log.info('{} entropy dev examples'.format(len(dev_ent_iter)))
    train_ent_data = data.get_reduced_loader('results/rawr.train.pkl',
                                             train=True)
    log.info('{} entropy training examples'.format(len(train_ent_data)))

    dev_mle_iter = data.get_loader(val=True)
    train_mle_data = data.get_loader(train=True)

    dev_loss, dev_acc = evaluate(net, dev_mle_iter)
    log.info('dev loss {:.4f} acc {:.4f}'.format(dev_loss, dev_acc))
    dev_ent, dev_ent_acc = evaluate_ent(net, dev_ent_iter)
    log.info('dev ent {:.4f} acc {:.4f}'.format(dev_ent, dev_ent_acc))
    best_acc = dev_acc
    # best_acc = 0

    epoch = 0
    i_mle, i_ent = 0, 0
    train_loss, train_ent = 0, 0
    size_mle, size_ent = 0, 0
    train_ent_iter = iter(train_ent_data)
    train_mle_iter = iter(train_mle_data)
    while True:
        net.train()
        for i in range(args.n_ent):
            try:
                v, q, a, idx, q_len = next(train_ent_iter)
                v = Variable(v.cuda(non_blocking=True))
                q = Variable(q.cuda(non_blocking=True))
                a = Variable(a.cuda(non_blocking=True))
                q_len = Variable(q_len.cuda(non_blocking=True))
            except StopIteration:
                i_ent = 0
                train_ent = 0
                train_ent_iter = iter(train_ent_data)
                break
            out = F.softmax(net(v, q, q_len), 1)
            ent = entropy(out).sum()
            train_ent += ent.item()
            optimizer.zero_grad()
            ent_optimizer.zero_grad()
            loss = -args.gamma * ent
            loss.backward()
            ent_optimizer.step()
            i_ent += 1
            if i_ent > len(train_ent_data):
                i_ent = 0
                train_ent = 0
                train_ent_iter = iter(train_ent_data)
                break

        end_of_epoch = False
        for i in range(args.n_mle):
            try:
                v, q, a, idx, q_len = next(train_mle_iter)
                v = Variable(v.cuda(non_blocking=True))
                q = Variable(q.cuda(non_blocking=True))
                a = Variable(a.cuda(non_blocking=True))
                q_len = Variable(q_len.cuda(non_blocking=True))
            except StopIteration:
                i_mle = 0
                train_loss = 0
                end_of_epoch = True
                epoch += 1
                train_mle_iter = iter(train_mle_data)
                break
            out = net(v, q, q_len)  # raw answer scores; log_softmax normalizes
            nll = -log_softmax(out)
            loss = (nll * a / 10).sum(dim=1).mean()
            train_loss += loss.item()
            optimizer.zero_grad()
            ent_optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            i_mle += 1
            if i_mle > len(train_mle_data):
                i_mle = 0
                train_loss = 0
                end_of_epoch = True
                epoch += 1
                train_mle_iter = iter(train_mle_data)
                break

        if i_mle % 1000 == 0:
            _loss = train_loss / i_mle if i_mle != 0 else 0
            _ent = train_ent / i_ent if i_ent != 0 else 0
            log.info(
                'epoch [{:2}] [{} / {}] loss[{:.5f}] entropy[{:.5f}]'.format(
                    epoch, i_mle, len(train_mle_data), _loss, _ent))

        if end_of_epoch or i_mle % 1e5 == 0:
            dev_loss, dev_acc = evaluate(net, dev_mle_iter)
            log.info('dev loss {:.4f} acc {:.4f}'.format(dev_loss, dev_acc))
            dev_ent, dev_ent_acc = evaluate_ent(net, dev_ent_iter)
            log.info('dev ent {:.4f} acc {:.4f}'.format(dev_ent, dev_ent_acc))
            model_path = os.path.join(out_dir,
                                      'checkpoint_epoch_{}.pt'.format(epoch))
            torch.save(net.state_dict(), model_path)
            if dev_acc > best_acc:
                best_acc = dev_acc
                model_path = os.path.join(out_dir, 'best_model.pt')
                torch.save(net.state_dict(), model_path)
                log.info("best model saved {}".format(dev_acc))

        if epoch > 40:
            break
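
The MLE loss in Example #4 weights the negative log-likelihood by the annotator counts in `a`: each VQA question has up to 10 human answers, so `a / 10` acts as a soft target distribution over answers. A self-contained sketch of that loss, with shapes assumed:

import torch.nn.functional as F

def soft_vqa_loss(scores, a):
    # scores: (batch, n_answers) raw model scores.
    # a: (batch, n_answers) counts of annotators giving each answer (0-10).
    nll = -F.log_softmax(scores, dim=1)
    return (nll * a / 10).sum(dim=1).mean()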
Example #5
def main():
    from args import args
    os.makedirs(args.run_dir, exist_ok=True)

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    args.load_model_dir = parser.parse_args().model
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)
    fh = logging.FileHandler(os.path.join(args.run_dir, 'output.log'))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s',
                                  datefmt='%m/%d/%Y %I:%M:%S')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    log.addHandler(fh)
    log.addHandler(ch)
    log.info('===== {} ====='.format(args.timestamp))

    with open(os.path.join(args.run_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    input_field = data.Field(lower=args.lower)
    output_field = data.Field(sequential=False)
    train_reg, dev_reg, test = datasets.SNLI.splits(
            input_field, output_field, root=args.data_root)
    input_field.build_vocab(train_reg, dev_reg, test)
    output_field.build_vocab(train_reg)
    input_field.vocab.vectors = torch.load(args.vector_cache)
    
    train_reg_iter, dev_reg_iter = data.BucketIterator.splits(
            (train_reg, dev_reg), batch_size=300, device=args.gpu)

    with open('pkls/rawr.train.pkl', 'rb') as f:
        train_ent = pickle.load(f)['data']
    with open('pkls/rawr.dev.pkl', 'rb') as f:
        dev_ent = pickle.load(f)['data']

    train_ent_iter = EntIterator(
            train_ent, batch_size=300, device=args.gpu)
    dev_ent_iter = EntIterator(
            dev_ent, batch_size=300, device=args.gpu,
            evaluation=True)

    model = torch.load(args.load_model_dir, 
            map_location=lambda storage, location: storage.cuda(args.gpu))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ent_optimizer = torch.optim.Adam(model.parameters(), lr=args.ent_lr)


    ''' initial evaluation '''
    model.eval()
    dev_reg_iter.init_epoch()
    n_dev_correct, dev_loss = 0, 0
    total = 0
    for dev_batch_idx, dev_batch in enumerate(dev_reg_iter):
        total += dev_batch.hypothesis.shape[1]  # batch dim of (seq_len, batch)
        answer = model(dev_batch)
        pred = torch.max(answer, 1)[1].view(dev_batch.label.size())
        n_dev_correct += (pred.data == dev_batch.label.data).sum().item()
    dev_acc = 100. * n_dev_correct / total
    log.info('dev acc {:.4f}'.format(dev_acc))

    dev_ent_iter.init_epoch()
    avg_entropy = 0
    total = 0
    for dev_batch_idx in range(len(dev_ent_iter)):
        dev_batch = dev_ent_iter.next()
        total += dev_batch.hypothesis.shape[1]
        output = model(dev_batch)
        ent = entropy(F.softmax(output, 1)).sum()
        avg_entropy += ent.item()
    log.info('dev entropy {:.4f}'.format(avg_entropy / total))

    best_dev_acc = -1
    train_ent_iter.init_epoch()
    for epoch in range(args.epochs):
        epoch_loss = []
        epoch_entropy = []
        n_reg = 0
        n_ent = 0
        train_reg_iter.init_epoch()
        for i_reg, batch in enumerate(train_reg_iter):
            model.train()
            output = model(batch)
            loss = criterion(output, batch.label)
            epoch_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            n_reg += 1

            if n_reg % args.n_reg_per_ent == 0:
                model.train()
                for j in range(args.n_ent_per_reg):
                    batch_ent = train_ent_iter.next()
                    output = model(batch_ent)
                    ent = entropy(F.softmax(output, 1)).sum()
                    epoch_entropy.append(ent.item())
                    loss = -args.gamma * ent  # maximize entropy on reduced examples
                    ent_optimizer.zero_grad()
                    loss.backward()
                    ent_optimizer.step()
                    n_ent += 1  # counted for the progress log below

            if n_reg % args.n_report == 0:
                if len(epoch_loss) != 0 and len(epoch_entropy) != 0:
                    log.info('epoch [{}] batch [{}, {}] loss [{:.4f}] entropy [{:.4f}]'.format(
                        epoch, i_reg, n_ent, sum(epoch_loss) / len(epoch_loss),
                        sum(epoch_entropy) / len(epoch_entropy)))

            if n_reg % args.n_eval == 0:
                model.eval()
                dev_reg_iter.init_epoch()
                n_dev_correct, dev_loss = 0, 0
                total = 0
                for dev_batch_idx, dev_batch in enumerate(dev_reg_iter):
                    total += dev_batch.hypothesis.shape[1]
                    answer = model(dev_batch)
                    pred = torch.max(answer, 1)[1].view(dev_batch.label.size())
                    n_dev_correct += (pred.data == dev_batch.label.data).sum().item()
                dev_acc = 100. * n_dev_correct / total
                log.info('dev acc {:.4f}'.format(dev_acc))

                dev_ent_iter.init_epoch()
                avg_entropy = 0
                total = 0
                for dev_batch_idx in range(len(dev_ent_iter)):
                    dev_batch = dev_ent_iter.next()
                    total += dev_batch.hypothesis.shape[1]
                    output = model(dev_batch)
                    ent = entropy(F.softmax(output, 1)).sum()
                    avg_entropy += ent.item()
                log.info('dev entropy {:.4f}'.format(avg_entropy / total))

                if dev_acc > best_dev_acc:
                    snapshot_path = os.path.join(args.run_dir, 'best_model.pt')
                    torch.save(model, snapshot_path)
                    best_dev_acc = dev_acc
                    log.info('save best model {}'.format(best_dev_acc))

        snapshot_path = os.path.join(args.run_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
        torch.save(model, snapshot_path)
        log.info('save model {}'.format(snapshot_path))
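
All five examples implement the same fine-tuning recipe: interleave maximum-likelihood updates on regular data with entropy-maximization updates (loss = -gamma * H) on reduced RAWR examples, each phase with its own Adam optimizer. A distilled, hypothetical sketch of one round of that schedule, reusing the entropy helper sketched after Example #1:

import torch.nn.functional as F

def interleaved_round(model, ent_iter, mle_iter, ent_optimizer, optimizer,
                      criterion, gamma, n_ent=2, n_mle=2):
    # ent_iter / mle_iter yield (inputs, target) pairs, where inputs is a
    # tuple of model arguments; the names here are placeholders.
    model.train()
    for _ in range(n_ent):  # entropy phase: push predictions toward uniform
        inputs, _ = next(ent_iter)
        probs = F.softmax(model(*inputs), dim=1)
        loss = -gamma * entropy(probs).sum()
        ent_optimizer.zero_grad()
        loss.backward()
        ent_optimizer.step()
    for _ in range(n_mle):  # MLE phase: standard supervised update
        inputs, target = next(mle_iter)
        loss = criterion(model(*inputs), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()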