def _search_stack_code(model, data, fixed_codes, code_shapes, feat_suffix,
                       feat_input_key, num_samples, forward_bs, dci_cfg,
                       stage_name):
    """Choose one code per image for level ``len(fixed_codes)`` via DCI search.

    Levels below ``len(fixed_codes)`` reuse the already-chosen per-image codes
    in ``fixed_codes``; the current level and any later levels get fresh
    ``torch.randn`` candidates (``num_samples`` per image). For each image the
    candidate whose ``'gen_feat' + feat_suffix`` feature is nearest (per DCI)
    to ``'real_feat' + feat_suffix`` is kept.

    Args:
        model: exposes ``netF``, ``feed_data(data, code=[c0, c1, c2])`` and
            ``get_features()``.
        data: dict with ``'LR'``, ``'HR'``, ``'D1'``, ``'D2'`` tensors
            (batch on dim 0).
        fixed_codes: list of per-image code tensors for earlier levels.
        code_shapes: per-level ``(channels, height, width)`` tuples.
        feat_suffix: suffix of the feature keys (``'_D1'``, ``'_D2'`` or ``''``).
        feat_input_key: ``data`` key fed to ``model.netF`` to size the DCI index.
        num_samples: candidate codes drawn per image.
        forward_bs: maximum number of candidates per forward pass.
        dci_cfg: dict with ``num_comp_indices``, ``num_simp_indices`` and the
            ``add`` / ``query`` keyword-argument dicts for the DCI calls.
        stage_name: word used in the progress messages.

    Returns:
        Tensor of shape ``(N,) + code_shapes[len(fixed_codes)]``.
    """
    level = len(fixed_codes)
    lr, hr, d1, d2 = data['LR'], data['HR'], data['D1'], data['D2']
    data_length = lr.shape[0]
    stage_codes = torch.empty((data_length,) + code_shapes[level])

    out_feature_shape = model.netF(data[feat_input_key][:1])[-1].shape[1:]
    dci_db = DCI(np.prod(out_feature_shape), dci_cfg['num_comp_indices'],
                 dci_cfg['num_simp_indices'])
    message = '\rFinding %s stack code: Processed %%d out of %%d instances' % stage_name

    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline(message % (sample_index + 1, data_length))

        pools = []
        for lvl, shape in enumerate(code_shapes):
            if lvl < level:
                pools.append(fixed_codes[lvl][sample_index].expand(num_samples, -1, -1, -1))
            else:
                pools.append(torch.randn((num_samples,) + shape))

        gen_features = []
        for start in range(0, num_samples, forward_bs):
            # Clamp the last batch so the expanded images and the code slices
            # always have the same batch size.
            end = min(start + forward_bs, num_samples)
            batch = end - start
            cur_data = {
                'LR': lr[sample_index].expand(batch, -1, -1, -1),
                'HR': hr[sample_index].expand(batch, -1, -1, -1),
                'D1': d1[sample_index].expand(batch, -1, -1, -1),
                'D2': d2[sample_index].expand(batch, -1, -1, -1),
            }
            model.feed_data(cur_data, code=[pool[start:end] for pool in pools])
            feature_output = model.get_features()
            gen_features.append(feature_output['gen_feat' + feat_suffix].double().numpy())

        gen_features = np.concatenate(gen_features, axis=0)
        gen_features = gen_features.reshape(-1, np.prod(gen_features.shape[1:]))
        dci_db.add(gen_features.copy(), **dci_cfg['add'])

        # The real inputs were expanded from a single image, so row 0 of the
        # real feature is used as the query (assumes 'real_feat*' depends only
        # on those real inputs — TODO confirm).
        real_features = feature_output['real_feat' + feat_suffix]
        target_feature = real_features[0].reshape(1, -1).double().numpy().copy()
        nn_idx, _ = dci_db.query(target_feature, num_neighbours=1, **dci_cfg['query'])
        stage_codes[sample_index, :] = pools[level][int(nn_idx[0][0]), :]

        # Reset the index for the next image.
        dci_db.clear()

    print(message % (data_length, data_length))
    return stage_codes


