Ejemplo n.º 1
0
def main(*args):
    """Run a small end-to-end DCI k-nearest-neighbour benchmark.

    Synthesises ``num_points + num_queries`` points with ``gen_data``, builds a
    DCI index over the first ``num_points`` of them, looks up the
    ``num_neighbours`` nearest indexed points for each remaining query point,
    and prints the elapsed time of each phase followed by the neighbour
    indices and distances. Any positional arguments are ignored.
    """
    def report_elapsed(since):
        # Timing line printed after each phase.
        print("Took %.4fs" % (time() - since))

    # Problem size.
    dim, intrinsic_dim = 5000, 50
    num_points, num_queries = 10000, 5
    num_neighbours = 10  # the k in k-NN

    # Hyperparameter tuning guide:
    #   num_comp_indices / num_simp_indices: larger -> more accurate results, but
    #     slower construction and querying. When raising num_simp_indices,
    #     num_comp_indices may need to grow as well.
    #   num_levels: larger -> faster querying, slower construction; raising it may
    #     require larger query_field_of_view and construction_field_of_view.
    #   construction_field_of_view: larger -> *slightly* more accurate results
    #     and/or *slightly* faster querying, at *slightly* slower construction.
    #   construction_prop_to_retrieve: larger -> *slightly* more accurate results,
    #     slower construction.
    #   query_field_of_view: larger -> more accurate results, *slightly* slower
    #     querying.
    #   query_prop_to_retrieve: larger -> more accurate results, slower querying.
    num_comp_indices, num_simp_indices = 2, 7
    num_levels = 2
    construction_field_of_view, construction_prop_to_retrieve = 10, 0.002
    query_field_of_view, query_prop_to_retrieve = 100, 0.05

    print("Generating Data... ")
    started = time()
    all_points = gen_data(dim, intrinsic_dim, num_points + num_queries)
    data = np.copy(all_points[:num_points, :])
    queries = all_points[num_points:, :]
    report_elapsed(started)

    print("Constructing Data Structure... ")
    started = time()
    index = DCI(dim, num_comp_indices, num_simp_indices)
    index.add(
        data,
        num_levels=num_levels,
        field_of_view=construction_field_of_view,
        prop_to_retrieve=construction_prop_to_retrieve,
    )
    report_elapsed(started)

    print("Querying... ")
    started = time()
    neighbour_idx, neighbour_dists = index.query(
        queries,
        num_neighbours=num_neighbours,
        field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve,
        blind=True,
    )
    report_elapsed(started)

    print(neighbour_idx)
    print(neighbour_dists)
