Ejemplo n.º 1
0
    def define_graph(self):
        """
        Builds the discriminator part of the TensorFlow graph.

        Populates, in order: self.scale_nets (one DScaleModel per scale),
        self.scale_preds (the `preds` attribute of each of those networks),
        the self.labels placeholder, and the training ops (self.global_loss,
        self.global_step, self.optimizer, self.train_op, self.summaries).
        """
        with tf.name_scope('discriminator'):
            ##
            # Scale networks: one per scale. The last network sees the full
            # height x width; each earlier one sees half the resolution of
            # the next.
            ##

            self.scale_nets = []
            last_scale = self.num_scale_nets - 1
            for idx in range(self.num_scale_nets):
                with tf.name_scope('scale_net_' + str(idx)):
                    shrink = 1. / 2**(last_scale - idx)
                    net = DScaleModel(idx,
                                      int(self.height * shrink),
                                      int(self.width * shrink),
                                      self.scale_conv_layer_fms[idx],
                                      self.scale_kernel_sizes[idx],
                                      self.scale_fc_layer_sizes[idx])
                    self.scale_nets.append(net)

            # The prediction tensor of every scale network, in scale order.
            self.scale_preds = [net.preds for net in self.scale_nets]

            ##
            # Data
            ##

            self.labels = tf.placeholder(
                tf.float32, shape=[None, 1], name='labels')

            ##
            # Training
            ##

            with tf.name_scope('training'):
                # The global loss combines the predictions of all scale networks.
                self.global_loss = adv_loss(self.scale_preds, self.labels)

                # Step counter; incremented by train_op, excluded from training.
                self.global_step = tf.Variable(
                    0, trainable=False, name='global_step')

                # Plain gradient descent with the learning rate c.LRATE_D.
                self.optimizer = tf.train.GradientDescentOptimizer(
                    c.LRATE_D, name='optimizer')
                self.train_op = self.optimizer.minimize(
                    self.global_loss,
                    global_step=self.global_step,
                    name='train_op')

                # Summary for TensorBoard.
                d_loss_summary = tf.summary.scalar('loss_D', self.global_loss)
                self.summaries = tf.summary.merge([d_loss_summary])
    def define_graph(self, generator):
        """
        Sets up the model graph in TensorFlow.

        Builds, in order: the input placeholder and its history / ground-truth slices, the
        generator's predictions at every scale, the per-scale discriminator inputs and labels,
        the discriminator predictions, and the training ops (loss, optimizer, accuracy,
        summaries).

        @param generator: The generator model that generates frames for this to discriminate.
                          Its generate_predictions method is called once per scale.
        """

        with tf.name_scope('discriminator'):
            ##
            # Data
            ##

            # History frames followed by ground-truth frames, 3 channels each, stacked along
            # the channel axis.
            self.input_clips = tf.placeholder(
                tf.float32, shape=[None, self.height, self.width, (c.HIST_LEN + c.GT_LEN) * 3])

            self.g_input_frames = self.input_clips[:, :, :, :c.HIST_LEN * 3]
            self.gt_frames = self.input_clips[:, :, :, c.HIST_LEN * 3:]
            # The batch size is only known at run time, so read it from the tensor shape.
            input_shape = tf.shape(self.g_input_frames)
            batch_size = input_shape[0]

            ##
            # Get Generator Frames
            ##

            with tf.name_scope('gen_frames'):
                self.g_scale_preds = []  # the generated images at each scale
                self.scale_gts = []  # the ground truth images at each scale
                self.resized_inputs = []  # the resized input images at each scale

                for scale_num in range(self.num_scale_nets):
                    with tf.name_scope('scale_' + str(scale_num)):
                        # for all scales but the first, add the frame generated by the last
                        # scale to the input

                        if scale_num > 0:
                            last_scale_pred = self.g_scale_preds[scale_num - 1]
                        else:
                            last_scale_pred = None

                        # calculate
                        # NOTE(review): the generator is invoked with the mode string 'test';
                        # confirm against generate_predictions that this is the intended mode.
                        train_preds, train_gts = generator.generate_predictions(scale_num,
                                                                                self.height,
                                                                                self.width,
                                                                                self.g_input_frames,
                                                                                self.gt_frames,
                                                                                last_scale_pred,
                                                                                'test')

                        # resize the history frames to this scale's resolution
                        input_scale_factor = 1. / self.inverse_scale_factor[scale_num]
                        input_scale_height = int(self.height * input_scale_factor)
                        input_scale_width = int(self.width * input_scale_factor)
                        resized_inputs = tf.image.resize_images(self.g_input_frames,
                                                                [input_scale_height, input_scale_width])

                        self.g_scale_preds.append(train_preds)
                        self.scale_gts.append(train_gts)
                        self.resized_inputs.append(resized_inputs)

            # concatenate the generated images and ground truths at each scale
            # (generated frames first, ground truths second, along the batch axis)
            self.scale_inputs = []
            for scale_num in range(self.num_scale_nets):
                self.scale_inputs.append(
                    tf.concat([self.g_scale_preds[scale_num], self.scale_gts[scale_num]], 0))

            # create the labels: 0 for the generated half, 1 for the ground-truth half, in the
            # same order as the concatenation above
            self.labels = tf.concat([tf.zeros([batch_size, 1]), tf.ones([batch_size, 1])], 0)

            ##
            # Calculation
            ##

            # A list of the prediction tensors for each scale network
            self.scale_preds = []

            for scale_num in range(self.num_scale_nets):
                with tf.name_scope('scale_' + str(scale_num)):
                    with tf.name_scope('calculation'):
                        # get predictions from the scale network; the resized history frames
                        # are repeated once for the generated half and once for the
                        # ground-truth half of the batch
                        self.scale_preds.append(
                            self.scale_nets[scale_num].generate_all_predictions(
                                tf.concat([self.resized_inputs[scale_num], self.resized_inputs[scale_num]], 0),
                                self.scale_inputs[scale_num]))

            ##
            # Training
            ##

            with tf.name_scope('training'):
                # global loss is the combined loss from every scale network
                self.global_loss = adv_loss(self.scale_preds, self.labels)

                with tf.name_scope('train_step'):
                    self.global_step = tf.Variable(0, trainable=False, name='global_step')
                    self.optimizer = tf.train.GradientDescentOptimizer(self.lrate, name='optimizer')
                    self.train_op = self.optimizer.minimize(self.global_loss, var_list=self.train_vars, name='train_op',
                                                            global_step=self.global_step)

                    # Accuracy test
                    all_preds = tf.stack(self.scale_preds)
                    # Cast to float32, not int32: tf.reduce_mean keeps the input dtype, so an
                    # integer mean would be truncated to 0 or 1 instead of a fraction.
                    self.accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(all_preds), self.labels),
                                                           tf.float32))

                    # add summaries to visualize in TensorBoard
                    loss_summary = tf.summary.scalar('loss_D', self.global_loss)
                    accuracy_summary = tf.summary.scalar('accuracy_D', self.accuracy)
                    self.summaries = tf.summary.merge([loss_summary, accuracy_summary])