def get_code_for_data_three(model, data, opt):
    """Select latent codes for a three-level generator by nearest-neighbour search.

    The three code levels are chosen one after another. Level 0 is matched on
    the ``*_D1`` features, level 1 on ``*_D2`` and level 2 on the un-suffixed
    features. Each later level is searched with the earlier levels' chosen
    codes held fixed. Afterwards every chosen code is perturbed with
    ``sample_perturbation_magnitude * torch.randn``.

    Args:
        model: exposes ``netF``, ``feed_data`` and ``get_features``.
        data: dict with ``'LR'``, ``'HR'``, ``'D1'`` and ``'D2'`` tensors.
        opt: options; reads the DCI / sampling settings in ``opt['train']``
            and ``opt['network_G']['in_code_nc']``.

    Returns:
        ``[code_0, code_1, code_2]`` whose spatial sizes are 2x, 4x and 8x
        that of ``data['LR']``.
    """
    options = opt['train']
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    dci_cfg = {
        'num_comp_indices': int(options['dci_num_comp_indices']),
        'num_simp_indices': int(options['dci_num_simp_indices']),
        'add': {
            'num_levels': dci_num_levels,
            'field_of_view': dci_construction_field_of_view,
            'prop_to_visit': dci_prop_to_visit,
            'prop_to_retrieve': dci_prop_to_retrieve,
        },
        'query': {
            'field_of_view': dci_query_field_of_view,
            'prop_to_visit': dci_prop_to_visit,
            'prop_to_retrieve': dci_prop_to_retrieve,
        },
    }
    # Read (and validate) the remaining options up front, as before.
    dci_num_comp_indices = dci_cfg['num_comp_indices']
    dci_num_simp_indices = dci_cfg['num_simp_indices']
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])
    code_nc = int(opt['network_G']['in_code_nc'])
    num_samples = int(options['num_code_per_img'])
    forward_bs = 10

    lr = data['LR']
    data_length = lr.shape[0]
    height, width = lr.shape[2], lr.shape[3]
    # Level k codes are 2, 4 and 8 times the LR spatial size.
    code_shapes = [(code_nc, height * scale, width * scale) for scale in (2, 4, 8)]

    print("Generating Pull Samples")
    # (progress word, feature-key suffix, data key used to size the DCI index)
    stages = (('first', '_D1', 'D1'), ('second', '_D2', 'D2'), ('third', '', 'HR'))
    codes = []
    for stage_name, feat_suffix, feat_input_key in stages:
        codes.append(_search_stack_code(model, data, codes, code_shapes, feat_suffix,
                                        feat_input_key, num_samples, forward_bs,
                                        dci_cfg, stage_name))

    # Perturb only after all levels are chosen, so later searches used the
    # unperturbed earlier codes.
    for code, shape in zip(codes, code_shapes):
        code += sample_perturbation_magnitude * torch.randn((data_length,) + shape)
    return codes
def _draw_code(options, shape):
    """Return a code tensor of ``shape`` drawn according to ``options``.

    All zeros if ``options['zero_code']`` is truthy; otherwise ``torch.rand``
    (uniform on [0, 1)) if ``options['rand_code']`` is truthy; otherwise
    ``torch.randn``.
    """
    if options.get('zero_code'):
        return torch.zeros(shape)
    if options.get('rand_code'):
        return torch.rand(shape)
    return torch.randn(shape)