class IMLE():
    """IMLE trainer for a generator conditioned on per-example features ``Sx``.

    The generator input is a random latent ``z`` concatenated with the
    example's conditioning tensor along dim 1. During training every data point
    gets its own block of ``num_samples_factor`` candidates (same ``Sx``, fresh
    ``z``); the candidate whose output is nearest to the data point (DCI search
    on flattened outputs) is kept, and the generator is regressed (MSE) from
    those latents onto the data.
    """

    def __init__(self, z_dim, Sx_dim,
                 pretrained_path="../net_weights_2D_scattering_J=6_L=2_times=10.pth"):
        """Build the CUDA generator and optionally load pretrained weights.

        Args:
            z_dim: number of random latent channels.
            Sx_dim: number of conditioning channels appended to the latent.
            pretrained_path: checkpoint loaded with ``load_state_dict`` after
                initialisation; ``None`` keeps the freshly initialised weights.
                Defaults to the checkpoint this class has always loaded.
        """
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim, 0.5).cuda()
        self.model.apply(self.model.get_initializer())
        self.dci_db = None

        # Load pre-trained weights (optional).
        if pretrained_path is not None:
            state_dict = torch.load(pretrained_path)
            self.model.load_state_dict(state_dict)

    def train(self, data_np, data_Sx, name_JL, base_lr=1e-5, batch_size=128, num_epochs=3000,
              decay_step=25, decay_rate=0.98, staleness=100, num_samples_factor=100):
        """Fit the generator with IMLE.

        Earlier runs used base_lr=1e-4 and decay_rate=0.95.

        Args:
            data_np: training targets, examples along axis 0. Truncated to a
                multiple of ``batch_size``.
            data_Sx: conditioning features aligned with ``data_np`` on axis 0;
                concatenated with ``z`` (shape (n, z_dim, 1, 1)) along dim 1, so
                presumably (N, Sx_dim, 1, 1) -- TODO confirm.
            name_JL: tag inserted into the output file names.
            base_lr, decay_step, decay_rate: learning rate is
                ``base_lr * decay_rate ** (epoch // decay_step)``; a fresh Adam
                optimizer is created at every decay step.
            staleness: nearest-neighbour latents are recomputed every
                ``staleness`` epochs; samples and weights are saved at the end
                of every ``staleness``-th epoch.
            num_samples_factor: candidate samples drawn per data point.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` examples
                (no complete batch to train on).
        """
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        # Train in whole batches only.
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            raise ValueError(
                "need at least batch_size=%d examples, got %d"
                % (batch_size, data_np.shape[0]))

        # Truncate data to fit the batch size.
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        # Flatten each example into a 1-D vector for DCI.
        data_flat_np = np.reshape(data_np, (num_data, np.prod(data_np.shape[1:])))

        # Buffers: fitted outputs of the current epoch and one block of candidates.
        samples_predict = np.empty(data_np.shape)
        samples_np = np.empty((num_samples_factor,) + data_np.shape[1:])

        # Device-resident copies of the data and of the conditioning features,
        # each Sx row repeated num_samples_factor times (one block per example).
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor, axis=0)).float().cuda()

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        for epoch in range(num_epochs):

            # Decay the learning rate (this also resets Adam's moment estimates).
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            # Periodically re-pair every data point with its nearest candidate.
            if epoch % staleness == 0:
                z = torch.randn(num_data * num_samples_factor, self.z_dim, 1, 1).cuda()
                z_Sx_all = torch.cat((z, Sx), dim=1)

                nearest_indices = np.empty(num_data, dtype=int)

                # Candidate generation needs no gradients; skipping autograd
                # avoids building a graph for every data point. The model's
                # train/eval mode is left unchanged.
                with torch.no_grad():
                    for i in range(num_data):
                        block_start = i * num_samples_factor
                        samples = self.model(
                            z_Sx_all[block_start:block_start + num_samples_factor])
                        samples_np[:] = samples.cpu().numpy()
                        samples_flat_np = np.reshape(
                            samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                        # Nearest candidate within this example's own block.
                        self.dci_db.reset()
                        self.dci_db.add(np.copy(samples_flat_np),
                                        num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                        nearest_indices_temp, _ = self.dci_db.query(
                            data_flat_np[i:i + 1],
                            num_neighbours=1, field_of_view=20, prop_to_retrieve=0.02)
                        nearest_indices[i] = nearest_indices_temp[0][0] + block_start

                if epoch == 0:
                    print(np.percentile(data_flat_np, 25), np.percentile(data_flat_np, 50),
                          np.percentile(data_flat_np, 75))
                    print(np.percentile(samples_flat_np, 25), np.percentile(samples_flat_np, 50),
                          np.percentile(samples_flat_np, 75))

                # Restrict latent inputs to the selected nearest neighbours.
                z_Sx = z_Sx_all[nearest_indices]

            # Gradient descent over all batches.
            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[i * batch_size:(i + 1) * batch_size])

                # Keep the fitted outputs of the epoch preceding a save.
                if (epoch + 1) % staleness == 0:
                    samples_predict[i * batch_size:(i + 1) * batch_size] = \
                        cur_samples.cpu().data.numpy()

                loss = loss_fn(cur_samples, data_all[i * batch_size:(i + 1) * batch_size])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # Save random mock samples and the network.
            if (epoch + 1) % staleness == 0:
                # Fitted outputs could also be saved here:
                # np.savez("../results_2D_times=10_J=4_L=2_epoch=" + str(epoch) + ".npz",
                #          data_np=data_np, z_Sx_np=z_Sx.cpu().data.numpy(),
                #          samples_np=samples_predict)

                # Every 100th of the first 10**4 candidate latents.
                with torch.no_grad():
                    samples_random = self.model(z_Sx_all[:10**4][::100]).cpu().numpy()
                np.savez("../results_2D_random_times=10_" + name_JL + "_epoch=" + str(epoch) + ".npz",
                         samples_np=samples_random,
                         mse_err=err / num_batches)

                torch.save(self.model.state_dict(), '../net_weights_2D_times=10_' + name_JL + '_epoch='
                           + str(epoch) + '.pth')
Ejemplo n.º 3
0
class IMLE():
    """IMLE trainer for a class-conditional BigGAN generator, run on the CPU.

    Every ``hyperparams.staleness`` epochs a pool of random (z, y) pairs is
    generated, and each data point is paired with the pool entry whose sample
    is nearest to it under DCI search; the generator is then regressed (MSE)
    from those paired latents onto the data.
    """

    def __init__(self, z_dim):
        self.z_dim = z_dim
        # Alternative generator kept from earlier experiments:
        # self.model = ConvolutionalImplicitModel(z_dim).cpu()
        self.model = BigGAN.Generator(resolution=Resolution,
                                      dim_z=z_dim,
                                      n_classes=Classes).cpu()
        self.dci_db = None

    def train(self,
              data_np,
              label_np,
              hyperparams,
              shuffle_data=True,
              path='results'):
        """Fit the generator with IMLE.

        Args:
            data_np: training images, examples along axis 0.
            label_np: permuted together with ``data_np`` when ``shuffle_data``
                is set but otherwise unused; the class labels fed to the
                generator are drawn uniformly from ``[0, Classes)``.
            hyperparams: object with ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``base_lr``, ``decay_step``, ``decay_rate`` and
                ``staleness`` attributes.
            shuffle_data: randomly permute the data once before training.
            path: directory receiving ``net_weights.pth`` (at the start of
                every 10th epoch) and the ``save_grid`` output (every 5th).

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` examples.
        """
        loss_fn = nn.MSELoss().cpu()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            raise ValueError(
                'need at least batch_size=%d examples, got %d'
                % (batch_size, data_np.shape[0]))
        # Number of batch_size-sized candidate batches drawn per refresh.
        num_samples = num_batches * hyperparams.num_samples_factor

        # Fixed latents/labels so successive grids are comparable.
        grid_size = (5, 5)
        grid_z = torch.randn(np.prod(grid_size), self.z_dim).cpu()
        grid_y = torch.randint(low=0,
                               high=Classes,
                               size=(np.prod(grid_size), ),
                               dtype=torch.int64).cpu()

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]
            label_np = label_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):
            if epoch % 10 == 0:
                print('Saving net_weights.pth...')
                torch.save(self.model.state_dict(),
                           os.path.join(path, 'net_weights.pth'))
            if epoch % 5 == 0:
                print('Saving grid...')
                save_grid(path=path,
                          index=epoch,
                          count=np.prod(grid_size),
                          imle=self,
                          z=grid_z,
                          y=grid_y,
                          z_dim=self.z_dim)
                print('Saved')

            # Decay the learning rate (a fresh Adam also resets its moments).
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim))
                y_np = np.empty((num_samples * batch_size, 1), dtype=np.int64)
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                # Building the candidate pool needs no gradients; skipping
                # autograd avoids constructing a graph for every batch.
                with torch.no_grad():
                    for i in range(num_samples):
                        lo, hi = i * batch_size, (i + 1) * batch_size
                        z = torch.randn(batch_size,
                                        self.z_dim,
                                        requires_grad=False).cpu()
                        y = torch.randint(low=0,
                                          high=Classes,
                                          size=(batch_size, 1),
                                          dtype=torch.int64,
                                          requires_grad=False).cpu()
                        samples = self.model(z, self.model.shared(y))
                        z_np[lo:hi] = z.cpu().numpy()
                        y_np[lo:hi] = y.cpu().numpy()
                        samples_np[lo:hi] = samples.cpu().numpy()

                samples_flat_np = np.reshape(
                    samples_np, (samples_np.shape[0],
                                 np.prod(samples_np.shape[1:]))).copy()

                # For every (shuffled) data point, find its nearest candidate.
                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # Align latents with data order; add a little jitter to z.
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)
                y_np = y_np[nearest_indices]

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) *
                                              batch_size]).float().cpu()
                cur_y = torch.from_numpy(y_np[i * batch_size:(i + 1) *
                                              batch_size]).long().cpu()
                cur_data = torch.from_numpy(data_np[i * batch_size:(i + 1) *
                                                    batch_size]).float().cpu()
                cur_samples = self.model(cur_z, self.model.shared(cur_y))
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
Ejemplo n.º 4
0
def get_code_for_data(model, data, opt):
    """Select one latent code per input image by nearest-neighbour search.

    For each image in ``data['LR']`` this draws ``num_code_per_img`` candidate
    codes, feeds them (with the image and its ``data['HR']`` target) to
    ``model``, and keeps the candidate whose generated feature is nearest,
    under DCI search, to the target's feature. The chosen codes are finally
    jittered by ``sample_perturbation_magnitude`` times a fresh draw from the
    same code prior.

    Args:
        model: object exposing ``netF``, ``feed_data(data, code=...)`` and
            ``get_features()``; the latter is read for ``'gen_feat'`` and
            ``'real_feat'``.
        data: dict with ``'LR'`` and ``'HR'`` tensors, images along dim 0;
            code maps use LR's spatial size (``shape[2]``, ``shape[3]``).
        opt: config dict; reads DCI/sampling options from ``opt['train']`` and
            ``opt['network_G']['in_code_nc']``. The code prior is all zeros if
            ``zero_code`` is truthy, U[0, 1) if ``rand_code`` is truthy, else
            N(0, 1).

    Returns:
        Tensor of shape (N, in_code_nc, LR.shape[2], LR.shape[3]).
    """
    options = opt['train']

    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])

    code_nc = int(opt['network_G']['in_code_nc'])
    pull_num_sample_per_img = int(options['num_code_per_img'])

    show_message = options.get('show_message', False)

    pull_gen_img = data['LR']
    real_gen_img = data['HR']
    data_length = pull_gen_img.shape[0]
    pull_gen_code_0 = torch.empty(data_length, code_nc, pull_gen_img.shape[2], pull_gen_img.shape[3])

    def sample_codes(num):
        # Draw `num` code maps from the configured prior.
        shape = (num, code_nc, pull_gen_img.shape[2], pull_gen_img.shape[3])
        if options.get('zero_code'):
            return torch.zeros(*shape)
        if options.get('rand_code'):
            return torch.rand(*shape)
        return torch.randn(*shape)

    if show_message:
        print("Generating Pull Samples")

    out_feature_shape = model.netF(data['HR'][:1]).shape[1:]
    # One DCI index, cleared and refilled for every image.
    pull_samples_dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)

    for sample_index in range(data_length):
        if show_message and (sample_index + 1) % 10 == 0:
            print_without_newline(
                '\rFinding first stack code: Processed %d out of %d instances' % (
                    sample_index + 1, data_length))

        pull_gen_code_pool_0 = sample_codes(pull_num_sample_per_img)

        # Pair every candidate code with the same LR input and HR target.
        cur_data = {
            'LR': pull_gen_img[sample_index].expand(pull_num_sample_per_img, -1, -1, -1),
            'HR': real_gen_img[sample_index:sample_index + 1].expand(
                (pull_num_sample_per_img,) + real_gen_img.shape[1:]),
        }
        model.feed_data(cur_data, code=pull_gen_code_pool_0)
        feature_output = model.get_features()

        pull_gen_features_pool = feature_output['gen_feat']
        pull_gen_features_pool = pull_gen_features_pool.reshape(-1, np.prod(
            pull_gen_features_pool.shape[1:])).double().numpy().copy()

        # Query with the *real* target's feature. Querying with 'gen_feat' (as
        # before) searched the pool for itself and always returned candidate 0.
        # NOTE(review): assumes get_features() provides 'real_feat' as in the
        # three-stack variant -- confirm for this model. The HR target is
        # expanded across the batch, so row 0 is used as the single query.
        target_feature = feature_output['real_feat']
        target_feature = target_feature[0].reshape(
            1, np.prod(target_feature.shape[1:])).double().numpy().copy()

        pull_samples_dci_db.add(pull_gen_features_pool,
                                num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature,
            num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit,
            prop_to_retrieve=dci_prop_to_retrieve)

        pull_gen_code_0[sample_index, :] = pull_gen_code_pool_0[int(pull_sample_idx_for_img[0][0]), :]
        # Clear the index for the next image.
        pull_samples_dci_db.clear()

    if show_message:
        print('\rFinding first stack code: Processed %d out of %d instances' % (
            data_length, data_length))

    # Jitter the chosen codes with noise from the same prior.
    pull_gen_code_0 += sample_perturbation_magnitude * sample_codes(data_length)

    return pull_gen_code_0
