Example #1
def main(args):
    # data
    split = args.val_split
    dataset = FigQADataset(args.figqa_dir, args.figqa_pre,
                           split=split, max_examples=args.max_examples)
    dataloader = DataLoader(dataset, batch_size=args.batch_size,
                            num_workers=args.workers)

    # model
    model, model_args = utils.load_model(fname=args.start_from,
                                         return_args=True,
                                         ngpus=args.cuda)
    model.eval()
    criterion = nn.NLLLoss()

    # evaluate metrics on dataset
    preds = []
    batches = []
    for batch_idx, batch in batch_iter(dataloader, args, volatile=True):
        if batch_idx % 50 == 0:
            print('Batch {}/{}'.format(batch_idx, len(dataloader)))
        # forward
        pred = model(batch)
        loss = criterion(pred, batch['answer'])

        # visualization
        preds.append(pred)
        batches.append(batch)

    # save webpage that displays example predictions
    with open(pth.join(args.figqa_pre, 'vocab.json'), 'r') as f:
        vocab = json.load(f)
    html = render_webpage(batches, preds, args.examples_dir, vocab)
    with open(pth.join(args.examples_dir, 'examples.html'), 'w') as f:
        f.write(html)
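
These examples assume the repository's module-level imports (torch, torch.nn as nn, os.path as pth, json, and so on) and a project-local batch_iter helper that is not shown on this page. Below is a minimal sketch of what batch_iter plausibly does, assuming the pre-0.4 PyTorch Variable API that the volatile=True flags and .data[0] accesses elsewhere imply; the project's actual helper may differ:

import torch
from torch.autograd import Variable

def batch_iter(dataloader, args, volatile=False):
    # hypothetical reconstruction: wrap each tensor as an (optionally
    # volatile) Variable and move it to the GPU when args.cuda is set
    for batch_idx, batch in enumerate(dataloader):
        out = {}
        for key, val in batch.items():
            if torch.is_tensor(val):
                val = Variable(val, volatile=volatile)
                if args.cuda:
                    val = val.cuda()
            out[key] = val
        yield batch_idx, out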
Example #2
def main(args):
    global running_loss, start_t
    # logging info that needs to persist across iterations
    viz = utils.visualize.VisdomVisualize(env_name=args.env_name)
    viz.viz.text(str(args))
    running_loss = None
    running_accs = {qtype: 0.5
                    for qtype, _ in enumerate(utils.QTYPE_ID_TO_META)}
    start_t = None

    # data
    dataset = FigQADataset(args.figqa_dir, args.figqa_pre,
                           split='train1')
    dataloader = DataLoader(dataset, batch_size=args.batch_size,
                            num_workers=args.workers, pin_memory=True,
                            shuffle=bool(args.shuffle_train))
    val_dataset = FigQADataset(args.figqa_dir, args.figqa_pre,
                               split=args.val_split)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size,
                                num_workers=args.workers, pin_memory=True,
                                shuffle=True)

    # model
    if args.start_from:
        model = utils.load_model(fname=args.start_from, ngpus=args.cuda)
    else:
        model_args = figqa.options.model_args(args)
        model = utils.load_model(model_args, ngpus=args.cuda)

    # optimization
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    def exp_lr(epoch):
        iters = epoch * len(dataloader)
        return args.lr_decay**iters
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, exp_lr)
    criterion = nn.NLLLoss()

    # training
    for epoch in range(args.epochs):
        checkpoint_stuff(**locals())
        scheduler.step()
        start_t = timer()
        # TODO: understand when/why automatic garbage collection slows down
        # the train loop
        gc.disable()
        for local_iter_idx, batch in batch_iter(dataloader, args):
            iter_idx = local_iter_idx + epoch * len(dataloader)

            # forward + update
            optimizer.zero_grad()
            pred = model(batch)
            loss = criterion(pred, batch['answer'])
            loss.backward()
            optimizer.step()

            # visualize, log, checkpoint
            log_stuff(**locals())
        gc.enable()
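
The exp_lr closure turns a per-iteration decay factor into a per-epoch multiplier, so after e epochs the learning rate is args.lr * args.lr_decay**(e * len(dataloader)). A self-contained sketch of the same scheduling idea, with made-up values and scheduler.step() placed after optimizer.step() as newer PyTorch expects (illustration only, not project code):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
lr_decay, iters_per_epoch = 0.9999, 500   # hypothetical values

# LambdaLR multiplies the base lr by the factor returned for each epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: lr_decay ** (epoch * iters_per_epoch))

for epoch in range(3):
    optimizer.step()    # stands in for the inner training loop
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])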
Example #3
def main(args):
    running_loss = None
    start_t = None

    # data
    split = args.val_split
    dataset = FigQADataset(args.figqa_dir, args.figqa_pre, split=split)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            num_workers=args.workers)

    # model
    model, model_args = utils.load_model(fname=args.start_from,
                                         return_args=True,
                                         ngpus=args.cuda)
    model.eval()
    criterion = nn.NLLLoss()

    # evaluate metrics on dataset
    accs = []
    accs_by_qtype = {
        qtype: []
        for qtype, _ in enumerate(utils.QTYPE_ID_TO_META)
    }
    start_t = timer()
    for batch_idx, batch in batch_iter(dataloader, args, volatile=True):
        if batch_idx % 50 == 0:
            print('Batch {}/{}'.format(batch_idx, len(dataloader)))
        # forward
        pred = model(batch)
        loss = criterion(pred, batch['answer'])

        # accuracy
        _, pred_idx = torch.max(pred, dim=1)
        correct = (batch['answer'] == pred_idx)
        acc = correct.cpu().data.numpy()
        accs.append(acc)
        for qtype, meta in enumerate(utils.QTYPE_ID_TO_META):
            qtype_mask = (batch['qtype'] == qtype)
            if qtype_mask.sum().data[0] == 0:
                continue
            acc = correct[qtype_mask].cpu().data.numpy()
            accs_by_qtype[qtype].append(acc)

    # accumulate results into convenient dict
    accs = np.concatenate(accs, axis=0)
    for qtype in accs_by_qtype:
        qaccs = accs_by_qtype[qtype]
        accs_by_qtype[qtype] = np.concatenate(qaccs, axis=0).mean()
    result = {
        'split': split,
        'model_kind': model_args['model'],
        'acc': accs.mean(),
        'accs_by_qtype': accs_by_qtype,
        'qtypes': [qt[0] for qt in utils.QTYPE_ID_TO_META],
    }
    pprint(result)
    result['args'] = args
    result_name = args.result_name

    # save to disk
    name = 'result_{split}_{result_name}.pkl'.format(**locals())
    result_fname = pth.join(args.result_dir, name)
    os.makedirs(args.result_dir, exist_ok=True)
    with open(result_fname, 'wb') as f:
        pkl.dump(result, f)
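
The pickled result dict can be reloaded later for analysis. A minimal sketch, with a hypothetical file name (the real name depends on --result_dir, --val_split, and --result_name):

import os.path as pth
import pickle as pkl

result_fname = pth.join('results', 'result_val1_baseline.pkl')  # hypothetical
with open(result_fname, 'rb') as f:
    result = pkl.load(f)
print(result['split'], result['model_kind'], result['acc'])
for qtype, acc in sorted(result['accs_by_qtype'].items()):
    print(qtype, acc)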
Example #4
def log_stuff(iter_idx, loss, batch, pred, val_dataloader, model,
              criterion, epoch, optimizer, running_accs, viz, args,
              **kwargs):
    global running_loss, start_t
    #######################################################################
    # report numbers on this train batch
    if iter_idx % 100 != 0:
        return
    # loss
    alpha = 0.70
    if running_loss is None:
        running_loss = loss.data[0]
    else:
        running_loss = alpha * running_loss + (1 - alpha) * loss.data[0]
    viz.append_data(iter_idx, running_loss, 'Loss', 'running loss')

    # accuracy
    _, pred_idx = torch.max(pred, dim=1)
    correct = (batch['answer'] == pred_idx)
    train_acc = correct.cpu().data.numpy().mean()
    viz.append_data(iter_idx, train_acc, 'Acc', 'acc')

    # learning rate
    viz.append_data(iter_idx, optimizer.param_groups[0]['lr'],
                    'Learning rate', 'lr', ytype='log')

    # accuracy by question type
    for qtype, meta in enumerate(utils.QTYPE_ID_TO_META):
        qtype_mask = (batch['qtype'] == qtype)
        if qtype_mask.sum().data[0] != 0:
            qtype_correct = correct[qtype_mask]
            qtype_acc = qtype_correct.sum().data[0] / qtype_correct.size(0)
            running_accs[qtype] = 0.20 * qtype_acc + \
                                  (1 - 0.20) * running_accs[qtype]
        viz.append_data(iter_idx, running_accs[qtype],
                        'Train Question Type Acc', meta[0] + ' ' + str(meta[1]))

    # print to command line
    end_t = timer()
    time_stamp = strftime('%a %d %b %y %X', gmtime())
    t_diff = end_t - start_t
    log_line = ('[{time_stamp}][Ep: {epoch:0>2d}][Iter: {iter_idx}]'
                '[Time: {t_diff:.2f}][Loss: {running_loss:.4f}]')
    # running_loss is global, so it is not in locals() and must be passed
    # to format() explicitly
    print(log_line.format(running_loss=running_loss, **locals()))
    start_t = end_t

    #######################################################################
    # numbers on a few batches of val
    if iter_idx % 500 != 0:
        return
    val_batches = 10
    val_losses = []
    val_accs = []
    val_correct_by_qtype = {qtype: [] for qtype, _ in
                            enumerate(utils.QTYPE_ID_TO_META)}
    val_iter = batch_iter(val_dataloader, args, volatile=True)
    for _, val_batch in islice(val_iter, val_batches):
        val_pred = model(val_batch)
        val_loss = criterion(val_pred, val_batch['answer']).cpu().data.numpy()
        val_losses.append(val_loss)
        _, val_pred_idx = torch.max(val_pred, dim=1)
        val_correct = (val_batch['answer'] == val_pred_idx)
        val_acc = val_correct.cpu().data.numpy().mean()
        val_accs.append(val_acc)
        # accuracy by question type
        for qtype, meta in enumerate(utils.QTYPE_ID_TO_META):
            qtype_mask = (val_batch['qtype'] == qtype)
            if qtype_mask.sum().data[0] == 0:
                continue
            qtype_correct = val_correct[qtype_mask]
            val_correct_by_qtype[qtype].append(qtype_correct)

    # plot stuff
    viz.append_data(iter_idx, np.mean(val_losses), 'Loss', 'val loss')
    viz.append_data(iter_idx, np.mean(val_accs), 'Acc', 'val acc')
    acc_per_chart_type = defaultdict(list)
    for qtype, meta in enumerate(utils.QTYPE_ID_TO_META):
        correct = sum(c.sum().data[0] for c in val_correct_by_qtype[qtype])
        total = sum(c.size(0) for c in val_correct_by_qtype[qtype])
        qtype_acc = correct / total if total > 0 else 0.5
        viz.append_data(iter_idx, qtype_acc, 'Val Question Type Acc',
                        meta[0] + ' ' + str(meta[1]))
        chart_type = meta[1]
        acc_per_chart_type[chart_type].append(qtype_acc)
    for chart_type in acc_per_chart_type:
        acc = np.mean(acc_per_chart_type[chart_type])
        viz.append_data(iter_idx, acc, 'Val Chart Type Acc', str(chart_type))
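
The running loss above (which keeps 0.70 of the old value) and the per-question-type running accuracies (which keep 0.80) are both exponential moving averages. The update is isolated below with made-up loss values:

def ema_update(running, new, alpha):
    # keep `alpha` of the old estimate, blend in (1 - alpha) of the new value
    return new if running is None else alpha * running + (1 - alpha) * new

running_loss = None
for loss in (0.9, 0.7, 0.8, 0.6):   # made-up values
    running_loss = ema_update(running_loss, loss, alpha=0.70)
    print(running_loss)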