Example #1
import torch.nn.functional as F
from spatial_correlation_sampler import spatial_correlation_sample

def correlate(input1, input2, args):
    out_corr = spatial_correlation_sample(input1, input2, **args)
    # collate dimensions 1 and 2 in order to be treated as a
    # regular 4D tensor
    b, ph, pw, h, w = out_corr.size()
    out_corr = out_corr.view(b, ph * pw, h, w) / input1.size(1)
    return F.leaky_relu_(out_corr, 0.1)
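A minimal usage sketch for this variant. The input shapes and the contents of the args dict below are assumed for illustration, not taken from the original project:

import torch

x1 = torch.randn(4, 256, 48, 64)   # (batch, channels, height, width), assumed sizes
x2 = torch.randn(4, 256, 48, 64)
args = {'kernel_size': 1, 'patch_size': 9, 'stride': 1}  # assumed keyword arguments

corr = correlate(x1, x2, args)
print(corr.shape)  # torch.Size([4, 81, 48, 64]) -- patch_size ** 2 = 81 correlation channels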
Example #2
def correlate(input1, input2):
    out_corr = spatial_correlation_sample(input1,
                                          input2,
                                          kernel_size=1,
                                          patch_size=9,
                                          stride=1)
    # collate dimensions 1 and 2 in order to be treated as a
    # regular 4D tensor
    b, ph, pw, h, w = out_corr.size()
    out_corr = out_corr.view(b, ph * pw, h, w) / input1.size(1)
    return out_corr
Example #3
def correlate(input1, input2):
    out_corr = spatial_correlation_sample(input1,
                                          input2,
                                          kernel_size=1,
                                          patch_size=21,
                                          stride=1,
                                          padding=0,
                                          dilation_patch=2)
    # collate dimensions 1 and 2 in order to be treated as a
    # regular 4D tensor
    b, ph, pw, h, w = out_corr.size()
    out_corr = out_corr.view(b, ph * pw, h, w) / input1.size(1)
    return F.leaky_relu_(out_corr, 0.1)
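With kernel_size=1 and dilation_patch=2, this variant covers an equivalent max_displacement of dilation_patch * (patch_size - 1) / 2 = 2 * (21 - 1) / 2 = 20 pixels. A quick shape check under assumed input sizes, reusing the imports from Example #1:

import torch

x1 = torch.randn(1, 256, 48, 64)
x2 = torch.randn(1, 256, 48, 64)
out = correlate(x1, x2)
print(out.shape)  # torch.Size([1, 441, 48, 64]) -- 21 * 21 = 441 correlation channels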
Example #4
        def restricted(k, q):
            _, N, T, C, H, W = k.shape

            A = spatial_correlation_sample(
                    k.view(k.shape[1]*T, C, H, W),
                    q.view(q.shape[1]*T, C, H, W),
                    patch_size=int(2*args.radius+1))

            A = A.view(1, q.shape[1], T, *A.shape[-4:]) / args.temperature
            # A[A==0] = -1e20  # ignored idxs in softmax

            _, N, T, H1, W1, H, W = A.shape
            A = A.view(N, T*H1*W1, H, W)
            weights, ids = torch.topk(A, topk_vis, dim=-3)
            weights = torch.nn.functional.softmax(weights, dim=-3)

            return A, weights, ids
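A hypothetical driver for restricted, assuming the function is lifted to module scope and that args.radius, args.temperature, and topk_vis (all taken from the enclosing scope in the original, see Example #6) are stubbed with made-up values:

import torch
from types import SimpleNamespace
from spatial_correlation_sampler import spatial_correlation_sample

args = SimpleNamespace(radius=6, temperature=0.07)  # assumed values
topk_vis = 10                                       # assumed value

# (1, N, T, C, H, W): N query maps matched against T reference frames each
k = torch.randn(1, 2, 3, 128, 32, 32)
q = torch.randn(1, 2, 3, 128, 32, 32)

A, weights, ids = restricted(k, q)
print(A.shape)        # torch.Size([2, 507, 32, 32]) -- T * (2 * radius + 1) ** 2 = 507 candidates
print(weights.shape)  # torch.Size([2, 10, 32, 32])
print(ids.shape)      # torch.Size([2, 10, 32, 32])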
Example #5
def correlate(x1, x2, patch_size=11, dilation_patch=1):
    """
    :param x1: features 1
    :param x2: features 2
    :param patch_size: the size of the whole patch used to calculate the correlation
    :param dilation_patch: dilation between patch elements, analogous to dilated convolutions
    :return: correlation volume of shape (b, patch_size ** 2, h, w)
    """

    # Output sizes oH and oW are no longer dependent on patch size, but only on kernel size and padding
    # patch_size is now the whole patch, and not only the radii.
    # stride1 is now stride and stride2 is dilation_patch, which behave like dilated convolutions
    # equivalent max_displacement is then dilation_patch * (patch_size - 1) / 2.
    # to get the right parameters for FlowNetC, you would use
    # kernel_size=1, patch_size=21, stride=1, padding=0, dilation_patch=2
    out_corr = spatial_correlation_sample(x1,
                                          x2,
                                          kernel_size=1,
                                          patch_size=patch_size,
                                          stride=1,
                                          padding=0,
                                          dilation_patch=dilation_patch)
    b, ph, pw, h, w = out_corr.size()
    out_corr = out_corr.view(b, ph * pw, h, w) / x1.size(1)
    return F.leaky_relu_(out_corr, 0.1)
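Following the comment block above, the FlowNetC configuration corresponds to patch_size=21 with dilation_patch=2 (effective max_displacement = 2 * (21 - 1) / 2 = 20). A hypothetical call through this wrapper, with assumed input sizes:

import torch

x1 = torch.randn(8, 256, 48, 64)
x2 = torch.randn(8, 256, 48, 64)
out = correlate(x1, x2, patch_size=21, dilation_patch=2)
print(out.shape)  # torch.Size([8, 441, 48, 64])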
Example #6
def test(val_loader, model, epoch, use_cuda):

    save_path = args.save_path + '/'
    save_file = '%s/list.txt' % save_path
    fileout = open(save_file, 'w')

    end = time.time()

    job_args = []

    n_context = params['videoLen']
    topk_vis = args.topk_vis

    t_vid = 0

    for batch_idx, (imgs_total, imgs_orig, lbl_set, lbls_tensor, lbls_onehot,
                    lbls_resize, meta) in enumerate(val_loader):
        t_vid = time.time()
        print('******* Vid %s *******' % batch_idx)

        # measure data loading time
        imgs_total = imgs_total.cuda()
        bs, total_frame_num, channel_num, height_len, width_len = imgs_total.shape

        assert (bs == 1)

        folder_paths = meta['folder_path']
        print('total_frame_num: ' + str(total_frame_num))

        ##################################################################
        # Print the images
        ##################################################################

        imgs_set = imgs_total.data
        imgs_set = imgs_set.cpu().numpy()
        imgs_set = imgs_set[0]

        imgs_toprint = [ii for ii in imgs_orig[0]]

        # ref image
        t00 = time.time()

        # for t in range(imgs_orig.shape[0]):
        #     img_now = imgs_orig[t]
        #     img_now = np.transpose(img_now, (1, 2, 0))
        #     img_now = cv2.resize(img_now, (img_now.shape[0] * 2, img_now.shape[1] * 2) )
        #     imgs_toprint.append(img_now)

        #     imname  = save_path + str(batch_idx) + '_' + str(t) + '_frame.jpg'
        #     imageio.imwrite(imname, img_now.astype(np.uint8))

        # print('printed images', time.time()-t00)

        ##################################################################
        # Compute image features
        ##################################################################

        t00 = time.time()

        feats = []
        bsize = 5
        for b in range(0, imgs_total.shape[1], bsize):
            torch.cuda.empty_cache()
            node, feat = model.module(imgs_total[:, b:b + bsize],
                                      None,
                                      True,
                                      func='forward')
            feats.append(feat.cpu())
        feats = torch.cat(feats, dim=2)

        feats = feats.detach().squeeze(1)
        feats = torch.nn.functional.normalize(feats, dim=1)

        print('computed features', time.time() - t00)

        ##################################################################
        # Prep labels
        ##################################################################

        for t in range(n_context):
            imname = save_path + str(batch_idx) + '_' + str(t) + '_label.jpg'
            imageio.imwrite(imname, lbls_tensor[0][t].numpy().astype(np.uint8))
        # print('wrote frames and labels')

        ##################################################################
        # Compute correlation features
        ##################################################################

        imgs_stack = []
        im_num = total_frame_num - n_context
        t03 = time.time()

        indices = torch.cat([
            torch.zeros(im_num, 1).long(),
            (torch.arange(n_context)[None].repeat(im_num, 1) +
             torch.arange(im_num)[:, None])[:, 1:]
        ], dim=-1)

        feats = feats.cpu()

        if isinstance(lbl_set, list):
            lbl_set = torch.cat(lbl_set)[None]
        lbls_resize[0, n_context * 2 - 1:] *= 0

        # H x W x L -> L x H x W
        lbls_resize = lbls_resize.transpose(-1, -3).transpose(-1, -2)

        As, Ws, Is = [], [], []
        keys, query = feats[:, :, indices], feats[:, :, n_context:]

        _, C, N, T, H, W = keys.shape
        # for ease with spatial_correlation_sampler
        keys = keys.permute(0, 2, 3, 1, 4, 5)  # reshape to 1 x N x T x C x H x W
        query = query.permute(0, 2, 1, 3, 4).unsqueeze(2).expand_as(keys)

        q_dim = 2 if args.all_nn else 3
        bsize = 2

        for b in range(0, keys.shape[1], bsize):
            _k, _q = keys[:, b:b + bsize].cuda(), query[:, b:b + bsize].cuda()

            A = spatial_correlation_sample(_q.view(_q.shape[1] * T, C, H, W),
                                           _k.view(_k.shape[1] * T, C, H, W),
                                           patch_size=int(2 * args.radius + 1))

            A = A.view(1, _q.shape[1], T, *A.shape[-4:])
            A /= args.temperature
            # A[A==0] = -1e20  # ignored idxs in softmax

            _, N, T, H1, W1, H, W = A.shape
            A1 = A.view(N, T * H1 * W1, H, W)
            weights, ids = torch.topk(A1, topk_vis, dim=-3)
            weights = torch.nn.functional.softmax(weights, dim=-3)

            # As += [a for a in A.cpu()[0]]
            Ws += [w for w in weights.cpu()]
            Is += [ii for ii in ids.cpu()]

        # As, Ws, Is = (torch.cat(_, dim=1) for _ in (As, Ws, Is))

        t04 = time.time()
        print(t04 - t03, 'computed affinities',
              torch.cuda.max_memory_allocated() / (1024**2))

        # As, Ws, Is = As[0], Ws[0], Is[0]
        lbl_set, lbls_resize = lbl_set[0], lbls_resize[0]

        ##################################################################
        # Label propagation
        ##################################################################

        L, H, W = lbls_resize.shape[1:]
        lbls_idx = torch.arange(T * H * W).view(T, H, W)
        lbls_idx = F.pad(lbls_idx, [int(args.radius)] * 4, 'constant', -1)
        lbls_idx = F.unfold(lbls_idx[None].float(),
                            kernel_size=int(args.radius * 2 + 1))
        lbls_idx = lbls_idx.view(1, -1, H, W).long().cuda()

        nstep = len(imgs_toprint) - n_context

        for it in range(nstep):
            if it % 10 == 0:
                print(it, torch.cuda.max_memory_allocated() / (1024**2))

            weights, idxs = Ws[it].cuda(), Is[it].cuda()
            lbls_base = lbls_resize[indices[it]].cuda()
            t06 = time.time()

            # indexing based
            flat_lbls = lbls_base.transpose(0, 1).flatten(1)

            global_idxs = torch.gather(lbls_idx, 1, idxs[None])
            nn_lbls = flat_lbls[:, global_idxs.view(topk_vis, -1).t()]
            nn_lbls = nn_lbls.transpose(-1, -2)

            predlbls = (nn_lbls.view(L, topk_vis, H, W) * weights[None]).sum(1)

            # hard prop
            # predlbls = hard_prop(predlbls)

            img_now = imgs_toprint[it + n_context].permute(1, 2, 0).numpy() * 255

            if it > 0:
                lbls_resize[it + n_context] = predlbls
            else:
                predlbls = lbls_resize[0]

            # Save Predictions
            dump_predictions(predlbls.cpu().permute(1, 2, 0).numpy(), lbl_set,
                             img_now,
                             save_path + str(batch_idx) + '_' + str(it))

        torch.cuda.empty_cache()

        print('******* Vid %s TOOK %s *******' %
              (batch_idx, time.time() - t_vid))