Ejemplo n.º 5
0
def get_code_for_data_three(model, data, opt):
    """Select latent codes for a three-stack generator, one stack at a time.

    The generator consumes three code maps at 2x, 4x and 8x the spatial size of
    ``data['LR']``. Stacks are searched in order. For stack k, codes of earlier
    stacks are pinned to the ones already chosen for that image, while stack k
    and later stacks get fresh N(0, 1) candidates. The stack-k candidate whose
    generated feature is nearest (DCI search) to the real feature is kept.
    Feature keys per stack: ('gen_feat_D1', 'real_feat_D1'),
    ('gen_feat_D2', 'real_feat_D2'), ('gen_feat', 'real_feat'). After all three
    searches each code map is jittered by
    ``sample_perturbation_magnitude * N(0, 1)``.

    Args:
        model: object exposing ``netF`` (indexed with ``[-1]`` on its output),
            ``feed_data(data, code=[c0, c1, c2])`` and ``get_features()``.
        data: dict with ``'LR'``, ``'D1'``, ``'D2'`` and ``'HR'`` tensors,
            images along dim 0.
        opt: config dict; reads DCI/sampling options from ``opt['train']`` and
            ``opt['network_G']['in_code_nc']``.

    Returns:
        ``[code_0, code_1, code_2]``, each (N, in_code_nc, s*H, s*W) for
        s = 2, 4, 8, where H, W = ``data['LR'].shape[2]``, ``shape[3]``.

    Raises:
        ValueError: if ``num_code_per_img`` is less than 1.
    """
    options = opt['train']

    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])

    code_nc = int(opt['network_G']['in_code_nc'])
    pull_num_sample_per_img = int(options['num_code_per_img'])
    if pull_num_sample_per_img < 1:
        raise ValueError('num_code_per_img must be >= 1, got %d' % pull_num_sample_per_img)

    pull_gen_img = data['LR']
    d1_gen_img = data['D1']
    d2_gen_img = data['D2']
    real_gen_img = data['HR']
    data_length = pull_gen_img.shape[0]

    # Number of candidates pushed through the model per forward pass.
    forward_bs = 10
    # Upsampling factor of each stack's code map relative to the LR input.
    code_scales = (2, 4, 8)

    def code_shape(num, level):
        # Shape of `num` code maps for stack `level`.
        scale = code_scales[level]
        return (num, code_nc, pull_gen_img.shape[2] * scale, pull_gen_img.shape[3] * scale)

    def search_stack(level, stage_name, feat_input_key, gen_key, real_key, chosen):
        # Return one stack-`level` code per image; `chosen` holds the codes
        # already selected for stacks 0..level-1.
        stack_codes = torch.empty(*code_shape(data_length, level))
        out_feature_shape = model.netF(data[feat_input_key][:1])[-1].shape[1:]
        dci_db = DCI(np.prod(out_feature_shape), dci_num_comp_indices, dci_num_simp_indices)

        for sample_index in range(data_length):
            if (sample_index + 1) % 10 == 0:
                print_without_newline(
                    '\rFinding %s stack code: Processed %d out of %d instances' % (
                        stage_name, sample_index + 1, data_length))

            # Earlier stacks are pinned to this image's chosen codes; this and
            # later stacks get fresh random candidates, drawn in stack order.
            pools = [chosen[j][sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
                     for j in range(level)]
            pools += [torch.randn(*code_shape(pull_num_sample_per_img, j))
                      for j in range(level, len(code_scales))]

            gen_features = []
            for start in range(0, pull_num_sample_per_img, forward_bs):
                # The last chunk is short when num_code_per_img is not a multiple
                # of forward_bs; size the image batch to match the code slice.
                end = min(start + forward_bs, pull_num_sample_per_img)
                chunk = end - start
                cur_data = {
                    'LR': pull_gen_img[sample_index].expand(chunk, -1, -1, -1),
                    'HR': real_gen_img[sample_index].expand(chunk, -1, -1, -1),
                    'D1': d1_gen_img[sample_index].expand(chunk, -1, -1, -1),
                    'D2': d2_gen_img[sample_index].expand(chunk, -1, -1, -1),
                }
                model.feed_data(cur_data, code=[pool[start:end] for pool in pools])
                feature_output = model.get_features()
                gen_features.append(feature_output[gen_key].double().numpy())

            gen_features = np.concatenate(gen_features, axis=0)
            gen_features = gen_features.reshape(-1, np.prod(gen_features.shape[1:]))

            dci_db.add(gen_features.copy(),
                       num_levels=dci_num_levels,
                       field_of_view=dci_construction_field_of_view,
                       prop_to_visit=dci_prop_to_visit,
                       prop_to_retrieve=dci_prop_to_retrieve)

            # The real image is expanded across the batch, so row 0 of the last
            # chunk serves as the single query.
            target_feature = feature_output[real_key]
            target_feature = target_feature[0].reshape(
                1, np.prod(target_feature.shape[1:])).double().numpy().copy()

            nearest_idx, _ = dci_db.query(
                target_feature,
                num_neighbours=1,
                field_of_view=dci_query_field_of_view,
                prop_to_visit=dci_prop_to_visit,
                prop_to_retrieve=dci_prop_to_retrieve)

            stack_codes[sample_index, :] = pools[level][int(nearest_idx[0][0]), :]
            # Clear the index for the next image.
            dci_db.clear()

        print('\rFinding %s stack code: Processed %d out of %d instances' % (
            stage_name, data_length, data_length))
        return stack_codes

    print("Generating Pull Samples")

    # (progress label, netF probe input, generated-feature key, real-feature key)
    stages = (
        ('first', 'D1', 'gen_feat_D1', 'real_feat_D1'),
        ('second', 'D2', 'gen_feat_D2', 'real_feat_D2'),
        ('third', 'HR', 'gen_feat', 'real_feat'),
    )
    chosen_codes = []
    for level, (stage_name, feat_input_key, gen_key, real_key) in enumerate(stages):
        chosen_codes.append(
            search_stack(level, stage_name, feat_input_key, gen_key, real_key, chosen_codes))

    # Jitter every chosen code map with Gaussian noise.
    for level, codes in enumerate(chosen_codes):
        codes += sample_perturbation_magnitude * torch.randn(*code_shape(data_length, level))

    return chosen_codes
Ejemplo n.º 6
0
class IMLE():
    """IMLE (Implicit Maximum Likelihood Estimation) trainer, TensorFlow/TensorLayer port.

    Every ``hyperparams.staleness`` epochs a pool of random latent codes is
    pushed through the generator, and a DCI index is used to find, for each
    training example, the code whose generated sample is its (approximate)
    nearest neighbour. Between refreshes the generator is trained so that the
    sample for each matched code moves towards its training example.
    """

    def __init__(self, z_dim):
        # Dimensionality of the latent code; codes are shaped (N, 1, 1, z_dim).
        self.z_dim = z_dim
        # DCI nearest-neighbour index; created lazily on the first train() call.
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        """Build a generator with ``get_model`` and train it on ``data_np``.

        Args:
            data_np: NumPy array of training examples, first axis indexing
                examples. Only the first ``num_batches * batch_size`` examples
                are used in the gradient steps.
            hyperparams: object providing ``batch_size``, ``num_samples_factor``,
                ``num_outputs``, ``num_epochs``, ``decay_step``, ``base_lr``,
                ``decay_rate`` and ``staleness``.
            shuffle_data: if True, permute the examples once before training.

        Returns:
            None. The per-epoch error is printed and sample images are written
            with ``save_img``.

        NOTE(review): the reshapes to ``(batch_size, 28, 28)`` below assume
        28x28 single-channel examples (e.g. MNIST) — confirm.
        """

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        # Number of generator batches drawn at each latent-code refresh.
        num_samples = num_batches * hyperparams.num_samples_factor
        # NOTE(review): read but never used below.
        num_outputs = hyperparams.num_outputs

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        # One flattened row per example, as required by the DCI query.
        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        # NOTE(review): batch dimension is hard-coded to 64 rather than
        # batch_size; check that get_model tolerates other batch sizes.
        input_shape = [64, 1, 1, self.z_dim]
        net = get_model(input_shape)
        train_weights = net.trainable_weights

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            # Step-wise exponential learning-rate decay: a new optimizer
            # (with fresh Adam moment estimates) every decay_step epochs.
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = tf.optimizers.Adam(learning_rate=lr,
                                               beta_1=0.5,
                                               beta_2=0.999)
                # Ported from the PyTorch version: betas -> beta_1/beta_2, and
                # the original weight_decay=1e-5 has no direct equivalent here.

            # Refresh the latent code matched to each training example.
            if epoch % hyperparams.staleness == 0:
                net.eval()  # switch the network to evaluation mode for sampling
                z_np = np.empty((num_samples * batch_size, 1, 1, self.z_dim))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    z = tf.random.normal([batch_size, 1, 1, self.z_dim])
                    samples = net(z)
                    z_np[i * batch_size:(i + 1) * batch_size] = z.numpy()
                    samples_np[i * batch_size:(i + 1) *
                               batch_size] = samples.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                # Index the generated samples, then look up the nearest sample
                # for every training example.
                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # z_np[i] is now the code matched to example i, plus a small
                # Gaussian jitter.
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            # Evaluation pass: sum of per-batch losses (TensorLayer MSE with
            # is_mean=False), reported as an average over batches.
            err = 0.
            for i in range(num_batches):
                net.eval()
                cur_z = tf.convert_to_tensor(z_np[i * batch_size:(i + 1) *
                                                  batch_size],
                                             dtype=tf.float32)
                cur_data = tf.convert_to_tensor(
                    data_np[i * batch_size:(i + 1) * batch_size],
                    dtype=tf.float32)
                cur_samples = net(cur_z)
                _loss = tl.cost.mean_squared_error(cur_data,
                                                   cur_samples,
                                                   is_mean=False)
                err += _loss

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # Training pass: one gradient step per batch, pulling the
            # sample for each matched code towards its example.
            for i in range(num_batches):
                net.train()
                with tf.GradientTape() as tape:
                    cur_z = tf.convert_to_tensor(z_np[i * batch_size:(i + 1) *
                                                      batch_size],
                                                 dtype=tf.float32)
                    cur_data = tf.convert_to_tensor(
                        data_np[i * batch_size:(i + 1) * batch_size],
                        dtype=tf.float32)
                    cur_samples = net(cur_z)
                    _loss = tl.cost.mean_squared_error(cur_data,
                                                       cur_samples,
                                                       is_mean=False)
                grad = tape.gradient(_loss, train_weights)
                optimizer.apply_gradients(zip(grad, train_weights))

            # Optional debug visualisation of the matched codes (slow, so kept
            # commented out; the four lines below can be deleted):
            #cur_z = tf.convert_to_tensor(z_np[0: batch_size], dtype = tf.float32)
            #cur_samples = net(cur_z)
            #pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            #save_img(1, pics)
            # Every epoch, render samples from fresh random codes and save them.
            cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
            cur_samples = net(cur_z)
            pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            save_img(2, pics)

        # After training, render and save a final set of random samples.
        #cur_z = tf.convert_to_tensor(z_np[0: batch_size], dtype = tf.float32)
        cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
        cur_samples = net(cur_z)
        pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))

        save_img(batch_size, pics)
