def branch_balancing(self, bundle):
    '''
    1. concat z_a and zeros -> decoder -> i_hat_a
    2. concat z_p and zeros -> decoder -> i_hat_p
    3. compute loss L(i_hat_a, i) and L(i_hat_p, i)
    4. output only the i_hat where the loss is greater
    '''
    inp, z_a, z_p = bundle
    if self.cut_zp:
        z_p = tf.stop_gradient(z_p)
    assert z_a.get_shape().as_list() == z_p.get_shape().as_list()
    zeros = tf.zeros_like(z_a)

    # decode each branch on its own, with the other half of the code zeroed out
    concat_a = self.concat(z_a, zeros)
    i_hat_a = self.decoder_model(concat_a)
    print('shape i_hat_a %s' % str(i_hat_a.shape))
    concat_p = self.concat(zeros, z_p)
    i_hat_p = self.decoder_model(concat_p)
    print('shape i_hat_p %s' % str(i_hat_p.shape))

    # losses
    rec_loss_a = reconstruction_loss()(inp, i_hat_a)
    rec_loss_p = reconstruction_loss()(inp, i_hat_p)
    print('shape loss_a %s' % str(rec_loss_a.shape))
    print('shape loss_p %s' % str(rec_loss_p.shape))

    # mix of reconstructions using a or p only, depending on which has the higher loss
    i_hat = tf.where(rec_loss_a > rec_loss_p, i_hat_a, i_hat_p)
    print('shape final i_hat %s' % str(i_hat.shape))
    return i_hat
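# reconstruction_loss() is used above as a factory returning a callable loss.
# A minimal sketch of such a factory, assuming a per-sample L1 reconstruction
# error (the actual loss lives elsewhere in the repo). The per-sample
# reduction matters for branch_balancing: tf.where with a rank-1 boolean
# condition of shape [batch] selects whole images from i_hat_a / i_hat_p.
import tensorflow as tf

def reconstruction_loss():
    # Hypothetical sketch, not the repo's confirmed implementation.
    def loss(y_true, y_pred):
        # one scalar per batch element: mean absolute error over H, W, C
        return tf.reduce_mean(tf.abs(y_true - y_pred), axis=[1, 2, 3])
    return loss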
def build(self):
    inp = Input(shape=self.input_shape)

    # encoders
    time_1 = time.time()
    z_a = self.appearance_encoder(inp)
    time_2 = time.time()
    z_p = self.pose_encoder(inp)
    time_3 = time.time()
    print("Build E_a %s, build E_p %s" % (time_2 - time_1, time_3 - time_2))
    print(type(z_a), type(z_p))
    print("Shape z_a %s" % str(z_a.shape))

    # decoder
    concat = self.concat(z_a, z_p)
    print("Shape concat %s" % str(concat.shape))
    i_hat = self.decoder(concat)

    outputs = [i_hat]
    outputs.extend(z_p)
    self.model = Model(inputs=inp, outputs=outputs)
    print("Outputs shape %s" % self.model.output_shape)

    ploss = [pose_loss()] * len(z_p)
    losses = [reconstruction_loss()]
    losses.extend(ploss)
    self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))
    self.model.summary()
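# Every build() variant joins the two codes with self.concat. A minimal
# sketch, assuming it is a plain channel-axis concatenation; the shapes
# asserted later ((None, 16, 16, 1024) twice -> (None, 16, 16, 2048))
# are consistent with that assumption.
from keras.layers import concatenate

def concat(self, z_a, z_p):
    # Hypothetical sketch: stack appearance and pose codes channel-wise
    # so the result matches the decoder's expected input depth.
    return concatenate([z_a, z_p], axis=-1)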
def build(self):
    '''
    Override of build!
    -> the first part (up to i_hat) is the same
    -> after i_hat we add the cycle-specific part
    TODO: refactor (ugly)

    Outputs for the reduced cycle (in that order):
    - i_hat (None, 256, 256, 3)
    - pose (None, 17, 4) (times n_blocks)
    - concat_z_a (None, 16, 16, 256)
    - concat_z_p (None, 16, 16, 256)
    - i_hat_mixed (None, 256, 256, 3)
    '''
    # build everything
    self.build_everything()

    # reduced decoder
    inp = Input(shape=self.input_shape, name='image_input')
    self.log("Input shape %s" % str(inp.shape))

    # encoders
    z_a, z_p, poses = self.call_encoders(inp)

    # decoder
    i_hat = self.concat_and_decode(z_a, z_p)
    # identity Lambda, only there to name the output and differentiate it from the mixed one
    i_hat = Lambda(lambda x: x, name='i_hat')(i_hat)

    # shuffle z_a and z_p from images of the batch and create new images
    concat_shuffled = self.shuffle(z_a, z_p)
    i_hat_mixed = self.decoder_model(concat_shuffled)
    i_hat_mixed = Lambda(lambda x: x, name='i_hat_mixed')(i_hat_mixed)

    # re-encode the mixed images and get new z_a and z_p
    cycle_z_a = self.appearance_model(i_hat_mixed)
    cycle_pose_outputs = self.pose_model(i_hat_mixed)
    cycle_poses, cycle_z_p = self.check_pose_output(cycle_pose_outputs)

    # concat z_a and z_a', z_p and z_p' to get outputs usable by the cycle loss
    concat_z_a = concatenate([z_a, cycle_z_a], name='cycle_za_concat')
    concat_z_p = concatenate([z_p, cycle_z_p], name='cycle_zp_concat')

    # build the whole model
    outputs = [i_hat] + poses + [concat_z_a, concat_z_p, i_hat_mixed]
    self.model = Model(inputs=inp, outputs=outputs)
    print("Outputs shape %s" % self.model.output_shape)

    ploss = [pose_loss()] * len(poses)
    losses = [reconstruction_loss()] + ploss + [cycle_loss(), cycle_loss(), noop_loss()]
    self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))

    if self.verbose:
        self.log("Final model summary")
        self.model.summary()
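# self.shuffle pairs each appearance code with a pose code from another
# image of the batch before decoding. A minimal sketch, assuming a random
# batch permutation applied to z_p only; the exact mixing scheme in the
# repo may differ. tf.random_shuffle is the TF1-era op matching the rest
# of this code.
import tensorflow as tf
from keras.layers import Lambda

def shuffle(self, z_a, z_p):
    # Hypothetical sketch, not the repo's confirmed helper.
    def mix(tensors):
        za, zp = tensors
        # permute the batch dimension of the pose codes
        idx = tf.random_shuffle(tf.range(tf.shape(zp)[0]))
        # concat each z_a with a z_p from a (likely) different image
        return tf.concat([za, tf.gather(zp, idx)], axis=-1)
    return Lambda(mix, name='shuffle_codes')([z_a, z_p])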
def get_losses_outputs(self, i_hat, poses):
    pose_losses = [pose_loss()] * self.n_blocks
    losses = [reconstruction_loss()] + pose_losses
    outputs = [i_hat] + poses
    return losses, outputs
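# The cycle build() variants also rely on cycle_loss() and noop_loss()
# factories. Minimal sketches, assuming the cycle outputs are
# concatenate([z, cycle_z]) along the channel axis (as built above) and
# that noop_loss exists only so i_hat_mixed can be exposed for
# visualisation without training signal. Both are assumptions.
import tensorflow as tf

def cycle_loss():
    # Hypothetical sketch: split the concatenated (z, z') prediction back
    # in two halves and penalise the distance between the original code
    # and its re-encoded counterpart.
    def loss(y_true, y_pred):
        z, cycle_z = tf.split(y_pred, num_or_size_splits=2, axis=-1)
        return tf.reduce_mean(tf.square(z - cycle_z))
    return loss

def noop_loss():
    # Hypothetical sketch: a zero loss, so the mixed reconstruction is a
    # model output without contributing any gradient.
    def loss(y_true, y_pred):
        return 0.0 * tf.reduce_mean(y_pred)
    return loss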
def build(self):
    # self.appearance_model = self.build_appearance_model(self.input_shape)
    self.pose_model = self.build_pose_model(self.input_shape)
    print("pose model summary")
    self.pose_model.summary()
    self.decoder_model = self.build_decoder_model(
        (8, 8, 2048))  # i.e. 2048 for the regular model

    inp = Input(shape=self.input_shape)

    # encoders
    z_a = self.appearance_encoder(inp)
    # z_a = self.appearance_model(inp)
    z_p = self.pose_model(inp)
    print(type(z_a), type(z_p))
    print("Shape z_a %s" % str(z_a.shape))
    print("Shape z_p %s" % str(z_p.shape))

    # decoder
    concat = self.concat(z_a, z_p)
    print("Shape concat %s" % str(concat.shape))
    i_hat = self.decoder_model(concat)

    outputs = [i_hat, z_p]
    # outputs.extend(z_p)
    self.model = Model(inputs=inp, outputs=outputs)
    print("Outputs shape %s" % self.model.output_shape)

    # note: Keras expects one loss per output, so with two outputs this
    # loss list only matches when self.n_blocks == 1
    # ploss = [pose_loss()] * len(z_p)
    ploss = [pose_loss()] * self.n_blocks
    losses = [reconstruction_loss()]
    losses.extend(ploss)
    self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))
    self.model.summary()
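# pose_loss() is the remaining factory used by every build(). A minimal
# sketch for a (None, n_joints, dim + 1) pose head, assuming the last
# channel is a visibility probability used to mask the coordinate error;
# this masking scheme is an assumption, not the repo's confirmed loss.
import tensorflow as tf

def pose_loss():
    # Hypothetical sketch, not the repo's confirmed implementation.
    def loss(y_true, y_pred):
        coords_true, vis_true = y_true[..., :-1], y_true[..., -1:]
        coords_pred, vis_pred = y_pred[..., :-1], y_pred[..., -1:]
        # squared coordinate error per joint, masked by ground-truth visibility
        coord_err = tf.reduce_sum(
            tf.square(coords_true - coords_pred), axis=-1, keepdims=True)
        masked = tf.reduce_mean(vis_true * coord_err)
        # squared error on the visibility probability itself
        vis_err = tf.reduce_mean(tf.square(vis_true - vis_pred))
        return masked + vis_err
    return loss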
def build(self):
    '''
    This model has many outputs so that each can be fed to its own loss.
    Outputs, in order:
    - the reconstructed image i_hat (None, 256, 256, 3) -> reconstruction loss
    - n_blocks * pose output (None, n_joints, dim + 1) (+1 for the visibility prob) -> pose loss
    - the concatenated z_a and z_a' -> cycle consistency loss
    - the concatenated z_p and z_p' -> cycle consistency loss
    - the intermediate mixed reconstructed images, for viz -> noop loss
    '''
    # build everything
    time_1 = time.time()
    self.appearance_model = self.build_appearance_model(self.input_shape)
    time_2 = time.time()
    self.pose_model = self.build_pose_model(self.input_shape)
    time_3 = time.time()
    self.decoder_model = self.build_decoder_model((16, 16, 2048))  # ...
    time_4 = time.time()
    print("Build E_a %s, build E_p %s, decoder D %s" %
          (time_2 - time_1, time_3 - time_2, time_4 - time_3))

    inp = Input(shape=self.input_shape)
    print("Input shape %s" % str(inp.shape))

    # encoders
    z_a = self.appearance_model(inp)
    assert z_a.shape.as_list() == [None, 16, 16, 1024], \
        'wrong shape for z_a %s' % str(z_a.shape.as_list())
    pose_outputs = self.pose_model(inp)
    poses, z_p = self.check_pose_output(pose_outputs)
    print("Shape z_a %s, shape z_p %s" % (str(z_a.shape), str(z_p.shape)))

    # decoder
    concat = self.concat(z_a, z_p)
    assert concat.shape.as_list() == [None, 16, 16, 2048], \
        'wrong concat shape %s' % str(concat.shape)
    i_hat = self.decoder_model(concat)

    # shuffle z_a and z_p from images of the batch and create new images
    concat_shuffled = self.shuffle(z_a, z_p)
    i_hat_mixed = self.decoder_model(concat_shuffled)

    # re-encode the mixed images and get new z_a and z_p
    cycle_z_a = self.appearance_model(i_hat_mixed)
    cycle_pose_outputs = self.pose_model(i_hat_mixed)
    cycle_poses, cycle_z_p = self.check_pose_output(cycle_pose_outputs)

    # concat z_a and z_a', z_p and z_p' to get outputs usable by the cycle loss
    concat_z_a = concatenate([z_a, cycle_z_a])
    concat_z_p = concatenate([z_p, cycle_z_p])

    # build the whole model
    outputs = [i_hat] + poses + [concat_z_a, concat_z_p, i_hat_mixed]
    self.model = Model(inputs=inp, outputs=outputs)
    print("Outputs shape %s" % self.model.output_shape)

    ploss = [pose_loss()] * len(poses)
    losses = [reconstruction_loss()] + ploss + [cycle_loss(), cycle_loss(), noop_loss()]
    self.model.compile(loss=losses, optimizer=RMSprop(lr=self.start_lr))
    self.model.summary()
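# check_pose_output is called on the pose model's outputs in both cycle
# builds. A minimal sketch, assuming the pose model returns one pose
# prediction per block plus the pose feature map z_p as its last output;
# the exact output layout is an assumption.
def check_pose_output(self, pose_outputs):
    # Hypothetical sketch: split the output list into per-block pose
    # predictions and the final z_p code, and sanity-check the count.
    *poses, z_p = pose_outputs
    assert len(poses) == self.n_blocks, \
        'expected %d pose outputs, got %d' % (self.n_blocks, len(poses))
    return poses, z_p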