Example #1
    def branch_balancing(self, bundle):
        '''
        1. concat z_a and zeros -> decoder -> i_hat_a
        2. concat z_p and zeros -> decoder -> i_hat_p
        3. compute loss L(i_hat_a, i) and L(i_hat_p, i)
        4. for each sample, output the i_hat whose loss is greater
        '''
        inp, z_a, z_p = bundle
        if self.cut_zp:
            # optionally block gradients from flowing back through the pose branch
            z_p = tf.stop_gradient(z_p)

        assert z_a.get_shape().as_list() == z_p.get_shape().as_list()
        zeros = tf.zeros_like(z_a)

        concat_a = self.concat(z_a, zeros)
        i_hat_a = self.decoder_model(concat_a)
        print('shape i_hat_a %s' % str(i_hat_a.shape))

        concat_p = self.concat(zeros, z_p)
        i_hat_p = self.decoder_model(concat_p)
        print('shape i_hat_p %s' % str(i_hat_p.shape))

        # losses
        rec_loss_a = reconstruction_loss()(inp, i_hat_a)
        rec_loss_p = reconstruction_loss()(inp, i_hat_p)
        print('shape loss_a %s' % str(rec_loss_a.shape))
        print('shape loss_p %s' % str(rec_loss_p.shape))

        # per sample, keep the reconstruction (from a or p) whose loss is higher
        i_hat = tf.where(rec_loss_a > rec_loss_p, i_hat_a, i_hat_p)
        print('shape final i_hat %s' % str(i_hat.shape))
        return i_hat
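A note on shapes: for the tf.where above to select one whole reconstruction per sample, reconstruction_loss() has to return one scalar per batch entry rather than a single reduced scalar. A minimal sketch of such a factory, assuming a plain L2 pixel loss (the actual loss is not shown in these examples):

    import tensorflow as tf

    def reconstruction_loss():
        # Minimal sketch, assuming an L2 pixel loss. Reducing over the
        # spatial and channel axes but not the batch axis yields a
        # (batch,) vector, which is what lets tf.where(cond, a, b) pick
        # a whole reconstruction per sample in branch_balancing.
        def loss(y_true, y_pred):
            return tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])
        return loss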
Example #2
    def build(self):
        inp = Input(shape=self.input_shape)

        # encoders
        time_1 = time.time()
        z_a = self.appearance_encoder(inp)
        time_2 = time.time()
        z_p = self.pose_encoder(inp)
        time_3 = time.time()

        print("Build E_a %s, build E_p %s" %
              (time_2 - time_1, time_3 - time_2))
        print(type(z_a), type(z_p))
        print("Shape z_a %s" % str(z_a.shape))

        # decoder
        concat = self.concat(z_a, z_p)
        print("Shape concat %s" % str(concat.shape))
        i_hat = self.decoder(concat)

        # z_p is a list here (one pose output per block), so extend adds
        # each block's prediction as a separate model output
        outputs = [i_hat]
        outputs.extend(z_p)
        self.model = Model(inputs=inp, outputs=outputs)
        print("Outputs shape %s" % self.model.output_shape)

        ploss = [pose_loss()] * len(z_p)
        losses = [reconstruction_loss()]
        losses.extend(ploss)

        self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))
        self.model.summary()
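pose_loss() is never defined in these examples. Example #6's docstring gives each pose output the shape (None, n_joints, dim + 1), with the extra channel a visibility probability, so one plausible sketch is a visibility-masked L2 on the joint coordinates plus an L2 term on the visibility itself (hypothetical, not necessarily the author's loss):

    import tensorflow as tf

    def pose_loss():
        # Hypothetical sketch. y_true / y_pred: (batch, n_joints, dim + 1),
        # last channel = visibility. Coordinate errors of invisible joints
        # are masked out; visibility itself is regressed with an L2 term.
        def loss(y_true, y_pred):
            vis_true = y_true[..., -1:]  # (batch, n_joints, 1)
            coord_err = tf.square(y_true[..., :-1] - y_pred[..., :-1])
            coord_term = tf.reduce_mean(vis_true * coord_err)
            vis_term = tf.reduce_mean(tf.square(vis_true - y_pred[..., -1:]))
            return coord_term + vis_term
        return loss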
Example #3
    def build(self):
        '''
        Override of build:
        -> the first part (up to i_hat) is the same as the parent build
        -> after i_hat, the cycle-specific graph is added

        TODO: refactor (duplicates most of the base build)

        Outputs for the reduced cycle (in that order):
        - i_hat (None, 256, 256, 3)
        - pose (None, 17, 4)   (n_blocks of these)
        - concat_z_a (None, 16, 16, 256)
        - concat_z_p (None, 16, 16, 256)
        - i_hat_shuffled (None, 256, 256, 3)
        '''

        # build everything
        self.build_everything()  # reduced decoder

        inp = Input(shape=self.input_shape, name='image_input')
        self.log("Input shape %s" % str(inp.shape))

        # encoders
        z_a, z_p, poses = self.call_encoders(inp)

        # decoder
        i_hat = self.concat_and_decode(z_a, z_p)
        # identity Lambda used only to name this output, so it can be told
        # apart from the mixed reconstruction below
        i_hat = Lambda(lambda x: x, name='i_hat')(i_hat)

        # shuffle z_a and z_p from images from the batch and create new images
        concat_shuffled = self.shuffle(z_a, z_p)
        i_hat_mixed = self.decoder_model(concat_shuffled)
        i_hat_mixed = Lambda(lambda x: x, name='i_hat_mixed')(i_hat_mixed)

        # re-encode mixed images and get new z_a and z_p
        cycle_z_a = self.appearance_model(i_hat_mixed)
        cycle_pose_outputs = self.pose_model(i_hat_mixed)
        cycle_poses, cycle_z_p = self.check_pose_output(cycle_pose_outputs)

        # concat z_a and z_a', z_p and z_p' to have an output usable by the cycle loss
        concat_z_a = concatenate([z_a, cycle_z_a], name='cycle_za_concat')
        concat_z_p = concatenate([z_p, cycle_z_p], name='cycle_zp_concat')

        # build the whole model
        outputs = [i_hat] + poses + [concat_z_a] + [concat_z_p] + [i_hat_mixed]
        self.model = Model(inputs=inp, outputs=outputs)
        print("Outputs shape %s" % self.model.output_shape)

        ploss = [pose_loss()] * len(poses)
        losses = [reconstruction_loss()] + ploss + [
            cycle_loss(), cycle_loss(),
            noop_loss()
        ]

        self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))

        if self.verbose:
            self.log("Final model summary")
            self.model.summary()
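self.shuffle is what produces the mixed codes for the cycle: appearance and pose codes are re-paired across the batch before decoding. A minimal free-function sketch, assuming the pose codes are simply rolled by one position within the batch (the author's actual permutation is not shown):

    import tensorflow as tf
    from keras.layers import Lambda, concatenate

    def shuffle(z_a, z_p):
        # Minimal sketch: pair each appearance code with the pose code of
        # the next sample in the batch (a roll by one), then concatenate
        # along channels just like the aligned z_a / z_p pairs.
        rolled_z_p = Lambda(lambda z: tf.concat([z[1:], z[:1]], axis=0),
                            name='roll_z_p')(z_p)
        return concatenate([z_a, rolled_z_p], axis=-1)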
Example #4
    def get_losses_outputs(self, i_hat, poses):
        pose_losses = [pose_loss()] * self.n_blocks
        losses = [reconstruction_loss()] + pose_losses

        outputs = [i_hat] + poses

        return losses, outputs
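A hypothetical caller, mirroring how the build methods in the other examples wire outputs and losses together (inp, i_hat and poses being the tensors built earlier):

    losses, outputs = self.get_losses_outputs(i_hat, poses)
    self.model = Model(inputs=inp, outputs=outputs)
    self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))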
Example #5
    def build(self):
        # self.appearance_model = self.build_appearance_model(self.input_shape)
        self.pose_model = self.build_pose_model(self.input_shape)
        print("pose model summary")
        self.pose_model.summary()
        self.decoder_model = self.build_decoder_model(
            (8, 8, 2048))  # i.e. 2048 for the regular model

        inp = Input(shape=self.input_shape)

        # encoders
        z_a = self.appearance_encoder(inp)
        # z_a = self.appearance_model(inp)
        z_p = self.pose_model(inp)

        print(type(z_a), type(z_p))
        print("Shape z_a HELLO %s" % str(z_a.shape))
        print("Shape z_p %s" % str(z_p.shape))

        # decoder
        concat = self.concat(z_a, z_p)
        print("Shape concat %s" % str(concat.shape))
        i_hat = self.decoder_model(concat)

        # two outputs here: the reconstruction and the single pose tensor;
        # the loss list below must contain one loss per output
        outputs = [i_hat, z_p]
        # outputs.extend(z_p)
        self.model = Model(inputs=inp, outputs=outputs)
        print("Outputs shape %s" % self.model.output_shape)

        # ploss = [pose_loss()] * len(z_p)
        ploss = [pose_loss()] * self.n_blocks
        losses = [reconstruction_loss()]
        losses.extend(ploss)

        self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))
        self.model.summary()
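self.concat is used in every build but never shown. The asserts in Example #6 (z_a of shape (None, 16, 16, 1024), concat of shape (None, 16, 16, 2048)) imply a plain channel-axis concatenation; a minimal free-function sketch:

    from keras.layers import concatenate

    def concat(z_a, z_p):
        # Channel-axis concatenation: two (None, 16, 16, 1024) codes become
        # one (None, 16, 16, 2048) tensor, matching Example #6's asserts.
        return concatenate([z_a, z_p], axis=-1)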
Example #6
    def build(self):
        '''
        This model has many outputs so that each can be matched to its loss.
        Outputs in order:
        - first the reconstructed image i_hat (None, 256, 256, 3) -> reconstruction loss
        - then n_blocks pose outputs (None, n_joints, dim + 1) (+1 for the visibility prob) -> pose loss
        - then the concatenated z_a and z_a' -> cycle consistency loss
        - then the concatenated z_p and z_p' -> cycle consistency loss
        - then the intermediate mixed reconstructed images, for visualization -> noop loss
        '''

        # build everything
        time_1 = time.time()
        self.appearance_model = self.build_appearance_model(self.input_shape)
        time_2 = time.time()
        self.pose_model = self.build_pose_model(self.input_shape)
        time_3 = time.time()
        self.decoder_model = self.build_decoder_model((16, 16, 2048))  # ...
        time_4 = time.time()

        print("Build E_a %s, build E_p %s, decoder D %s" %
              (time_2 - time_1, time_3 - time_2, time_4 - time_3))

        inp = Input(shape=self.input_shape)
        print("Input shape %s" % str(inp.shape))

        # encoders
        z_a = self.appearance_model(inp)
        assert z_a.shape.as_list() == [
            None, 16, 16, 1024
        ], 'wrong shape for z_a %s' % str(z_a.shape.as_list())
        pose_outputs = self.pose_model(inp)

        poses, z_p = self.check_pose_output(pose_outputs)
        print("Shape z_a %s, shape z_p %s" % (str(z_a.shape), str(z_p.shape)))

        # decoder
        concat = self.concat(z_a, z_p)
        assert concat.shape.as_list() == [
            None, 16, 16, 2048
        ], 'wrong concat shape %s' % str(concat.shape)
        i_hat = self.decoder_model(concat)

        # shuffle z_a and z_p from images from the batch and create new images
        concat_shuffled = self.shuffle(z_a, z_p)
        i_hat_mixed = self.decoder_model(concat_shuffled)

        # re-encode mixed images and get new z_a and z_p
        cycle_z_a = self.appearance_model(i_hat_mixed)
        cycle_pose_outputs = self.pose_model(i_hat_mixed)
        cycle_poses, cycle_z_p = self.check_pose_output(cycle_pose_outputs)

        # concat z_a and z_a', z_p and z_p' to have an output usable by the cycle loss
        concat_z_a = concatenate([z_a, cycle_z_a])
        concat_z_p = concatenate([z_p, cycle_z_p])

        # build the whole model
        outputs = [i_hat] + poses + [concat_z_a] + [concat_z_p] + [i_hat_mixed]
        self.model = Model(inputs=inp, outputs=outputs)
        print("Outputs shape %s" % self.model.output_shape)

        ploss = [pose_loss()] * len(poses)
        losses = [reconstruction_loss()] + ploss + [
            cycle_loss(), cycle_loss(),
            noop_loss()
        ]
        # loss = mean_squared_error
        self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))
        self.model.summary()
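Neither cycle_loss nor noop_loss appears in these examples. Since the cycle outputs concatenate z and its re-encoded z' along the channel axis, a plausible sketch splits that tensor back in half and penalizes the L2 distance between the halves, while noop_loss contributes nothing so the mixed images are exposed purely for visualization (both hypothetical):

    import tensorflow as tf

    def cycle_loss():
        # Hypothetical sketch: y_pred is the channel-wise concatenation of
        # z and its re-encoded counterpart z'; split it back in half and
        # penalize their L2 distance. y_true is a dummy target, unused.
        def loss(y_true, y_pred):
            n = y_pred.shape.as_list()[-1] // 2
            z, z_cycle = y_pred[..., :n], y_pred[..., n:]
            return tf.reduce_mean(tf.square(z - z_cycle))
        return loss

    def noop_loss():
        # Hypothetical sketch: a constant zero; the corresponding output
        # exists only so the mixed reconstructions can be fetched for viz.
        def loss(y_true, y_pred):
            return tf.zeros((), dtype=y_pred.dtype)
        return loss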