Ejemplo n.º 7
0
class IMLE():
    """Conditional IMLE trainer: latent codes are concatenated with side features Sx.

    For each training example, ``num_samples_factor`` random codes (each joined
    with that example's Sx row) are decoded, and DCI picks the one whose sample
    is the nearest neighbour of the example. The generator is then trained
    with an L1 loss to map the chosen inputs onto their examples.
    """

    def __init__(self, z_dim, Sx_dim):
        # Sizes of the random latent part and of the conditioning features.
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        # The generator consumes the concatenation [z, Sx].
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim).cuda()
        # DCI index, created lazily on the first call to train().
        self.dci_db = None

    def train(self, data_np, data_Sx, base_lr=1e-4, batch_size=256,
              num_epochs=3000, decay_step=25, decay_rate=0.95, staleness=100,
              num_samples_factor=100):
        """Train the generator on ``data_np`` conditioned on ``data_Sx``.

        Args:
            data_np: NumPy array of targets, first axis indexing examples.
                It is truncated to a whole number of batches.
            data_Sx: conditioning features, one row per example (truncated
                the same way).
            base_lr, decay_step, decay_rate: Adam learning rate
                ``base_lr * decay_rate ** (epoch // decay_step)``.
            batch_size: examples per gradient step.
            num_epochs: number of training epochs.
            staleness: re-match latent codes every ``staleness`` epochs. Results
                and weights are saved at the end of each such period.
            num_samples_factor: candidate codes drawn per example when matching.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` examples.

        Side effects: writes ``../results_spectra_deconv_256x2_<epoch>.npz``,
        ``../mse_err_deconv_256x2_<epoch>.npz`` and
        ``../net_weights_spectra_deconv_256x2_epoch=<epoch>.pth``.
        """
        loss_fn = nn.L1Loss().cuda()

        self.model.train()

        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # Previously this surfaced much later as a ZeroDivisionError.
            raise ValueError(
                "IMLE.train needs at least batch_size (%d) examples, got %d"
                % (batch_size, data_np.shape[0]))

        # Truncate data to a whole number of batches.
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        # Buffers: model predictions for saving, and the candidate pool for
        # a single example.
        samples_predict = np.empty(data_np.shape)
        samples_np = np.empty((num_samples_factor, ) + data_np.shape[1:])

        data_all = torch.from_numpy(data_np).float().cuda()
        # Row i*num_samples_factor + j holds the Sx of example i.
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor,
                                        axis=0)).float().cuda()

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(num_epochs):

            # Step-wise learning-rate decay via a freshly created optimizer.
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate**(epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # Periodically re-match each example to its closest generated sample.
            if epoch % staleness == 0:
                z = torch.randn(num_data * num_samples_factor,
                                self.z_dim).cuda()
                # Trailing singleton axis: the generator input is (N, C, 1).
                z_Sx_all = torch.cat((z, Sx), axis=1)[:, :, None]

                nearest_indices = np.empty((num_data)).astype("int")
                for i in range(num_data):
                    start = i * num_samples_factor
                    # Candidates only feed the NN search: no autograd graph needed.
                    with torch.no_grad():
                        samples = self.model(
                            z_Sx_all[start:start + num_samples_factor])
                    samples_np[:] = samples.cpu().data.numpy()

                    # Nearest of this example's own candidates.
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_np),
                                    num_levels=2,
                                    field_of_view=10,
                                    prop_to_retrieve=0.002)
                    nearest_indices_temp, _ = self.dci_db.query(
                        data_np[i:i + 1],
                        num_neighbours=1,
                        field_of_view=20,
                        prop_to_retrieve=0.02)
                    # Convert the local candidate index to a global row index.
                    nearest_indices[i] = nearest_indices_temp[0][0] + start

                # Restrict generator inputs to the matched candidates.
                z_Sx = z_Sx_all[nearest_indices]

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[i * batch_size:(i + 1) *
                                              batch_size])

                # On the last epoch of a staleness period, keep predictions for saving.
                if (epoch + 1) % staleness == 0:
                    samples_predict[i * batch_size:(
                        i + 1) * batch_size] = cur_samples.cpu().data.numpy()

                loss = loss_fn(cur_samples,
                               data_all[i * batch_size:(i + 1) * batch_size])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            if (epoch + 1) % staleness == 0:
                # Save data, matched inputs and predictions.
                np.savez("../results_spectra_deconv_256x2_" + str(epoch) + ".npz",
                         data_np=data_np,
                         z_Sx_np=z_Sx.cpu().data.numpy(),
                         samples_np=samples_predict)

                # NOTE: stored under "mse_err" although loss_fn is L1.
                np.savez("../mse_err_deconv_256x2_" + str(epoch) + ".npz",
                         mse_err=err / num_batches)

                torch.save(
                    self.model.state_dict(),
                    '../net_weights_spectra_deconv_256x2_epoch=' + str(epoch) +
                    '.pth')