Ejemplo n.º 3
0
    def define_graph(self):
        """
        Sets up the generator model graph in TensorFlow.

        Creates separate train and test placeholders, then, for every scale, a set of
        convolution weights/biases that is applied once to the train inputs and once to the
        test inputs. Results are collected in self.scale_preds_train / self.scale_gts_train /
        self.scale_lastinputs_train and self.scale_preds_test / self.scale_gts_test.
        Finally builds the loss terms, the Adam train op and the train loss summary.
        """
        with tf.name_scope('generator'):
            ##
            # Data
            ##

            with tf.name_scope('data'):
                # History frames: 3 channels per frame, c.HIST_LEN frames on the channel axis.
                self.input_frames_train = tf.placeholder(tf.float32,
                                                         shape=[
                                                             None,
                                                             self.height_train,
                                                             self.width_train,
                                                             3 * c.HIST_LEN
                                                         ])
                # A single 3-channel ground-truth frame per example.
                self.gt_frames_train = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_train, self.width_train, 3])

                # Same layout as the train placeholders, at the test resolution.
                self.input_frames_test = tf.placeholder(tf.float32,
                                                        shape=[
                                                            None,
                                                            self.height_test,
                                                            self.width_test,
                                                            3 * c.HIST_LEN
                                                        ])
                self.gt_frames_test = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_test, self.width_test, 3])

                # use variable batch_size for more flexibility
                self.batch_size_train = tf.shape(self.input_frames_train)[0]
                self.batch_size_test = tf.shape(self.input_frames_test)[0]

            ##
            # Scale network setup and calculation
            ##

            self.summaries_train = []
            self.scale_preds_train = []  # the generated images at each scale
            self.scale_gts_train = []  # the ground truth images at each scale
            self.d_scale_preds = [
            ]  # the predictions from the discriminator model

            # NOTE: nothing in this method appends to summaries_test.
            self.summaries_test = []
            self.scale_preds_test = []  # the generated images at each scale
            self.scale_gts_test = []  # the ground truth images at each scale
            # the last input frame (last 3 channels), resized to each scale
            self.scale_lastinputs_train = []

            for scale_num in xrange(self.num_scale_nets):
                with tf.name_scope('scale_' + str(scale_num)):
                    with tf.name_scope('setup'):
                        ws = []
                        bs = []

                        # create weights for kernels
                        # w and b are helpers defined outside this method; presumably they
                        # create a weight / bias variable of the given shape — TODO confirm.
                        # The weight shape is [kernel, kernel, in feature maps, out feature
                        # maps], which is how ws[i] is consumed by tf.nn.conv2d below.
                        for i in xrange(len(
                                self.scale_kernel_sizes[scale_num])):
                            ws.append(
                                w([
                                    self.scale_kernel_sizes[scale_num][i],
                                    self.scale_kernel_sizes[scale_num][i],
                                    self.scale_layer_fms[scale_num][i],
                                    self.scale_layer_fms[scale_num][i + 1]
                                ]))
                            bs.append(
                                b([self.scale_layer_fms[scale_num][i + 1]]))

                    with tf.name_scope('calculation'):

                        # Runs this scale's network on one set of inputs. It reads scale_num,
                        # ws and bs from the enclosing loop iteration and is only called
                        # within that same iteration (once for train, once for test), so both
                        # calls share the same weights.
                        #
                        # Returns (preds, scale_gts, scale_last_inputs): the network output,
                        # the ground truth resized to this scale, and the last 3 channels of
                        # the resized inputs.
                        def calculate(height, width, inputs, gts,
                                      last_gen_frames):
                            # scale inputs and gts
                            scale_factor = 1. / 2**(
                                (self.num_scale_nets - 1) - scale_num)
                            scale_height = int(height * scale_factor)
                            scale_width = int(width * scale_factor)

                            inputs = tf.image.resize_images(
                                inputs, [scale_height, scale_width])
                            # taken before the previous scale's output is concatenated, so
                            # these are the final 3 channels of the history frames
                            scale_last_inputs = inputs[:, :, :, -3:]
                            scale_gts = tf.image.resize_images(
                                gts, [scale_height, scale_width])

                            # for all scales but the first, add the frame generated by the last
                            # scale to the input
                            if scale_num > 0:
                                last_gen_frames = tf.image.resize_images(
                                    last_gen_frames,
                                    [scale_height, scale_width])
                                inputs = tf.concat([inputs, last_gen_frames],
                                                   3)

                            # generated frame predictions
                            preds = inputs

                            # perform convolutions
                            with tf.name_scope('convolutions'):
                                for i in xrange(
                                        len(self.scale_kernel_sizes[scale_num])
                                ):
                                    # Convolve layer (stride 1, padding mode c.PADDING_G)
                                    preds = tf.nn.conv2d(preds,
                                                         ws[i], [1, 1, 1, 1],
                                                         padding=c.PADDING_G)

                                    # Activate with ReLU (or Tanh for last layer)
                                    if i == len(
                                            self.scale_kernel_sizes[scale_num]
                                    ) - 1:
                                        preds = tf.nn.tanh(preds + bs[i])
                                    else:
                                        preds = tf.nn.relu(preds + bs[i])

                            return preds, scale_gts, scale_last_inputs

                        ##
                        # Perform train calculation
                        ##

                        # for all scales but the first, add the frame generated by the last
                        # scale to the input
                        if scale_num > 0:
                            last_scale_pred_train = self.scale_preds_train[
                                scale_num - 1]
                        else:
                            last_scale_pred_train = None

                        # calculate
                        train_preds, train_gts, train_last_inputs = calculate(
                            self.height_train, self.width_train,
                            self.input_frames_train, self.gt_frames_train,
                            last_scale_pred_train)
                        self.scale_preds_train.append(train_preds)
                        #self.buffer.append(train_preds)
                        self.scale_gts_train.append(train_gts)
                        self.scale_lastinputs_train.append(train_last_inputs)

                        # We need to run the network first to get generated frames, run the
                        # discriminator on those frames to get d_scale_preds, then run this
                        # again for the loss optimization.
                        # One placeholder per scale; when c.ADVERSARIAL is false the list
                        # stays empty.
                        if c.ADVERSARIAL:
                            self.d_scale_preds.append(
                                tf.placeholder(tf.float32, [None, 1]))

                        ##
                        # Perform test calculation
                        ##

                        # for all scales but the first, add the frame generated by the last
                        # scale to the input
                        if scale_num > 0:
                            last_scale_pred_test = self.scale_preds_test[
                                scale_num - 1]
                        else:
                            last_scale_pred_test = None

                        # calculate
                        # NOTE: test_last_inputs is not used after this call.
                        test_preds, test_gts, test_last_inputs = calculate(
                            self.height_test, self.width_test,
                            self.input_frames_test, self.gt_frames_test,
                            last_scale_pred_test)
                        self.scale_preds_test.append(test_preds)
                        self.scale_gts_test.append(test_gts)

            ##
            # Training
            ##

            with tf.name_scope('train'):
                # global loss is the combined loss from every scale network
                self.global_loss = combined_loss(
                    self.scale_preds_train, self.scale_gts_train,
                    self.d_scale_preds, self.scale_lastinputs_train, c.L_NUM,
                    c.ALPHA_NUM, c.LAM_ADV, c.LAM_LP, c.LAM_GDL, c.LAM_TV)
                # The individual weighted terms below are stored as attributes only; the
                # train op minimizes self.global_loss.
                self.l1_loss = c.LAM_LP * lp_loss(self.scale_preds_train,
                                                  self.scale_gts_train, 1)
                batch_size = tf.shape(self.scale_preds_train[0])[0]
                # NOTE(review): the factor 4 in the label shape is not explained by anything
                # in this method — confirm it matches the batch dimension fed to
                # self.d_scale_preds. This term is also built when c.ADVERSARIAL is false,
                # in which case self.d_scale_preds is an empty list.
                self.adv_loss = c.LAM_ADV * adv_loss(
                    self.d_scale_preds, tf.ones([4 * batch_size, 1]))
                self.gdl_loss = c.LAM_GDL * gdl_loss(self.scale_preds_train,
                                                     self.scale_gts_train, 1)
                self.tv_loss = c.LAM_TV * tv_loss(self.scale_preds_train)

                # step counter, incremented by train_op and excluded from training
                self.global_step = tf.Variable(0, trainable=False)
                self.optimizer = tf.train.AdamOptimizer(
                    learning_rate=c.LRATE_G, name='optimizer')
                self.train_op = self.optimizer.minimize(
                    self.global_loss,
                    global_step=self.global_step,
                    name='train_op')

                # train loss summary
                loss_summary = tf.summary.scalar('train_loss_G',
                                                 self.global_loss)
                self.summaries_train.append(loss_summary)
