Example #1
def triplets_get_data(config, net, dataloader, sobel):
  num_batches = len(dataloader)
  flat_targets_all = torch.zeros((num_batches * config.batch_sz),
                                 dtype=torch.int32).cuda()
  flat_preds_all = torch.zeros((num_batches * config.batch_sz),
                               dtype=torch.int32).cuda()

  num_test = 0
  for b_i, batch in enumerate(dataloader):
    imgs = batch[0].cuda()

    if sobel:
      imgs = sobel_process(imgs, config.include_rgb)

    flat_targets = batch[1]

    with torch.no_grad():
      x_outs = net(imgs)

    assert (x_outs.shape[1] == config.output_k)
    assert (len(x_outs.shape) == 2)

    num_test_curr = flat_targets.shape[0]
    num_test += num_test_curr

    start_i = b_i * config.batch_sz
    flat_preds_curr = torch.argmax(x_outs, dim=1)  # along output_k
    flat_preds_all[start_i:(start_i + num_test_curr)] = flat_preds_curr

    flat_targets_all[start_i:(start_i + num_test_curr)] = flat_targets.cuda()

  flat_preds_all = flat_preds_all[:num_test]
  flat_targets_all = flat_targets_all[:num_test]

  return flat_preds_all, flat_targets_all
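A minimal sketch (assumed, not part of the original example) of how the flat cluster predictions and targets returned above are typically scored: cluster ids are matched to ground-truth classes with a Hungarian assignment before computing accuracy. The helper name hungarian_accuracy and the scipy dependency are illustrative, not from the source.
import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_accuracy(flat_preds_all, flat_targets_all, k):
  # illustrative helper: match cluster ids to classes, then score accuracy
  preds = flat_preds_all.cpu().numpy()
  targets = flat_targets_all.cpu().numpy()

  # contingency matrix: counts[i, j] = how often cluster i co-occurs with class j
  counts = np.zeros((k, k), dtype=np.int64)
  for p, t in zip(preds, targets):
    counts[p, t] += 1

  # linear_sum_assignment minimises cost, so negate to maximise matched counts
  row_ind, col_ind = linear_sum_assignment(-counts)
  return counts[row_ind, col_ind].sum() / float(len(preds))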
Example #2
def assess_acc_block(
    net,
    test_loader,
    gt_k=None,
    include_rgb=None,
    penultimate_features=False,
    contiguous_sz=None,
):
    total = 0
    all = None
    all_targets = None
    dlen = None
    for imgs, targets in test_loader:
        imgs = sobel_process(imgs.cuda(), include_rgb)

        with torch.no_grad():
            x_out = net(imgs, penultimate_features=penultimate_features)

        bn, dlen = x_out.shape
        if all is None:
            all = np.zeros((len(test_loader) * bn, dlen))
            all_targets = np.zeros(len(test_loader) * bn)

        all[total : (total + bn), :] = x_out.cpu().numpy()
        all_targets[total : (total + bn)] = targets.numpy()
        total += bn
    assert dlen is not None

    # 40000
    all = all[:total, :]
    all_targets = all_targets[:total]

    num_orig, leftover = divmod(total, contiguous_sz)
    assert leftover == 0

    all = all.reshape((num_orig, contiguous_sz, dlen))
    all = all.sum(axis=1, keepdims=False) / float(contiguous_sz)

    all_targets = all_targets.reshape((num_orig, contiguous_sz))
    # sanity check
    all_targets_avg = all_targets.astype("int").sum(axis=1) / contiguous_sz
    all_targets = all_targets[:, 0].astype("int")
    assert np.array_equal(all_targets_avg, all_targets)

    preds = np.argmax(all, axis=1).astype("int")
    assert preds.min() >= 0 and preds.max() < gt_k
    assert all_targets.min() >= 0 and all_targets.max() < gt_k
    if not (preds.shape == all_targets.shape):
        print(preds.shape)
        print(all_targets.shape)
        assert False

    assert preds.shape == (num_orig,)
    correct = (preds == all_targets).sum()

    return correct / float(num_orig)
Example #3
def get_dlen(net_features, dataloader, include_rgb=None, penultimate_features=False):
    dlen = None
    for imgs, _ in dataloader:
        imgs = sobel_process(imgs.cuda(), include_rgb).cpu()
        x_features = net_features(
            imgs, trunk_features=True, penultimate_features=penultimate_features
        )

        x_features = x_features.view(x_features.shape[0], -1)
        dlen = x_features.shape[1]
        break

    assert dlen is not None
    return dlen
Example #4
def get_dlen(net_features,
             dataloader,
             include_rgb=None,
             penultimate_features=False):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = Variable(sobel_process(imgs.cuda(), include_rgb)).cpu()
        x_features = net_features(imgs,
                                  trunk_features=True,
                                  penultimate_features=penultimate_features)

        x_features = x_features.view(x_features.shape[0], -1)
        dlen = x_features.shape[1]
        break

    return dlen
Example #5
def get_dlen(net_features,
             dataloader,
             include_rgb=None,
             penultimate_features=False):
    imgs, _ = next(iter(dataloader))
    img = torch.unsqueeze(imgs[0], 0)
    sobel_img = sobel_process(img.cuda(), include_rgb)

    net_features.eval()
    with torch.no_grad():
        x_features = net_features(sobel_img,
                                  trunk_features=True,
                                  penultimate_features=penultimate_features)
    net_features.train()

    dlen = x_features.view(x_features.shape[0], -1).shape[1]

    return dlen
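A brief illustrative sketch (assumed, not from the original repo) of what the returned dlen is for: sizing a classification head over the flattened trunk features. SupHead5 in the later example is assumed to play a similar role; net_features, train_loader and gt_k are placeholders here.
import torch.nn as nn

dlen = get_dlen(net_features, train_loader, include_rgb=False)
linear_probe = nn.Linear(dlen, gt_k)  # gt_k: number of ground-truth classes (placeholder)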
Example #6
def triplets_get_data_kmeans_on_features(config, net, dataloader, sobel):
    # output of network is features (not softmaxed)
    num_batches = len(dataloader)
    flat_targets_all = torch.zeros((num_batches * config.batch_sz),
                                   dtype=torch.int32).cuda()
    features_all = np.zeros((num_batches * config.batch_sz, config.output_k),
                            dtype=np.float32)

    num_test = 0
    for b_i, batch in enumerate(dataloader):
        imgs = batch[0].cuda()

        if sobel:
            imgs = sobel_process(imgs, config.include_rgb)

        flat_targets = batch[1]

        with torch.no_grad():
            x_outs = net(imgs)

        assert (x_outs.shape[1] == config.output_k)
        assert (len(x_outs.shape) == 2)

        num_test_curr = flat_targets.shape[0]
        num_test += num_test_curr

        start_i = b_i * config.batch_sz
        features_all[start_i:(start_i +
                              num_test_curr), :] = x_outs.cpu().numpy()
        flat_targets_all[start_i:(start_i +
                                  num_test_curr)] = flat_targets.cuda()

    features_all = features_all[:num_test, :]
    flat_targets_all = flat_targets_all[:num_test]

    kmeans = KMeans(n_clusters=config.gt_k).fit(features_all)
    flat_preds_all = torch.from_numpy(kmeans.labels_).cuda()

    assert (flat_targets_all.shape == flat_preds_all.shape)
    assert (max(flat_preds_all) < config.gt_k)

    return flat_preds_all, flat_targets_all
Example #7
def assess_acc(
    net, test_loader, gt_k=None, include_rgb=None, penultimate_features=False
):
    correct = 0
    total = 0
    for imgs, targets in test_loader:
        imgs = sobel_process(imgs.cuda(), include_rgb)

        with torch.no_grad():
            x_out = net(imgs, penultimate_features=penultimate_features)

        # bug fix!!
        preds = np.argmax(x_out.cpu().numpy(), axis=1).astype("int")
        targets = targets.numpy().astype("int")
        assert preds.min() >= 0 and preds.max() < gt_k
        assert targets.min() >= 0 and targets.max() < gt_k
        assert preds.shape == targets.shape

        correct += (preds == targets).sum()
        total += preds.shape[0]

    return correct / float(total)
Example #8
if not os.path.exists(render_out_dir):
    os.makedirs(render_out_dir)

results_f = os.path.join(render_out_dir, "results.txt")

iterators = (d for d in [dataloader, render_dataloader])

for tup in zip(*iterators):
    train_batch = tup[0]
    render_batch = tup[1]

    imgs = train_batch[0].cuda()
    orig_imgs = render_batch[0]

    if sobel:
        imgs = sobel_process(imgs, config.include_rgb, using_IR=using_IR)

    flat_targets = train_batch[1]

    with torch.no_grad():
        x_outs = net(imgs)

    assert (x_outs[0].shape[1] == config.output_k)
    assert (len(x_outs[0].shape) == 2)

    x_outs_curr = x_outs[best_head]
    flat_preds_curr = torch.argmax(x_outs_curr, dim=1)  # along output_k

    with open(results_f, "w") as f:
        for i, img_i in enumerate(img_inds):
            img = orig_imgs[img_i].numpy()
Example #9
def main():
    config = parse_config()

    # Setup ----------------------------------------------------------------------

    config.contiguous_sz = 10  # Tencrop
    config.out_dir = os.path.join(config.out_root, str(config.model_ind))

    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)

    if config.restart:
        given_config = config
        reloaded_config_path = os.path.join(given_config.out_dir,
                                            "config.pickle")
        print("Loading restarting config from: %s" % reloaded_config_path)
        with open(reloaded_config_path, "rb") as config_f:
            config = pickle.load(config_f)
        assert (config.model_ind == given_config.model_ind)

        config.restart = True
        config.num_epochs = given_config.num_epochs  # train for longer

        config.restart_new_model_ind = given_config.restart_new_model_ind
        config.new_model_ind = given_config.new_model_ind

        start_epoch = config.last_epoch + 1

        print("...restarting from epoch %d" % start_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:start_epoch]
        config.epoch_loss = config.epoch_loss[:start_epoch]
    else:
        config.epoch_acc = []
        config.epoch_loss = []
        start_epoch = 0

    # old config only used retrospectively for setting up model at start
    reloaded_config_path = os.path.join(config.out_root,
                                        str(config.old_model_ind),
                                        "best_config.pickle")
    print("Loading old features config from: %s" % reloaded_config_path)
    with open(reloaded_config_path, "rb") as config_f:
        old_config = pickle.load(config_f)
        assert (old_config.model_ind == config.old_model_ind)

    if config.new_batch_sz == -1:
        config.new_batch_sz = old_config.batch_sz

    fig, axarr = plt.subplots(2, sharex=False, figsize=(20, 20))

    # Data -----------------------------------------------------------------------

    # make supervised data: train on train, test on test, unlabelled is unused
    assert old_config.sobel, "Old model should have been trained with sobel activated"
    tf1, tf2, tf3 = sobel_make_transforms(old_config,
                                          random_affine=config.random_affine,
                                          cutout=config.cutout,
                                          cutout_p=config.cutout_p,
                                          cutout_max_box=config.cutout_max_box,
                                          affine_p=config.affine_p)

    if old_config.dataset == "STL10":
        dataset_class = torchvision.datasets.STL10
        train_data = dataset_class(root=old_config.dataset_root,
                                   transform=tf2,
                                   split="train")  # also could use tf1

        test_data = dataset_class(root=old_config.dataset_root,
                                  transform=None,
                                  split="test")
    elif old_config.dataset in HANDWRITING_DATASETS:
        dataset_root = os.path.join(old_config.dataset_root,
                                    old_config.dataset)
        train_json_path = os.path.join("train",
                                       old_config.dataset + "_train.json")
        train_data = HandwritingDataset([train_json_path],
                                        dataset_root,
                                        transform=tf2)

        test_json_path = os.path.join("test",
                                      old_config.dataset + "_test.json")
        test_data = HandwritingDataset([test_json_path],
                                       dataset_root,
                                       transform=None)
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.new_batch_sz,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=False)

    test_data = TenCropAndFinish(test_data,
                                 input_sz=old_config.input_sz,
                                 include_rgb=old_config.include_rgb)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config.new_batch_sz,
        # full batch
        shuffle=False,
        num_workers=0,
        drop_last=False)

    # Model ----------------------------------------------------------------------

    net_features = archs.__dict__[old_config.arch](old_config).cuda()

    if not config.restart:
        model_path = os.path.join(old_config.out_dir, "best_net.pytorch")
        net_features.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    dlen = get_dlen(net_features,
                    train_loader,
                    include_rgb=old_config.include_rgb,
                    penultimate_features=config.penultimate_features)
    print("dlen: %d" % dlen)

    assert (config.arch == "SupHead5")
    net = SupHead5(net_features, dlen=dlen, gt_k=old_config.gt_k)

    if config.restart:
        print("restarting from latest net")
        model_path = os.path.join(config.out_dir, "latest_net.pytorch")
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    net.cuda()
    net = torch.nn.DataParallel(net)

    opt_trunk = torch.optim.Adam(net.module.trunk.parameters(),
                                 lr=config.trunk_lr)
    opt_head = torch.optim.Adam(net.module.head.parameters(),
                                lr=config.head_lr)

    if config.restart:
        print("restarting from latest optimiser")
        optimiser_states = torch.load(
            os.path.join(config.out_dir, "latest_optimiser.pytorch"))
        opt_trunk.load_state_dict(optimiser_states["opt_trunk"])
        opt_head.load_state_dict(optimiser_states["opt_head"])
    else:
        print("using new optimiser state")

    criterion = nn.CrossEntropyLoss().cuda()

    if not config.restart:
        net.eval()
        acc = assess_acc_block(
            net,
            test_loader,
            gt_k=old_config.gt_k,
            include_rgb=old_config.include_rgb,
            penultimate_features=config.penultimate_features,
            contiguous_sz=config.contiguous_sz)

        print("pre: model %d old model %d, acc %f time %s" %
              (config.model_ind, config.old_model_ind, acc, datetime.now()))
        sys.stdout.flush()

        config.epoch_acc.append(acc)

    if config.restart_new_model_ind:
        assert config.restart
        config.model_ind = config.new_model_ind  # old_model_ind stays same
        config.out_dir = os.path.join(config.out_root, str(config.model_ind))
        print("restarting as model %d" % config.model_ind)

        if not os.path.exists(config.out_dir):
            os.makedirs(config.out_dir)

    # Train ----------------------------------------------------------------------

    for e_i in range(start_epoch, config.num_epochs):
        net.train()

        if e_i in config.lr_schedule:
            print("e_i %d, multiplying lr for opt trunk and head by %f" %
                  (e_i, config.lr_mult))
            opt_trunk = update_lr(opt_trunk, lr_mult=config.lr_mult)
            opt_head = update_lr(opt_head, lr_mult=config.lr_mult)
            if not hasattr(config, "lr_changes"):
                config.lr_changes = []
            config.lr_changes.append((e_i, config.lr_mult))

        avg_loss = 0.
        num_batches = len(train_loader)
        for i, (imgs, targets) in enumerate(train_loader):
            imgs = sobel_process(imgs.cuda(), old_config.include_rgb)
            targets = targets.cuda()

            x_out = net(imgs, penultimate_features=config.penultimate_features)
            loss = criterion(x_out, targets)

            avg_loss += float(loss.data)

            opt_trunk.zero_grad()
            opt_head.zero_grad()

            loss.backward()

            opt_trunk.step()
            opt_head.step()

            if (i % 100 == 0) or (e_i == start_epoch):
                print("batch %d of %d, loss %f, time %s" %
                      (i, num_batches, float(loss.data), datetime.now()))
                sys.stdout.flush()

        avg_loss /= num_batches

        # Eval ----------------------------------------------------------------------

        net.eval()
        acc = assess_acc_block(
            net,
            test_loader,
            gt_k=old_config.gt_k,
            include_rgb=old_config.include_rgb,
            penultimate_features=config.penultimate_features,
            contiguous_sz=config.contiguous_sz)

        print(
            "model %d old model %d epoch %d acc %f time %s" %
            (config.model_ind, config.old_model_ind, e_i, acc, datetime.now()))
        sys.stdout.flush()

        is_best = False
        if acc > max(config.epoch_acc):
            is_best = True

        config.epoch_acc.append(acc)
        config.epoch_loss.append(avg_loss)

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("Acc")

        axarr[1].clear()
        axarr[1].plot(config.epoch_loss)
        axarr[1].set_title("Loss")

        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % 10 == 0):
            net.module.cpu()

            if is_best:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    {
                        "opt_head": opt_head.state_dict(),
                        "opt_trunk": opt_trunk.state_dict()
                    }, os.path.join(config.out_dir, "best_optimiser.pytorch"))

            # save model sparingly for this script
            if e_i % 10 == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    {
                        "opt_head": opt_head.state_dict(),
                        "opt_trunk": opt_trunk.state_dict()
                    }, os.path.join(config.out_dir,
                                    "latest_optimiser.pytorch"))

            net.module.cuda()

            config.last_epoch = e_i  # for last saved version

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)
Example #10
    if e_i in config.lr_schedule:
        optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

    avg_loss = 0.  # over heads and head_epochs (and sub_heads)
    avg_loss_count = 0

    sys.stdout.flush()

    iterators = (d for d in train_dataloaders)

    b_i = 0
    for tup in zip(*iterators):
        net.module.zero_grad()

        # no sobel yet
        imgs_orig = sobel_process(tup[0][0].cuda(), config.include_rgb)
        imgs_pos = sobel_process(tup[1][0].cuda(), config.include_rgb)
        imgs_neg = sobel_process(tup[2][0].cuda(), config.include_rgb)

        outs_orig = net(imgs_orig)
        outs_pos = net(imgs_pos)
        outs_neg = net(imgs_neg)

        curr_loss = triplets_loss(outs_orig, outs_pos, outs_neg)

        if ((b_i % 100) == 0) or (e_i == next_epoch and b_i < 10):
            print("Model ind %d epoch %d batch %d "
                  "loss %f time %s" % \
                  (config.model_ind, e_i, b_i, curr_loss.item(), datetime.now()))
            sys.stdout.flush()
Example #11
def training(config, net, current_epoch, next_epoch, heads, dataloaders_head_A,
             dataloaders_head_B, loss_fn, optimiser):
    """Computes the loss for heads A and B for the current epoch and carries
    out a backward pass through the net using the optimiser, with lr annealing.

    Params:
      config: experiment config (lr schedule, batch/input sizes, lambdas, cuda flags)
      net: the (DataParallel-wrapped) network being trained
      current_epoch: index of the epoch being trained
      next_epoch: first epoch of this run, used to trigger extra logging
      heads: ordered head names, e.g. ["A", "B"]
      dataloaders_head_A: paired dataloaders feeding head A
      dataloaders_head_B: paired dataloaders feeding head B
      loss_fn: segmentation loss over paired outputs, affine warp and mask
      optimiser: optimiser over the network parameters

    Returns:
        the trained network
    """

    if current_epoch in config.lr_schedule:
        optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

    for head_i in range(2):

        head = heads[head_i]
        if head == "A":
            dataloaders = dataloaders_head_A
            epoch_loss = config.epoch_loss_head_A
            epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
            lamb = config.lamb_A

        elif head == "B":
            dataloaders = dataloaders_head_B
            epoch_loss = config.epoch_loss_head_B
            epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
            lamb = config.lamb_B

        iterators = (d for d in dataloaders)
        b_i = 0
        avg_loss = 0.  # over heads and head_epochs (and sub_heads)
        avg_loss_no_lamb = 0.
        avg_loss_count = 0

        for tup in zip(*iterators):
            net.module.zero_grad()

            if not config.no_sobel:
                pre_channels = config.in_channels - 1
            else:
                pre_channels = config.in_channels

            all_img1 = torch.zeros(config.batch_sz, pre_channels,
                                   config.input_sz, config.input_sz).to(
                torch.float32)
            all_img2 = torch.zeros(config.batch_sz, pre_channels,
                                   config.input_sz, config.input_sz).to(
                torch.float32)
            all_affine2_to_1 = torch.zeros(config.batch_sz, 2, 3).to(
                torch.float32)
            all_mask_img1 = torch.zeros(config.batch_sz, config.input_sz,
                                        config.input_sz).to(torch.float32)

            if not config.nocuda:
                all_img1 = all_img1.cuda()
                all_img2 = all_img2.cuda()
                all_affine2_to_1 = all_affine2_to_1.cuda()
                all_mask_img1 = all_mask_img1.cuda()

            curr_batch_sz = tup[0][0].shape[0]
            for d_i in range(config.num_dataloaders):
                img1, img2, affine2_to_1, mask_img1 = tup[d_i]
                assert (img1.shape[0] == curr_batch_sz)

                actual_batch_start = d_i * curr_batch_sz
                actual_batch_end = actual_batch_start + curr_batch_sz

                all_img1[actual_batch_start:actual_batch_end, :, :, :] = img1
                all_img2[actual_batch_start:actual_batch_end, :, :, :] = img2
                all_affine2_to_1[actual_batch_start:actual_batch_end, :,
                                 :] = affine2_to_1
                all_mask_img1[actual_batch_start:actual_batch_end,
                              :, :] = mask_img1

            if not (curr_batch_sz == config.dataloader_batch_sz) and (
                    current_epoch == next_epoch):
                print("last batch sz %d" % curr_batch_sz)

            curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # times 2
            all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
            all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
            all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
            all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

            if (not config.no_sobel):
                all_img1 = sobel_process(all_img1, config.include_rgb,
                                         using_IR=config.using_IR, cuda_enabled=not config.nocuda)
                all_img2 = sobel_process(all_img2, config.include_rgb,
                                         using_IR=config.using_IR, cuda_enabled=not config.nocuda)

            x1_outs = net(all_img1, head=head)
            x2_outs = net(all_img2, head=head)

            avg_loss_batch = None  # avg over the heads
            avg_loss_no_lamb_batch = None

            for i in range(config.num_sub_heads):
                loss, loss_no_lamb = loss_fn(x1_outs[i],
                                             x2_outs[i],
                                             all_affine2_to_1=all_affine2_to_1,
                                             all_mask_img1=all_mask_img1,
                                             lamb=lamb,
                                             half_T_side_dense=config.half_T_side_dense,
                                             half_T_side_sparse_min=config.half_T_side_sparse_min,
                                             half_T_side_sparse_max=config.half_T_side_sparse_max)

                if avg_loss_batch is None:
                    avg_loss_batch = loss
                    avg_loss_no_lamb_batch = loss_no_lamb
                else:
                    avg_loss_batch += loss
                    avg_loss_no_lamb_batch += loss_no_lamb

            avg_loss_batch /= config.num_sub_heads
            avg_loss_no_lamb_batch /= config.num_sub_heads

            if ((b_i % 100) == 0) or (current_epoch == next_epoch):
                print(
                    "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no "
                    "lamb %f "
                    "time %s" %
                    (config.model_ind, current_epoch, head, b_i, avg_loss_batch.item(),
                     avg_loss_no_lamb_batch.item(), datetime.now()))
                sys.stdout.flush()

            if not np.isfinite(avg_loss_batch.item()):
                print("Loss is not finite... %s:" % str(avg_loss_batch))
                exit(1)

            avg_loss += avg_loss_batch.item()
            avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
            avg_loss_count += 1

            if not config.nocuda:
                avg_loss_batch = avg_loss_batch.cuda()

            avg_loss_batch.backward()
            optimiser.step()

            if not config.nocuda:
                torch.cuda.empty_cache()

            b_i += 1
            if b_i == 2 and config.test_code:
                break

        avg_loss = float(avg_loss / avg_loss_count)
        avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

        epoch_loss.append(avg_loss)
        epoch_loss_no_lamb.append(avg_loss_no_lamb)

    return net
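A hedged usage sketch for the training() helper above, assuming the net, heads, dataloaders, loss function and optimiser have been built as in the other examples on this page and that config carries last_epoch and num_epochs:
next_epoch = config.last_epoch + 1
for e_i in range(next_epoch, config.num_epochs + 1):
    net = training(config, net, e_i, next_epoch, ["A", "B"],
                   dataloaders_head_A, dataloaders_head_B,
                   IID_segmentation_loss, optimiser)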
Example #12
def train(config, net, optimiser, render_count=-1):
    # TODO: center crop ok or does it remove too much information if text is aligned left
    dataloader_list, mapping_assignment_dataloader, mapping_test_dataloader = get_dataloader_list(
        config)

    num_heads = len(config.output_ks)

    # Results ----------------------------------------------------------------------

    if config.restart:
        if not config.restart_from_best:
            next_epoch = config.last_epoch + 1  # corresponds to last saved model
        else:
            next_epoch = np.argmax(np.array(config.epoch_acc)) + 1
        print("starting from epoch %d" % next_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:next_epoch]
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:
                                                                    next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        if config.double_eval:
            config.double_eval_acc = config.double_eval_acc[:next_epoch]
            config.double_eval_avg_subhead_acc = config.double_eval_avg_subhead_acc[:
                                                                                    next_epoch]
            config.double_eval_stats = config.double_eval_stats[:next_epoch]

        for i, loss in enumerate(config.epoch_loss):
            config.epoch_loss[i] = loss[:(next_epoch - 1)]
        for i, loss_no_lamb in enumerate(config.epoch_loss_no_lamb):
            config.epoch_loss_no_lamb[i] = loss_no_lamb[:(next_epoch - 1)]

    else:
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        if config.double_eval:
            config.double_eval_acc = []
            config.double_eval_avg_subhead_acc = []
            config.double_eval_stats = []

        config.epoch_loss = [[] for _ in range(num_heads)]
        config.epoch_loss_no_lamb = [[] for _ in range(num_heads)]

        subhead = None
        if config.select_subhead_on_loss:
            assert num_heads == 2
            subhead = get_subhead_using_loss(config,
                                             dataloader_list[1],
                                             net,
                                             sobel=config.sobel,
                                             lamb=config.lamb)

        _ = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=config.sobel,
            use_subhead=subhead)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(2 + 2 * num_heads + 2 * int(config.double_eval),
                              sharex=False,
                              figsize=(20, 20))

    save_progression = hasattr(config,
                               "save_progression") and config.save_progression
    if save_progression:
        save_progression_count = 0
        save_progress(config,
                      net,
                      mapping_assignment_dataloader,
                      mapping_test_dataloader,
                      save_progression_count,
                      sobel=config.sobel,
                      render_count=render_count)
        save_progression_count += 1

    # Train ------------------------------------------------------------------------

    heads = range(num_heads)
    if config.reverse_heads:
        heads = reversed(heads)

    for e_i in range(next_epoch, config.num_epochs + 1):
        print("Starting e_i: %d" % e_i)

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_idx in heads:
            dataloaders = dataloader_list[head_idx]
            epoch_loss = config.epoch_loss[head_idx]
            epoch_loss_no_lamb = config.epoch_loss_no_lamb[head_idx]

            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            for head_i_epoch in range(config.head_epochs[head_idx]):
                sys.stdout.flush()

                iterators = (d for d in dataloaders)

                b_i = 0
                for tup in zip(*iterators):
                    net.module.zero_grad()

                    in_channels = config.in_channels
                    if config.sobel:
                        # one less because this is before sobel
                        in_channels -= 1
                    all_imgs = torch.zeros(
                        (config.batch_sz, in_channels, config.input_sz[0],
                         config.input_sz[1])).cuda()
                    all_imgs_tf = torch.zeros(
                        (config.batch_sz, in_channels, config.input_sz[0],
                         config.input_sz[1])).cuda()

                    imgs_curr = tup[0][0]  # always the first
                    curr_batch_sz = imgs_curr.size(0)
                    for d_i in range(config.num_dataloaders):
                        imgs_tf_curr = tup[1 + d_i][0]  # from 2nd to last
                        assert (curr_batch_sz == imgs_tf_curr.size(0))

                        actual_batch_start = d_i * curr_batch_sz
                        actual_batch_end = actual_batch_start + curr_batch_sz
                        all_imgs[actual_batch_start:
                                 actual_batch_end, :, :, :] = imgs_curr.cuda()
                        all_imgs_tf[
                            actual_batch_start:
                            actual_batch_end, :, :, :] = imgs_tf_curr.cuda()

                    if not (curr_batch_sz == config.dataloader_batch_sz):
                        print("last batch sz %d" % curr_batch_sz)

                    curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
                    all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
                    all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

                    if config.sobel:
                        all_imgs = sobel_process(all_imgs, config.include_rgb)
                        all_imgs_tf = sobel_process(all_imgs_tf,
                                                    config.include_rgb)

                    x_outs = net(all_imgs, head_idx=head_idx)
                    x_tf_outs = net(all_imgs_tf, head_idx=head_idx)

                    avg_loss_batch = None  # avg over the subheads
                    avg_loss_no_lamb_batch = None
                    for i in range(config.num_subheads):
                        loss, loss_no_lamb = IID_loss(x_outs[i],
                                                      x_tf_outs[i],
                                                      lamb=config.lamb)
                        if avg_loss_batch is None:
                            avg_loss_batch = loss
                            avg_loss_no_lamb_batch = loss_no_lamb
                        else:
                            avg_loss_batch += loss
                            avg_loss_no_lamb_batch += loss_no_lamb

                    avg_loss_batch /= config.num_subheads
                    avg_loss_no_lamb_batch /= config.num_subheads

                    if ((b_i % 100) == 0) or (e_i == next_epoch and b_i < 10):
                        print(
                            "Model ind %d epoch %d head %s head_i_epoch %d batch %d: avg loss %f avg loss no lamb %f "
                            "time %s" %
                            (config.model_ind, e_i, str(head_idx),
                             head_i_epoch, b_i, avg_loss_batch.item(),
                             avg_loss_no_lamb_batch.item(), datetime.now()))
                        sys.stdout.flush()

                    if not np.isfinite(avg_loss_batch.item()):
                        print("Loss is not finite... %s:" %
                              str(avg_loss_batch))
                        sys.exit(1)

                    avg_loss += avg_loss_batch.item()
                    avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                    avg_loss_count += 1

                    avg_loss_batch.backward()
                    optimiser.step()

                    if ((b_i % 50) == 0) and save_progression:
                        save_progress(config,
                                      net,
                                      mapping_assignment_dataloader,
                                      mapping_test_dataloader,
                                      save_progression_count,
                                      sobel=False,
                                      render_count=render_count)
                        save_progression_count += 1

                    b_i += 1
                    if b_i == 2 and config.test_code:
                        break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval -----------------------------------------------------------------------

        # Can also pick the subhead using the evaluation process (to do this, set use_subhead=None)
        subhead = None
        if config.select_subhead_on_loss:
            assert num_heads == 2
            subhead = get_subhead_using_loss(config,
                                             dataloader_list[1],
                                             net,
                                             sobel=config.sobel,
                                             lamb=config.lamb)

        is_best = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=config.sobel,
            use_subhead=subhead)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" %
                           max(config.epoch_avg_subhead_acc))

        last_ax_idx = 1
        starting_ax_idx = last_ax_idx + 1
        for i in range(num_heads):
            # two consecutive axes per head: loss and loss-no-lamb
            axarr[starting_ax_idx + 2 * i].clear()
            axarr[starting_ax_idx + 2 * i].plot(config.epoch_loss[i])
            axarr[starting_ax_idx + 2 * i].set_title("Loss head_idx " + str(i))

            axarr[starting_ax_idx + 2 * i + 1].clear()
            axarr[starting_ax_idx + 2 * i + 1].plot(config.epoch_loss_no_lamb[i])
            axarr[starting_ax_idx + 2 * i + 1].set_title("Loss no lamb head_idx " +
                                                         str(i))

        if config.double_eval:
            next_index = starting_ax_idx + 2 * num_heads
            axarr[next_index].clear()
            axarr[next_index].plot(config.double_eval_acc)
            axarr[next_index].set_title("double eval acc (best), top: %f" %
                                        max(config.double_eval_acc))

            axarr[next_index + 1].clear()
            axarr[next_index + 1].plot(config.double_eval_avg_subhead_acc)
            axarr[next_index +
                  1].set_title("double eval acc (avg), top: %f" %
                               max(config.double_eval_avg_subhead_acc))

        fig.tight_layout()
        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()

            if e_i % config.save_freq == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "latest_optimiser.pytorch"))
                config.last_epoch = e_i  # for last saved version

            if is_best:
                # also serves as backup if hardware fails - less likely to hit this
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "best_optimiser.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)

                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
Example #13
def train_kmeans(config, net, test_dataloader):
    num_imgs = len(test_dataloader.dataset)
    max_num_pixels_per_img = int(config.max_num_kmeans_samples / num_imgs)

    features_all = np.zeros(
        (config.max_num_kmeans_samples, net.module.features_sz),
        dtype=np.float32)

    actual_num_features = 0

    # discard the label information in the dataloader
    for i, tup in enumerate(test_dataloader):
        if (config.verbose and i < 10) or (i % max(1, len(test_dataloader) // 10) == 0):
            print("(kmeans_segmentation_eval) batch %d time %s" %
                  (i, datetime.now()))
            sys.stdout.flush()

        imgs, _, mask = tup  # test dataloader, cpu tensors
        imgs = imgs.cuda()
        mask = mask.numpy().astype(bool)
        num_unmasked = mask.sum()

        if not config.no_sobel:
            imgs = sobel_process(imgs, config.include_rgb,
                                 using_IR=config.using_IR)
            # now rgb(ir) and/or sobel

        with torch.no_grad():
            # penultimate = features
            x_out = net(imgs, penultimate=True).cpu().numpy()

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) through model %d time %s" % (i,
                                                                           datetime.now()))
            sys.stdout.flush()

        num_imgs_batch = x_out.shape[0]
        x_out = x_out.transpose((0, 2, 3, 1))  # features last

        x_out = x_out[mask, :]

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) applied mask %d time %s" % (i,
                                                                          datetime.now()))
            sys.stdout.flush()

        if i == 0:
            assert (x_out.shape[1] == net.module.features_sz)
            assert (x_out.shape[0] == num_unmasked)

        # select pixels randomly, and record how many selected
        num_selected = min(num_unmasked, num_imgs_batch *
                           max_num_pixels_per_img)
        selected = np.random.choice(num_unmasked, num_selected, replace=False)

        x_out = x_out[selected, :]

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) applied select %d time %s" % (i,
                                                                            datetime.now()))
            sys.stdout.flush()

        features_all[actual_num_features:actual_num_features + num_selected, :] = \
            x_out

        actual_num_features += num_selected

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) stored %d time %s" % (i,
                                                                    datetime.now()))
            sys.stdout.flush()

    assert (actual_num_features <= config.max_num_kmeans_samples)
    features_all = features_all[:actual_num_features, :]

    if config.verbose:
        print("running kmeans")
        sys.stdout.flush()
    kmeans = MiniBatchKMeans(n_clusters=config.gt_k, verbose=config.verbose).fit(
        features_all)

    return kmeans
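A minimal follow-on sketch (assumed, not in the original example) of how the fitted MiniBatchKMeans could be applied to the per-pixel features of a new batch; x_out here is assumed to be the (n, features_sz, h, w) numpy array produced by net(imgs, penultimate=True) as above:
n, c, h, w = x_out.shape
flat_features = x_out.transpose((0, 2, 3, 1)).reshape(-1, c)  # one row per pixel
flat_labels = kmeans.predict(flat_features)                   # cluster id per pixel
label_maps = flat_labels.reshape(n, h, w)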
Example #14
                    actual_batch_start = d_i * curr_batch_sz
                    actual_batch_end = actual_batch_start + curr_batch_sz
                    all_imgs[actual_batch_start:actual_batch_end, :, :, :] = \
                        imgs_curr.cuda()
                    all_imgs_tf[actual_batch_start:actual_batch_end, :, :, :] = \
                        imgs_tf_curr.cuda()

                if not (curr_batch_sz == config.dataloader_batch_sz):
                    print("last batch sz %d" % curr_batch_sz)

                curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
                all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
                all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

                all_imgs = sobel_process(all_imgs,
                                         config.include_rgb,
                                         cuda_enabled=not config.nocuda)
                all_imgs_tf = sobel_process(all_imgs_tf,
                                            config.include_rgb,
                                            cuda_enabled=not config.nocuda)

                x_outs = net(all_imgs, head=head)
                x_tf_outs = net(all_imgs_tf, head=head)

                avg_loss_batch = None  # avg over the sub_heads
                avg_loss_no_lamb_batch = None
                for i in range(config.num_sub_heads):
                    loss, loss_no_lamb = IID_loss(x_outs[i],
                                                  x_tf_outs[i],
                                                  lamb=config.lamb)
                    if avg_loss_batch is None:
Example #15
def _segmentation_get_data(config,
                           net,
                           dataloader,
                           sobel=False,
                           using_IR=False,
                           verbose=0):
    # returns (vectorised) cuda tensors for flat preds and targets
    # sister of _clustering_get_data

    assert (config.output_k <= 255)

    num_batches = len(dataloader)
    num_samples = 0

    # upper bound, will be less for last batch
    samples_per_batch = config.batch_sz * config.input_sz * config.input_sz

    if verbose > 0:
        print("started _segmentation_get_data %s" % datetime.now())
        sys.stdout.flush()

    # vectorised
    flat_predss_all = [
        torch.zeros((num_batches * samples_per_batch),
                    dtype=torch.uint8).cuda()
        for _ in range(config.num_sub_heads)
    ]
    flat_targets_all = torch.zeros((num_batches * samples_per_batch),
                                   dtype=torch.uint8).cuda()
    mask_all = torch.zeros((num_batches * samples_per_batch),
                           dtype=torch.bool).cuda()  # bool so it works with masked_select

    if verbose > 0:
        batch_start = datetime.now()
        all_start = batch_start
        print("starting batches %s" % batch_start)

    for b_i, batch in enumerate(dataloader):

        imgs, flat_targets, mask = batch
        imgs = imgs.cuda()

        if sobel:
            imgs = sobel_process(imgs, config.include_rgb, using_IR=using_IR)

        with torch.no_grad():
            x_outs = net(imgs)

        assert (x_outs[0].shape[1] == config.output_k)
        assert (x_outs[0].shape[2] == config.input_sz
                and x_outs[0].shape[3] == config.input_sz)

        # actual batch size
        actual_samples_curr = (flat_targets.shape[0] * config.input_sz *
                               config.input_sz)
        num_samples += actual_samples_curr

        # vectorise: collapse from 2D to 1D
        start_i = b_i * samples_per_batch
        for i in range(config.num_sub_heads):
            x_outs_curr = x_outs[i]
            assert (not x_outs_curr.requires_grad)
            flat_preds_curr = torch.argmax(x_outs_curr, dim=1)
            flat_predss_all[i][start_i:(
                start_i + actual_samples_curr)] = flat_preds_curr.view(-1)

        flat_targets_all[start_i:(start_i +
                                  actual_samples_curr)] = flat_targets.view(-1)
        mask_all[start_i:(start_i + actual_samples_curr)] = mask.view(-1)

        if verbose > 0 and b_i < 3:
            batch_finish = datetime.now()
            print("finished batch %d, %s, took %s, of %d" %
                  (b_i, batch_finish, batch_finish - batch_start, num_batches))
            batch_start = batch_finish
            sys.stdout.flush()

    if verbose > 0:
        all_finish = datetime.now()
        print("finished all batches %s, took %s" %
              (all_finish, all_finish - all_start))
        sys.stdout.flush()

    flat_predss_all = [
        flat_predss_all[i][:num_samples] for i in range(config.num_sub_heads)
    ]
    flat_targets_all = flat_targets_all[:num_samples]
    mask_all = mask_all[:num_samples]

    flat_predss_all = [
        flat_predss_all[i].masked_select(mask=mask_all)
        for i in range(config.num_sub_heads)
    ]
    flat_targets_all = flat_targets_all.masked_select(mask=mask_all)

    if verbose > 0:
        print("ended _segmentation_get_data %s" % datetime.now())
        sys.stdout.flush()

    selected_samples = mask_all.sum()
    assert (len(flat_predss_all[0].shape) == 1
            and len(flat_targets_all.shape) == 1)
    assert (flat_predss_all[0].shape[0] == selected_samples)
    assert (flat_targets_all.shape[0] == selected_samples)

    return flat_predss_all, flat_targets_all
Example #16
            actual_batch_start = d_i * curr_batch_sz
            actual_batch_end = actual_batch_start + curr_batch_sz
            all_imgs[actual_batch_start:actual_batch_end, :, :, :] = \
                imgs_curr.cuda()
            all_imgs_tf[actual_batch_start:actual_batch_end, :, :, :] = \
                imgs_tf_curr.cuda()

        if not (curr_batch_sz == config.dataloader_batch_sz):
            print("last batch sz %d" % curr_batch_sz)

        curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
        all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
        all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

        all_imgs = sobel_process(all_imgs, config.include_rgb)
        all_imgs_tf = sobel_process(all_imgs_tf, config.include_rgb)

        x_outs = net(all_imgs)
        x_tf_outs = net(all_imgs_tf)

        avg_loss_batch = None  # avg over the heads
        avg_loss_no_lamb_batch = None
        for i in range(config.num_sub_heads):
            loss, loss_no_lamb = IID_loss(
                x_outs[i], x_tf_outs[i], lamb=config.lamb)
            if avg_loss_batch is None:
                avg_loss_batch = loss
                avg_loss_no_lamb_batch = loss_no_lamb
            else:
                avg_loss_batch += loss
Example #17
      all_img2[actual_batch_start:actual_batch_end, :, :, :] = img2
      all_affine2_to_1[actual_batch_start:actual_batch_end, :, :] = affine2_to_1
      all_mask_img1[actual_batch_start:actual_batch_end, :, :] = mask_img1

    if not (curr_batch_sz == config.dataloader_batch_sz) and (
        e_i == next_epoch):
      print("last batch sz %d" % curr_batch_sz)

    curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
    all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
    all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
    all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
    all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

    if (not config.no_sobel):
      all_img1 = sobel_process(all_img1, config.include_rgb,
                               using_IR=config.using_IR)
      all_img2 = sobel_process(all_img2, config.include_rgb,
                               using_IR=config.using_IR)

    x1_outs = net(all_img1)
    x2_outs = net(all_img2)

    avg_loss_batch = None  # avg over the heads
    avg_loss_no_lamb_batch = None
    for i in range(config.num_sub_heads):
      loss, loss_no_lamb = loss_fn(x1_outs[i],
                                   x2_outs[i],
                                   all_affine2_to_1=all_affine2_to_1,
                                   all_mask_img1=all_mask_img1,
                                   lamb=config.lamb,
                                   half_T_side_dense=config.half_T_side_dense,
            print(("length of imgs_dataloader %d" % len(imgs_dataloader)))

            next_img_ind = 0

            for b_i, batch in enumerate(imgs_dataloader):
                orig_imgs, flat_targets, mask = batch
                orig_imgs, flat_targets, mask = (
                    orig_imgs.cuda(),
                    flat_targets.numpy(),
                    mask.numpy().astype(bool),
                )

                if not config.no_sobel:
                    imgs = sobel_process(orig_imgs,
                                         config.include_rgb,
                                         using_IR=config.using_IR)
                else:
                    imgs = orig_imgs

                with torch.no_grad():
                    x_outs_all = net(imgs)

                x_outs = x_outs_all[head_i]
                x_outs = x_outs.cpu().numpy()
                flat_preds = np.argmax(x_outs, axis=1)
                n, h, w = flat_preds.shape

                num_imgs_curr = flat_preds.shape[0]

                reordered_preds = np.zeros((num_imgs_curr, h, w),
Example #19
    net.module.train()
    is_best = False

    if e_i in config.lr_schedule:
        optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

    avg_loss = 0.  # over epoch

    for b_i, tup in enumerate(dataloader):
        net.module.zero_grad()

        img, mask = tup  # cuda

        # no need for requires_grad or Variable (torch 0.4.1)
        if (not config.no_sobel):
            img = sobel_process(img, config.include_rgb,
                                using_IR=config.using_IR)

        centre, other, position_gt = doersch_set_patches(input_sz=config.input_sz,
                                                         patch_side=config.doersch_patch_side)
        position_pred = net(img, centre=centre, other=other)

        loss = doersch_loss(position_pred, centre, other, position_gt, mask,
                            crossent=crossent,
                            verbose=config.verbose)

        if ((b_i % 100) == 0) or (e_i == next_epoch):
            print("Model ind %d epoch %d batch: %d loss %f "
                  "time %s" %
                  (config.model_ind, e_i, b_i, float(loss.item()), datetime.now()))
            sys.stdout.flush()
Example #20
def train():
    dataloaders_head_A, mapping_assignment_dataloader, mapping_test_dataloader = \
      segmentation_create_dataloaders(config)
    dataloaders_head_B = dataloaders_head_A  # unlike for clustering datasets

    net = archs.__dict__[config.arch](config)
    if config.restart:
        checkpoint = torch.load(os.path.join(config.out_dir, dict_name),
                                map_location=lambda storage, loc: storage)
        net.load_state_dict(checkpoint["net"])
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        optimiser.load_state_dict(checkpoint["optimiser"])

    heads = ["A", "B"]
    if hasattr(config, "head_B_first") and config.head_B_first:
        heads = ["B", "A"]

    # Results
    # ----------------------------------------------------------------------

    if config.restart:
        next_epoch = config.last_epoch + 1
        print("starting from epoch %d" % next_epoch)

        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:
                                                                    next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[:(
            next_epoch - 1)]
        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[:(
            next_epoch - 1)]
    else:
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []

        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        _ = segmentation_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=(not config.no_sobel),
            using_IR=config.using_IR)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

    if not config.use_uncollapsed_loss:
        print("using condensed loss (default)")
        loss_fn = IID_segmentation_loss
    else:
        print("using uncollapsed loss!")
        loss_fn = IID_segmentation_loss_uncollapsed

    # Train
    # ------------------------------------------------------------------------

    for e_i in range(next_epoch, config.num_epochs):
        print("Starting e_i: %d %s" % (e_i, datetime.now()))
        sys.stdout.flush()

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_i in range(2):
            head = heads[head_i]
            if head == "A":
                dataloaders = dataloaders_head_A
                epoch_loss = config.epoch_loss_head_A
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
                lamb = config.lamb_A

            elif head == "B":
                dataloaders = dataloaders_head_B
                epoch_loss = config.epoch_loss_head_B
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
                lamb = config.lamb_B

            iterators = (d for d in dataloaders)
            b_i = 0
            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            # each element of tup is one dataloader's (img1, img2, affine2_to_1, mask_img1) batch
            for tup in zip(*iterators):
                net.module.zero_grad()

                if not config.no_sobel:
                    pre_channels = config.in_channels - 1
                else:
                    pre_channels = config.in_channels

                # pre-allocate buffers to hold the concatenated batches from all dataloaders
                all_img1 = torch.zeros(config.batch_sz, pre_channels,
                                       config.input_sz, config.input_sz,
                                       dtype=torch.float32).cuda()
                all_img2 = torch.zeros(config.batch_sz, pre_channels,
                                       config.input_sz, config.input_sz,
                                       dtype=torch.float32).cuda()
                all_affine2_to_1 = torch.zeros(config.batch_sz, 2, 3,
                                               dtype=torch.float32).cuda()
                all_mask_img1 = torch.zeros(config.batch_sz, config.input_sz,
                                            config.input_sz,
                                            dtype=torch.float32).cuda()

                curr_batch_sz = tup[0][0].shape[0]
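                # copy each loader's batch into its slice of the shared buffers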
                for d_i in range(config.num_dataloaders):
                    img1, img2, affine2_to_1, mask_img1 = tup[d_i]
                    assert (img1.shape[0] == curr_batch_sz)

                    actual_batch_start = d_i * curr_batch_sz
                    actual_batch_end = actual_batch_start + curr_batch_sz

                    all_img1[actual_batch_start:actual_batch_end] = img1
                    all_img2[actual_batch_start:actual_batch_end] = img2
                    all_affine2_to_1[actual_batch_start:actual_batch_end] = affine2_to_1
                    all_mask_img1[actual_batch_start:actual_batch_end] = mask_img1

                if (curr_batch_sz != config.dataloader_batch_sz) and (e_i == next_epoch):
                    print("last batch sz %d" % curr_batch_sz)

                curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # e.g. 2x when num_dataloaders == 2
                all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
                all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
                all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
                all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

                if (not config.no_sobel):
                    all_img1 = sobel_process(all_img1,
                                             config.include_rgb,
                                             using_IR=config.using_IR)
                    all_img2 = sobel_process(all_img2,
                                             config.include_rgb,
                                             using_IR=config.using_IR)

                x1_outs = net(all_img1, head=head)
                x2_outs = net(all_img2, head=head)

                avg_loss_batch = None  # avg over the subheads
                avg_loss_no_lamb_batch = None

                for i in range(config.num_subheads):
                    loss, loss_no_lamb = loss_fn(
                        x1_outs[i],
                        x2_outs[i],
                        all_affine2_to_1=all_affine2_to_1,
                        all_mask_img1=all_mask_img1,
                        lamb=lamb,
                        half_T_side_dense=config.half_T_side_dense,
                        half_T_side_sparse_min=config.half_T_side_sparse_min,
                        half_T_side_sparse_max=config.half_T_side_sparse_max)

                    if avg_loss_batch is None:
                        avg_loss_batch = loss
                        avg_loss_no_lamb_batch = loss_no_lamb
                    else:
                        avg_loss_batch += loss
                        avg_loss_no_lamb_batch += loss_no_lamb

                avg_loss_batch /= config.num_subheads
                avg_loss_no_lamb_batch /= config.num_subheads

                if ((b_i % 100) == 0) or (e_i == next_epoch):
                    print(
                      "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no "
                      "lamb %f "
                      "time %s" % \
                      (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
                       avg_loss_no_lamb_batch.item(), datetime.now()))
                    sys.stdout.flush()

                if not np.isfinite(avg_loss_batch.item()):
                    print("Loss is not finite... %s:" % str(avg_loss_batch))
                    exit(1)

                avg_loss += avg_loss_batch.item()
                avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                avg_loss_count += 1

                avg_loss_batch.backward()
                optimiser.step()

                torch.cuda.empty_cache()

                b_i += 1
                if b_i == 2 and config.test_code:
                    break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval
        # -----------------------------------------------------------------------

        is_best = segmentation_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=(not config.no_sobel),
            using_IR=config.using_IR)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" %
                           max(config.epoch_avg_subhead_acc))

        axarr[2].clear()
        axarr[2].plot(config.epoch_loss_head_A)
        axarr[2].set_title("Loss head A")

        axarr[3].clear()
        axarr[3].plot(config.epoch_loss_no_lamb_head_A)
        axarr[3].set_title("Loss no lamb head A")

        axarr[4].clear()
        axarr[4].plot(config.epoch_loss_head_B)
        axarr[4].set_title("Loss head B")

        axarr[5].clear()
        axarr[5].plot(config.epoch_loss_no_lamb_head_B)
        axarr[5].set_title("Loss no lamb head B")

        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()
            save_dict = {
                "net": net.module.state_dict(),
                "optimiser": optimiser.state_dict()
            }

            if e_i % config.save_freq == 0:
                torch.save(save_dict,
                           os.path.join(config.out_dir, "latest.pytorch"))
                config.last_epoch = e_i  # for last saved version

            if is_best:
                torch.save(save_dict,
                           os.path.join(config.out_dir, "best.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)

                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
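The lr_schedule branch above calls update_lr(optimiser, lr_mult=config.lr_mult) from the surrounding codebase. A minimal sketch of a multiplicative decay helper consistent with that call site (an assumption based only on the call, not the repository's actual implementation):

def update_lr(optimiser, lr_mult=0.1):
    # scale the learning rate of every parameter group in place
    for param_group in optimiser.param_groups:
        param_group["lr"] *= lr_mult
    return optimiser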
Example #21
0
def apply_trained_kmeans(config, net, test_dataloader, kmeans):
    if config.verbose:
        print("starting inference")
        sys.stdout.flush()

    # on the entire test dataset
    num_imgs = len(test_dataloader.dataset)
    max_num_samples = num_imgs * config.input_sz * config.input_sz
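    # upper bound: one prediction per pixel of every test image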
    preds_all = torch.zeros(max_num_samples, dtype=torch.int32).cuda()
    targets_all = torch.zeros(max_num_samples, dtype=torch.int32).cuda()

    actual_num_unmasked = 0

    # discard the label information in the dataloader
    for i, tup in enumerate(test_dataloader):
        if (config.verbose and i < 10) or (i % max(1, len(test_dataloader) // 10) == 0):
            print("(apply_trained_kmeans) batch %d time %s" %
                  (i, datetime.now()))
            sys.stdout.flush()

        imgs, targets, mask = tup  # test dataloader, cpu tensors
        mask_np = mask.numpy().astype(bool)
        imgs, mask_cuda, targets = imgs.cuda(), mask.cuda(), targets.cuda()
        num_unmasked = mask_cuda.sum().item()

        if not config.no_sobel:
            imgs = sobel_process(imgs, config.include_rgb,
                                 using_IR=config.using_IR)
            # now rgb(ir) and/or sobel

        with torch.no_grad():
            # penultimate = features
            x_out = net(imgs, penultimate=True).cpu().numpy()

        x_out = x_out.transpose((0, 2, 3, 1))  # features last
        x_out = x_out[mask_np, :]
        targets = targets.masked_select(mask_cuda)  # can do because flat

        assert (x_out.shape == (num_unmasked, net.module.features_sz))
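        # assign each unmasked pixel's feature vector to its nearest fitted centroid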
        preds = torch.from_numpy(kmeans.predict(x_out)).cuda()

        preds_all[actual_num_unmasked:actual_num_unmasked + num_unmasked] = preds
        targets_all[actual_num_unmasked:actual_num_unmasked + num_unmasked] = targets

        actual_num_unmasked += num_unmasked

    preds_all = preds_all[:actual_num_unmasked]
    targets_all = targets_all[:actual_num_unmasked]

    torch.cuda.empty_cache()

    # permutation, not many-to-one
    match = _hungarian_match(preds_all, targets_all, preds_k=config.gt_k,
                             targets_k=config.gt_k)
    torch.cuda.empty_cache()

    # do in cpu because of RAM
    reordered_preds = torch.zeros(actual_num_unmasked, dtype=preds_all.dtype)
    for pred_i, target_i in match:
        selected = (preds_all == pred_i).cpu()
        reordered_preds[selected] = target_i

    reordered_preds = reordered_preds.cuda()

    # this checks values
    acc = _acc(reordered_preds, targets_all,
               config.gt_k, verbose=config.verbose)

    if GET_NMI_ARI:
        nmi, ari = _nmi(reordered_preds, targets_all), \
            _ari(reordered_preds, targets_all)
    else:
        nmi, ari = -1., -1.

    reordered_masses = np.zeros(config.gt_k)
    for c in range(config.gt_k):
        reordered_masses[c] = float(
            (reordered_preds == c).sum()) / actual_num_unmasked

    return acc, nmi, ari, reordered_masses
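The reordering above relies on _hungarian_match returning one (pred_cluster, gt_class) pair per class, i.e. a permutation rather than a many-to-one mapping. A minimal sketch of that matching step, assuming scipy's linear_sum_assignment and integer index tensors shaped like preds_all and targets_all (the real helper lives elsewhere in the codebase):

import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_match_sketch(preds, targets, k):
    # counts[i, j] = number of pixels predicted as cluster i with ground-truth class j
    counts = np.zeros((k, k), dtype=np.int64)
    for i in range(k):
        for j in range(k):
            counts[i, j] = int(((preds == i) & (targets == j)).sum())
    # maximise total agreement by minimising the negated counts
    row_ind, col_ind = linear_sum_assignment(-counts)
    return list(zip(row_ind, col_ind))

Note that apply_trained_kmeans only requires the kmeans argument to expose a predict method over (num_samples, features_sz) arrays, so, for example, a fitted sklearn MiniBatchKMeans with n_clusters=config.gt_k would work.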