Example #1
def triplets_get_data(config, net, dataloader, sobel):
  num_batches = len(dataloader)
  flat_targets_all = torch.zeros((num_batches * config.batch_sz),
                                 dtype=torch.int32).cuda()
  flat_preds_all = torch.zeros((num_batches * config.batch_sz),
                               dtype=torch.int32).cuda()

  num_test = 0
  for b_i, batch in enumerate(dataloader):
    imgs = batch[0].cuda()

    if sobel:
      imgs = sobel_process(imgs, config.include_rgb)

    flat_targets = batch[1]

    with torch.no_grad():
      x_outs = net(imgs)

    assert (x_outs.shape[1] == config.output_k)
    assert (len(x_outs.shape) == 2)

    num_test_curr = flat_targets.shape[0]
    num_test += num_test_curr

    start_i = b_i * config.batch_sz
    flat_preds_curr = torch.argmax(x_outs, dim=1)  # along output_k
    flat_preds_all[start_i:(start_i + num_test_curr)] = flat_preds_curr

    flat_targets_all[start_i:(start_i + num_test_curr)] = flat_targets.cuda()

  flat_preds_all = flat_preds_all[:num_test]
  flat_targets_all = flat_targets_all[:num_test]

  return flat_preds_all, flat_targets_all
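For orientation, a minimal sketch of how the returned tensors might be consumed. The toy config, net, and loader below are illustrative stand-ins, not part of the repository, and a CUDA device is assumed since the function moves everything to GPU. Note that raw agreement is only meaningful after cluster-to-class matching (see the hungarian-match sketch after Example #4).

import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-ins for the real config/net/dataloader
config = SimpleNamespace(batch_sz=16, output_k=10, include_rgb=False)
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, config.output_k)).cuda()
data = TensorDataset(torch.randn(64, 3, 8, 8),
                     torch.randint(0, 10, (64,), dtype=torch.int32))
loader = DataLoader(data, batch_size=config.batch_sz)

preds, targets = triplets_get_data(config, net, loader, sobel=False)
# both are flat int32 cuda tensors of equal length
raw_agreement = (preds == targets).float().mean().item()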
Example #2
def assess_acc_block(net,
                     test_loader,
                     gt_k=None,
                     include_rgb=None,
                     penultimate_features=False,
                     contiguous_sz=None):
    total = 0
    all_features = None
    all_targets = None
    for i, (imgs, targets) in enumerate(test_loader):
        imgs = Variable(sobel_process(imgs.cuda(), include_rgb))

        with torch.no_grad():
            x_out = net(imgs, penultimate_features=penultimate_features)

        bn, dlen = x_out.shape
        if all_features is None:
            all_features = np.zeros((len(test_loader) * bn, dlen))
            all_targets = np.zeros(len(test_loader) * bn)

        all_features[total:(total + bn), :] = x_out.cpu().numpy()
        all_targets[total:(total + bn)] = targets.numpy()
        total += bn

    # trim to the number of samples actually seen (e.g. 40000)
    all_features = all_features[:total, :]
    all_targets = all_targets[:total]

    num_orig, leftover = divmod(total, contiguous_sz)
    assert (leftover == 0)

    all_features = all_features.reshape((num_orig, contiguous_sz, dlen))
    all_features = all_features.sum(axis=1, keepdims=False) / float(contiguous_sz)

    all_targets = all_targets.reshape((num_orig, contiguous_sz))
    # sanity check: every contiguous block should contain a single target class
    all_targets_avg = all_targets.astype("int").sum(axis=1) / contiguous_sz
    all_targets = all_targets[:, 0].astype("int")
    assert (np.array_equal(all_targets_avg, all_targets))

    preds = np.argmax(all_features, axis=1).astype("int")
    assert (preds.min() >= 0 and preds.max() < gt_k)
    assert (all_targets.min() >= 0 and all_targets.max() < gt_k)
    assert (preds.shape == all_targets.shape), (preds.shape, all_targets.shape)

    assert (preds.shape == (num_orig, ))
    correct = (preds == all_targets).sum()

    return correct / float(num_orig)
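The reshape-then-average step above (grouping rows into contiguous blocks of size contiguous_sz and averaging within each block) is easiest to see on a tiny array. A self-contained numpy sketch with illustrative shapes:

import numpy as np

contiguous_sz = 2  # e.g. 2 consecutive crops per original image
dlen = 3
x = np.arange(4 * dlen, dtype=np.float64).reshape(4, dlen)  # 4 rows = 2 blocks

blocks = x.reshape((2, contiguous_sz, dlen))     # (num_orig, contiguous_sz, dlen)
avg = blocks.sum(axis=1) / float(contiguous_sz)  # (num_orig, dlen)
assert np.allclose(avg, blocks.mean(axis=1))     # same as a block-wise mean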
Example #3
def get_dlen(net_features,
             dataloader,
             include_rgb=None,
             penultimate_features=False):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = Variable(sobel_process(imgs.cuda(), include_rgb)).cpu()
        x_features = net_features(imgs,
                                  trunk_features=True,
                                  penultimate_features=penultimate_features)

        x_features = x_features.view(x_features.shape[0], -1)
        dlen = x_features.shape[1]
        break

    return dlen
Example #4
def triplets_get_data_kmeans_on_features(config, net, dataloader, sobel):
  # output of network is features (not softmaxed)
  num_batches = len(dataloader)
  flat_targets_all = torch.zeros((num_batches * config.batch_sz),
                                 dtype=torch.int32).cuda()
  features_all = np.zeros((num_batches * config.batch_sz, config.output_k),
                          dtype=np.float32)

  num_test = 0
  for b_i, batch in enumerate(dataloader):
    imgs = batch[0].cuda()

    if sobel:
      imgs = sobel_process(imgs, config.include_rgb)

    flat_targets = batch[1]

    with torch.no_grad():
      x_outs = net(imgs)

    assert (x_outs.shape[1] == config.output_k)
    assert (len(x_outs.shape) == 2)

    num_test_curr = flat_targets.shape[0]
    num_test += num_test_curr

    start_i = b_i * config.batch_sz
    features_all[start_i:(start_i + num_test_curr), :] = x_outs.cpu().numpy()
    flat_targets_all[start_i:(start_i + num_test_curr)] = flat_targets.cuda()

  features_all = features_all[:num_test, :]
  flat_targets_all = flat_targets_all[:num_test]

  kmeans = KMeans(n_clusters=config.gt_k).fit(features_all)
  flat_preds_all = torch.from_numpy(kmeans.labels_).cuda()

  assert (flat_targets_all.shape == flat_preds_all.shape)
  assert (flat_preds_all.max() < config.gt_k)

  return flat_preds_all, flat_targets_all
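The raw k-means labels returned above are arbitrary cluster ids; before they can be scored against ground truth they must be permuted onto the target classes, which the repository does with _hungarian_match (see apply_trained_kmeans below). A minimal numpy/scipy sketch of the same one-to-one matching; hungarian_match_np and its toy inputs are illustrative, not the repository's implementation:

import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_match_np(preds, targets, k):
  # votes[i, j] = number of samples with pred cluster i and target class j
  votes = np.zeros((k, k), dtype=np.int64)
  for p, t in zip(preds, targets):
    votes[p, t] += 1
  # maximise matched mass == minimise negated votes; one-to-one permutation
  row, col = linear_sum_assignment(-votes)
  return dict(zip(row, col))  # pred cluster -> target class

preds = np.array([0, 0, 1, 1, 2, 2])
targets = np.array([2, 2, 0, 0, 1, 1])
match = hungarian_match_np(preds, targets, k=3)
remapped = np.array([match[p] for p in preds])
acc = (remapped == targets).mean()  # 1.0 here: labels differ only by a permutation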
Example #5
def assess_acc(net,
               test_loader,
               gt_k=None,
               include_rgb=None,
               penultimate_features=False):
    correct = 0
    total = 0
    for i, (imgs, targets) in enumerate(test_loader):
        imgs = Variable(sobel_process(imgs.cuda(), include_rgb))

        with torch.no_grad():
            x_out = net(imgs, penultimate_features=penultimate_features)

        # bug fix!!
        preds = np.argmax(x_out.cpu().numpy(), axis=1).astype("int")
        targets = targets.numpy().astype("int")
        assert (preds.min() >= 0 and preds.max() < gt_k)
        assert (targets.min() >= 0 and targets.max() < gt_k)
        assert (preds.shape == targets.shape)

        correct += (preds == targets).sum()
        total += preds.shape[0]

    return correct / float(total)
Example #6
def train_kmeans(config, net, test_dataloader):
    num_imgs = len(test_dataloader.dataset)
    max_num_pixels_per_img = int(config.max_num_kmeans_samples / num_imgs)

    features_all = np.zeros(
        (config.max_num_kmeans_samples, net.module.features_sz),
        dtype=np.float32)

    actual_num_features = 0

    # discard the label information in the dataloader
    for i, tup in enumerate(test_dataloader):
        if (config.verbose and i < 10) or (
                i % max(1, len(test_dataloader) // 10) == 0):
            print("(kmeans_segmentation_eval) batch %d time %s" %
                  (i, datetime.now()))
            sysout.flush()

        imgs, _, mask = tup  # test dataloader, cpu tensors
        imgs = imgs.cuda()
        mask = mask.numpy().astype(bool)  # builtin bool: np.bool is removed in modern numpy
        num_unmasked = mask.sum()

        if not config.no_sobel:
            imgs = sobel_process(imgs,
                                 config.include_rgb,
                                 using_IR=config.using_IR)
            # now rgb(ir) and/or sobel

        with torch.no_grad():
            # penultimate = features
            x_out = net(imgs, penultimate=True).cpu().numpy()

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) through model %d time %s" %
                  (i, datetime.now()))
            sysout.flush()

        num_imgs_batch = x_out.shape[0]
        x_out = x_out.transpose((0, 2, 3, 1))  # features last

        x_out = x_out[mask, :]

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) applied mask %d time %s" %
                  (i, datetime.now()))
            sysout.flush()

        if i == 0:
            assert (x_out.shape[1] == net.module.features_sz)
            assert (x_out.shape[0] == num_unmasked)

        # select pixels randomly, and record how many were selected
        num_selected = min(num_unmasked,
                           num_imgs_batch * max_num_pixels_per_img)
        selected = np.random.choice(num_unmasked, size=num_selected,
                                    replace=False)

        x_out = x_out[selected, :]

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) applied select %d time %s" %
                  (i, datetime.now()))
            sysout.flush()

        features_all[actual_num_features:actual_num_features + num_selected, :] = \
          x_out

        actual_num_features += num_selected

        if config.verbose and i < 2:
            print("(kmeans_segmentation_eval) stored %d time %s" %
                  (i, datetime.now()))
            sysout.flush()

    assert (actual_num_features <= config.max_num_kmeans_samples)
    features_all = features_all[:actual_num_features, :]

    if config.verbose:
        print("running kmeans")
        sysout.flush()
    kmeans = MiniBatchKMeans(n_clusters=config.gt_k,
                             verbose=config.verbose).fit(features_all)

    return kmeans
Example #7
def apply_trained_kmeans(config, net, test_dataloader, kmeans):
    if config.verbose:
        print("starting inference")
        sysout.flush()

    # on the entire test dataset
    num_imgs = len(test_dataloader.dataset)
    max_num_samples = num_imgs * config.input_sz * config.input_sz
    preds_all = torch.zeros(max_num_samples, dtype=torch.int32).cuda()
    targets_all = torch.zeros(max_num_samples, dtype=torch.int32).cuda()

    actual_num_unmasked = 0

    # discard the label information in the dataloader
    for i, tup in enumerate(test_dataloader):
        if (config.verbose and i < 10) or (
                i % max(1, len(test_dataloader) // 10) == 0):
            print("(apply_trained_kmeans) batch %d time %s" %
                  (i, datetime.now()))
            sysout.flush()

        imgs, targets, mask = tup  # test dataloader, cpu tensors
        imgs, targets = imgs.cuda(), targets.cuda()
        mask_cuda = mask.cuda()
        mask_np = mask.numpy().astype(bool)  # builtin bool: np.bool is removed in modern numpy
        num_unmasked = mask_cuda.sum().item()

        if not config.no_sobel:
            imgs = sobel_process(imgs,
                                 config.include_rgb,
                                 using_IR=config.using_IR)
            # now rgb(ir) and/or sobel

        with torch.no_grad():
            # penultimate = features
            x_out = net(imgs, penultimate=True).cpu().numpy()

        x_out = x_out.transpose((0, 2, 3, 1))  # features last
        x_out = x_out[mask_np, :]
        targets = targets.masked_select(mask_cuda)  # can do because flat

        assert (x_out.shape == (num_unmasked, net.module.features_sz))
        preds = torch.from_numpy(kmeans.predict(x_out)).cuda()

        end_i = actual_num_unmasked + num_unmasked
        preds_all[actual_num_unmasked:end_i] = preds
        targets_all[actual_num_unmasked:end_i] = targets

        actual_num_unmasked += num_unmasked

    preds_all = preds_all[:actual_num_unmasked]
    targets_all = targets_all[:actual_num_unmasked]

    torch.cuda.empty_cache()

    # permutation, not many-to-one
    match = _hungarian_match(preds_all,
                             targets_all,
                             preds_k=config.gt_k,
                             targets_k=config.gt_k)
    torch.cuda.empty_cache()

    # do in cpu because of RAM
    reordered_preds = torch.zeros(actual_num_unmasked, dtype=preds_all.dtype)
    for pred_i, target_i in match:
        selected = (preds_all == pred_i).cpu()
        reordered_preds[selected] = target_i

    reordered_preds = reordered_preds.cuda()

    # this checks values
    acc = _acc(reordered_preds,
               targets_all,
               config.gt_k,
               verbose=config.verbose)

    if GET_NMI_ARI:
        nmi, ari = _nmi(reordered_preds, targets_all), \
                   _ari(reordered_preds, targets_all)
    else:
        nmi, ari = -1., -1.

    reordered_masses = np.zeros(config.gt_k)
    for c in range(config.gt_k):
        reordered_masses[c] = float(
            (reordered_preds == c).sum()) / actual_num_unmasked

    return acc, nmi, ari, reordered_masses
Example #8
def _segmentation_get_data(config, net, dataloader, sobel=False,
                           using_IR=False, verbose=0):
  # returns (vectorised) cuda tensors for flat preds and targets
  # sister of _clustering_get_data

  assert (config.output_k <= 255)

  num_batches = len(dataloader)
  num_samples = 0

  # upper bound, will be less for last batch
  samples_per_batch = config.batch_sz * config.input_sz * config.input_sz

  if verbose > 0:
    print("started _segmentation_get_data %s" % datetime.now())
    sys.stdout.flush()

  # vectorised
  flat_predss_all = [torch.zeros((num_batches * samples_per_batch),
                                 dtype=torch.uint8).cuda()
                     for _ in range(config.num_sub_heads)]
  flat_targets_all = torch.zeros((num_batches * samples_per_batch),
                                 dtype=torch.uint8).cuda()
  mask_all = torch.zeros((num_batches * samples_per_batch),
                         dtype=torch.uint8).cuda()

  if verbose > 0:
    batch_start = datetime.now()
    all_start = batch_start
    print("starting batches %s" % batch_start)

  for b_i, batch in enumerate(dataloader):
    imgs, flat_targets, mask = batch
    imgs = imgs.cuda()

    if sobel:
      imgs = sobel_process(imgs, config.include_rgb, using_IR=using_IR)

    with torch.no_grad():
      x_outs = net(imgs)

    assert (x_outs[0].shape[1] == config.output_k)
    assert (x_outs[0].shape[2] == config.input_sz and
            x_outs[0].shape[3] == config.input_sz)

    # actual batch size
    actual_samples_curr = (
      flat_targets.shape[0] * config.input_sz * config.input_sz)
    num_samples += actual_samples_curr

    # vectorise: collapse from 2D to 1D
    start_i = b_i * samples_per_batch
    for i in range(config.num_sub_heads):
      x_outs_curr = x_outs[i]
      assert (not x_outs_curr.requires_grad)
      flat_preds_curr = torch.argmax(x_outs_curr, dim=1)
      flat_predss_all[i][start_i:(start_i + actual_samples_curr)] = \
        flat_preds_curr.view(-1)

    flat_targets_all[start_i:(start_i + actual_samples_curr)] = \
      flat_targets.view(-1)
    mask_all[start_i:(start_i + actual_samples_curr)] = mask.view(-1)

    if verbose > 0 and b_i < 3:
      batch_finish = datetime.now()
      print("finished batch %d, %s, took %s, of %d" %
            (b_i, batch_finish, batch_finish - batch_start, num_batches))
      batch_start = batch_finish
      sys.stdout.flush()

  if verbose > 0:
    all_finish = datetime.now()
    print(
      "finished all batches %s, took %s" % (all_finish, all_finish - all_start))
    sys.stdout.flush()

  flat_predss_all = [flat_predss_all[i][:num_samples]
                     for i in range(config.num_sub_heads)]
  flat_targets_all = flat_targets_all[:num_samples]
  mask_all = mask_all[:num_samples].bool()  # masked_select needs a bool mask

  flat_predss_all = [flat_predss_all[i].masked_select(mask=mask_all)
                     for i in range(config.num_sub_heads)]
  flat_targets_all = flat_targets_all.masked_select(mask=mask_all)

  if verbose > 0:
    print("ended _segmentation_get_data %s" % datetime.now())
    sys.stdout.flush()

  selected_samples = mask_all.sum()
  assert (len(flat_predss_all[0].shape) == 1 and
          len(flat_targets_all.shape) == 1)
  assert (flat_predss_all[0].shape[0] == selected_samples)
  assert (flat_targets_all.shape[0] == selected_samples)

  return flat_predss_all, flat_targets_all
Example #9
def train():
  dataloaders_head_A, mapping_assignment_dataloader, mapping_test_dataloader = \
    segmentation_create_dataloaders(config)
  dataloaders_head_B = dataloaders_head_A  # unlike for clustering datasets

  net = archs.__dict__[config.arch](config)
  if config.restart:
    checkpoint = torch.load(os.path.join(config.out_dir, dict_name),
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint["net"])
  net.cuda()
  net = torch.nn.DataParallel(net)
  net.train()

  optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
  if config.restart:
    optimiser.load_state_dict(checkpoint["optimiser"])

  heads = ["A", "B"]
  if hasattr(config, "head_B_first") and config.head_B_first:
    heads = ["B", "A"]

  # Results
  # ----------------------------------------------------------------------

  if config.restart:
    next_epoch = config.last_epoch + 1
    print("starting from epoch %d" % next_epoch)

    config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
    config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:next_epoch]
    config.epoch_stats = config.epoch_stats[:next_epoch]

    config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
    config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[
                                       :(next_epoch - 1)]
    config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
    config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[
                                       :(next_epoch - 1)]
  else:
    config.epoch_acc = []
    config.epoch_avg_subhead_acc = []
    config.epoch_stats = []

    config.epoch_loss_head_A = []
    config.epoch_loss_no_lamb_head_A = []

    config.epoch_loss_head_B = []
    config.epoch_loss_no_lamb_head_B = []

    _ = segmentation_eval(config, net,
                          mapping_assignment_dataloader=mapping_assignment_dataloader,
                          mapping_test_dataloader=mapping_test_dataloader,
                          sobel=(not config.no_sobel),
                          using_IR=config.using_IR)

    print(
      "Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
    sys.stdout.flush()
    next_epoch = 1

  fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

  if not config.use_uncollapsed_loss:
    print("using condensed loss (default)")
    loss_fn = IID_segmentation_loss
  else:
    print("using uncollapsed loss!")
    loss_fn = IID_segmentation_loss_uncollapsed

  # Train
  # ------------------------------------------------------------------------

  for e_i in range(next_epoch, config.num_epochs):
    print("Starting e_i: %d %s" % (e_i, datetime.now()))
    sys.stdout.flush()

    if e_i in config.lr_schedule:
      optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

    for head_i in range(2):
      head = heads[head_i]
      if head == "A":
        dataloaders = dataloaders_head_A
        epoch_loss = config.epoch_loss_head_A
        epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
        lamb = config.lamb_A

      elif head == "B":
        dataloaders = dataloaders_head_B
        epoch_loss = config.epoch_loss_head_B
        epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
        lamb = config.lamb_B

      iterators = (d for d in dataloaders)
      b_i = 0
      avg_loss = 0.  # over heads and head_epochs (and sub_heads)
      avg_loss_no_lamb = 0.
      avg_loss_count = 0

      for tup in zip(*iterators):
        net.module.zero_grad()

        if not config.no_sobel:
          pre_channels = config.in_channels - 1
        else:
          pre_channels = config.in_channels

        all_img1 = torch.zeros(config.batch_sz, pre_channels,
                               config.input_sz, config.input_sz).to(
          torch.float32).cuda()
        all_img2 = torch.zeros(config.batch_sz, pre_channels,
                               config.input_sz, config.input_sz).to(
          torch.float32).cuda()
        all_affine2_to_1 = torch.zeros(config.batch_sz, 2, 3).to(
          torch.float32).cuda()
        all_mask_img1 = torch.zeros(config.batch_sz, config.input_sz,
                                    config.input_sz).to(torch.float32).cuda()

        curr_batch_sz = tup[0][0].shape[0]
        for d_i in range(config.num_dataloaders):
          img1, img2, affine2_to_1, mask_img1 = tup[d_i]
          assert (img1.shape[0] == curr_batch_sz)

          actual_batch_start = d_i * curr_batch_sz
          actual_batch_end = actual_batch_start + curr_batch_sz

          all_img1[actual_batch_start:actual_batch_end, :, :, :] = img1
          all_img2[actual_batch_start:actual_batch_end, :, :, :] = img2
          all_affine2_to_1[actual_batch_start:actual_batch_end, :, :] = \
            affine2_to_1
          all_mask_img1[actual_batch_start:actual_batch_end, :, :] = mask_img1

        if not (curr_batch_sz == config.dataloader_batch_sz) and (
            e_i == next_epoch):
          print("last batch sz %d" % curr_batch_sz)

        curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # e.g. times 2
        all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
        all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
        all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
        all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

        if (not config.no_sobel):
          all_img1 = sobel_process(all_img1, config.include_rgb,
                                   using_IR=config.using_IR)
          all_img2 = sobel_process(all_img2, config.include_rgb,
                                   using_IR=config.using_IR)

        x1_outs = net(all_img1, head=head)
        x2_outs = net(all_img2, head=head)

        avg_loss_batch = None  # avg over the heads
        avg_loss_no_lamb_batch = None

        for i in range(config.num_sub_heads):
          loss, loss_no_lamb = loss_fn(x1_outs[i],
                                       x2_outs[i],
                                       all_affine2_to_1=all_affine2_to_1,
                                       all_mask_img1=all_mask_img1,
                                       lamb=lamb,
                                       half_T_side_dense=config.half_T_side_dense,
                                       half_T_side_sparse_min=config.half_T_side_sparse_min,
                                       half_T_side_sparse_max=config.half_T_side_sparse_max)

          if avg_loss_batch is None:
            avg_loss_batch = loss
            avg_loss_no_lamb_batch = loss_no_lamb
          else:
            avg_loss_batch += loss
            avg_loss_no_lamb_batch += loss_no_lamb

        avg_loss_batch /= config.num_sub_heads
        avg_loss_no_lamb_batch /= config.num_sub_heads

        if ((b_i % 100) == 0) or (e_i == next_epoch):
          print("Model ind %d epoch %d head %s batch: %d "
                "avg loss %f avg loss no lamb %f time %s" %
                (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
                 avg_loss_no_lamb_batch.item(), datetime.now()))
          sys.stdout.flush()

        if not np.isfinite(avg_loss_batch.item()):
          print("Loss is not finite... %s:" % str(avg_loss_batch))
          exit(1)

        avg_loss += avg_loss_batch.item()
        avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
        avg_loss_count += 1

        avg_loss_batch.backward()
        optimiser.step()

        torch.cuda.empty_cache()

        b_i += 1
        if b_i == 2 and config.test_code:
          break

      avg_loss = float(avg_loss / avg_loss_count)
      avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

      epoch_loss.append(avg_loss)
      epoch_loss_no_lamb.append(avg_loss_no_lamb)

    # Eval
    # -----------------------------------------------------------------------

    is_best = segmentation_eval(config, net,
                                mapping_assignment_dataloader=mapping_assignment_dataloader,
                                mapping_test_dataloader=mapping_test_dataloader,
                                sobel=(
                                  not config.no_sobel),
                                using_IR=config.using_IR)

    print("Epoch %d: time %s: \n %s" %
          (e_i, datetime.now(), nice(config.epoch_stats[-1])))
    sys.stdout.flush()

    axarr[0].clear()
    axarr[0].plot(config.epoch_acc)
    axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

    axarr[1].clear()
    axarr[1].plot(config.epoch_avg_subhead_acc)
    axarr[1].set_title("acc (avg), top: %f" % max(config.epoch_avg_subhead_acc))

    axarr[2].clear()
    axarr[2].plot(config.epoch_loss_head_A)
    axarr[2].set_title("Loss head A")

    axarr[3].clear()
    axarr[3].plot(config.epoch_loss_no_lamb_head_A)
    axarr[3].set_title("Loss no lamb head A")

    axarr[4].clear()
    axarr[4].plot(config.epoch_loss_head_B)
    axarr[4].set_title("Loss head B")

    axarr[5].clear()
    axarr[5].plot(config.epoch_loss_no_lamb_head_B)
    axarr[5].set_title("Loss no lamb head B")

    fig.canvas.draw_idle()
    fig.savefig(os.path.join(config.out_dir, "plots.png"))

    if is_best or (e_i % config.save_freq == 0):
      net.module.cpu()
      save_dict = {"net": net.module.state_dict(),
                   "optimiser": optimiser.state_dict()}

      if e_i % config.save_freq == 0:
        torch.save(save_dict, os.path.join(config.out_dir, "latest.pytorch"))
        config.last_epoch = e_i  # for last saved version

      if is_best:
        torch.save(save_dict, os.path.join(config.out_dir, "best.pytorch"))

        with open(os.path.join(config.out_dir, "best_config.pickle"),
                  'wb') as outfile:
          pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "best_config.txt"),
                  "w") as text_file:
          text_file.write("%s" % config)

      net.module.cuda()

    with open(os.path.join(config.out_dir, "config.pickle"), 'wb') as outfile:
      pickle.dump(config, outfile)

    with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
      text_file.write("%s" % config)

    if config.test_code:
      exit(0)
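train() adjusts the learning rate through update_lr, which is defined elsewhere in the repository. A minimal sketch of what such a helper typically does (illustrative, assuming the standard torch optimiser API, not the repository's definition):

def update_lr(optimiser, lr_mult):
  # scale the learning rate of every param group in place
  for param_group in optimiser.param_groups:
    param_group["lr"] *= lr_mult
  return optimiser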
Example #10
def get_subhead_using_loss(config,
                           dataloaders_head_B,
                           net,
                           sobel,
                           lamb,
                           compare=False):
    net.eval()

    head = "B"  # main output head
    dataloaders = dataloaders_head_B
    iterators = (d for d in dataloaders)

    b_i = 0
    loss_per_sub_head = np.zeros(config.num_sub_heads)
    for tup in zip(*iterators):
        net.module.zero_grad()

        dim = config.in_channels
        if sobel:
            dim -= 1

        all_imgs = torch.zeros(config.batch_sz, dim, config.input_sz,
                               config.input_sz).cuda()
        all_imgs_tf = torch.zeros(config.batch_sz, dim, config.input_sz,
                                  config.input_sz).cuda()

        imgs_curr = tup[0][0]  # always the first
        curr_batch_sz = imgs_curr.size(0)
        for d_i in range(config.num_dataloaders):
            imgs_tf_curr = tup[1 + d_i][0]  # from 2nd to last
            assert (curr_batch_sz == imgs_tf_curr.size(0))

            actual_batch_start = d_i * curr_batch_sz
            actual_batch_end = actual_batch_start + curr_batch_sz
            all_imgs[actual_batch_start:actual_batch_end, :, :, :] = \
              imgs_curr.cuda()
            all_imgs_tf[actual_batch_start:actual_batch_end, :, :, :] = \
              imgs_tf_curr.cuda()

        curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
        all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
        all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

        if sobel:
            all_imgs = sobel_process(all_imgs, config.include_rgb)
            all_imgs_tf = sobel_process(all_imgs_tf, config.include_rgb)

        with torch.no_grad():
            x_outs = net(all_imgs, head=head)
            x_tf_outs = net(all_imgs_tf, head=head)

        for i in range(config.num_sub_heads):
            loss, loss_no_lamb = IID_loss(x_outs[i], x_tf_outs[i], lamb=lamb)
            loss_per_sub_head[i] += loss.item()

        if b_i % 100 == 0:
            print("at batch %d" % b_i)
            sys.stdout.flush()
        b_i += 1

    best_sub_head_loss = np.argmin(loss_per_sub_head)

    if compare:
        print(loss_per_sub_head)
        print("best sub_head by loss: %d" % best_sub_head_loss)

        best_epoch = np.argmax(np.array(config.epoch_acc))
        if "best_train_sub_head" in config.epoch_stats[best_epoch]:
            best_sub_head_eval = config.epoch_stats[best_epoch][
                "best_train_sub_head"]
            test_accs = config.epoch_stats[best_epoch]["test_accs"]
        else:  # older config version
            best_sub_head_eval = config.epoch_stats[best_epoch]["best_head"]
            test_accs = config.epoch_stats[best_epoch]["all"]

        print("best sub_head by eval: %d" % best_sub_head_eval)

        print("... loss select acc: %f, eval select acc: %f" %
              (test_accs[best_sub_head_loss], test_accs[best_sub_head_eval]))

    net.train()

    return best_sub_head_loss
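get_subhead_using_loss ranks sub-heads by IID_loss, the mutual-information objective from the IIC paper. A compact sketch of that objective for reference; iid_loss_sketch is illustrative and not the repository's IID_loss:

import torch

def iid_loss_sketch(x_out, x_tf_out, lamb=1.0, eps=1e-10):
    # x_out, x_tf_out: (n, k) softmax outputs for paired views
    k = x_out.shape[1]
    p = (x_out.unsqueeze(2) * x_tf_out.unsqueeze(1)).sum(dim=0)  # (k, k) joint
    p = ((p + p.t()) / 2.0) / p.sum()  # symmetrise, normalise
    pi = p.sum(dim=1).view(k, 1).expand(k, k).clamp(min=eps)
    pj = p.sum(dim=0).view(1, k).expand(k, k).clamp(min=eps)
    p = p.clamp(min=eps)
    # negative mutual information; lamb weights the marginal entropy terms
    return -(p * (torch.log(p) - lamb * torch.log(pi)
                  - lamb * torch.log(pj))).sum()

x1 = torch.softmax(torch.randn(8, 5), dim=1)
x2 = torch.softmax(torch.randn(8, 5), dim=1)
loss = iid_loss_sketch(x1, x2, lamb=1.0)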
Example #11
def _clustering_get_data(config,
                         net,
                         dataloader,
                         sobel=False,
                         using_IR=False,
                         get_soft=False,
                         verbose=0):
    """
  Returns cuda tensors for flat preds and targets.
  """

    assert (not using_IR)  # sanity; IR used by segmentation only

    num_batches = len(dataloader)
    flat_targets_all = torch.zeros((num_batches * config.batch_sz),
                                   dtype=torch.int32).cuda()
    flat_predss_all = [
        torch.zeros((num_batches * config.batch_sz), dtype=torch.int32).cuda()
        for _ in range(config.num_sub_heads)
    ]

    if get_soft:
        soft_predss_all = [
            torch.zeros((num_batches * config.batch_sz, config.output_k),
                        dtype=torch.float32).cuda()
            for _ in range(config.num_sub_heads)
        ]

    num_test = 0
    for b_i, batch in enumerate(dataloader):
        imgs = batch[0].cuda()

        if sobel:
            imgs = sobel_process(imgs, config.include_rgb, using_IR=using_IR)

        flat_targets = batch[1]

        with torch.no_grad():
            x_outs = net(imgs)

        assert (x_outs[0].shape[1] == config.output_k)
        assert (len(x_outs[0].shape) == 2)

        num_test_curr = flat_targets.shape[0]
        num_test += num_test_curr

        start_i = b_i * config.batch_sz
        for i in range(config.num_sub_heads):
            x_outs_curr = x_outs[i]
            flat_preds_curr = torch.argmax(x_outs_curr,
                                           dim=1)  # along output_k
            flat_predss_all[i][start_i:(start_i +
                                        num_test_curr)] = flat_preds_curr

            if get_soft:
                soft_predss_all[i][start_i:(start_i +
                                            num_test_curr), :] = x_outs_curr

        flat_targets_all[start_i:(start_i +
                                  num_test_curr)] = flat_targets.cuda()

    flat_predss_all = [
        flat_predss_all[i][:num_test] for i in range(config.num_sub_heads)
    ]
    flat_targets_all = flat_targets_all[:num_test]

    if not get_soft:
        return flat_predss_all, flat_targets_all
    else:
        soft_predss_all = [
            soft_predss_all[i][:num_test] for i in range(config.num_sub_heads)
        ]

        return flat_predss_all, flat_targets_all, soft_predss_all