Ejemplo n.º 4
0
    def train_step(self, batch, loss_list, discriminator=None):
        """
        Runs a training step using the global loss on each of the scale networks.

        @param batch: An array of shape
                      [c.BATCH_SIZE x self.height x self.width x (3 * (c.HIST_LEN + 1))].
                      The input and output frames, concatenated along the channel axis (index 3).
        @param loss_list: A list of adversarial loss values from earlier steps. When
                          c.ADVERSARIAL is set, this step's value is appended to it in place
                          and the mean of the list is fed to self.adv_windowed_loss.
        @param discriminator: The discriminator model. Default = None, if not adversarial.

        @return: A tuple (global step, loss_list).
        """
        ##
        # Split into inputs and outputs
        ##

        # the last 3 channels are the ground-truth frame, everything before is history
        input_frames = batch[:, :, :, :-3]
        gt_frames = batch[:, :, :, -3:]

        ##
        # Train
        ##

        feed_dict = {
            self.input_frames_train: input_frames,
            self.gt_frames_train: gt_frames
        }

        if c.ADVERSARIAL:
            # Run the generator first to get generated frames
            scale_preds = self.sess.run(self.scale_preds_train,
                                        feed_dict=feed_dict)

            # Run the discriminator nets on those frames to get predictions
            d_feed_dict = {}
            for scale_num, gen_frames in enumerate(scale_preds):
                d_feed_dict[discriminator.scale_nets[scale_num].
                            input_frames] = gen_frames
            d_scale_preds = self.sess.run(discriminator.scale_preds,
                                          feed_dict=d_feed_dict)

            # Collect the discriminator predictions for every scale
            d_scale_preds_list = list(d_scale_preds)

            # Calculate current loss from adversary
            # NOTE(review): adv_loss and tf.ones are called on every step, which adds new
            # ops to the graph each time; consider building this loss once in the graph
            # definition and feeding the predictions through placeholders.
            batch_size = len(d_scale_preds_list[0])
            discriminator_curr_loss = adv_loss(d_scale_preds_list,
                                               tf.ones([batch_size, 1]))
            loss_value = self.sess.run(discriminator_curr_loss)
            print(loss_value)
            loss_list.append(loss_value)

            # feed the mean of all adversarial losses recorded so far
            feed_dict[self.adv_windowed_loss] = sum(loss_list) / len(loss_list)

        _, global_loss, global_psnr_error, global_sharpdiff_error, global_step, summaries = \
            self.sess.run([self.train_op,
                           self.global_loss,
                           self.psnr_error_train,
                           self.sharpdiff_error_train,
                           self.global_step,
                           self.summaries_train],
                          feed_dict=feed_dict)

        ##
        # User output
        ##
        if global_step % c.STATS_FREQ == 0:
            stats = (
                ('GeneratorModel : Step ', global_step),
                ('                 Global Loss    : ', global_loss),
                ('                 PSNR Error     : ', global_psnr_error),
                ('                 Sharpdiff Error: ', global_sharpdiff_error),
            )
            for label, value in stats:
                # One pre-formatted string per print() call renders the same text under
                # both Python 2 and Python 3 (label, a space, then the value).
                print('%s %s' % (label, value))
        if global_step % c.SUMMARY_FREQ == 0:
            self.summary_writer.add_summary(summaries, global_step)
            print('GeneratorModel: saved summaries')
        if global_step % c.IMG_SAVE_FREQ == 0:
            print('-' * 30)
            print('Saving images...')

            # if not adversarial, we didn't get the preds for each scale net before for the
            # discriminator prediction, so do it now
            if not c.ADVERSARIAL:
                scale_preds = self.sess.run(self.scale_preds_train,
                                            feed_dict=feed_dict)

            # re-generate scale gt_frames to avoid having to run through TensorFlow.
            scale_gts = []
            for scale_num in range(self.num_scale_nets):
                scale_factor = 1. / 2**((self.num_scale_nets - 1) - scale_num)
                scale_height = int(self.height_train * scale_factor)
                scale_width = int(self.width_train * scale_factor)

                # resize gt_output_frames for scale and append to scale_gts_train
                # Sized from the actual batch rather than c.BATCH_SIZE, so a batch of a
                # different length neither overflows the buffer nor leaves unused rows.
                scaled_gt_frames = np.empty(
                    [len(gt_frames), scale_height, scale_width, 3])
                for i, img in enumerate(gt_frames):
                    # for skimage.transform.resize, images need to be in range [0, 1], so normalize
                    # to [0, 1] before resize and back to [-1, 1] after
                    sknorm_img = (img / 2) + 0.5
                    resized_frame = resize(sknorm_img,
                                           [scale_height, scale_width, 3])
                    scaled_gt_frames[i] = (resized_frame - 0.5) * 2
                scale_gts.append(scaled_gt_frames)

            # for every clip in the batch, save the inputs, scale preds and scale gts
            for pred_num in range(len(input_frames)):
                pred_dir = c.get_dir(
                    os.path.join(c.IMG_SAVE_DIR, 'Step_' + str(global_step),
                                 str(pred_num)))

                # save input images
                for frame_num in range(c.HIST_LEN):
                    img = input_frames[pred_num, :, :,
                                       (frame_num * 3):((frame_num + 1) * 3)]
                    imsave(
                        os.path.join(pred_dir,
                                     'input_' + str(frame_num) + '.png'), img)

                # save preds and gts at each scale
                # noinspection PyUnboundLocalVariable
                for scale_num, scale_pred in enumerate(scale_preds):
                    gen_img = scale_pred[pred_num]

                    path = os.path.join(pred_dir, 'scale' + str(scale_num))
                    gt_img = scale_gts[scale_num][pred_num]

                    imsave(path + '_gen.png', gen_img)
                    imsave(path + '_gt.png', gt_img)

            print('Saved images!')
            print('-' * 30)

        return global_step, loss_list
