Example #1
0
class PC2Pix():
    """Conditional GAN that renders images from point-cloud codes and a viewpoint.

    The generator maps (noise, point-cloud code, elevation code, azimuth code)
    to an image. The discriminator maps an image to four outputs — a
    real/fake score, a point-cloud code, an elevation code and an azimuth
    code — trained with binary cross-entropy, MAE, ``elev_loss`` and
    ``azim_loss`` respectively. Point-cloud codes come from the encoder of a
    pretrained point-cloud autoencoder (``ptcloud_ae``).
    """

    def __init__(self,
                 ptcloud_ae=None,
                 gw=None,
                 dw=None,
                 pc_code_dim=32,
                 batch_size=64,
                 color=True,
                 gpus=1,
                 norm=False,
                 category='all'):
        """Set up data sources and build the GAN.

        Args:
            ptcloud_ae: pretrained point-cloud autoencoder; only its
                ``encoder.predict`` is used by this class.
            gw: optional path of generator weights to load.
            dw: optional path of discriminator weights to load.
            pc_code_dim: size of the point-cloud latent code.
            batch_size: training batch size.
            color: if True use 128x128 RGB images ('im_128'), otherwise
                grayscale images ('gray').
            gpus: number of GPUs; more than 1 wraps the models with
                ``multi_gpu_model``.
            norm: only used with ``category == 'all'``; selects the
                'all_exp_norm.json' split file.
            category: dataset category; selects the split file and is used
                as the prefix of the saved pc-code and weight files.
        """
        self.noise_dim = 128
        self.ptcloud_ae = ptcloud_ae
        self.gw = gw
        self.dw = dw
        self.gpus = gpus
        self.pc_code_dim = pc_code_dim
        self.category = category
        self.model_dir = "saved_models"
        self.kernel_size = 3
        self.batch_size = batch_size
        self.generator = None
        self.discriminator = None
        self.adversarial = None
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs("weights", exist_ok=True)
        self.color = color
        self.gen_spectral_normalization = False

        if color:
            # color images 128x128 rgb
            # (for big 224x224 rgb images use 'im' instead of 'im_128')
            items = ['im_128', 'pc', 'elev', 'azim']
        else:
            # grayscale images 224x224
            items = ['gray', 'pc', 'elev', 'azim']

        # The normalized split file only exists for the combined category.
        if category == 'all' and norm:
            path = 'all_exp_norm.json'
        else:
            path = category + '_exp.json'
        self.split_file = os.path.join('data', path)

        self.train_source = DataSource(batch_size=self.batch_size,
                                       items=items,
                                       split_file=self.split_file)
        shapenet = self.train_source.dset
        self.epoch_datalen = len(
            shapenet.get_smids('train')) * shapenet.num_renders
        self.train_steps = self.epoch_datalen // self.batch_size

        self.pc_codes_filename = os.path.join(
            "pc_codes",
            self.category + "-" + str(pc_code_dim) + "-pc_codes.npy")
        self.test_source = DataSource(batch_size=36,
                                      smids='test',
                                      items=items,
                                      nepochs=20,
                                      split_file=self.split_file)

        self.build_gan()

    def generate_fake_pc_codes(self):
        """Encode training point clouds and save the codes to disk.

        Runs ``self.ptcloud_ae.encoder`` over ``4 * self.train_steps``
        training batches and writes the stacked codes to
        ``self.pc_codes_filename`` with ``np.save``. ``train_gan`` later
        samples its fake point-cloud codes from this file.

        Raises:
            RuntimeError: if no batches are processed (``train_steps`` is 0),
                instead of silently saving an unusable ``None`` array.
        """
        chunks = []
        num_codes = 0
        start_time = datetime.datetime.now()
        print("Generating fake pc codes...")
        steps = 4 * self.train_steps
        for i in range(steps):
            _, fake_pc, _, _ = self.train_source.next_batch()
            fake_pc = fake_pc / 0.5
            fake_pc_code = self.ptcloud_ae.encoder.predict(fake_pc)
            # Collect per-batch chunks and concatenate once at the end:
            # np.append inside the loop copied the whole array every step.
            chunks.append(fake_pc_code)
            num_codes += len(fake_pc_code)
            elapsed_time = datetime.datetime.now() - start_time
            pcent = 100. * float(i) / steps
            shape = (num_codes, ) + tuple(fake_pc_code.shape[1:])
            log = "%0.2f%% [shape: %s] [time: %s]" % (pcent, shape,
                                                       elapsed_time)
            print(log)

        if not chunks:
            raise RuntimeError(
                "No pc codes generated: train_steps is 0, nothing to save")
        fake_pc_codes = np.concatenate(chunks, axis=0)

        # The pc_codes directory is not created anywhere else.
        out_dir = os.path.dirname(self.pc_codes_filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        print("Saving pc codes to file: ", self.pc_codes_filename)
        np.save(self.pc_codes_filename, fake_pc_codes)

    def _weights_filename(self, kind):
        """Return the checkpoint path for network ``kind`` ("gen" or "dis")."""
        prefix = self.category + "-" + kind
        prefix += "-color" if self.color else "-gray"
        # NOTE(review): the discriminator name also keys off the generator's
        # spectral-norm flag; kept so existing checkpoint names still match.
        if self.gen_spectral_normalization:
            prefix += "-sn"
        prefix += "-" + str(self.pc_code_dim)
        return os.path.join("weights", prefix + ".h5")

    def train_gan(self):
        """Alternate discriminator and adversarial (generator) updates.

        Trains for 120 epochs of ``self.train_steps`` steps each. Fake
        point-cloud codes are sampled from the file written by
        ``generate_fake_pc_codes``. Fake elevation/azimuth codes are drawn
        uniformly from [0, 1]. On the first step and every 500 steps, test
        images are plotted and generator/discriminator weights are saved
        under ``weights/``.
        """
        plot_interval = 500
        save_interval = 500
        start_time = datetime.datetime.now()
        (test_image, pc, test_elev_code,
         test_azim_code) = self.test_source.next_batch()
        pc = pc / 0.5
        test_pc_code = self.ptcloud_ae.encoder.predict(pc)
        # Fixed noise so the periodic test plots are comparable across steps.
        noise_ = np.random.uniform(-1.0, 1.0, size=[36, self.noise_dim])
        test_image -= 0.5
        test_image /= 0.5
        # x * 0.5 + 0.5: presumably maps view codes from [-1, 1] to [0, 1],
        # the range of the fake view codes sampled below — TODO confirm.
        test_elev_code *= 0.5
        test_elev_code += 0.5
        test_azim_code *= 0.5
        test_azim_code += 0.5
        plot_image(test_image, color=self.color)

        valid = np.ones([self.batch_size, 1])
        fake = np.zeros([self.batch_size, 1])

        # Discriminator batches are [real; fake] stacked along axis 0.
        valid_fake = np.concatenate((valid, fake))
        epochs = 120
        train_steps = self.train_steps * epochs

        fake_pc_codes = np.load(self.pc_codes_filename)
        fake_pc_codes_len = len(fake_pc_codes)
        print("Loaded pc codes", self.pc_codes_filename, " with len: ",
              fake_pc_codes_len)
        print("fake_pc_codes min: ", np.amin(fake_pc_codes),
              "fake_pc_codes max: ", np.amax(fake_pc_codes))
        print("test_pc_code min: ", np.amin(test_pc_code),
              " test_pc_code max: ", np.amax(test_pc_code))
        print("test_elev_code min: ", np.amin(test_elev_code),
              " test_elev_code max: ", np.amax(test_elev_code))
        print("test_azim_code min: ", np.amin(test_azim_code),
              " test_azim_code max: ", np.amax(test_azim_code))
        print("batch_size: ", self.batch_size, " pc_code_dim: ",
              self.pc_code_dim)
        print("Color images: ", self.color)

        for step in range(train_steps):
            # --- discriminator step on real + fake images ---
            (real_image, real_pc, real_elev_code,
             real_azim_code) = self.train_source.next_batch()
            real_image -= 0.5
            real_image /= 0.5
            # pc is [-0.5, 0.5]
            real_pc = real_pc / 0.5
            real_pc_code = self.ptcloud_ae.encoder.predict(real_pc)

            rand_indexes = np.random.randint(0,
                                             fake_pc_codes_len,
                                             size=self.batch_size)
            fake_pc_code = fake_pc_codes[rand_indexes]

            pc_code = np.concatenate((real_pc_code, fake_pc_code))

            real_elev_code *= 0.5
            real_elev_code += 0.5
            fake_elev_code = np.random.uniform(0.0,
                                               1.0,
                                               size=[self.batch_size, 1])
            real_azim_code *= 0.5
            real_azim_code += 0.5
            fake_azim_code = np.random.uniform(0.0,
                                               1.0,
                                               size=[self.batch_size, 1])

            elev_code = np.concatenate((real_elev_code, fake_elev_code))
            azim_code = np.concatenate((real_azim_code, fake_azim_code))

            noise = np.random.uniform(-1.0,
                                      1.0,
                                      size=[self.batch_size, self.noise_dim])
            fake_image = self.generator.predict(
                [noise, fake_pc_code, fake_elev_code, fake_azim_code])
            x = np.concatenate((real_image, fake_image))
            metrics = self.discriminator.train_on_batch(
                x, [valid_fake, pc_code, elev_code, azim_code])
            pcent = step * 100.0 / train_steps
            fmt = "%02.4f%%/%06d:[loss:%02.6f d:%02.6f pc:%02.6f elev:%02.6f azim:%02.6f]"
            log = fmt % (pcent, step, metrics[0], metrics[1], metrics[2],
                         metrics[3], metrics[4])

            # --- generator step through the frozen discriminator ---
            rand_indexes = np.random.randint(0,
                                             fake_pc_codes_len,
                                             size=self.batch_size)
            fake_pc_code = fake_pc_codes[rand_indexes]
            fake_elev_code = np.random.uniform(0.0,
                                               1.0,
                                               size=[self.batch_size, 1])
            fake_azim_code = np.random.uniform(0.0,
                                               1.0,
                                               size=[self.batch_size, 1])
            noise = np.random.uniform(-1.0,
                                      1.0,
                                      size=[self.batch_size, self.noise_dim])

            # Generated images are labelled valid and must reproduce the
            # conditioning codes they were generated from.
            metrics = self.adversarial.train_on_batch(
                [noise, fake_pc_code, fake_elev_code, fake_azim_code],
                [valid, fake_pc_code, fake_elev_code, fake_azim_code])
            fmt = "%s [loss:%02.6f a:%02.6f pc:%02.6f elev:%02.6f azim:%02.6f]"
            log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3],
                         metrics[4])

            elapsed_time = datetime.datetime.now() - start_time
            log = "%s [time: %s]" % (log, elapsed_time)
            print(log)
            if (step + 1) % plot_interval == 0 or step == 0:
                # plot generator images on a periodic basis
                plot_images(self.generator,
                            noise=noise_,
                            pc_code=test_pc_code,
                            elev_code=test_elev_code,
                            azim_code=test_azim_code,
                            color=self.color,
                            show=False,
                            step=(step + 1))

            if (step + 1) % save_interval == 0 or step == 0:
                # save the single-device models so checkpoints load anywhere
                self.generator_single.save_weights(
                    self._weights_filename("gen"))
                self.discriminator_single.save_weights(
                    self._weights_filename("dis"))

    def azim_loss(self, y_true, y_pred):
        """Mean absolute wrapped angular error for azimuth codes.

        A code difference of 1 corresponds to a full 2*pi turn. atan2(sin,
        cos) wraps the angle into [-pi, pi], so errors across the 0/1
        boundary are measured the short way round.
        """
        rad = 2. * np.pi
        rad *= (y_true - y_pred)
        return K.mean(K.abs(tf.atan2(K.sin(rad), K.cos(rad))), axis=-1)

    def elev_loss(self, y_true, y_pred):
        """Mean absolute wrapped angular error for elevation codes.

        A code difference of 1 corresponds to 0.444... * pi (= 2*pi*80/360).
        """
        # rad = 2. * np.pi * 80. /360.
        rad = 0.4444444444444444 * np.pi
        rad *= (y_true - y_pred)
        return K.mean(K.abs(tf.atan2(K.sin(rad), K.cos(rad))), axis=-1)

    def build_gan(self):
        """Build and compile the discriminator, generator and adversarial model.

        Loads ``self.dw``/``self.gw`` weights when given. With ``gpus > 1``
        the single-device models are built on the CPU and wrapped with
        ``multi_gpu_model``. The discriminator is frozen inside the
        adversarial model. Model diagrams are written to ``self.model_dir``.
        """
        # one batch is read only to learn the image size
        image, _, _, _ = self.train_source.next_batch()
        elev_code = Input(shape=(1, ), name='elev_code')
        azim_code = Input(shape=(1, ), name='azim_code')
        pc_code = Input(shape=(self.pc_code_dim, ), name='pc_code')
        noise_code = Input(shape=(self.noise_dim, ), name='noise_code')
        model_name = "pc2pix"
        image_size = image.shape[1]
        channels = 3 if self.color else 1
        input_shape = (image_size, image_size, channels)

        if self.gen_spectral_normalization:
            optimizer = Adam(lr=4e-4, beta_1=0.0, beta_2=0.9)
        else:
            optimizer = Adam(lr=2e-4, beta_1=0.5, beta_2=0.999)

        # build discriminator
        # by default, discriminator uses SN
        if self.gpus <= 1:
            self.discriminator = model.discriminator(
                input_shape, pc_code_dim=self.pc_code_dim)
            if self.dw is not None:
                print("loading discriminator weights: ", self.dw)
                self.discriminator.load_weights(self.dw)
            self.discriminator_single = self.discriminator
        else:
            with tf.device("/cpu:0"):
                self.discriminator_single = model.discriminator(
                    input_shape, pc_code_dim=self.pc_code_dim)
                if self.dw is not None:
                    print("loading discriminator weights: ", self.dw)
                    self.discriminator_single.load_weights(self.dw)

            self.discriminator = multi_gpu_model(self.discriminator_single,
                                                 gpus=self.gpus)

        # outputs: real/fake, pc code, elevation, azimuth
        loss = ['binary_crossentropy', 'mae', self.elev_loss, self.azim_loss]
        loss_weights = [1., 10., 10., 10.]
        self.discriminator.compile(loss=loss,
                                   loss_weights=loss_weights,
                                   optimizer=optimizer)
        self.discriminator_single.summary()
        path = os.path.join(self.model_dir, "discriminator.png")
        plot_model(self.discriminator_single, to_file=path, show_shapes=True)

        # build generator
        # try SN to see if mode collapse is avoided
        if self.gpus <= 1:
            self.generator = model.generator(
                input_shape,
                noise_code=noise_code,
                pc_code=pc_code,
                elev_code=elev_code,
                azim_code=azim_code,
                spectral_normalization=self.gen_spectral_normalization,
                color=self.color)
            if self.gw is not None:
                print("loading generator weights: ", self.gw)
                self.generator.load_weights(self.gw)
            self.generator_single = self.generator
        else:
            with tf.device("/cpu:0"):
                self.generator_single = model.generator(
                    input_shape,
                    noise_code=noise_code,
                    pc_code=pc_code,
                    elev_code=elev_code,
                    azim_code=azim_code,
                    spectral_normalization=self.gen_spectral_normalization,
                    color=self.color)
                if self.gw is not None:
                    print("loading generator weights: ", self.gw)
                    self.generator_single.load_weights(self.gw)

            self.generator = multi_gpu_model(self.generator_single,
                                             gpus=self.gpus)

        self.generator_single.summary()
        path = os.path.join(self.model_dir, "generator.png")
        plot_model(self.generator_single, to_file=path, show_shapes=True)

        # Freeze after compiling the discriminator: its own train_on_batch
        # still updates it, but the adversarial model only trains the generator.
        self.discriminator.trainable = False
        if self.gen_spectral_normalization:
            optimizer = Adam(lr=1e-4, beta_1=0.0, beta_2=0.9)
        else:
            optimizer = Adam(lr=1e-4, beta_1=0.5, beta_2=0.999)

        if self.gpus <= 1:
            self.adversarial = Model(
                [noise_code, pc_code, elev_code, azim_code],
                self.discriminator(
                    self.generator([noise_code, pc_code, elev_code,
                                    azim_code])),
                name=model_name)
            self.adversarial_single = self.adversarial
        else:
            # NOTE(review): this composes the already multi-GPU-wrapped
            # models and wraps the result again; using the *_single models
            # may be what was intended — confirm before changing.
            with tf.device("/cpu:0"):
                self.adversarial_single = Model(
                    [noise_code, pc_code, elev_code, azim_code],
                    self.discriminator(
                        self.generator(
                            [noise_code, pc_code, elev_code, azim_code])),
                    name=model_name)
            self.adversarial = multi_gpu_model(self.adversarial_single,
                                               gpus=self.gpus)

        self.adversarial.compile(loss=loss,
                                 loss_weights=loss_weights,
                                 optimizer=optimizer)
        self.adversarial_single.summary()
        path = os.path.join(self.model_dir, "adversarial.png")
        plot_model(self.adversarial_single, to_file=path, show_shapes=True)

        print("Using split file: ", self.split_file)
        print("1 epoch datalen: ", self.epoch_datalen)
        print("1 epoch train steps: ", self.train_steps)
        print("Using pc codes: ", self.pc_codes_filename)

    def stop_sources(self):
        """Close the train/test data sources, skipping any never created.

        ``__init__`` may fail before both sources exist; ``__del__`` still
        runs in that case and must not raise ``AttributeError``.
        """
        for source in (getattr(self, 'train_source', None),
                       getattr(self, 'test_source', None)):
            if source is not None:
                source.close()

    def __del__(self):
        self.stop_sources()