class IMLE():
    """IMLE trainer for an unconditional convolutional generator (PyTorch).

    Every ``hyperparams.staleness`` epochs a pool of random latent codes is
    decoded and DCI finds, for each training example, the code whose sample is
    its approximate nearest neighbour. Between refreshes the generator is
    trained with MSE to map those codes onto their examples.
    """

    def __init__(self, z_dim):
        # Latent dimensionality; codes are shaped (N, z_dim, 1, 1).
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        # DCI index, created lazily on the first call to train().
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        """Fit the generator to ``data_np``.

        Args:
            data_np: NumPy array of training examples (first axis indexes
                examples). Only the first ``num_batches * batch_size`` take
                part in the gradient steps.
            hyperparams: object with ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``decay_step``, ``base_lr``, ``decay_rate`` and
                ``staleness``.
            shuffle_data: permute the examples once before training.

        Raises:
            ValueError: if there are fewer examples than one batch.
        """
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # Previously this failed later (empty DCI index / ZeroDivisionError).
            raise ValueError(
                "IMLE.train needs at least batch_size (%d) examples, got %d"
                % (batch_size, data_np.shape[0]))
        # Number of generator batches drawn at each latent-code refresh.
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            # Step-wise learning-rate decay via a freshly created optimizer.
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                # Samples only feed the nearest-neighbour search: skip autograd.
                with torch.no_grad():
                    for i in range(num_samples):
                        z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                        samples = self.model(z)
                        rows = slice(i * batch_size, (i + 1) * batch_size)
                        z_np[rows] = z.cpu().data.numpy()
                        samples_np[rows] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # z_np[i] becomes the (jittered) code matched to example i.
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) *
                                              batch_size]).float().cuda()
                cur_data = torch.from_numpy(
                    data_np[i * batch_size:(i + 1) *
                            batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
def runknn(sz=4,
           num_points=202599,
           num_queries=1,
           num_comp_indices=2,
           num_simp_indices=7,
           num_outer_iterations=202599,
           max_num_candidates=5000,
           num_neighbours=10,
           patch_size=5):
    """Find nearest-neighbour windows for every 32x32 window of one fake image.

    Builds a DCI index from ``../church_npy/church_<sz>_vgg_12_5.npy``. It then
    projects the feature map of fake image ``imgid`` with the matrix in
    ``../church_npy/rpvec_64_5.npy`` and cuts it into overlapping 32x32 windows
    (stride 4, 25x25 windows, row-major). Each window is L2-normalised and
    scaled by 255, and its 200 nearest neighbours are queried.

    The results are saved to ``fake_<imgid>_all_overlap.npz`` with these arrays:

    * ``finalidx``: flat neighbour index // 625, i.e. which indexed image the
      neighbour comes from. This assumes the index stores 25*25 windows per
      image in the same row-major order (TODO confirm).
    * ``finaldist``: neighbour distances.
    * ``finalpos``: pixel offset (row, col) of each neighbour's window, under
      the same assumption.

    NOTE: ``num_queries``, ``num_outer_iterations``, ``max_num_candidates`` and
    ``num_neighbours`` are kept for interface compatibility but are unused;
    200 neighbours are always requested.
    """
    num_levels = 3
    construction_field_of_view = 10
    construction_prop_to_retrieve = 0.02
    dim = patch_size * patch_size * 5

    dci_db = DCI(dim, num_comp_indices, num_simp_indices)
    st = time.time()
    print("before adding")
    dci_db.add(num_points,
               "../church_npy/church_%d_vgg_12_5.npy" % sz,
               num_levels=num_levels,
               field_of_view=construction_field_of_view,
               prop_to_retrieve=construction_prop_to_retrieve,
               load_from_file=1)
    print("construction time:", time.time() - st)

    try:
        imgid = 270
        flip = 0

        if flip:
            fake_path = "../church_npy/fakechurch_%d_vgg_12_flip.npy" % sz
        else:
            fake_path = "../church_npy/fakechurch_%d_vgg_12.npy" % sz
        fakeimgs = np.load(fake_path, mmap_mode='r')
        mularr = np.load("../church_npy/rpvec_64_5.npy")
        fakeimgs_rp = fakeimgs[imgid].copy()
        print(fakeimgs_rp.shape)
        # Project the feature channels with the matrix from rpvec_64_5.npy.
        fakeimg_rp = np.dot(fakeimgs_rp, mularr)
        del fakeimgs, mularr

        window = 32
        stride = 4
        windows_per_axis = 25
        windows_per_img = windows_per_axis * windows_per_axis
        minx, miny, maxx, maxy = 0, 0, 128, 128
        ambient_dim = window * window * 5
        # Integer division (these were floats under Python 3's true division).
        numx = (max(minx + 1, maxx - window + 1) - minx + window - 1) // window
        numy = (max(miny + 1, maxy - window + 1) - miny + window - 1) // window
        print(numx, numy)

        queries = np.empty([windows_per_img, ambient_dim], dtype=np.float32)
        n_windows = 0
        for j in range(miny, max(miny + 1, maxy - window + 1), stride):
            for k in range(minx, max(minx + 1, maxx - window + 1), stride):
                patch = fakeimg_rp[j:j + window,
                                   k:k + window, :].flatten().astype(np.float32)
                queries[n_windows] = patch / np.linalg.norm(patch) * 255
                n_windows += 1
        print(queries.shape, queries.dtype, n_windows)

        num_results = 200
        st = time.time()
        nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
            queries,
            num_results,
            field_of_view=10000,
            prop_to_retrieve=1.0,
            blind=False)
        print("query time:", time.time() - st)

        finaldist = np.array(nearest_neighbour_dists)
        rawidx = np.array(nearest_neighbour_idx)
        print(rawidx.shape)
        # Floor division: `/` yields fractional indices/positions under Python 3.
        finalidx = rawidx // windows_per_img
        offsets = (rawidx[:, :num_results] % windows_per_img).astype(np.int64)
        # np.int was removed in NumPy 1.24; use an explicit width.
        finalpos = np.empty([rawidx.shape[0], num_results, 2], dtype=np.int64)
        finalpos[:, :, 0] = offsets // windows_per_axis * stride
        finalpos[:, :, 1] = offsets % windows_per_axis * stride
        np.savez("fake_%04d_all_overlap.npz" % imgid,
                 finalidx=finalidx,
                 finaldist=finaldist,
                 finalpos=finalpos)
    finally:
        # Release the index; the old clear() call sat on an unreachable path.
        dci_db.clear()
Ejemplo n.º 10
0
class IMLE():
    """IMLE trainer for an unconditional convolutional generator (PyTorch).

    Every ``hyperparams.staleness`` epochs a new pool of random latent codes
    is drawn and decoded. For each training example x_i, DCI finds the code
    z whose sample G(z) is its approximate nearest neighbour. Between
    refreshes the generator minimises ||G(z_i) - x_i||^2, pulling the closest
    sample towards each data point (the IMLE objective).
    """

    def __init__(self, z_dim):
        # Latent dimensionality; codes are shaped (N, z_dim, 1, 1).
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        # DCI index, created lazily on the first call to train().
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        """Fit the generator to ``data_np``.

        Args:
            data_np: NumPy array of training examples (first axis indexes
                examples); it is the fixed training set, while the generated
                samples are redrawn at each refresh.
            hyperparams: object with ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``decay_step``, ``base_lr``, ``decay_rate`` and
                ``staleness``.
            shuffle_data: permute the examples once before training.

        Raises:
            ValueError: if there are fewer examples than one batch.
        """
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # Previously this failed later (empty DCI index / ZeroDivisionError).
            raise ValueError(
                "IMLE.train needs at least batch_size (%d) examples, got %d"
                % (batch_size, data_np.shape[0]))
        # Number of generator batches drawn at each refresh.
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            # Step-wise learning-rate decay via a freshly created optimizer.
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # Re-sample generated points and re-match them to the data.
            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                # Samples only feed the nearest-neighbour search: skip autograd.
                with torch.no_grad():
                    for i in range(num_samples):
                        z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                        samples = self.model(z)
                        rows = slice(i * batch_size, (i + 1) * batch_size)
                        z_np[rows] = z.cpu().data.numpy()
                        samples_np[rows] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # z_np[i] becomes the (jittered) code whose sample was the
                # nearest neighbour of example i.
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            # Mini-batch gradient descent on ||G(z_i) - x_i||^2.
            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) *
                                              batch_size]).float().cuda()
                cur_data = torch.from_numpy(
                    data_np[i * batch_size:(i + 1) *
                            batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
Ejemplo n.º 11
0
class CoolSystem(pl.LightningModule):

    def __init__(self, z_dim, hyperparams, shuffle_data=True):
        """Store configuration and create the generator and loss.

        Args:
            z_dim: latent code dimensionality passed to the generator.
            hyperparams: training configuration read by ``regen``,
                ``training_step`` and ``configure_optimizers``.
            shuffle_data: stored on the instance; not read by the methods
                visible in this chunk.
        """
        super(CoolSystem, self).__init__()
        # Configuration.
        self.hyperparams = hyperparams
        self.shuffle_data = shuffle_data
        self.z_dim = z_dim
        # State filled in lazily by regen().
        self.dci_db = None
        self.z_grid = None
        # Generator, then the loss (registered in the same order as before).
        self.model = ConvolutionalImplicitModel(z_dim)
        # Optional perceptual feature network; disabled by default.
        self.squeezeNet = None
        self.loss_fn = nn.MSELoss()
        # Global step counter used for TensorBoard image logging.
        self._step = 0

    def regen(self, batch):
        """Re-match latent codes to the images of ``batch`` (the IMLE step).

        The method draws ``hyperparams.num_samples_factor`` batches of
        ``hyperparams.batch_size`` random codes and decodes them with the
        batch labels. For each image it keeps the code whose sample is the
        nearest neighbour found by DCI, then adds a 0.01-scale Gaussian
        jitter to the kept codes.

        Side effects: sets ``self.data_np``, ``self.data_flat_np`` and
        ``self.z_np``. It lazily creates ``self.dci_db`` and, on the first
        call only, ``self.z_grid`` (0.7 times the matched codes).

        Args:
            batch: ``(imgs, labels)`` pair of CPU tensors.

        Returns:
            NumPy array of matched latent codes, one per image.
        """
        imgs, labels = batch
        data_np = imgs.numpy()
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        self.data_np = data_np
        self.data_flat_np = data_flat_np
        hyperparams = self.hyperparams

        batch_size = hyperparams.batch_size
        # Only the current batch is matched, hence a single "batch".
        num_batches = 1
        num_samples = num_batches * hyperparams.num_samples_factor

        z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
        samples_np = np.empty((num_samples * batch_size,) + data_np.shape[1:])
        # Samples only feed the nearest-neighbour search: no autograd graph needed.
        with torch.no_grad():
            for i in range(num_samples):
                z = torch.randn(batch_size, self.z_dim, 1, 1).cpu()
                samples = self.model(z, labels)
                z_np[i * batch_size:(i + 1) * batch_size] = z.cpu().data.numpy()
                samples_np[i * batch_size:(i + 1) * batch_size] = samples.cpu().data.numpy()

        samples_flat_np = np.reshape(samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:]))).copy()

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        self.dci_db.reset()
        self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1, field_of_view=20, prop_to_retrieve=0.02)
        nearest_indices = np.array(nearest_indices)[:, 0]

        z_np = z_np[nearest_indices]
        z_np += 0.01 * np.random.randn(*z_np.shape)
        self.z_np = z_np
        if self.z_grid is None:
            # Fixed, shrunken copy of the first matched codes, used for the image grid.
            self.z_grid = self.z_np.copy() * 0.7

        del samples_np, samples_flat_np

        return z_np

    # def forward(self, x):
    #     return torch.relu(self.l1(x.view(x.size(0), -1)))

    def forward(self, z, class_id=None):
        """Decode latent codes ``z`` (optionally with ``class_id``) via the generator."""
        return self.model(z, class_id)

    def training_step(self, batch, batch_idx):
        """Run one IMLE optimisation step and log images to TensorBoard.

        It uses the latent codes and target images cached by the most recent
        ``regen`` call (``self.z_np`` / ``self.data_np``). Presumably the
        caller runs ``regen`` on the same batch beforehand; TODO confirm.

        The loss is the pixel loss, plus a negative-Pearson feature loss when
        ``self.squeezeNet`` is set.

        Returns:
            dict with ``'loss'`` and a ``'log'`` dict (``train_loss``, ``lr``,
            ``epoch``).
        """
        hyperparams = self.hyperparams
        cur_z = torch.from_numpy(self.z_np).float().cpu()
        imgTarget = torch.from_numpy(self.data_np).float().cpu()
        print(cur_z.shape, batch_idx)

        def nimg(img):
            # Clamp into [0, 1] so the logged image grid is displayable.
            return torch.clamp(img, 0, 1)

        imgInput = batch[0]
        imgLabels = batch[1]
        self.logger.experiment.add_image('imgInput', torchvision.utils.make_grid(nimg(imgInput)), self._step)
        imgOutput = self.forward(cur_z, imgLabels)
        self.logger.experiment.add_image('imgOutput', torchvision.utils.make_grid(nimg(imgOutput)), self._step)
        # _step starts at 0 and only grows, so the old "_step > -1 and True" terms were always true.
        if self._step % me.interp_every == 0:
            cur_zgrid = torch.from_numpy(self.z_grid).float().cpu()
            imgGrid = self.forward(cur_zgrid, imgLabels)
            self.logger.experiment.add_image('imgGrid', torchvision.utils.make_grid(nimg(imgGrid)), self._step)
            imgs = me.circle_interpolation(self.model, imgInput.shape[0])
            self.logger.experiment.add_image('imgInterp', torchvision.utils.make_grid(nimg(imgs)), self._step)

        if self.squeezeNet is not None:
            # Grey images are repeated to 3 channels for the feature network.
            activationsTarget = self.squeezeNet(imgTarget.repeat(1, 3, 1, 1))
            activationsOutput = self.squeezeNet(imgOutput.repeat(1, 3, 1, 1))
            featLoss = None
            for actTarget, actOutput in zip(activationsTarget, activationsOutput):
                corr_loss = -pearsonr(actTarget.view(-1), actOutput.view(-1))
                featLoss = corr_loss if featLoss is None else featLoss + corr_loss
        else:
            featLoss = 0.0

        pixelLoss = self.loss_fn(imgOutput, imgTarget)
        loss = featLoss + pixelLoss
        # lr_fn (set in configure_optimizers) returns only the LambdaLR
        # multiplier; scale by base_lr so the logged 'lr' is the real rate.
        lr = hyperparams.base_lr * self.lr_fn(self.current_epoch)
        tensorboard_logs = {'train_loss': loss, 'lr': lr, 'epoch': self.current_epoch}
        self._step += 1
        return {'loss': loss, 'log': tensorboard_logs}

    def number_interpolation(self, count, lo, hi):
        """Unfinished stub.

        NOTE(review): this only defines a local ``gen_latent`` helper that is
        never called. ``count``, ``lo`` and ``hi`` are unused, and
        ``gen_func`` is not defined anywhere in the visible source. The method
        returns None.
        """
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size

        def gen_latent(pos):
          # Would map a position to one latent code reshaped to (1, -1, 1, 1)
          # via gen_func (undefined here).
          z = gen_func(pos)
          z = z.reshape([1,-1,1,1])
          return z