Ejemplo n.º 5
0
    def define_graph(self):
        """
        Builds the discriminator part of the TensorFlow graph.

        Populates, in order: self.scale_nets (one DScaleModel per scale, each
        also given self.is_w), self.scale_preds, the self.labels placeholder,
        the training ops (self.global_loss, self.global_step, self.optimizer,
        self.train_op), the weight-clipping ops in self.clip, and
        self.summaries.
        """
        with tf.name_scope('discriminator'):
            ##
            # Scale networks: one per scale. The last network sees the full
            # height x width; each earlier one sees half the resolution of
            # the next.
            ##

            self.scale_nets = []
            last_scale = self.num_scale_nets - 1
            for idx in range(self.num_scale_nets):
                with tf.name_scope('scale_net_' + str(idx)):
                    shrink = 1. / 2**(last_scale - idx)
                    net = DScaleModel(idx,
                                      int(self.height * shrink),
                                      int(self.width * shrink),
                                      self.scale_conv_layer_fms[idx],
                                      self.scale_kernel_sizes[idx],
                                      self.scale_fc_layer_sizes[idx],
                                      self.is_w)
                    self.scale_nets.append(net)

            # The prediction tensor of every scale network, in scale order.
            self.scale_preds = [net.preds for net in self.scale_nets]

            ##
            # Data
            ##

            self.labels = tf.placeholder(
                tf.float32, shape=[None, 1], name='labels')

            ##
            # Training
            ##

            with tf.name_scope('training'):
                # The global loss combines the predictions of all scale networks.
                self.global_loss = adv_loss(
                    self.scale_preds, self.labels, self.is_w)

                # Step counter (stored as a float), excluded from training.
                self.global_step = tf.Variable(
                    0.0, trainable=False, name='global_step')

                # Make the train op wait for everything registered in the
                # UPDATE_OPS collection.
                pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(pending_updates):
                    # Explicit comparison kept on purpose: only a value equal
                    # to True selects RMSProp, anything else falls back to
                    # plain gradient descent.
                    if self.is_w == True:
                        optimizer_cls = tf.train.RMSPropOptimizer
                    else:
                        optimizer_cls = tf.train.GradientDescentOptimizer
                    self.optimizer = optimizer_cls(
                        self.c.LRATE_D, name='optimizer')

                    self.train_op = self.optimizer.minimize(
                        self.global_loss,
                        global_step=self.global_step,
                        name='train_op',
                        var_list=self.discriminator_vars)

                # Ops that clamp every discriminator variable to
                # [-0.01, 0.01]. They are not part of train_op and only take
                # effect when run explicitly.
                clip_ops = []
                for var in self.discriminator_vars:
                    clamped = tf.clip_by_value(var, -0.01, 0.01)
                    clip_ops.append(var.assign(clamped))
                self.clip = clip_ops

                # Summary for TensorBoard.
                d_loss_summary = tf.summary.scalar('loss_D', self.global_loss)
                self.summaries = tf.summary.merge([d_loss_summary])
