Ejemplo n.º 1
0
            w.write('{}: {:.5f}\n'.format(k, family_correct[k] / v))

    print('Avg Acc: {:.5f}'.format(
        sum(family_correct.values()) / sum(family_total.values())))

    clevr.close()


if __name__ == '__main__':
    # Script entry point: build the trained network and its running copy,
    # then alternate training and validation, writing one checkpoint per
    # epoch.
    import os  # local import: only this entry point needs it

    # NOTE: unpickling executes whatever the file encodes; only load a
    # dictionary file produced by this project's own preprocessing.
    with open('data/dic.pkl', 'rb') as f:
        dic = pickle.load(f)

    # One slot more than the vocabulary has entries (presumably reserved
    # for a padding index -- TODO confirm against the preprocessing step).
    n_words = len(dic['word_dic']) + 1
    n_answers = len(dic['answer_dic'])

    net = MACNetwork(n_words, dim).to(device)
    net_running = MACNetwork(n_words, dim).to(device)
    # Called once with 0 before training; presumably this initialises
    # net_running from net -- confirm in accumulate().
    accumulate(net_running, net, 0)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    # Make sure the output directory exists *before* training starts, so a
    # missing directory cannot throw away a fully trained epoch at save time.
    os.makedirs('checkpoint', exist_ok=True)

    for epoch in range(n_epoch):
        train(epoch)
        valid(epoch)

        # File names are numbered from 01 and zero-padded to two digits.
        with open(
                'checkpoint/checkpoint_{}.model'.format(
                    str(epoch + 1).zfill(2)), 'wb') as f:
            torch.save(net_running.state_dict(), f)
Ejemplo n.º 2
0
        'Avg Acc: {:.5f}'.format(
            sum(family_correct.values()) / sum(family_total.values())
        )
    )

    clevr.close()


if __name__ == '__main__':
    # Script entry point: build the trained network and its running copy,
    # then alternate training and validation, writing one checkpoint per
    # epoch.
    import os  # local import: only this entry point needs it

    # NOTE: unpickling executes whatever the file encodes; only load a
    # dictionary file produced by this project's own preprocessing.
    with open('data/dic.pkl', 'rb') as f:
        dic = pickle.load(f)

    # One slot more than the vocabulary has entries (presumably reserved
    # for a padding index -- TODO confirm against the preprocessing step).
    n_words = len(dic['word_dic']) + 1
    n_answers = len(dic['answer_dic'])

    net = MACNetwork(n_words, dim).to(device)
    net_running = MACNetwork(n_words, dim).to(device)
    # Called once with 0 before training; presumably this initialises
    # net_running from net -- confirm in accumulate().
    accumulate(net_running, net, 0)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    # Make sure the output directory exists *before* training starts, so a
    # missing directory cannot throw away a fully trained epoch at save time.
    os.makedirs('checkpoint', exist_ok=True)

    for epoch in range(n_epoch):
        train(epoch)
        valid(epoch)

        # File names are numbered from 01 and zero-padded to two digits.
        with open(
            'checkpoint/checkpoint_{}.model'.format(str(epoch + 1).zfill(2)), 'wb'
        ) as f:
            # NOTE(review): this serialises the whole `net` module object,
            # not a state_dict, and not `net_running`. Left unchanged because
            # existing loaders may depend on this format -- confirm intent.
            torch.save(net, f)
Ejemplo n.º 3
0
if __name__ == '__main__':
    # Script entry point: read the configuration, register the run with
    # Comet, build the two networks and choose the loss function.
    cfg, mode = generate_cfg()

    # Credentials are read from the environment; a missing variable raises
    # KeyError here.
    comet_settings = {
        "api_key": os.environ["COMET_KEY"],
        "project_name": cfg.COMET.PROJECT_NAME,
        "workspace": os.environ["COMET_USER"],
    }
    experiment = Experiment(**comet_settings)

    experiment.log_parameters(config_to_comet(cfg))
    experiment.add_tag("NEW_MAC")
    run_name = cfg.COMET.EXPERIMENT_NAME
    if run_name:
        # Only override the run name when the config provides one.
        experiment.set_name(run_name)

    net = MACNetwork(cfg).to(device)
    if cfg.MAC.TRAINED_EMBD_PATH:
        load_pretrained_embedings(cfg, net.embed, device)
    net_running = MACNetwork(cfg).to(device)

    if cfg.LOAD:
        with open(cfg.LOAD_PATH, 'rb') as checkpoint_file:
            state = torch.load(checkpoint_file, map_location=device)
        # strict=False: keys that do not match the model are tolerated.
        net.load_state_dict(state, strict=False)
    # Called once with 0 after the optional load; presumably this
    # initialises net_running from net -- confirm in accumulate().
    accumulate(net_running, net, 0)

    if not (cfg.MAC.USE_ACT and cfg.ACT.SMOOTH):
        criterion = nn.CrossEntropyLoss()
    else:
        _criterion = nn.NLLLoss()

        def criterion(x, y):
            """Apply NLLLoss to the log of ``x``, clamped below at 1e-8.

            The clamp keeps ``torch.log`` away from zero. Presumably ``x``
            already holds probabilities in this mode -- TODO confirm.
            """
            return _criterion(torch.log(x.clamp(min=1e-8)), y)
