class IMLE():
    def __init__(self, z_dim, Sx_dim):
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim+Sx_dim, 0.5).cuda()
        self.model.apply(self.model.get_initializer())
        self.dci_db = None

#-----------------------------------------------------------------------------------------------------------
        # load pre-trained model
        state_dict = torch.load("../net_weights_2D_scattering_J=6_L=2_times=10.pth")
        self.model.load_state_dict(state_dict)

#-----------------------------------------------------------------------------------------------------------
    #def train(self, data_np, data_Sx, name_JL, base_lr=1e-4, batch_size=128, num_epochs=3000,\
    #         decay_step=25, decay_rate=0.95, staleness=100, num_samples_factor=100):
    def train(self, data_np, data_Sx, name_JL, base_lr=1e-5, batch_size=128, num_epochs=3000,\
             decay_step=25, decay_rate=0.98, staleness=100, num_samples_factor=100):

        # define metric
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        # train in batch
        num_batches = data_np.shape[0] // batch_size

        # truncate data to fit the batch size
        num_data = num_batches*batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        # flatten each data image to 1D for DCI
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

#-----------------------------------------------------------------------------------------------------------
        # make empty array to store results
        samples_predict = np.empty(data_np.shape)
        samples_np = np.empty((num_samples_factor,)+data_np.shape[1:])

#-----------------------------------------------------------------------------------------------------------
        # make global torch variables
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx,num_samples_factor,axis=0)).float().cuda()

#-----------------------------------------------------------------------------------------------------------
        # initiate dci
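        # (DCI = Dynamic Continuous Indexing: an approximate nearest-neighbour index used below to match each data image to its closest generated sample)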
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices = 2, num_simp_indices = 7)


#=============================================================================================================
        # train through various epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
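                # e.g. with the defaults base_lr=1e-5, decay_rate=0.98, decay_step=25, epoch 100 gives lr = 1e-5 * 0.98**4 ≈ 9.2e-6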
                optimizer = optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

#-----------------------------------------------------------------------------------------------------------
            # routinely update the nearest generated samples
            if epoch % staleness == 0:

                # draw random z
                z = torch.randn(num_data*num_samples_factor, self.z_dim, 1, 1).cuda()
                z_Sx_all = torch.cat((z, Sx), axis=1)
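                # each data point thus gets num_samples_factor candidate latents, all paired with its own (repeated) scattering coefficients Sx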

                # find the closest object for individual data
                nearest_indices = np.empty((num_data)).astype("int")

                for i in range(num_data):
                    samples = self.model(z_Sx_all[i*num_samples_factor:(i+1)*num_samples_factor])
                    samples_np[:] = samples.cpu().data.numpy()
                    samples_flat_np = np.reshape(samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:])))

#-----------------------------------------------------------------------------------------------------------
                    # find the nearest neighbours
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_flat_np),\
                                    num_levels = 2, field_of_view = 10, prop_to_retrieve = 0.002)
                    nearest_indices_temp, _ = self.dci_db.query(data_flat_np[i:i+1],\
                                        num_neighbours = 1, field_of_view = 20, prop_to_retrieve = 0.02)
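                    # shift the within-pool index so it indexes the global array of num_data*num_samples_factor candidates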
                    nearest_indices[i] = nearest_indices_temp[0][0] + i*num_samples_factor

                if epoch == 0:
                    print(np.percentile(data_flat_np, 25), np.percentile(data_flat_np, 50), np.percentile(data_flat_np, 75))
                    print(np.percentile(samples_flat_np, 25), np.percentile(samples_flat_np, 50), np.percentile(samples_flat_np, 75))

                # restrict latent parameters to the nearest neighbour
                z_Sx = z_Sx_all[nearest_indices]


#=============================================================================================================
            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[i*batch_size:(i+1)*batch_size])

                # save the mock sample
                if (epoch+1) % staleness == 0:
                    samples_predict[i*batch_size:(i+1)*batch_size] = cur_samples.cpu().data.numpy()

                # gradient descent
                loss = loss_fn(cur_samples, data_all[i*batch_size:(i+1)*batch_size])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

#-----------------------------------------------------------------------------------------------------------
            # save the mock sample
            if (epoch+1) % staleness == 0:
                #np.savez("../results_2D_times=10_J=4_L=2_epoch=" + str(epoch) +  ".npz", data_np=data_np,\
                #                z_Sx_np=z_Sx.cpu().data.numpy(),\
                #                samples_np=samples_predict)

                # make random mock
                samples_random = self.model(z_Sx_all[:10**4][::100]).cpu().data.numpy()
                np.savez("../results_2D_random_times=10_" + name_JL + "_epoch=" + str(epoch) +  ".npz",\
                          samples_np=samples_random,\
                          mse_err=err / num_batches)

                # save network
                torch.save(self.model.state_dict(), '../net_weights_2D_times=10_' + name_JL + '_epoch=' \
                             + str(epoch) + '.pth')
Example #2
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor
        num_outputs = hyperparams.num_outputs

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        input_shape = [64, 1, 1, self.z_dim]
        net = get_model(input_shape)
        train_weights = net.trainable_weights

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = tf.optimizers.Adam(learning_rate=lr,
                                               beta_1=0.5,
                                               beta_2=0.999)
                # Optimizer API differs from PyTorch: betas -> beta_1/beta_2, and the weight_decay=1e-5 argument is dropped

            if epoch % hyperparams.staleness == 0:
                net.eval()  # switch the network to eval mode while drawing candidate samples
                z_np = np.empty((num_samples * batch_size, 1, 1, self.z_dim))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    z = tf.random.normal([batch_size, 1, 1, self.z_dim])
                    samples = net(z)
                    z_np[i * batch_size:(i + 1) * batch_size] = z.numpy()
                    samples_np[i * batch_size:(i + 1) *
                               batch_size] = samples.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)
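                # perturb the selected latents with a little Gaussian noise (the PyTorch examples note this is done to facilitate training)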

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                net.eval()
                cur_z = tf.convert_to_tensor(z_np[i * batch_size:(i + 1) *
                                                  batch_size],
                                             dtype=tf.float32)
                cur_data = tf.convert_to_tensor(
                    data_np[i * batch_size:(i + 1) * batch_size],
                    dtype=tf.float32)
                cur_samples = net(cur_z)
                _loss = tl.cost.mean_squared_error(cur_data,
                                                   cur_samples,
                                                   is_mean=False)
                err += _loss

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            for i in range(num_batches):
                net.train()
                with tf.GradientTape() as tape:
                    cur_z = tf.convert_to_tensor(z_np[i * batch_size:(i + 1) *
                                                      batch_size],
                                                 dtype=tf.float32)
                    cur_data = tf.convert_to_tensor(
                        data_np[i * batch_size:(i + 1) * batch_size],
                        dtype=tf.float32)
                    cur_samples = net(cur_z)
                    _loss = tl.cost.mean_squared_error(cur_data,
                                                       cur_samples,
                                                       is_mean=False)
                grad = tape.gradient(_loss, train_weights)
                optimizer.apply_gradients(zip(grad, train_weights))

            #debug: extremely slows training, but worth doing to visualize the process (the following 4 lines can be deleted directly)
            #cur_z = tf.convert_to_tensor(z_np[0: batch_size], dtype = tf.float32)
            #cur_samples = net(cur_z)
            #pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            #save_img(1, pics)
            cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
            cur_samples = net(cur_z)
            pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            save_img(2, pics)

        #cur_z = tf.convert_to_tensor(z_np[0: batch_size], dtype = tf.float32)
        cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
        cur_samples = net(cur_z)
        pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))

        save_img(batch_size, pics)
Example #3
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        #self.model = ConvolutionalImplicitModel(z_dim).cpu()
        self.model = BigGAN.Generator(resolution=Resolution,
                                      dim_z=z_dim,
                                      n_classes=Classes).cpu()
        self.dci_db = None

    def train(self,
              data_np,
              label_np,
              hyperparams,
              shuffle_data=True,
              path='results'):
        loss_fn = nn.MSELoss().cpu()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor

        grid_size = (5, 5)
        grid_z = torch.randn(np.prod(grid_size), self.z_dim).cpu()
        grid_y = torch.randint(low=0,
                               high=Classes,
                               size=(np.prod(grid_size), ),
                               dtype=torch.int64).cpu()

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]
            label_np = label_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):
            if epoch % 10 == 0:
                print('Saving net_weights.pth...')
                torch.save(self.model.state_dict(),
                           os.path.join(path, 'net_weights.pth'))
            if epoch % 5 == 0:
                print('Saving grid...')
                save_grid(path=path,
                          index=epoch,
                          count=np.prod(grid_size),
                          imle=self,
                          z=grid_z,
                          y=grid_y,
                          z_dim=self.z_dim)
                print('Saved')

            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim))
                y_np = np.empty((num_samples * batch_size, 1), dtype=np.int64)
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    z = torch.randn(batch_size,
                                    self.z_dim,
                                    requires_grad=False).cpu()
                    y = torch.randint(low=0,
                                      high=Classes,
                                      size=(batch_size, 1),
                                      dtype=torch.int64,
                                      requires_grad=False).cpu()
                    samples = self.model(z, self.model.shared(y))
                    #import pdb; pdb.set_trace()
                    z_np[i * batch_size:(i + 1) *
                         batch_size] = z.cpu().data.numpy()
                    y_np[i * batch_size:(i + 1) *
                         batch_size] = y.cpu().data.numpy()
                    samples_np[i * batch_size:(i + 1) *
                               batch_size] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(
                    samples_np, (samples_np.shape[0],
                                 np.prod(samples_np.shape[1:]))).copy()

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)
                y_np = y_np[nearest_indices]

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) *
                                              batch_size]).float().cpu()
                cur_y = torch.from_numpy(y_np[i * batch_size:(i + 1) *
                                              batch_size]).long().cpu()
                cur_data = torch.from_numpy(data_np[i * batch_size:(i + 1) *
                                                    batch_size]).float().cpu()
                cur_samples = self.model(cur_z, self.model.shared(cur_y))
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
Example #4
class IMLE():
    def __init__(self, z_dim, Sx_dim):
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim).cuda()
        self.dci_db = None

#-----------------------------------------------------------------------------------------------------------
    def train(self, data_np, data_Sx, base_lr=1e-4, batch_size=256, num_epochs=3000,\
             decay_step=25, decay_rate=0.95, staleness=100, num_samples_factor=100):

        # define metric
        # loss_fn = nn.MSELoss().cuda()
        loss_fn = nn.L1Loss().cuda()
        # loss_fn = nn.BCELoss().cuda()

        self.model.train()

        # train in batch
        num_batches = data_np.shape[0] // batch_size

        # truncate data to fit the batch size
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        #-----------------------------------------------------------------------------------------------------------
        # make empty array to store results
        samples_predict = np.empty(data_np.shape)

        samples_np = np.empty((num_samples_factor, ) + data_np.shape[1:])
        # samples_np = np.empty((num_data*num_samples_factor,)+data_np.shape[1:])

        nearest_indices = np.empty((num_data)).astype("int")

        # make global torch variables
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor,
                                        axis=0)).float().cuda()

        # initiate dci
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

#=============================================================================================================
        # train through various epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate**(epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

#-----------------------------------------------------------------------------------------------------------
            # routinely update the nearest generated samples
            if epoch % staleness == 0:

                # draw random z
                z = torch.randn(num_data * num_samples_factor,
                                self.z_dim).cuda()
                #z_Sx_all = torch.cat((z, Sx), axis=1)
                z_Sx_all = torch.cat((z, Sx), axis=1)[:, :, None]
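                # the trailing None adds a length-1 axis, giving each concatenated (z, Sx) vector the (N, z_dim+Sx_dim, 1) shape the convolutional model expects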

                #-----------------------------------------------------------------------------------------------------------
                # find the closest object for individual data
                nearest_indices = np.empty((num_data)).astype("int")

                for i in range(num_data):
                    samples = self.model(
                        z_Sx_all[i * num_samples_factor:(i + 1) *
                                 num_samples_factor])
                    samples_np[:] = samples.cpu().data.numpy()

                    # find the nearest neighbours
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_np),\
                                    num_levels = 2, field_of_view = 10, prop_to_retrieve = 0.002)
                    nearest_indices_temp, _ = self.dci_db.query(data_np[i:i+1],\
                                        num_neighbours = 1, field_of_view = 20, prop_to_retrieve = 0.02)
                    nearest_indices[i] = nearest_indices_temp[0][
                        0] + i * num_samples_factor

#-----------------------------------------------------------------------------------------------------------
                # # find the closest object for individual data
                # samples = self.model(z_Sx_all)
                # samples_np[:] = samples.cpu().data.numpy()
                #
                # # find the nearest neighbours
                # self.dci_db.reset()
                # self.dci_db.add(np.copy(samples_np),\
                #                 num_levels = 2, field_of_view = 10, prop_to_retrieve = 0.002)
                # nearest_indices_temp, _ = self.dci_db.query(data_np,\
                #                 num_neighbours = 1, field_of_view = 20, prop_to_retrieve = 0.02)
                # nearest_indices[:] = nearest_indices_temp

#-----------------------------------------------------------------------------------------------------------
                # restrict latent parameters to the nearest neighbour
                z_Sx = z_Sx_all[nearest_indices]

#=============================================================================================================
            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[i * batch_size:(i + 1) *
                                              batch_size])

                # save the mock sample
                if (epoch + 1) % staleness == 0:
                    samples_predict[i * batch_size:(
                        i + 1) * batch_size] = cur_samples.cpu().data.numpy()

                # gradient descent
                loss = loss_fn(cur_samples,
                               data_all[i * batch_size:(i + 1) * batch_size])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            #-----------------------------------------------------------------------------------------------------------
            # save the mock sample
            if (epoch + 1) % staleness == 0:

                # save closest models
                np.savez("../results_spectra_deconv_256x2_" + str(epoch) +  ".npz", data_np=data_np,\
                                               z_Sx_np=z_Sx.cpu().data.numpy(),\
                                               samples_np=samples_predict)

                np.savez("../mse_err_deconv_256x2_" + str(epoch) +  ".npz",\
                                                mse_err=err/num_batches)

                # save network
                torch.save(
                    self.model.state_dict(),
                    '../net_weights_spectra_deconv_256x2_epoch=' + str(epoch) +
                    '.pth')
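Example #5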
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                    samples = self.model(z)
                    z_np[i * batch_size:(i + 1) *
                         batch_size] = z.cpu().data.numpy()
                    samples_np[i * batch_size:(i + 1) *
                               batch_size] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) *
                                              batch_size]).float().cuda()
                cur_data = torch.from_numpy(
                    data_np[i * batch_size:(i + 1) *
                            batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
Example #6
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor  # number of generated sample batches (total samples = num_samples * batch_size)

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # Data-generation step: every `staleness` epochs, draw fresh candidate samples and re-select the latent z for each data point.
            # Note that the targets are always the original dataset data_np; only the latents z_np are re-selected.
            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                    samples = self.model(z)
                    z_np[i * batch_size:(i + 1) *
                         batch_size] = z.cpu().data.numpy()
                    samples_np[i * batch_size:(i + 1) *
                               batch_size] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            # z_np now holds, for each data point, the z whose generated sample is closest to it.
            # But why do we do it this way?
            # We want to compare the real data point x_i with the closest generated point.
            # Call this point hat{x}_{sigma(i)}; it was generated from the random noise z_i, i.e. hat{x}_{sigma(i)} = G_{theta}(z_i).
            # Once we have determined this point, how can we change G such that hat{x}_{sigma(i)} moves closer to x_i?
            # We minimize || G_{theta}(z_i) - x_i ||.
            # What happens in the conditional case, where we condition on a random variable A?
            # The idea is to generate m different samples for each a in A and then, out of those
            # m samples, find the one that is closest to the data point (see the sketch below).
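            # A hedged, illustrative sketch of that conditional selection (not executed by this
            # unconditional example; it mirrors the per-data-point loop of Example #1 above, and
            # the names m, cond, data and z_selected are hypothetical):
            #
            #   for i in range(num_data):
            #       z = torch.randn(m, self.z_dim, 1, 1).cuda()
            #       a = cond[i:i+1].expand(m, -1, -1, -1)              # condition a_i, repeated m times
            #       candidates = self.model(torch.cat((z, a), dim=1))  # m samples conditioned on a_i
            #       dists = ((candidates - data[i:i+1]) ** 2).flatten(1).sum(1)
            #       best = dists.argmin().item()
            #       z_selected[i] = z[best]                            # then minimize || G(z_best, a_i) - x_i ||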

            err = 0.
            # batch-gradient descent
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) *
                                              batch_size]).float().cuda()
                cur_data = torch.from_numpy(
                    data_np[i * batch_size:(i + 1) *
                            batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
Example #7
class CoolSystem(pl.LightningModule):

    def __init__(self, z_dim, hyperparams, shuffle_data=True):
        super(CoolSystem, self).__init__()
        # not the best model...
        #self.l1 = torch.nn.Linear(28 * 28, 10)
        self.z_dim = z_dim
        self.z_grid = None
        self.hyperparams = hyperparams
        self.shuffle_data = shuffle_data
        self.dci_db = None
        self.model = ConvolutionalImplicitModel(z_dim)
        #self.model = Generator128(z_dim, n_class=10)
        #self.squeezeNet = NetPerLayer()
        self.squeezeNet = None
        self.loss_fn = nn.MSELoss()
        self._step = 0

    def regen(self, batch):
        #import pdb; pdb.set_trace()

        imgs, labels = batch
        data_np = imgs.numpy()
        #import pdb;  pdb.set_trace()
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))
        
        self.data_np = data_np
        self.data_flat_np = data_flat_np
        hyperparams = self.hyperparams
        
        batch_size = hyperparams.batch_size
        #num_batches = data_np.shape[0] // batch_size
        num_batches = 1
        num_samples = num_batches * hyperparams.num_samples_factor
        #import pdb; pdb.set_trace()
      
        z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
        samples_np = np.empty((num_samples * batch_size,)+data_np.shape[1:])
        for i in range(num_samples):
          z = torch.randn(batch_size, self.z_dim, 1, 1).cpu()
          samples = self.model(z, labels)
          z_np[i*batch_size:(i+1)*batch_size] = z.cpu().data.numpy()
          samples_np[i*batch_size:(i+1)*batch_size] = samples.cpu().data.numpy()
        
        samples_flat_np = np.reshape(samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:]))).copy()
        
        if self.dci_db is None:
          self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices = 2, num_simp_indices = 7)
        
        self.dci_db.reset()
        self.dci_db.add(samples_flat_np, num_levels = 2, field_of_view = 10, prop_to_retrieve = 0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours = 1, field_of_view = 20, prop_to_retrieve = 0.02)
        nearest_indices = np.array(nearest_indices)[:,0]
        
        z_np = z_np[nearest_indices]
        z_np += 0.01*np.random.randn(*z_np.shape)
        self.z_np = z_np
        if self.z_grid is None:
          self.z_grid = self.z_np.copy() * 0.7
        
        del samples_np, samples_flat_np

        return z_np

    # def forward(self, x):
    #     return torch.relu(self.l1(x.view(x.size(0), -1)))

    def forward(self, z, class_id=None):
        cur_samples = self.model(z, class_id)
        return cur_samples

    def training_step(self, batch, batch_idx):
        # REQUIRED
        #x, y = batch
        #y_hat = self.forward(x)
        #loss = F.cross_entropy(y_hat, y)

        #batch_size = batch.shape[0]
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size
        i = batch_idx
        data_np = self.data_np

        #z_np = self.regen()
        z_np = self.z_np
        z_grid = self.z_grid
        #cur_z = torch.from_numpy(z_np[i*batch_size:(i+1)*batch_size]).float().cpu()
        #imgTarget = torch.from_numpy(data_np[i*batch_size:(i+1)*batch_size]).float().cpu()
        cur_z = torch.from_numpy(z_np).float().cpu()
        imgTarget = torch.from_numpy(data_np).float().cpu()
        print(cur_z.shape, batch_idx)
        def nimg(img):
          #img = img + 1
          #img = img / 2
          #import pdb; pdb.set_trace()
          img = torch.clamp(img, 0, 1)
          return img
        imgInput = batch[0]
        imgLabels = batch[1]
        self.logger.experiment.add_image('imgInput', torchvision.utils.make_grid(nimg(imgInput)), self._step)
        imgOutput = self.forward(cur_z, imgLabels)
        self.logger.experiment.add_image('imgOutput', torchvision.utils.make_grid(nimg(imgOutput)), self._step)
        if self._step % me.interp_every == 0 and self._step > -1 and True:
          cur_zgrid = torch.from_numpy(z_grid).float().cpu()
          imgGrid = self.forward(cur_zgrid, imgLabels)
          self.logger.experiment.add_image('imgGrid', torchvision.utils.make_grid(nimg(imgGrid)), self._step)
          #print('circle...')
          #import pdb; pdb.set_trace()
          imgs = me.circle_interpolation(self.model, imgInput.shape[0])
          #imgs = np.array(imgs)
          #imgs = torch.from_numpy(imgs).float().cpu()
          self.logger.experiment.add_image('imgInterp', torchvision.utils.make_grid(nimg(imgs)), self._step)
          #import pdb; pdb.set_trace()
        if self.squeezeNet is not None:
          #print('squeezeNet(imgTarget)...')
          activationsTarget = self.squeezeNet(imgTarget.repeat(1,3,1,1))
          #print('squeezeNet(imgOutput)...')
          activationsOutput = self.squeezeNet(imgOutput.repeat(1,3,1,1))
          #print('loss...')
          featLoss = None
          #for actTarget, actOutput in zip(activationsTarget[1:3], activationsOutput[1:3]):
          #for actTarget, actOutput in tqdm.tqdm(list(zip(activationsTarget, activationsOutput))):
          for actTarget, actOutput in zip(activationsTarget, activationsOutput):
            #l = F.mse_loss(actTarget, actOutput)
            #l = torch.abs(actTarget - actOutput).sum()
            #l = F.l1_loss(actTarget, actOutput)

            l = -pearsonr(actTarget.view(-1), actOutput.view(-1))
            if featLoss is None:
              featLoss = l
            else:
              featLoss += l
        else:
          featLoss = 0.0
        #pixelLoss = F.mse_loss(imgOutput, imgTarget)
        pixelLoss = self.loss_fn(imgOutput, imgTarget)
        loss = featLoss + pixelLoss
        lr = self.lr_fn(self.current_epoch)
        tensorboard_logs = {'train_loss': loss, 'lr': lr, 'epoch': self.current_epoch}
        self._step += 1
        return {'loss': loss, 'log': tensorboard_logs}

    def number_interpolation(self, count, lo, hi):
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size

        def gen_latent(pos):
          z = gen_func(pos)
          z = z.reshape([1,-1,1,1])
          return z

#     def validation_step(self, batch, batch_idx):
#         # OPTIONAL
#         x, y = batch
#         y_hat = self.forward(x)
#         return {'val_loss': F.cross_entropy(y_hat, y)}

#     def validation_end(self, outputs):
#         # OPTIONAL
#         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#         tensorboard_logs = {'val_loss': avg_loss}
#         return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        
#     def test_step(self, batch, batch_idx):
#         # OPTIONAL
#         x, y = batch
#         y_hat = self.forward(x)
#         return {'test_loss': F.cross_entropy(y_hat, y)}

#     def test_end(self, outputs):
#         # OPTIONAL
#         avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
#         tensorboard_logs = {'test_loss': avg_loss}
#         return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        #return torch.optim.Adam(self.parameters(), lr=0.0004)
        #epoch = 0
        hyperparams = self.hyperparams
        lr = hyperparams.base_lr # * hyperparams.decay_rate ** (epoch // hyperparams.decay_step)
        self.lr_fn = lambda epoch: hyperparams.decay_rate ** (epoch // hyperparams.decay_step)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, amsgrad=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999), amsgrad=True, weight_decay=1e-5)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999), amsgrad=True)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.0, 0.999))
        from torch.optim.lr_scheduler import LambdaLR
        scheduler = LambdaLR(optimizer, lr_lambda=[self.lr_fn])
        return [optimizer], [scheduler]
        

    def prepare_data(self):
        #MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        #MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        CIFAR10(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

        #self.data_np = np.random.randn(128, 1, 28, 28)
        # self.data_np = np.array([img.numpy() for img, label in dataset])
        # print('ready')
        # data_np = self.data_np
        # hyperparams = self.hyperparams

        # if self.shuffle_data:
        #     data_ordering = np.random.permutation(data_np.shape[0])
        #     data_np = data_np[data_ordering]

        # batch_size = hyperparams.batch_size
        # num_batches = data_np.shape[0] // batch_size
        # num_samples = num_batches * hyperparams.num_samples_factor
        
        # self.data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

    def on_batch_start(self, batch):
        self.regen(batch)
        

    def train_dataloader(self):
        # REQUIRED
        #dataset = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
        #dataset = CIFAR10(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
        #loader = DataLoader(dataset, batch_size=32)

        # Downloading/Loading CIFAR10 data
        trainset  = CIFAR10(root=os.getcwd(), train=True , download=False)#, transform = transform_with_aug)
        testset   = CIFAR10(root=os.getcwd(), train=False, download=False)#, transform = transform_no_aug)
        classDict = {'plane':0, 'car':1, 'bird':2, 'cat':3, 'deer':4, 'dog':5, 'frog':6, 'horse':7, 'ship':8, 'truck':9}

        # Separating trainset/testset data/label
        x_train  = trainset.data
        x_test   = testset.data
        y_train  = trainset.targets
        y_test   = testset.targets

        # Define a function to separate CIFAR classes by class index

        def get_class_i(x, y, i):
            """
            x: trainset.train_data or testset.test_data
            y: trainset.train_labels or testset.test_labels
            i: class label, a number from 0 to 9
            return: x_i
            """
            # Convert to a numpy array
            y = np.array(y)
            # Locate position of labels that equal to i
            pos_i = np.argwhere(y == i)
            # Convert the result into a 1-D list
            pos_i = list(pos_i[:,0])
            # Collect all data that match the desired label
            x_i = [x[j] for j in pos_i]
            
            return x_i

        class DatasetMaker(Dataset):
            def __init__(self, datasets, transformFunc = transforms.ToTensor()):
                """
                datasets: a list of get_class_i outputs, i.e. a list of list of images for selected classes
                """
                self.datasets = datasets
                self.lengths  = [len(d) for d in self.datasets]
                self.transformFunc = transformFunc
            def __getitem__(self, i):
                class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
                img = self.datasets[class_label][index_wrt_class]
                if self.transformFunc:
                  img = self.transformFunc(img)
                return img, class_label

            def __len__(self):
                return sum(self.lengths)
            
            def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
                """
                Given the absolute index, returns which bin it falls in and which element of that bin it corresponds to.
                """
                # Which class/bin does i fall into?
                accum = np.add.accumulate(bin_sizes)
                if verbose:
                    print("accum =", accum)
                bin_index  = len(np.argwhere(accum <= absolute_index))
                if verbose:
                    print("class_label =", bin_index)
                # Which element of that class/bin does i correspond to?
                index_wrt_class = absolute_index - np.insert(accum, 0, 0)[bin_index]
                if verbose:
                    print("index_wrt_class =", index_wrt_class)

                return bin_index, index_wrt_class

        # ================== Usage ================== #

        # Let's choose cats (class 3 of CIFAR) and dogs (class 5 of CIFAR) as trainset/testset
        cat_dog_trainset = \
            DatasetMaker(
                [get_class_i(x_train, y_train, classDict['cat']), get_class_i(x_train, y_train, classDict['dog'])]
                #transform_with_aug
            )
        cat_dog_testset  = \
            DatasetMaker(
                [get_class_i(x_test , y_test , classDict['cat']), get_class_i(x_test , y_test , classDict['dog'])]
                #transform_no_aug
            )

        kwargs = {'num_workers': 2, 'pin_memory': False}
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size

        # Create datasetLoaders from trainset and testset
        trainsetLoader   = DataLoader(cat_dog_trainset, batch_size=batch_size, shuffle=True , **kwargs)
        testsetLoader    = DataLoader(cat_dog_testset , batch_size=batch_size, shuffle=False, **kwargs)

        #loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        # self.data_np = np.array([img.numpy() for img, label in dataset])
        #self._train_loader = loader
        #return loader
        return trainsetLoader
Example #8
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim

        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.model2 = ConvolutionalImplicitModel(z_dim).cuda()

        self.dci_db = None
        self.dci_db2 = None

#-----------------------------------------------------------------------------------------------------------
    def train(self, data_np, data_np2, base_lr=1e-3, batch_size=64, num_epochs=10000,\
              decay_step=25, decay_rate=1.0, staleness=500, num_samples_factor=100):

        # define metric
        loss_fn = nn.MSELoss().cuda()

        # make model trainable
        self.model.train()
        self.model2.train()

        # train in batch
        num_batches = data_np.shape[0] // batch_size

        # number of mock-sample batches drawn per re-selection
        num_samples = num_batches * num_samples_factor

#-----------------------------------------------------------------------------------------------------------
        # flatten each data image to 1D for DCI
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        # initiate dci
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices = 2, num_simp_indices = 7)
        if self.dci_db2 is None:
            self.dci_db2 = DCI(np.prod(data_np2.shape[1:]), num_comp_indices = 2, num_simp_indices = 7)

#-----------------------------------------------------------------------------------------------------------
        # train through various epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                optimizer = optim.Adam(list(self.model.parameters()) + list(self.model2.parameters()),\
                                        lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

#-----------------------------------------------------------------------------------------------------------
            # re-evaluate the closest models routinely
            if epoch % staleness == 0:

                # initiate numpy arrays to store the latent draws and the associated samples
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1, 1))

                samples_np = np.empty((num_samples * batch_size,)+data_np.shape[1:])
                samples_np2 = np.empty((num_samples * batch_size,)+data_np2.shape[1:])

                # make sample (in batch to avoid GPU memory problem)
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, 1, 1, 1).cuda()

                    samples = self.model(z)
                    samples2 = self.model2(z)

                    z_np[i*batch_size:(i+1)*batch_size] = z.cpu().data.numpy()

                    samples_np[i*batch_size:(i+1)*batch_size] = samples.cpu().data.numpy()
                    samples_np2[i*batch_size:(i+1)*batch_size] = samples2.cpu().data.numpy()

                # make 1D images
                samples_flat_np = np.reshape(samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:])))

#-----------------------------------------------------------------------------------------------------------
                # find the nearest neighbours
                self.dci_db.reset()
                self.dci_db.add(np.copy(samples_flat_np), num_levels = 2, field_of_view = 10, prop_to_retrieve = 0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours = 1, field_of_view = 20, prop_to_retrieve = 0.02)
                nearest_indices = np.array(nearest_indices)[:,0]
                z_np = z_np[nearest_indices]

                # add random noise to the latent space to facilitate training
                z_np += 0.01*np.random.randn(*z_np.shape)

                # delete to save memory
                del samples_np, samples_np2, samples_flat_np


#=============================================================================================================
            # permute data
            data_ordering = np.random.permutation(data_np.shape[0])

            data_np = data_np[data_ordering]
            data_np2 = data_np2[data_ordering]

            data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

            z_np = z_np[data_ordering]

#-----------------------------------------------------------------------------------------------------------
            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):

                # set up backprop
                self.model.zero_grad()
                self.model2.zero_grad()

#-----------------------------------------------------------------------------------------------------------
                # evaluate the models
                cur_z = torch.from_numpy(z_np[i*batch_size:(i+1)*batch_size]).float().cuda()

                cur_data = torch.from_numpy(data_np[i*batch_size:(i+1)*batch_size]).float().cuda()
                cur_data2 = torch.from_numpy(data_np2[i*batch_size:(i+1)*batch_size]).float().cuda()

                cur_samples = self.model(cur_z)
                cur_samples2 = self.model2(cur_z)

#-----------------------------------------------------------------------------------------------------------
                # calculate MSE loss of the two images
                loss = loss_fn(cur_samples, cur_data) + loss_fn(cur_samples2, cur_data2)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # save the mock sample
            if (epoch+1) % staleness == 0:
                np.savez("../results_3D.npz",
                        data_np=data_np,\
                        data_np2=data_np2,\
                        samples_np=self.model(torch.from_numpy(z_np).float().cuda()).cpu().data.numpy(),\
                        samples_np2=self.model2(torch.from_numpy(z_np).float().cuda()).cpu().data.numpy())