def get_code_for_data(model, data, opt):
    """Select one latent code per image by DCI nearest-neighbour search.

    For every image in ``data['LR']``, ``num_code_per_img`` candidate codes
    are drawn with the LR spatial size. The draw uses zeros, uniform or
    Gaussian values, depending on ``zero_code`` / ``rand_code``. The
    candidates are fed to ``model`` together with the image's HR target. The
    candidate whose ``'gen_feat'`` is nearest to that image's ``'real_feat'``
    is kept. Finally the chosen codes are perturbed by
    ``sample_perturbation_magnitude`` times a fresh draw from the same
    distribution (a no-op for ``zero_code``).

    Args:
        model: exposes ``netF``, ``feed_data(data, code=...)`` and
            ``get_features()``.
        data: dict with ``'LR'`` and ``'HR'`` tensors (batch on dim 0).
        opt: options; reads the DCI / sampling settings in ``opt['train']``
            and ``opt['network_G']['in_code_nc']``.

    Returns:
        Tensor of shape ``(N, in_code_nc, H_lr, W_lr)``.
    """
    options = opt['train']
    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])
    code_nc = int(opt['network_G']['in_code_nc'])
    num_samples = int(options['num_code_per_img'])
    show_message = options.get('show_message', False)

    lr_imgs = data['LR']
    hr_imgs = data['HR']
    data_length = lr_imgs.shape[0]
    code_shape = (code_nc, lr_imgs.shape[2], lr_imgs.shape[3])
    chosen_codes = torch.empty((data_length,) + code_shape)

    if show_message:
        print("Generating Pull Samples")

    out_feature_shape = model.netF(data['HR'][:1]).shape[1:]
    dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)
    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0 and show_message:
            print_without_newline(
                '\rFinding first stack code: Processed %d out of %d instances' % (
                    sample_index + 1, data_length))

        code_pool = _draw_code(options, (num_samples,) + code_shape)
        cur_data = {
            'LR': lr_imgs[sample_index].expand(num_samples, -1, -1, -1),
            'HR': hr_imgs[sample_index:sample_index + 1].expand(
                (num_samples,) + hr_imgs.shape[1:]),
        }
        model.feed_data(cur_data, code=code_pool)
        feature_output = model.get_features()

        gen_features = feature_output['gen_feat']
        gen_features = gen_features.reshape(
            -1, np.prod(gen_features.shape[1:])).double().numpy().copy()
        # Query with the real image's feature. Querying with 'gen_feat' (the
        # very rows inserted below) would just match a candidate to itself.
        # Assumes get_features() returns 'real_feat' for this model — TODO confirm.
        real_features = feature_output['real_feat']
        target_feature = real_features[0].reshape(1, -1).double().numpy().copy()

        dci_db.add(gen_features, num_levels=dci_num_levels,
                   field_of_view=dci_construction_field_of_view,
                   prop_to_visit=dci_prop_to_visit, prop_to_retrieve=dci_prop_to_retrieve)
        nn_idx, _ = dci_db.query(
            target_feature, num_neighbours=1, field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit, prop_to_retrieve=dci_prop_to_retrieve)
        chosen_codes[sample_index, :] = code_pool[int(nn_idx[0][0]), :]

        # Clear the db for the next query.
        dci_db.clear()

    if show_message:
        print('\rFinding first stack code: Processed %d out of %d instances' % (
            data_length, data_length))

    chosen_codes += sample_perturbation_magnitude * _draw_code(options, (data_length,) + code_shape)
    return chosen_codes