Ejemplo n.º 4
0
    if args.comet:
        experiment = Experiment(api_key='VD0MYyhx0BQcWhxWvLbcalX51',
                                project_name="MAC")
        experiment.set_name(args.exp_name)
        experiment.log_parameters(params_to_dic())

    with open(f'{root}/dic.pkl', 'rb') as f:
        dic = pickle.load(f)

    n_words = len(dic['word_dic']) + 1
    n_answers = len(dic['answer_dic'])

    net = MACNetwork(n_words,
                     MAC_UNIT_DIM[dataset_type],
                     classes=n_answers,
                     max_step=MAX_STEPS,
                     self_attention=USE_SELF_ATTENTION,
                     memory_gate=USE_MEMORY_GATE,
                     num_heads=NUM_HEADS)
    net = nn.DataParallel(net)
    net.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=BASE_LR)
    best_loss = float('inf')
    best_acc = float('inf')
    best_epoch = float('inf')

    for epoch in range(TRAIN_EPOCHS):
        if args.comet:
            with experiment.train():
Ejemplo n.º 5
0
def _unwrap_data_parallel(model):
    """Return the module wrapped by ``nn.DataParallel``, or ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` without the ``module.`` key prefix of DataParallel."""
    prefix = "module."
    # Only strip when every key carries the prefix, i.e. the weights were
    # saved from a wrapped model; otherwise hand the mapping back untouched.
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def main(clevr_dir,
         load_filename=None,
         n_epochs=20,
         n_memories=3,
         only_test=False):
    """Train and validate a MAC network, or run it on the test set.

    Args:
        clevr_dir: Dataset root. ``preprocessed/dic.pkl`` is read from it and
            the directory is forwarded to ``train``, ``valid`` and ``test``.
        load_filename: Optional checkpoint path. Both a bare state dict
            ("old format") and a dict with ``model_state_dict``,
            ``optimizer_state_dict`` and ``epoch`` are accepted.
        n_epochs: Exclusive upper bound of the epoch loop.
        n_memories: Forwarded to ``MACNetwork`` and embedded in checkpoint
            file names.
        only_test: When true, skip training and call ``test`` instead. It is
            also forwarded as ``save_attns`` and limits DataParallel to
            device 0.
    """
    # NOTE: unpickling executes whatever the file encodes; only load a
    # dictionary file produced by this project's own preprocessing.
    with open(os.path.join(clevr_dir, "preprocessed/dic.pkl"), "rb") as f:
        dic = pickle.load(f)

    # One slot more than the vocabulary has entries (presumably reserved
    # for a padding index -- TODO confirm against the preprocessing step).
    n_words = len(dic["word_dic"]) + 1

    net = MACNetwork(n_words, dim, n_memories=n_memories, save_attns=only_test)
    accum_net = MACNetwork(n_words,
                           dim,
                           n_memories=n_memories,
                           save_attns=only_test)
    net = net.to(device)
    accum_net = accum_net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-4)
    start_epoch = 0

    if device.type == "cuda":
        devices = [0] if only_test else None
        print("Using", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net, device_ids=devices)
        accum_net = nn.DataParallel(accum_net, device_ids=devices)

    if load_filename:
        # map_location lets a checkpoint written on a GPU be restored on
        # whatever device this run uses (including CPU-only hosts).
        checkpoint = torch.load(load_filename, map_location=device)
        # Checkpoints written below hold keys of the *unwrapped* module, so
        # load into the unwrapped module and normalise any "module." prefix.
        # This accepts weights saved with or without DataParallel.
        target = _unwrap_data_parallel(net)
        if checkpoint.get("model_state_dict", None) is None:
            # old format - just the net, not a dict of stuff
            print("Loading old-format checkpoint...")
            target.load_state_dict(_strip_module_prefix(checkpoint))
        else:
            # new format
            target.load_state_dict(
                _strip_module_prefix(checkpoint["model_state_dict"]))
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            # "epoch" is stored as (loop index + 1), which is already the
            # index of the next epoch to run; adding 1 again would skip one.
            start_epoch = checkpoint["epoch"]
            print(f"Starting at epoch {start_epoch+1}")

    accumulate(accum_net, net, 0)  # copy net's parameters to accum_net

    if not only_test:
        # Create the output directory up front so a missing directory cannot
        # throw away a fully trained epoch at save time.
        os.makedirs("checkpoint", exist_ok=True)

        # do training and validation
        for epoch in range(start_epoch, n_epochs):
            train(net, accum_net, optimizer, criterion, clevr_dir, epoch)
            valid(accum_net, clevr_dir, epoch)

            checkpoint_path = (
                f"checkpoint/checkpoint_{str(epoch + 1).zfill(2)}_{n_memories}m.model"
            )
            with open(checkpoint_path, "wb") as f:
                torch.save(
                    {
                        "epoch": epoch + 1,
                        # Save the unwrapped module so key names do not
                        # depend on whether DataParallel was in use.
                        "model_state_dict":
                        _unwrap_data_parallel(accum_net).state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    f,
                )
    else:
        # predict on the test set and make visualization data
        test(accum_net, clevr_dir)