def build_dynamics_descriptors(self, name, weight):
     """Build the weighted dynamics style loss over two-frame clips.

     For every frame index ``i`` in ``range(self.input_frame_count)`` a
     ``DynamicsDescriptor`` is created from grayscale two-frame clips
     concatenated along axis 0, and the Gramians of its ``loss_layers``
     are collected. The collected Gramians are passed to
     ``self.style_loss`` and the result is scaled by ``weight``.

     The last index has no following frame, so it uses a wrap-around clip
     made of the last and the first frame of ``self.output``.

     Args:
         name: Name scope under which the ops are created; it is also
             forwarded to each ``DynamicsDescriptor``.
         weight: Multiplier applied to the style loss.

     Returns:
         ``self.style_loss('dynamics_style_loss', gramians)`` multiplied
         by ``weight``.
     """
     with tf.get_default_graph().name_scope(name):
         loss_layers = ['MSOEnet_concat/concat']
         last_index = self.input_frame_count - 1
         gramians = []
         for i in range(self.input_frame_count):
             if i == last_index:
                 # Wrap-around clip: last output frame followed by the
                 # first one. The regular target/output clips are not
                 # built here because they would be discarded anyway.
                 wrapped = tf.image.rgb_to_grayscale(vgg_deprocess(
                     tf.concat([self.output[:, i:i+1],
                                self.output[:, :1]], 1),
                     no_clip=True, unit_scale=True))
                 # The same clip fills both slots so the batch layout
                 # matches the other iterations.
                 clips = [wrapped, wrapped]  # target not going to be used
             else:
                 # input is in BGR [0-mean,255-mean] mean subtracted, but
                 # MSOEnet accepts grayscale [0,1]
                 target = tf.image.rgb_to_grayscale(
                     tf.stack(self.target_dynamic_texture[i:i+2], 1))
                 generated = tf.image.rgb_to_grayscale(
                     vgg_deprocess(self.output[:, i:i+2], no_clip=True,
                                   unit_scale=True))
                 clips = [target, generated]
             d = DynamicsDescriptor('dynamics_descriptor_' + str(i+1),
                                    name, tf.concat(axis=0, values=clips),
                                    self.user_config['dynamics_model'])
             gramians.append(
                 [d.gramian_for_layer(layer) for layer in loss_layers])
         return tf.multiply(self.style_loss('dynamics_style_loss',
                                            gramians), weight)
    def minimize_callback(self, dyntex_loss, appearance_loss, dynamics_loss,
                          physics_loss, output, summaries):
        """Per-iteration hook: report losses and periodically persist state.

        The 1-based step number is ``self.iterations_so_far + 1``. On each
        call the four losses are handed to ``self.print_info``; then,
        whenever the step is a multiple of the matching frequency from
        ``self.user_config``:

        * ``snapshot_frequency`` -- save a checkpoint with ``self.saver``
          under ``snapshots/<run_id>/iter``.
        * ``log_frequency`` -- write ``summaries`` through
          ``self.summary_writer`` and flush it.
        * ``network_out_frequency`` -- reshape ``output`` to
          ``(-1, input_frame_count, input_dimension, input_dimension, 3)``
          and write every frame as a PNG under ``data/out/<run_id>``.

        Finally ``self.iterations_so_far`` is incremented.

        Raises:
            OSError: if an output directory cannot be created and does not
                already exist.
        """
        def ensure_dir(path):
            # Create ``path``; an existing directory is not an error.
            try:
                os.makedirs(path)
            except OSError:
                if not os.path.isdir(path):
                    raise

        completed = self.iterations_so_far
        config = self.user_config
        snapshot_frequency = config['snapshot_frequency']
        network_out_frequency = config['network_out_frequency']
        log_frequency = config['log_frequency']
        run_id = config['run_id']

        # print training information
        losses = [dyntex_loss, appearance_loss, dynamics_loss, physics_loss]
        self.print_info(losses)

        step = completed + 1

        if step % snapshot_frequency == 0:
            print('Saving snapshot...')
            snapshot_dir = 'snapshots/' + run_id
            ensure_dir(snapshot_dir)
            self.saver.save(self.sess, snapshot_dir + '/iter',
                            global_step=step)

        if step % log_frequency == 0:
            print('Saving log file...')
            writer = self.summary_writer
            writer.add_summary(summaries, step)
            writer.flush()

        if step % network_out_frequency == 0:
            print('Saving image(s)...')
            out_dir = 'data/out/' + run_id
            ensure_dir(out_dir)
            clip_shape = (-1, self.input_frame_count, self.input_dimension,
                          self.input_dimension, 3)
            pattern = out_dir + '/iter_%d_frame_%d_%d.png'
            # File names carry (step, frame number, clip number), 1-based.
            for clip_no, clip in enumerate(output.reshape(clip_shape), 1):
                for frame_no, frame in enumerate(clip, 1):
                    image = vgg_deprocess(frame,
                                          no_clip=False,
                                          unit_scale=False)
                    skimage.io.imsave(
                        pattern % (step, frame_no, clip_no), image)

        self.iterations_so_far += 1