def runknn(sz=4, num_points=202599, num_queries=1, num_comp_indices=2, num_simp_indices=7,
           num_outer_iterations=202599, max_num_candidates=5000, num_neighbours=10,
           patch_size=5):
    """Query a file-backed DCI index with the overlapping patches of one fake image.

    Steps:

    1. Build a DCI index from ``num_points`` vectors loaded from
       ``../church_npy/church_<sz>_vgg_12_5.npy``.
    2. Take the precomputed feature map of fake image ``imgid``, multiply it
       by the matrix in ``../church_npy/rpvec_64_5.npy``, and cut it into
       32x32 patches on a stride-4 grid.
    3. Scale each patch to L2 norm 255 and retrieve 200 neighbours per patch.
    4. Save the results to ``fake_<imgid>_all_overlap.npz``:

       - ``finalidx``: neighbour index // 625, presumably the source image.
       - ``finaldist``: the neighbour distances.
       - ``finalpos``: presumably the (row, col) patch offset in the source.

    ``num_queries``, ``num_outer_iterations``, ``max_num_candidates`` and
    ``num_neighbours`` are accepted for call compatibility but are unused; the
    query always retrieves 200 neighbours.
    """
    num_levels = 3
    construction_field_of_view = 10
    construction_prop_to_retrieve = 0.02
    # NOTE(review): the query vectors below are 32*32*5-dimensional, which
    # equals ``dim`` only when patch_size == 32 — confirm the intended value.
    dim = patch_size * patch_size * 5

    dci_db = DCI(dim, num_comp_indices, num_simp_indices)
    t0 = time.time()
    print("before adding")
    dci_db.add(num_points, "../church_npy/church_%d_vgg_12_5.npy" % sz,
               num_levels=num_levels, field_of_view=construction_field_of_view,
               prop_to_retrieve=construction_prop_to_retrieve, load_from_file=1)
    print("construction time:", time.time() - t0)

    imgid = 270
    flip = 0
    if flip:
        fakeimgs = np.load("../church_npy/fakechurch_%d_vgg_12_flip.npy" % sz, mmap_mode='r')
    else:
        fakeimgs = np.load("../church_npy/fakechurch_%d_vgg_12.npy" % sz, mmap_mode='r')
    mularr = np.load("../church_npy/rpvec_64_5.npy")
    fakeimgs_rp = fakeimgs[imgid].copy()
    print(fakeimgs_rp.shape)
    fakeimg_rp = np.dot(fakeimgs_rp, mularr)
    del fakeimgs
    del mularr

    minx, miny, maxx, maxy = 0, 0, 128, 128
    d = 32       # patch side length
    stride = 4   # grid step between patches
    grid = 25    # patches per grid row/column for the bounds above
    ambient_dim = d * d * 5
    # Ceiling division: number of d-sized tiles covering the scan range.
    numx = (max(minx + 1, maxx - d + 1) - minx + d - 1) // d
    numy = (max(miny + 1, maxy - d + 1) - miny + d - 1) // d
    print(numx, numy)

    queries = np.empty([grid * grid, ambient_dim], dtype=np.float32)
    num_filled = 0
    for j in range(miny, max(miny + 1, maxy - d + 1), stride):
        for k in range(minx, max(minx + 1, maxx - d + 1), stride):
            patch = fakeimg_rp[j:j + d, k:k + d, :].flatten().astype(np.float32)
            norm = np.linalg.norm(patch)
            # Guard all-zero patches, which would otherwise become NaN queries.
            queries[num_filled] = patch / norm * 255 if norm > 0 else patch
            num_filled += 1
    print(queries.shape, queries.dtype, num_filled)

    t0 = time.time()
    query_num_neighbours = 200
    query_field_of_view = 10000
    query_prop_to_retrieve = 1.0
    nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
        queries, query_num_neighbours, field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve, blind=False)
    print("query time:", time.time() - t0)

    finaldist = np.array(nearest_neighbour_dists)
    rawidx = np.array(nearest_neighbour_idx)
    # Results are copied into NumPy arrays above, so the index can be released.
    dci_db.clear()
    print(rawidx.shape)

    patches_per_img = grid * grid
    # Integer (floor) division: '/' would yield float indices and wrong offsets.
    finalidx = rawidx // patches_per_img
    offset = rawidx[:, :query_num_neighbours] % patches_per_img
    # builtin int: the np.int alias was removed in NumPy 1.24.
    finalpos = np.empty([rawidx.shape[0], query_num_neighbours, 2], dtype=int)
    finalpos[:, :, 0] = offset // grid * stride
    finalpos[:, :, 1] = offset % grid * stride
    np.savez("fake_%04d_all_overlap.npz" % imgid, finalidx=finalidx, finaldist=finaldist,
             finalpos=finalpos)