Example 1
def main(args):
    global running_loss, start_t
    # logging info that needs to persist across iterations
    viz = utils.visualize.VisdomVisualize(env_name=args.env_name)
    viz.viz.text(str(args))
    running_loss = None
    running_accs = {qtype: 0.5 for qtype, _ in enumerate(utils.QTYPE_ID_TO_META)}
    start_t = None

    # data
    dataset = FigQADataset(args.figqa_dir, args.figqa_pre,
                           split='train1')
    dataloader = DataLoader(dataset, batch_size=args.batch_size,
                            num_workers=args.workers, pin_memory=True,
                            shuffle=bool(args.shuffle_train))
    val_dataset = FigQADataset(args.figqa_dir, args.figqa_pre,
                               split=args.val_split)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size,
                                num_workers=args.workers, pin_memory=True,
                                shuffle=True)

    # model
    if args.start_from:
        model = utils.load_model(fname=args.start_from, ngpus=args.cuda)
    else:
        model_args = figqa.options.model_args(args)
        model = utils.load_model(model_args, ngpus=args.cuda)

    # optimization
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    def exp_lr(epoch):
        iters = epoch * len(dataloader)
        return args.lr_decay**iters
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, exp_lr)
    criterion = nn.NLLLoss()

    # training
    for epoch in range(args.epochs):
        checkpoint_stuff(**locals())
        scheduler.step()
        start_t = timer()
        # TODO: understand when/why automatic garbage collection slows down
        # the train loop
        gc.disable()
        for local_iter_idx, batch in batch_iter(dataloader, args):
            iter_idx = local_iter_idx + epoch * len(dataloader)

            # forward + update
            optimizer.zero_grad()
            pred = model(batch)
            loss = criterion(pred, batch['answer'])
            loss.backward()
            optimizer.step()

            # visualize, log, checkpoint
            log_stuff(**locals())
        gc.enable()
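
Example 1 is the training entry point: it builds the train/val dataloaders, constructs or restores a model, and runs one Adam update per batch. The learning-rate schedule deserves a note: LambdaLR multiplies the base rate by whatever exp_lr returns, so at epoch e the effective rate is args.lr * args.lr_decay ** (e * len(dataloader)), i.e. the decay is expressed per iteration even though the scheduler only steps once per epoch. A minimal, self-contained sketch of that schedule (the numeric values below are illustrative assumptions, not the script's defaults):

import torch

base_lr = 1e-3          # stands in for args.lr
lr_decay = 0.9999       # stands in for args.lr_decay
iters_per_epoch = 500   # stands in for len(dataloader)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=base_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: lr_decay ** (epoch * iters_per_epoch))

for epoch in range(3):
    # LambdaLR sets the rate to base_lr * lambda(epoch), so each print shows
    # base_lr * lr_decay ** (epoch * iters_per_epoch).
    print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()    # placeholder update; a real loop iterates over batches
    scheduler.step()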
Example 2
def main(args):
    # data
    split = args.val_split
    dataset = FigQADataset(args.figqa_dir, args.figqa_pre,
                           split=split, max_examples=args.max_examples)
    dataloader = DataLoader(dataset, batch_size=args.batch_size,
                            num_workers=args.workers)

    # model
    model, model_args = utils.load_model(fname=args.start_from,
                                         return_args=True,
                                         ngpus=args.cuda)
    model.eval()
    criterion = nn.NLLLoss()

    # evaluate metrics on dataset
    preds = []
    batches = []
    for batch_idx, batch in batch_iter(dataloader, args, volatile=True):
        if batch_idx % 50 == 0:
            print('Batch {}/{}'.format(batch_idx, len(dataloader)))
        # forward
        pred = model(batch)
        loss = criterion(pred, batch['answer'])

        # visualization
        preds.append(pred)
        batches.append(batch)

    # save webpage that displays example predictions
    with open(pth.join(args.figqa_pre, 'vocab.json'), 'r') as f:
        vocab = json.load(f)
    html = render_webpage(batches, preds, args.examples_dir, vocab)
    with open(pth.join(args.examples_dir, 'examples.html'), 'w') as f:
        f.write(html)
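
Example 2 evaluates a checkpoint on a validation split and writes an HTML page of example predictions through the project's render_webpage helper, whose implementation isn't shown here. As a rough idea of the shape of such a helper, here is a hypothetical stand-in that turns already-decoded prediction records into an HTML table; the function name and record fields are assumptions for illustration, not the project's API:

import html

def render_examples_html(examples):
    # examples: iterable of dicts with 'question', 'prediction', 'correct' keys
    rows = []
    for ex in examples:
        mark = 'yes' if ex['correct'] else 'no'
        rows.append('<tr><td>{}</td><td>{}</td><td>{}</td></tr>'.format(
            html.escape(ex['question']), html.escape(ex['prediction']), mark))
    return ('<html><body><table>'
            '<tr><th>question</th><th>prediction</th><th>correct</th></tr>'
            + ''.join(rows) + '</table></body></html>')

print(render_examples_html([
    {'question': 'Is Red the maximum?', 'prediction': 'yes', 'correct': True},
]))

The real helper would additionally look up words via the loaded vocab and link the figure images; the sketch only shows how a list of per-example records maps onto an HTML page.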
Example 3
def main(args):
    running_loss = None
    start_t = None

    # data
    split = args.val_split
    dataset = FigQADataset(args.figqa_dir, args.figqa_pre, split=split)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            num_workers=args.workers)

    # model
    model, model_args = utils.load_model(fname=args.start_from,
                                         return_args=True,
                                         ngpus=args.cuda)
    model.eval()
    criterion = nn.NLLLoss()

    # evaluate metrics on dataset
    accs = []
    accs_by_qtype = {
        qtype: []
        for qtype, _ in enumerate(utils.QTYPE_ID_TO_META)
    }
    start_t = timer()
    for batch_idx, batch in batch_iter(dataloader, args, volatile=True):
        if batch_idx % 50 == 0:
            print('Batch {}/{}'.format(batch_idx, len(dataloader)))
        # forward
        pred = model(batch)
        loss = criterion(pred, batch['answer'])

        # accuracy
        _, pred_idx = torch.max(pred, dim=1)
        correct = (batch['answer'] == pred_idx)
        acc = correct.cpu().data.numpy()
        accs.append(acc)
        for qtype, meta in enumerate(utils.QTYPE_ID_TO_META):
            qtype_mask = (batch['qtype'] == qtype)
            if qtype_mask.sum().data[0] == 0:
                continue
            acc = correct[qtype_mask].cpu().data.numpy()
            accs_by_qtype[qtype].append(acc)

    # accumulate results into convenient dict
    accs = np.concatenate(accs, axis=0)
    for qtype in accs_by_qtype:
        qaccs = accs_by_qtype[qtype]
        accs_by_qtype[qtype] = np.concatenate(qaccs, axis=0).mean()
    result = {
        'split': split,
        'model_kind': model_args['model'],
        'acc': accs.mean(),
        'accs_by_qtype': accs_by_qtype,
        'qtypes': [qt[0] for qt in utils.QTYPE_ID_TO_META],
    }
    pprint(result)
    result['args'] = args
    result_name = args.result_name

    # save to disk
    name = 'result_{split}_{result_name}.pkl'.format(**locals())
    result_fname = pth.join(args.result_dir, name)
    os.makedirs(args.result_dir, exist_ok=True)
    with open(result_fname, 'wb') as f:
        pkl.dump(result, f)
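
Example 3 computes overall accuracy plus per-question-type accuracy and pickles the result dict ('split', 'model_kind', 'acc', 'accs_by_qtype', 'qtypes', 'args'). A minimal sketch of reading such a file back; the path below is an assumption following the 'result_{split}_{result_name}.pkl' pattern above, so substitute whatever result_dir and result_name were actually used:

import pickle as pkl
from pprint import pprint

# Assumed example path; matches the naming scheme used when saving above.
with open('results/result_validation1_baseline.pkl', 'rb') as f:
    result = pkl.load(f)

print('overall accuracy:', result['acc'])
pprint(result['accs_by_qtype'])   # mean accuracy keyed by question-type id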