class PtCloudStackedAE():
    """Point-cloud autoencoder built from stacked dilated Conv1D stages.

    The encoder maps a point cloud to a ``latent_dim`` code. The MLP decoder
    maps the code back to a point cloud. Training minimizes either an
    EMD-style or a Chamfer reconstruction loss (see ``loss``).
    """

    def __init__(self,
                 latent_dim=32,
                 kernel_size=5,
                 lr=1e-4,
                 category="all",
                 evaluate=False,
                 emd=True):
        """Set up data sources and build the autoencoder.

        Args:
            latent_dim: size of the latent code produced by the encoder.
            kernel_size: Conv1D kernel size used by the encoder layers.
            lr: RMSprop learning rate.
            category: dataset category; 'all' uses 'all_exp_norm.json',
                otherwise '<category>_exp.json'.
            evaluate: if True, the autoencoder is built but not compiled.
            emd: if True train with the EMD-style loss, else Chamfer.
        """
        self.latent_dim = latent_dim
        self.lr = lr
        self.batch_size = 32
        self.evaluate = evaluate
        self.emd = emd
        self.inputs = None
        self.encoder = None
        self.decoder = None
        self.ae = None
        self.z_log_var = None
        self.z_mean = None
        self.z = None
        self.kernel_size = kernel_size
        self.model_dir = "saved_models"
        os.makedirs(self.model_dir, exist_ok=True)
        self.category = category
        if category == 'all':
            path = 'all_exp_norm.json'
        else:
            path = category + '_exp.json'
        split_file = os.path.join('data', path)
        print("Using train split file: ", split_file)

        self.train_source = DataSource(batch_size=self.batch_size,
                                       split_file=split_file)
        self.test_source = DataSource(batch_size=self.batch_size,
                                      smids='test',
                                      nepochs=20,
                                      split_file=split_file)
        shapenet = self.train_source.dset
        self.epoch_datalen = len(
            shapenet.get_smids('train')) * shapenet.num_renders
        self.train_steps = self.epoch_datalen // self.batch_size
        # Peek at one batch to learn the per-sample point-cloud shape.
        _, pc = self.train_source.next_batch()
        self.input_shape = pc[0].shape
        self.build_ae()

    def encoder_layer(self, x, filters, strides=1, dilation_rate=1):
        """Pre-activation conv unit: BatchNorm -> ReLU -> Conv1D ('same')."""
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv1D(filters=filters,
                   kernel_size=self.kernel_size,
                   strides=strides,
                   dilation_rate=dilation_rate,
                   padding='same')(x)
        return x

    def compression_layer(self, x, y, maxpool=True):
        """Merge shortcut ``y`` into ``x`` and derive the next shortcut.

        ``y`` is max-pooled first when ``maxpool`` is True, concatenated onto
        ``x``, and a 64-filter 1x1 ReLU conv of the result becomes the new
        shortcut. Returns ``(x, y)``.
        """
        if maxpool:
            y = MaxPooling1D()(y)
        x = concatenate([x, y])

        y = Conv1D(filters=64,
                   kernel_size=1,
                   activation='relu',
                   padding='same')(x)
        return x, y

    def build_encoder(self, filters=64, activation='linear'):
        """Build ``self.encoder``: point cloud -> ``latent_dim`` code.

        Four stages, each of four parallel dilated convolutions (rates 1, 2,
        4, 8) merged with a shortcut path. Stages are separated by a stride-2
        downsampling layer. Layers are created in the same order as the
        original unrolled version, so auto-generated layer names (and saved
        weights) stay compatible.

        Returns:
            ``(shape, filters)``: ``K.int_shape`` of the last conv output
            before flattening, and the ``filters`` argument.
        """
        self.inputs = Input(shape=self.input_shape, name='encoder_input')
        x = self.inputs
        # y is a shortcut path starting at the raw input.
        y = self.inputs
        num_stages = 4
        for stage in range(num_stages):
            branches = [
                self.encoder_layer(x, filters, strides=1, dilation_rate=rate)
                for rate in (1, 2, 4, 8)
            ]
            x = concatenate(branches)
            # x has not been downsampled before the first stage, so the
            # shortcut is pooled only from the second stage on.
            x, y = self.compression_layer(x, y, maxpool=stage > 0)
            if stage < num_stages - 1:
                x = self.encoder_layer(x, 128, strides=2, dilation_rate=1)

        x = self.encoder_layer(x, 32)
        shape = K.int_shape(x)

        x = Flatten()(x)
        # output activation is configurable (default 'linear')
        outputs = Dense(self.latent_dim,
                        activation=activation,
                        name='ae_encoder_out')(x)
        path = os.path.join(self.model_dir, "ae_encoder.png")
        self.encoder = Model(self.inputs, outputs, name='ae_encoder')

        self.encoder.summary()
        plot_model(self.encoder, to_file=path, show_shapes=True)

        return shape, filters

    def build_decoder_mlp(self, dim=1024):
        """Build ``self.decoder`` as an MLP: code -> point cloud.

        Three ReLU Dense layers of width ``dim``, then a tanh Dense layer
        reshaped to ``self.input_shape``.
        """
        latent_inputs = Input(shape=(self.latent_dim, ), name='decoder_input')
        x = latent_inputs
        for _ in range(3):
            x = Dense(dim, activation='relu')(x)
        x = Dense(np.prod(self.input_shape), activation='tanh')(x)
        outputs = Reshape(self.input_shape)(x)

        path = os.path.join(self.model_dir, "decoder_mlp.png")
        self.decoder = Model(latent_inputs, outputs, name='decoder')
        self.decoder.summary()
        plot_model(self.decoder, to_file=path, show_shapes=True)

    def build_decoder(self, filters, shape):
        """Build ``self.decoder`` as a Conv1D/upsampling stack (not used by
        ``build_ae``, which uses the MLP decoder).

        Args:
            filters: filters of the first conv; halved after each of the 4
                upsampling stages.
            shape: tensor shape whose dims [1] and [2] give the initial
                (length, channels) after the dense projection.
        """
        latent_inputs = Input(shape=(self.latent_dim, ), name='decoder_input')
        pt_cloud_shape = (shape[1], shape[2])
        dim = shape[1] * shape[2]
        x = Dense(128, activation='relu')(latent_inputs)
        x = Dense(dim, activation='relu')(x)
        x = Reshape(pt_cloud_shape)(x)

        for _ in range(4):
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = Conv1D(filters=filters,
                       kernel_size=self.kernel_size,
                       padding='same')(x)
            x = UpSampling1D()(x)
            filters //= 2

        outputs = Conv1D(filters=3,
                         kernel_size=self.kernel_size,
                         activation='tanh',
                         padding='same',
                         name='decoder_output')(x)

        path = os.path.join(self.model_dir, "decoder.png")
        self.decoder = Model(latent_inputs, outputs, name='decoder')
        self.decoder.summary()
        plot_model(self.decoder, to_file=path, show_shapes=True)

    def loss(self, gt, pred):
        """Reconstruction loss between ground truth and predicted clouds.

        With ``self.emd``: each predicted point is matched to a ground-truth
        point via ``auction_match``, and the per-sample mean squared
        difference to the matched points is returned. Otherwise: the
        Chamfer distance, i.e. the mean nearest-neighbour distance in both
        directions, as a scalar. Only the custom ops needed by the selected
        branch are imported.
        """
        if self.emd:
            from tf_ops.emd import tf_auctionmatch
            from tf_ops.sampling import tf_sampling
            matchl_out, matchr_out = tf_auctionmatch.auction_match(pred, gt)
            matched_out = tf_sampling.gather_point(gt, matchl_out)
            # NOTE(review): assumes every batch has exactly self.batch_size
            # samples.
            emd_loss = tf.reshape((pred - matched_out)**2,
                                  shape=(self.batch_size, -1))
            emd_loss = tf.reduce_mean(emd_loss, axis=1, keepdims=True)
            return emd_loss

        from structural_losses import tf_nndistance
        p1top2, _, p2top1, _ = tf_nndistance.nn_distance(pred, gt)
        cd_loss = K.mean(p1top2) + K.mean(p2top1)
        return cd_loss

    def build_ae(self):
        """Build ``self.ae`` = decoder(encoder(x)) and compile it unless evaluating."""
        self.build_encoder()
        self.build_decoder_mlp()

        outputs = self.decoder(self.encoder(self.inputs))
        self.ae = Model(self.inputs, outputs, name='ae')

        self.ae.summary()
        optimizer = RMSprop(lr=self.lr)
        if not self.evaluate:
            self.ae.compile(optimizer=optimizer, loss=self.loss)
        path = os.path.join(self.model_dir, "ae.png")
        plot_model(self.ae, to_file=path, show_shapes=True)
        print("Learning rate: ", self.lr)

    def _weights_prefix(self):
        """Return the checkpoint filename prefix for the current config."""
        prefix = self.category + "-" + "pt-cloud-stacked-ae"
        prefix += "-emd" if self.emd else "-chamfer"
        prefix += "-" + str(self.kernel_size)
        return prefix

    def _save_all_weights(self):
        """Checkpoint the encoder, decoder and full autoencoder under ``weights``."""
        prefix = self._weights_prefix()
        weights_dir = "weights"
        for net, name in ((self.encoder, "encoder"),
                          (self.decoder, "decoder"),
                          (self.ae, "ae")):
            save_weights(net,
                         name,
                         weights_dir,
                         self.latent_dim,
                         prefix=prefix)

    def train_ae(self):
        """Train the autoencoder for 30 epochs.

        Prints the average loss every 100 steps and saves weights every 500
        steps.
        """
        save_interval = 500
        print_interval = 100
        start_time = datetime.datetime.now()
        loss = 0.0
        epochs = 30
        train_steps = self.train_steps * epochs

        for step in range(train_steps):
            _, pc = self.train_source.next_batch()
            # doubles coordinates, presumably [-0.5, 0.5] -> [-1, 1] to
            # match the decoder's tanh output — TODO confirm
            pc = pc / 0.5
            metrics = self.ae.train_on_batch(x=pc, y=pc)
            loss += metrics

            if (step + 1) % print_interval == 0:
                elapsed_time = datetime.datetime.now() - start_time
                loss /= print_interval
                pcent = step * 100.0 / train_steps
                fmt = "%02.4f%%/%06d:[loss:%02.6f time:%s]"
                log = fmt % (pcent, step + 1, loss, elapsed_time)
                print(log)
                loss = 0.0

            if (step + 1) % save_interval == 0:
                self._save_all_weights()

    def stop_sources(self):
        """Close the train/test data sources, skipping any never created.

        ``__init__`` may fail before both sources exist; ``__del__`` still
        runs in that case and must not raise ``AttributeError``.
        """
        for source in (getattr(self, 'train_source', None),
                       getattr(self, 'test_source', None)):
            if source is not None:
                source.close()

    def __del__(self):
        self.stop_sources()