Example #1
import numpy as np
import torch

# DCI (the nearest-neighbour index used below) and print_without_newline are
# project-level helpers; their import paths are not shown in this snippet.

def get_code_for_data_three(model, data, opt):
    options = opt['train']

    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])

    code_nc = int(opt['network_G']['in_code_nc'])
    pull_num_sample_per_img = int(options['num_code_per_img'])

    pull_gen_img = data['LR']
    d1_gen_img = data['D1']
    d2_gen_img = data['D2']
    real_gen_img = data['HR']

    pull_gen_code_0 = torch.empty(pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2] * 2, pull_gen_img.shape[3] * 2)

    forward_bs = 10

    print("Generating Pull Samples")
    data_length = pull_gen_img.shape[0]

    out_feature_shape = model.netF(data['D1'][:1])[-1].shape[1:]

    # Build a DCI index over the first-stage (D1) features; points are added and
    # cleared per image inside the loop below.
    pull_samples_dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)

    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline(
                '\rFinding first stack code: Processed %d out of %d instances' % (
                    sample_index + 1, data_length))
        pull_gen_code_pool_0 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2] * 2, pull_gen_img.shape[3] * 2)
        pull_gen_code_pool_1 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2] * 4,
                                           pull_gen_img.shape[3] * 4)
        pull_gen_code_pool_2 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2] * 8,
                                           pull_gen_img.shape[3] * 8)
        pull_gen_features_pool = []
        for i in range(0, pull_num_sample_per_img, forward_bs):
            pull_img = pull_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_target = real_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d1 = d1_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d2 = d2_gen_img[sample_index].expand(forward_bs, -1, -1, -1)

            cur_data = {'LR': pull_img, 'HR': pull_target, 'D1': pull_d1, 'D2': pull_d2}
            start = i
            end = i + forward_bs

            model.feed_data(cur_data, code=[pull_gen_code_pool_0[start:end], pull_gen_code_pool_1[start:end],
                                            pull_gen_code_pool_2[start:end]])
            feature_output = model.get_features()

            pull_gen_features_pool.append(feature_output['gen_feat_D1'].double().numpy())
        pull_gen_features_pool = np.concatenate(pull_gen_features_pool, axis=0)
        pull_gen_features_pool = pull_gen_features_pool.reshape(-1, np.prod(pull_gen_features_pool.shape[1:]))

        pull_samples_dci_db.add(pull_gen_features_pool.copy(),
                                num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        target_feature = feature_output['real_feat_D1']
        target_feature = target_feature[0].reshape(1, np.prod(target_feature.shape[1:])).double().numpy().copy()

        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature,
            num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit,
            prop_to_retrieve=dci_prop_to_retrieve)

        pull_gen_code_0[sample_index, :] = pull_gen_code_pool_0[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db
        pull_samples_dci_db.clear()

    print('\rFinding first stack code: Processed %d out of %d instances' % (
        data_length, data_length))

    # Second stage: allocate codes and build a DCI index over the D2 features.
    pull_gen_code_1 = torch.empty(pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2] * 4, pull_gen_img.shape[3] * 4)
    out_feature_shape = model.netF(data['D2'][:1])[-1].shape[1:]

    pull_samples_dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)

    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline(
                '\rFinding second stack code: Processed %d out of %d instances' % (
                    sample_index + 1, data_length))
        # Keep the first-stage code fixed to the one selected above and sample
        # fresh candidates for the remaining stages.
        pull_gen_code_pool_0 = pull_gen_code_0[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        pull_gen_code_pool_1 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2] * 4,
                                           pull_gen_img.shape[3] * 4)
        pull_gen_code_pool_2 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2] * 8,
                                           pull_gen_img.shape[3] * 8)
        pull_gen_features_pool = []
        for i in range(0, pull_num_sample_per_img, forward_bs):
            pull_img = pull_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_target = real_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d1 = d1_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d2 = d2_gen_img[sample_index].expand(forward_bs, -1, -1, -1)

            cur_data = {'LR': pull_img, 'HR': pull_target, 'D1': pull_d1, 'D2': pull_d2}
            start = i
            end = i + forward_bs

            model.feed_data(cur_data, code=[pull_gen_code_pool_0[start:end], pull_gen_code_pool_1[start:end],
                                            pull_gen_code_pool_2[start:end]])
            feature_output = model.get_features()

            pull_gen_features_pool.append(feature_output['gen_feat_D2'].double().numpy())

        pull_gen_features_pool = np.concatenate(pull_gen_features_pool, axis=0)
        pull_gen_features_pool = pull_gen_features_pool.reshape(-1, np.prod(pull_gen_features_pool.shape[1:]))

        pull_samples_dci_db.add(pull_gen_features_pool.copy(),
                                num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        target_feature = feature_output['real_feat_D2'].double().numpy().copy()
        target_feature = target_feature[0].reshape(1, np.prod(target_feature.shape[1:]))

        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature,
            num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit,
            prop_to_retrieve=dci_prop_to_retrieve)

        pull_gen_code_1[sample_index, :] = pull_gen_code_pool_1[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db
        pull_samples_dci_db.clear()

    print('\rFinding second stack code: Processed %d out of %d instances' % (
        data_length, data_length))

    # Third stage: allocate codes and build a DCI index over the HR-level features.
    pull_gen_code_2 = torch.empty(pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2] * 8, pull_gen_img.shape[3] * 8)
    out_feature_shape = model.netF(data['HR'][:1])[-1].shape[1:]

    pull_samples_dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)

    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline(
                '\rFinding third stack code: Processed %d out of %d instances' % (
                    sample_index + 1, data_length))
        # Keep the first- and second-stage codes fixed to the ones selected above
        # and sample fresh candidates for the third stage.
        pull_gen_code_pool_0 = pull_gen_code_0[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        pull_gen_code_pool_1 = pull_gen_code_1[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        pull_gen_code_pool_2 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2] * 8,
                                           pull_gen_img.shape[3] * 8)
        pull_gen_features_pool = []
        for i in range(0, pull_num_sample_per_img, forward_bs):
            pull_img = pull_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_target = real_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d1 = d1_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d2 = d2_gen_img[sample_index].expand(forward_bs, -1, -1, -1)

            cur_data = {'LR': pull_img, 'HR': pull_target, 'D1': pull_d1, 'D2': pull_d2}
            start = i
            end = i + forward_bs

            model.feed_data(cur_data, code=[pull_gen_code_pool_0[start:end], pull_gen_code_pool_1[start:end],
                                            pull_gen_code_pool_2[start:end]])
            feature_output = model.get_features()

            pull_gen_features_pool.append(feature_output['gen_feat'].double().numpy())

        pull_gen_features_pool = np.concatenate(pull_gen_features_pool, axis=0)
        pull_gen_features_pool = pull_gen_features_pool.reshape(-1, np.prod(pull_gen_features_pool.shape[1:]))

        pull_samples_dci_db.add(pull_gen_features_pool.copy(),
                                num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        target_feature = feature_output['real_feat'].double().numpy().copy()
        target_feature = target_feature[0].reshape(1, np.prod(target_feature.shape[1:]))

        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature,
            num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit,
            prop_to_retrieve=dci_prop_to_retrieve)

        pull_gen_code_2[sample_index, :] = pull_gen_code_pool_2[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db
        pull_samples_dci_db.clear()

    print('\rFinding third stack code: Processed %d out of %d instances' % (
        data_length, data_length))

    pull_gen_code_0 += sample_perturbation_magnitude * torch.randn(pull_gen_img.shape[0], code_nc,
                                                                   pull_gen_img.shape[2] * 2,
                                                                   pull_gen_img.shape[3] * 2)
    pull_gen_code_1 += sample_perturbation_magnitude * torch.randn(pull_gen_img.shape[0], code_nc,
                                                                   pull_gen_img.shape[2] * 4,
                                                                   pull_gen_img.shape[3] * 4)
    pull_gen_code_2 += sample_perturbation_magnitude * torch.randn(pull_gen_img.shape[0], code_nc,
                                                                   pull_gen_img.shape[2] * 8,
                                                                   pull_gen_img.shape[3] * 8)

    return [pull_gen_code_0, pull_gen_code_1, pull_gen_code_2]
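
For reference, here is a minimal sketch of the inputs this function reads. The option keys are exactly the ones accessed above; the numeric values and the comment on `data` are illustrative placeholders, not the original project's settings.

# Illustrative configuration; every value below is a placeholder.
opt = {
    'train': {
        'dci_num_comp_indices': 2,
        'dci_num_simp_indices': 10,
        'dci_num_levels': 2,
        'dci_construction_field_of_view': 10,
        'dci_query_field_of_view': 100,
        'dci_prop_to_visit': 1.0,
        'dci_prop_to_retrieve': 0.025,
        'sample_perturbation_magnitude': 1.0,
        'num_code_per_img': 30,
    },
    'network_G': {'in_code_nc': 1},
}

# data maps 'LR', 'D1', 'D2', 'HR' to image batches with a shared batch dimension;
# the function returns one code tensor per stage:
# code_0, code_1, code_2 = get_code_for_data_three(model, data, opt)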
Example #2
def get_code_for_data(model, data, opt):
    options = opt['train']

    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])

    code_nc = int(opt['network_G']['in_code_nc'])
    pull_num_sample_per_img = int(options['num_code_per_img'])

    show_message = options.get('show_message', False)

    pull_gen_img = data['LR']
    real_gen_img = data['HR']
    pull_gen_code_0 = torch.empty(pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2], pull_gen_img.shape[3])

    if show_message:
        print("Generating Pull Samples")
    data_length = pull_gen_img.shape[0]

    out_feature_shape = model.netF(data['HR'][:1]).shape[1:]
    # initialize dci db
    pull_samples_dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)

    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0 and show_message:
            print_without_newline(
                '\rFinding first stack code: Processed %d out of %d instances' % (
                    sample_index + 1, data_length))
        if options.get('zero_code'):
            pull_gen_code_pool_0 = torch.zeros(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2],
                                               pull_gen_img.shape[3])
        elif options.get('rand_code'):
            pull_gen_code_pool_0 = torch.rand(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2],
                                              pull_gen_img.shape[3])
        else:
            pull_gen_code_pool_0 = torch.randn(pull_num_sample_per_img, code_nc, pull_gen_img.shape[2],
                                               pull_gen_img.shape[3])

        pull_img = pull_gen_img[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)

        # cur_data = {'LR': pull_img, 'HR': real_gen_img[sample_index: sample_index + 1]}
        cur_data = {'LR': pull_img, 'HR': real_gen_img[sample_index: sample_index + 1].expand(
            (pull_num_sample_per_img,) + real_gen_img.shape[1:])}

        model.feed_data(cur_data, code=pull_gen_code_pool_0)

        feature_output = model.get_features()

        pull_gen_features_pool = feature_output['gen_feat']
        # Query target: the real-image feature, mirroring the three-stage version
        # above; using 'gen_feat' for both the database and the query would make
        # each query retrieve its own index.
        target_feature = feature_output['real_feat']

        pull_gen_features_pool = pull_gen_features_pool.reshape(-1, np.prod(
            pull_gen_features_pool.shape[1:])).double().numpy().copy()
        target_feature = target_feature[0].reshape(1, np.prod(target_feature.shape[1:])).double().numpy().copy()

        pull_samples_dci_db.add(pull_gen_features_pool,
                                num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature,
            num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit,
            prop_to_retrieve=dci_prop_to_retrieve)

        pull_gen_code_0[sample_index, :] = pull_gen_code_pool_0[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db for next query
        pull_samples_dci_db.clear()

    if show_message:
        print('\rFinding first stack code: Processed %d out of %d instances' % (
            data_length, data_length))

    if options.get('zero_code'):
        # zero-valued perturbation: the selected codes are left unchanged
        pull_gen_code_0 += sample_perturbation_magnitude * torch.zeros(pull_gen_img.shape[0], code_nc,
                                                                       pull_gen_img.shape[2],
                                                                       pull_gen_img.shape[3])
    elif options.get('rand_code'):
        pull_gen_code_0 += sample_perturbation_magnitude * torch.rand(pull_gen_img.shape[0], code_nc,
                                                                      pull_gen_img.shape[2],
                                                                      pull_gen_img.shape[3])
    else:
        pull_gen_code_0 += sample_perturbation_magnitude * torch.randn(pull_gen_img.shape[0], code_nc,
                                                                       pull_gen_img.shape[2],
                                                                       pull_gen_img.shape[3])

    return pull_gen_code_0
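
get_code_for_data reads the same DCI options as the sketch after Example #1, plus three optional switches. A short sketch, assuming that `opt` dictionary (the values are again illustrative):

# Optional switches; when a key is absent the function falls back to Gaussian
# codes and silent operation.
opt['train']['show_message'] = True   # print progress messages
opt['train']['zero_code'] = False     # True: all-zero code pool (perturbation also becomes zero)
opt['train']['rand_code'] = False     # True: uniform [0, 1) code pool instead of Gaussian

# code_0 = get_code_for_data(model, data, opt)  # one code tensor, same H and W as data['LR']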
import math
import time

# `imread` is assumed to come from an image I/O library such as imageio
# (the original import is not shown in this snippet).
def runknn(sz=4,
           num_points=202599,
           num_queries=1,
           num_comp_indices=2,
           num_simp_indices=7,
           num_outer_iterations=202599,
           max_num_candidates=5000,
           num_neighbours=10,
           patch_size=5):
    num_levels = 3
    construction_field_of_view = 10
    construction_prop_to_retrieve = 0.02
    query_field_of_view = 12000
    query_prop_to_retrieve = 1.0
    dim = patch_size * patch_size * 5
    #setmodulename('\"../church_npy/church_%d.npy\"'%sz)

    #data_and_queries = get_img(dim, num_points + num_queries)
    #data = np.load("../church_npy/church_%d.npy"%sz,mmap_mode='r')#get_img(dim, num_points)#np.copy(data_and_queries[:num_points,:])
    down = int(math.log(256 / sz, 2))

    dci_db = DCI(dim, num_comp_indices, num_simp_indices)
    st = time.time()
    print("before adding")
    dci_db.add(num_points,
               "../church_npy/church_%d_vgg_12_5.npy" % sz,
               num_levels=num_levels,
               field_of_view=construction_field_of_view,
               prop_to_retrieve=construction_prop_to_retrieve,
               load_from_file=1)
    print("construction time:", time.time() - st)
    imgid = 270
    flip = 0
    datamat = np.load("../church_npy/church_%d.npy" % 128, mmap_mode='r')
    rawimg = imread("../results_churchoutdoor/fake_%04d.png" % imgid)

    if flip:
        rawimg = np.flip(rawimg, 1)
    #im = Image.fromarray(target, 'RGB')
    for j in range(1):
        # one step of 2x2 average pooling to downsample the raw image
        rawimg = np.mean(np.concatenate([rawimg[0::2, 0::2, None], rawimg[0::2, 1::2, None],
                                         rawimg[1::2, 0::2, None], rawimg[1::2, 1::2, None]], axis=2), axis=2)
    rawimg = rawimg.astype(np.uint8)
    if flip:
        fakeimgs = np.load("../church_npy/fakechurch_%d_vgg_12_flip.npy" % sz,
                           mmap_mode='r')
    else:
        fakeimgs = np.load("../church_npy/fakechurch_%d_vgg_12.npy" % sz,
                           mmap_mode='r')
    mularr = np.load("../church_npy/rpvec_64_5.npy")
    fakeimgs_rp = fakeimgs[imgid].copy()
    print(fakeimgs_rp.shape)
    fakeimg_rp = np.dot(fakeimgs_rp, mularr)
    del fakeimgs
    del mularr
    minx = 0
    miny = 0
    maxx = 128
    maxy = 128
    ambient_dim = 32 * 32 * 5
    numx = (max(minx + 1, maxx - 32 + 1) - minx + 32 - 1) // 32
    numy = (max(miny + 1, maxy - 32 + 1) - miny + 32 - 1) // 32
    print(numx, numy)
    queries = np.empty([25 * 25, ambient_dim], dtype=np.float32)
    d = 32
    st = 0
    for j in range(miny, max(miny + 1, maxy - 32 + 1), 4):
        for k in range(minx, max(minx + 1, maxx - 32 + 1), 4):
            eyeimg = fakeimg_rp[j:j + d,
                                k:k + d, :].flatten().astype(np.float32)
            queries[st] = eyeimg / np.linalg.norm(eyeimg) * 255
            st += 1
    print(queries.shape, queries.dtype, st)
    st = time.time()
    num_neighbours = 200
    query_field_of_view = 10000  #11000
    query_prop_to_retrieve = 1.0
    nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
        queries,
        num_neighbours,
        field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve,
        blind=False)
    print("query time:", time.time() - st)
    finaldist = np.array(nearest_neighbour_dists)
    rawidx = np.array(nearest_neighbour_idx)
    print(rawidx.shape)
    finalidx = rawidx // (25 * 25)
    finalpos = np.empty([rawidx.shape[0], 200, 2], dtype=np.int64)
    for i in range(rawidx.shape[0]):
        for j in range(200):
            offset = rawidx[i, j] % (25 * 25)
            finalpos[i, j, 0] = offset // 25 * 4
            finalpos[i, j, 1] = offset % 25 * 4
    np.savez("fake_%04d_all_overlap.npz" % imgid,
             finalidx=finalidx,
             finaldist=finaldist,
             finalpos=finalpos)

    return
    # NOTE: the code below never runs because of the early return above.
    queries = get_query(dim, 1, down, sz, patch_size,
                        patch_size)  #data_and_queries[num_points:,:]
    print(queries.shape, queries.dtype)
    queries = queries[0:2]
    st = time.time()
    nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
        queries,
        num_neighbours,
        field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve,
        blind=False)
    print("query time:", time.time() - st)
    #gen_res(num_points, queries, nearest_neighbour_idx, down, sz)
    #resfile = "../save_%d/res.txt"%(sz)
    #f = open(resfile,'w')
    print(np.array(nearest_neighbour_idx)[:, 0])
    print(np.array(nearest_neighbour_dists)[:, 0])
    np.save("churchidx_18_v7.npy", np.array(nearest_neighbour_idx))
    np.save("churchdist_18_v7.npy", np.array(nearest_neighbour_dists))
    dci_db.clear()
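
runknn is self-contained apart from its hard-coded data paths ("../church_npy/...", "../results_churchoutdoor/..."), which must exist on disk. A minimal invocation sketch using the function's own default-style arguments:

if __name__ == '__main__':
    # Builds the patch index from the precomputed feature file and queries it
    # with overlapping 32x32 feature patches of the hard-coded fake image (id 270).
    runknn(sz=4, num_points=202599, num_simp_indices=7, patch_size=5)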