def build_dynamics_descriptors(self, name, weight):
    """Build the weighted dynamics style loss from MSOEnet gramians.

    One ``DynamicsDescriptor`` is built per frame index ``i`` of
    ``range(self.input_frame_count)``. Its input is the target pair of
    frames ``(i, i+1)`` concatenated along axis 0 with the generated pair
    ``(i, i+1)``. On the last index the generated sequence wraps around:
    its last frame is paired with its first. No target pair exists for that
    case, so the generated pair is duplicated in the target's place.

    The gramians of ``loss_layers`` from every descriptor are passed to
    ``self.style_loss`` under the name ``'dynamics_style_loss'``.

    Args:
        name: Name scope for the created ops; also forwarded to each
            ``DynamicsDescriptor``.
        weight: Multiplier applied to the style loss.

    Returns:
        ``self.style_loss('dynamics_style_loss', gramians) * weight``.
    """
    with tf.get_default_graph().name_scope(name):
        loss_layers = ['MSOEnet_concat/concat']
        last = self.input_frame_count - 1
        gramians = []
        for i in range(self.input_frame_count):
            # Generated frames are VGG-preprocessed (BGR, mean subtracted),
            # but MSOEnet accepts grayscale in [0, 1]. They are deprocessed
            # with unit_scale=True before the grayscale conversion.
            # NOTE(review): target frames are NOT deprocessed, so they are
            # presumably already stored in [0, 1] -- confirm.
            if i < last:
                target = tf.image.rgb_to_grayscale(
                    tf.stack(self.target_dynamic_texture[i:i + 2], 1))
                output = tf.image.rgb_to_grayscale(vgg_deprocess(
                    self.output[:, i:i + 2], no_clip=True, unit_scale=True))
                pair = [target, output]
            else:
                # Wrap-around pair: last generated frame followed by the
                # first. Only this branch's ops are built, so no unused
                # target/output ops are added to the graph for this frame.
                output = tf.image.rgb_to_grayscale(vgg_deprocess(
                    tf.concat([self.output[:, i:i + 1], self.output[:, :1]],
                              1),
                    no_clip=True, unit_scale=True))
                # There is no target pair here; the output is duplicated as
                # a placeholder (the target half is not going to be used).
                pair = [output, output]
            descriptor = DynamicsDescriptor(
                'dynamics_descriptor_' + str(i + 1), name,
                tf.concat(axis=0, values=pair),
                self.user_config['dynamics_model'])
            gramians.append(
                [descriptor.gramian_for_layer(layer) for layer in loss_layers])
        return tf.multiply(
            self.style_loss('dynamics_style_loss', gramians), weight)
def minimize_callback(self, dyntex_loss, appearance_loss, dynamics_loss,
                      physics_loss, output, summaries):
    """Report losses for one training iteration and run periodic I/O.

    The four losses are handed to ``self.print_info``. Then, using the
    1-based step number ``self.iterations_so_far + 1`` and the frequencies in
    ``self.user_config``, the callback:

    * every ``snapshot_frequency`` steps, saves ``self.sess`` through
      ``self.saver`` under ``snapshots/<run_id>/iter`` with the step as
      ``global_step``;
    * every ``log_frequency`` steps, adds ``summaries`` to
      ``self.summary_writer`` at that step and flushes it;
    * every ``network_out_frequency`` steps, reshapes ``output`` into
      sequences of frames and writes each frame as a PNG to
      ``data/out/<run_id>/iter_<step>_frame_<frame>_<image>.png``
      (frame and image indices start at 1).

    ``self.iterations_so_far`` is incremented at the end.
    """
    cfg = self.user_config
    snapshot_every = cfg['snapshot_frequency']
    images_every = cfg['network_out_frequency']
    log_every = cfg['log_frequency']
    run_id = cfg['run_id']
    step = self.iterations_so_far + 1

    def ensure_dir(path):
        # Create the directory tree, tolerating a directory that already
        # exists (any other OSError propagates).
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise

    self.print_info(
        [dyntex_loss, appearance_loss, dynamics_loss, physics_loss])

    if step % snapshot_every == 0:
        print('Saving snapshot...')
        snapshot_dir = 'snapshots/' + run_id
        ensure_dir(snapshot_dir)
        self.saver.save(self.sess, snapshot_dir + '/iter', global_step=step)

    if step % log_every == 0:
        print('Saving log file...')
        writer = self.summary_writer
        writer.add_summary(summaries, step)
        writer.flush()

    if step % images_every == 0:
        print('Saving image(s)...')
        out_dir = 'data/out/' + run_id
        ensure_dir(out_dir)
        sequences = output.reshape(
            (-1, self.input_frame_count,
             self.input_dimension, self.input_dimension, 3))
        for image_idx, sequence in enumerate(sequences, start=1):
            for frame_idx, frame in enumerate(sequence, start=1):
                # Undo the VGG preprocessing before writing to disk.
                picture = vgg_deprocess(frame, no_clip=False, unit_scale=False)
                template = out_dir + '/iter_%d_frame_%d_%d.png'
                skimage.io.imsave(
                    template % (step, frame_idx, image_idx), picture)

    self.iterations_so_far += 1