#     def validation_step(self, batch, batch_idx):
#         # OPTIONAL
#         x, y = batch
#         y_hat = self.forward(x)
#         return {'val_loss': F.cross_entropy(y_hat, y)}

#     def validation_end(self, outputs):
#         # OPTIONAL
#         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#         tensorboard_logs = {'val_loss': avg_loss}
#         return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        
#     def test_step(self, batch, batch_idx):
#         # OPTIONAL
#         x, y = batch
#         y_hat = self.forward(x)
#         return {'test_loss': F.cross_entropy(y_hat, y)}

#     def test_end(self, outputs):
#         # OPTIONAL
#         avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
#         tensorboard_logs = {'test_loss': avg_loss}
#         return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Create the Adam optimizer and its step-decay learning-rate schedule.

        The multiplier function is also stored on ``self.lr_fn`` so the training
        step can log it.

        Returns:
            ``([optimizer], [scheduler])``: one AMSGrad-Adam optimizer over
            ``self.model``'s parameters and a ``LambdaLR`` scaling ``base_lr`` by
            ``decay_rate ** (epoch // decay_step)``.
        """
        from torch.optim.lr_scheduler import LambdaLR

        hp = self.hyperparams

        def decay_factor(epoch):
            # Multiplier on base_lr: one factor of decay_rate per completed decay step.
            return hp.decay_rate ** (epoch // hp.decay_step)

        self.lr_fn = decay_factor
        adam = torch.optim.Adam(
            self.model.parameters(),
            lr=hp.base_lr,
            betas=(0.5, 0.999),
            amsgrad=True,
            weight_decay=1e-5,
        )
        schedule = LambdaLR(adam, lr_lambda=[self.lr_fn])
        return [adam], [schedule]
        

    def prepare_data(self):
        """Ensure both CIFAR-10 splits exist under the current working directory.

        Instantiates the training split and then the test split with
        ``download=True`` rooted at ``os.getcwd()``; the resulting dataset
        objects are discarded.
        """
        root = os.getcwd()
        for use_train_split in (True, False):
            CIFAR10(root, train=use_train_split, download=True, transform=transforms.ToTensor())

    def on_batch_start(self, batch):
        """Hand ``batch`` to ``self.regen`` (presumably Lightning's per-batch hook).

        Whatever ``self.regen`` returns is ignored; this method returns ``None``.
        """
        regenerate = self.regen
        regenerate(batch)
        

    def train_dataloader(self):
        """Return a shuffled DataLoader over the cat and dog images of CIFAR-10's training split.

        The training split is read from ``os.getcwd()`` with ``download=False``
        (``prepare_data`` downloads it). Only the 'cat' and 'dog' classes are
        kept; they are relabelled 0 (cat) and 1 (dog). Each image is converted
        with ``transforms.ToTensor()``, and batches have
        ``self.hyperparams.batch_size`` items.
        """
        # Loading the CIFAR10 training data (the test split is not used here).
        trainset = CIFAR10(root=os.getcwd(), train=True, download=False)
        classDict = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                     'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

        x_train = trainset.data
        y_train = trainset.targets

        def get_class_i(x, y, i):
            """Return the list of items of ``x`` whose label in ``y`` equals ``i``.

            x: indexable image collection, e.g. ``trainset.data``
            y: labels aligned with ``x``, e.g. ``trainset.targets``
            i: class label, a number between 0 and 9
            """
            positions = np.flatnonzero(np.asarray(y) == i)
            return [x[j] for j in positions]

        class DatasetMaker(Dataset):
            """Concatenate several per-class image lists into one labelled dataset.

            Element ``k`` of the ``c``-th list is returned as
            ``(transformFunc(image), c)``.
            """

            def __init__(self, datasets, transformFunc=transforms.ToTensor()):
                """
                datasets: a list of get_class_i outputs, i.e. a list of lists of images for selected classes
                """
                self.datasets = datasets
                self.lengths = [len(d) for d in self.datasets]
                self.transformFunc = transformFunc

            def __getitem__(self, i):
                if i < 0:
                    # Python-style negative indexing over the concatenated dataset.
                    i += len(self)
                    if i < 0:
                        raise IndexError("dataset index out of range")
                class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
                img = self.datasets[class_label][index_wrt_class]
                if self.transformFunc:
                    img = self.transformFunc(img)
                return img, class_label

            def __len__(self):
                return sum(self.lengths)

            def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
                """
                Given the absolute index, returns which bin it falls in and which element of that bin it corresponds to.
                Bins are laid out back to back in the order of ``bin_sizes``.
                """
                # ends[b] is one past the last absolute index belonging to bin b.
                ends = np.cumsum(bin_sizes)
                if verbose:
                    print("accum =", ends)
                # The number of bins ending at or before absolute_index is the index of its bin.
                bin_index = int(np.searchsorted(ends, absolute_index, side='right'))
                if verbose:
                    print("class_label =", bin_index)
                # Offset of absolute_index from the first element of its bin.
                starts = np.insert(ends, 0, 0)
                index_wrt_class = absolute_index - starts[bin_index]
                if verbose:
                    print("index_wrt_class =", index_wrt_class)

                return bin_index, index_wrt_class

        # Cats (CIFAR class 3) become label 0, dogs (CIFAR class 5) become label 1.
        cat_dog_trainset = DatasetMaker(
            [get_class_i(x_train, y_train, classDict['cat']),
             get_class_i(x_train, y_train, classDict['dog'])]
        )

        batch_size = self.hyperparams.batch_size
        # NOTE(review): DatasetMaker is defined inside this method. With num_workers > 0 the
        # dataset must be pickled unless workers are started with 'fork', and locally
        # defined classes cannot be pickled; confirm the target platform's start method.
        return DataLoader(cat_dog_trainset, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=False)
Ejemplo n.º 12
0
class IMLE():
    """Jointly train two generators that share a latent code, using IMLE.

    ``model`` and ``model2`` receive the same latent codes. Latent codes are
    assigned to training examples by nearest-neighbour search (DCI) between
    samples of ``model`` and the rows of ``data_np``. ``model2`` is trained to
    map the same code to the paired row of ``data_np2``.
    """

    def __init__(self, z_dim):
        # Latent dimensionality shared by both generators.
        self.z_dim = z_dim

        # Both generators are moved to the GPU immediately (CUDA is required).
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.model2 = ConvolutionalImplicitModel(z_dim).cuda()

        # DCI nearest-neighbour indices are created lazily by ``train``.
        self.dci_db = None
        self.dci_db2 = None

    def train(self, data_np, data_np2, base_lr=1e-3, batch_size=64, num_epochs=10000,
              decay_step=25, decay_rate=1.0, staleness=500, num_samples_factor=100,
              save_path="../results_3D.npz"):
        """Train both generators on row-paired data.

        Args:
            data_np: array of shape (N, ...). These are the targets for ``model``,
                and the nearest-neighbour latent assignment is done against them.
            data_np2: array of shape (N, ...). These are the targets for
                ``model2``, paired row by row with ``data_np``.
            base_lr: initial Adam learning rate.
            batch_size: minibatch size. It is also the batch size used when
                drawing samples.
            num_epochs: number of passes over the data.
            decay_step: every ``decay_step`` epochs a fresh Adam optimizer is
                built with ``lr = base_lr * decay_rate ** (epoch // decay_step)``.
            decay_rate: multiplicative learning-rate decay per decay step.
            staleness: latents are re-assigned at the start of every
                ``staleness``-th epoch. Results are saved at the end of every
                ``staleness``-th epoch.
            num_samples_factor: about this many candidate samples are drawn per
                training example at each re-assignment.
            save_path: file written with ``np.savez`` when results are saved.

        Raises:
            ValueError: if the two datasets differ in length, if ``batch_size``
                is not positive or exceeds the number of examples (no full
                batch could be formed), or if ``decay_step``/``staleness`` is
                not positive.
        """
        num_examples = data_np.shape[0]
        if data_np2.shape[0] != num_examples:
            raise ValueError("data_np and data_np2 must have the same number of rows, got %d and %d"
                             % (num_examples, data_np2.shape[0]))
        if batch_size <= 0 or num_examples < batch_size:
            raise ValueError("batch_size must be in [1, %d], got %r" % (num_examples, batch_size))
        if decay_step <= 0 or staleness <= 0:
            raise ValueError("decay_step and staleness must be positive")

        loss_fn = nn.MSELoss().cuda()

        self.model.train()
        self.model2.train()

        # Only full batches are used; leftover rows are visited in other epochs via shuffling.
        num_batches = num_examples // batch_size
        # Number of sampling batches per re-assignment (≈ num_samples_factor samples per example).
        num_samples = num_batches * num_samples_factor

        # Flattened view of the data for DCI, which indexes 1-D vectors.
        data_flat_np = np.reshape(data_np, (num_examples, np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)
        if self.dci_db2 is None:
            self.dci_db2 = DCI(np.prod(data_np2.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        optimizer = None
        z_np = None
        for epoch in range(num_epochs):

            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                # A new optimizer is built here, which also resets Adam's moment estimates.
                optimizer = optim.Adam(list(self.model.parameters()) + list(self.model2.parameters()),
                                       lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

            if epoch % staleness == 0:
                z_np = self._nearest_latents(data_flat_np, data_np.shape[1:], num_samples, batch_size)

            # Shuffle examples, their pairs and their latents with one shared permutation.
            data_ordering = np.random.permutation(num_examples)
            data_np = data_np[data_ordering]
            data_np2 = data_np2[data_ordering]
            data_flat_np = np.reshape(data_np, (num_examples, np.prod(data_np.shape[1:])))
            z_np = z_np[data_ordering]

            err = 0.
            for i in range(num_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)

                self.model.zero_grad()
                self.model2.zero_grad()

                cur_z = torch.from_numpy(z_np[rows]).float().cuda()
                cur_data = torch.from_numpy(data_np[rows]).float().cuda()
                cur_data2 = torch.from_numpy(data_np2[rows]).float().cuda()

                cur_samples = self.model(cur_z)
                cur_samples2 = self.model2(cur_z)

                # The same latent drives both generators; their MSE losses are summed.
                loss = loss_fn(cur_samples, cur_data) + loss_fn(cur_samples2, cur_data2)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            if (epoch + 1) % staleness == 0:
                self._save_results(save_path, data_np, data_np2, z_np)

    def _nearest_latents(self, data_flat_np, sample_shape, num_samples, batch_size):
        """Return one jittered latent per row of ``data_flat_np``: the latent whose ``model`` sample is nearest."""
        total = num_samples * batch_size
        z_np = np.empty((total, self.z_dim, 1, 1, 1))
        samples_np = np.empty((total,) + tuple(sample_shape))

        # Sampling needs no gradients; no_grad avoids building unused autograd graphs.
        with torch.no_grad():
            for i in range(num_samples):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                z = torch.randn(batch_size, self.z_dim, 1, 1, 1).cuda()
                z_np[rows] = z.cpu().numpy()
                samples_np[rows] = self.model(z).cpu().numpy()

        samples_flat_np = np.reshape(samples_np, (total, np.prod(sample_shape)))

        self.dci_db.reset()
        self.dci_db.add(np.copy(samples_flat_np), num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1, field_of_view=20,
                                               prop_to_retrieve=0.02)
        # Free the (large) sample buffers; DCI was given its own copy.
        del samples_np, samples_flat_np

        z_np = z_np[np.array(nearest_indices)[:, 0]]
        # Small noise on the assigned latents to facilitate training.
        z_np += 0.01 * np.random.randn(*z_np.shape)
        return z_np

    def _save_results(self, save_path, data_np, data_np2, z_np):
        """Write the data and both generators' outputs for the assigned latents to ``save_path``."""
        # NOTE(review): the whole latent set goes through each model in one forward
        # pass (as before); for large datasets this may exceed GPU memory.
        with torch.no_grad():
            z = torch.from_numpy(z_np).float().cuda()
            samples_np = self.model(z).cpu().numpy()
            samples_np2 = self.model2(z).cpu().numpy()
        np.savez(save_path,
                 data_np=data_np,
                 data_np2=data_np2,
                 samples_np=samples_np,
                 samples_np2=samples_np2)