Ejemplo n.º 6
0
    def define_graph(self):
        """
        Builds the discriminator part of the TensorFlow graph.

        Populates, in order: self.scale_nets (one DScaleModel per scale),
        self.scale_preds, the self.labels placeholder and the training ops.
        When c.WASSERSTEIN and c.W_GP are both set, a gradient-penalty term
        weighted by c.LAM_GP is added to the loss. When c.WASSERSTEIN is set
        without c.W_GP, self.train_op additionally clips the discriminator
        variables to [-c.W_Clip, c.W_Clip] after the optimizer step.
        """
        with tf.name_scope('discriminator'):
            ##
            # Scale networks: one per scale. The last network sees the full
            # height x width; each earlier one sees half the resolution of
            # the next.
            ##

            self.scale_nets = []
            last_scale = self.num_scale_nets - 1
            for idx in range(self.num_scale_nets):
                with tf.name_scope('scale_net_' + str(idx)):
                    shrink = 1. / 2 ** (last_scale - idx)
                    net = DScaleModel(idx,
                                      int(self.height * shrink),
                                      int(self.width * shrink),
                                      self.scale_conv_layer_fms[idx],
                                      self.scale_kernel_sizes[idx],
                                      self.scale_fc_layer_sizes[idx])
                    self.scale_nets.append(net)

            # The prediction tensor of every scale network, in scale order.
            self.scale_preds = [net.preds for net in self.scale_nets]

            ##
            # Data
            ##

            self.labels = tf.placeholder(
                tf.float32, shape=[None, 1], name='labels')

            ##
            # Training
            ##

            with tf.name_scope('training'):
                # The global loss combines the predictions of all scale networks.
                self.global_loss = adv_loss(self.scale_preds, self.labels)

                if c.WASSERSTEIN and c.W_GP:
                    # A single scalar mixing coefficient, shared by all scales.
                    epsilon = tf.random_uniform([], 0.0, 1.0)
                    penalties = []
                    for net in self.scale_nets:
                        # The two halves of the batch are unpacked as
                        # (fake, real).
                        fake, real = tf.split(net.input_frames, 2)
                        # Blend of real and fake frames; x_hat / d_hat are
                        # overwritten per scale, so the attributes end up
                        # holding the tensors of the last scale only.
                        self.x_hat = real * epsilon + (1 - epsilon) * fake
                        self.d_hat = net.generate_predictions(self.x_hat)
                        penalties.append(
                            grad_penality_loss(self.x_hat, self.d_hat))
                    self.global_loss = (
                        self.global_loss
                        + c.LAM_GP * tf.reduce_mean(penalties))

                self.global_step = tf.Variable(
                    0, trainable=False, name='global_step')
                self.optimizer = tf.train.AdamOptimizer(
                    c.LRATE_D, name='optimizer')
                # Raw optimizer step; self.train_op below may wrap it.
                self.train_op_ = self.optimizer.minimize(
                    self.global_loss,
                    global_step=self.global_step,
                    name='train_op')

                # Clipping to enforce 1-Lipschitz function
                if c.WASSERSTEIN and not c.W_GP:
                    # The clip ops depend on the optimizer step, so running
                    # train_op performs the update first and then clamps.
                    with tf.control_dependencies([self.train_op_]):
                        clip_ops = []
                        for var in tf.trainable_variables():
                            if 'discriminator' in var.name:
                                clamped = tf.clip_by_value(
                                    var, -c.W_Clip, c.W_Clip)
                                clip_ops.append(tf.assign(var, clamped))
                        self.train_op = tf.group(*clip_ops)
                else:
                    self.train_op = self.train_op_

                # Summary for TensorBoard.
                d_loss_summary = tf.summary.scalar('loss_D', self.global_loss)
                self.summaries = tf.summary.merge([d_loss_summary])