Example #1
    def define_graph(self):
        """
        Sets up the model graph in TensorFlow.
        """
        with tf.name_scope('generator'):
            ##
            # Data
            ##

            with tf.name_scope('data'):
                self.input_frames_train = tf.placeholder(
                    tf.float32, shape=[None, self.height_train, self.width_train, 3 * c.HIST_LEN])
                self.gt_frames_train = tf.placeholder(
                    tf.float32, shape=[None, self.height_train, self.width_train, 3])

                self.input_frames_test = tf.placeholder(
                    tf.float32, shape=[None, self.height_test, self.width_test, 3 * c.HIST_LEN])
                self.gt_frames_test = tf.placeholder(
                    tf.float32, shape=[None, self.height_test, self.width_test, 3])

                # use variable batch_size for more flexibility
                self.batch_size_train = tf.shape(self.input_frames_train)[0]
                self.batch_size_test = tf.shape(self.input_frames_test)[0]

            ##
            # Scale network setup and calculation
            ##

            self.summaries_train = []
            self.scale_preds_train = []  # the generated images at each scale
            self.scale_gts_train = []  # the ground truth images at each scale
            self.d_scale_preds = []  # the predictions from the discriminator model

            self.summaries_test = []
            self.scale_preds_test = []  # the generated images at each scale
            self.scale_gts_test = []  # the ground truth images at each scale

            for scale_num in range(self.num_scale_nets):
                with tf.name_scope('scale_' + str(scale_num)):
                    with tf.name_scope('setup'):
                        ws = []
                        bs = []

                        # create weights for kernels
                        for i in range(len(self.scale_kernel_sizes[scale_num])):
                            ws.append(w([self.scale_kernel_sizes[scale_num][i],
                                         self.scale_kernel_sizes[scale_num][i],
                                         self.scale_layer_fms[scale_num][i],
                                         self.scale_layer_fms[scale_num][i + 1]]))
                            bs.append(b([self.scale_layer_fms[scale_num][i + 1]]))

                    with tf.name_scope('calculation'):
                        def calculate(height, width, inputs, gts, last_gen_frames):
                            # scale inputs and gts
                            scale_factor = 1. / 2 ** ((self.num_scale_nets - 1) - scale_num)
                            scale_height = int(height * scale_factor)
                            scale_width = int(width * scale_factor)

                            inputs = tf.image.resize_images(inputs, [scale_height, scale_width])
                            scale_gts = tf.image.resize_images(gts, [scale_height, scale_width])

                            # for all scales but the first, add the frame generated by the last
                            # scale to the input
                            if scale_num > 0:
                                last_gen_frames = tf.image.resize_images(
                                    last_gen_frames, [scale_height, scale_width])
                                inputs = tf.concat([inputs, last_gen_frames], 3)

                            # generated frame predictions
                            preds = inputs

                            # perform convolutions
                            with tf.name_scope('convolutions'):
                                for i in range(len(self.scale_kernel_sizes[scale_num])):
                                    # Convolve layer
                                    preds = tf.nn.conv2d(
                                        preds, ws[i], [1, 1, 1, 1], padding=c.PADDING_G)

                                    # Activate with ReLU (or Tanh for last layer)
                                    if i == len(self.scale_kernel_sizes[scale_num]) - 1:
                                        preds = tf.nn.tanh(preds + bs[i])
                                    else:
                                        preds = tf.nn.relu(preds + bs[i])

                            return preds, scale_gts

                        ##
                        # Perform train calculation
                        ##

                        # for all scales but the first, add the frame generated by the last
                        # scale to the input
                        if scale_num > 0:
                            last_scale_pred_train = self.scale_preds_train[scale_num - 1]
                        else:
                            last_scale_pred_train = None

                        # calculate
                        train_preds, train_gts = calculate(self.height_train,
                                                           self.width_train,
                                                           self.input_frames_train,
                                                           self.gt_frames_train,
                                                           last_scale_pred_train)
                        self.scale_preds_train.append(train_preds)
                        self.scale_gts_train.append(train_gts)

                        # We need to run the network first to get generated frames, run the
                        # discriminator on those frames to get d_scale_preds, then run this
                        # again for the loss optimization.
                        if c.ADVERSARIAL:
                            self.d_scale_preds.append(tf.placeholder(tf.float32, [None, 1]))

                        ##
                        # Perform test calculation
                        ##

                        # for all scales but the first, add the frame generated by the last
                        # scale to the input
                        if scale_num > 0:
                            last_scale_pred_test = self.scale_preds_test[scale_num - 1]
                        else:
                            last_scale_pred_test = None

                        # calculate
                        test_preds, test_gts = calculate(self.height_test,
                                                         self.width_test,
                                                         self.input_frames_test,
                                                         self.gt_frames_test,
                                                         last_scale_pred_test)
                        self.scale_preds_test.append(test_preds)
                        self.scale_gts_test.append(test_gts)

            ##
            # Training
            ##

            with tf.name_scope('train'):
                # global loss is the combined loss from every scale network
                self.global_loss = combined_loss(self.scale_preds_train,
                                                 self.scale_gts_train,
                                                 self.d_scale_preds)
                self.global_step = tf.Variable(0, trainable=False)
                self.optimizer = tf.train.AdamOptimizer(learning_rate=c.LRATE_G, name='optimizer')
                self.train_op = self.optimizer.minimize(self.global_loss,
                                                        global_step=self.global_step,
                                                        name='train_op')

                # train loss summary
                loss_summary = tf.summary.scalar('train_loss_G', self.global_loss)
                self.summaries_train.append(loss_summary)

            ##
            # Error
            ##

            with tf.name_scope('error'):
                # error computation
                # get error at largest scale
                self.psnr_error_train = psnr_error(self.scale_preds_train[-1],
                                                   self.gt_frames_train)
                self.sharpdiff_error_train = sharp_diff_error(self.scale_preds_train[-1],
                                                              self.gt_frames_train)
                self.psnr_error_test = psnr_error(self.scale_preds_test[-1],
                                                  self.gt_frames_test)
                self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1],
                                                             self.gt_frames_test)
                # train error summaries
                summary_psnr_train = tf.summary.scalar('train_PSNR',
                                                       self.psnr_error_train)
                summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff',
                                                            self.sharpdiff_error_train)
                self.summaries_train += [summary_psnr_train, summary_sharpdiff_train]

                # test error
                summary_psnr_test = tf.summary.scalar('test_PSNR',
                                                      self.psnr_error_test)
                summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff',
                                                           self.sharpdiff_error_test)
                self.summaries_test += [summary_psnr_test, summary_sharpdiff_test]

            # add summaries to visualize in TensorBoard
            self.summaries_train = tf.summary.merge(self.summaries_train)
            self.summaries_test = tf.summary.merge(self.summaries_test)
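
A minimal driver sketch (not part of the example above), assuming the class is instantiated as `model`, that `writer` is a `tf.summary.FileWriter`, and that `input_batch` / `gt_batch` are NumPy arrays shaped `[batch, height_train, width_train, 3 * c.HIST_LEN]` and `[batch, height_train, width_train, 3]`; it shows how the placeholders, `train_op`, and merged train summaries defined above would typically be fed in a TF1 session:

import tensorflow as tf

# Hypothetical training step for the graph in Example #1; `model`,
# `input_batch`, `gt_batch`, and `writer` are assumed to exist already.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    feed_dict = {
        model.input_frames_train: input_batch,  # [batch, H, W, 3 * HIST_LEN]
        model.gt_frames_train: gt_batch,        # [batch, H, W, 3]
    }
    # When c.ADVERSARIAL is True, each model.d_scale_preds placeholder must
    # also be fed with the discriminator's output on the generated frames.

    _, loss, step, summaries = sess.run(
        [model.train_op, model.global_loss, model.global_step, model.summaries_train],
        feed_dict=feed_dict)
    writer.add_summary(summaries, global_step=step)
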
Example #2
    def define_graph(self):
        """
        Sets up the model graph in TensorFlow.
        """
        with tf.name_scope('generator'):
            ##
            # Data
            ##

            with tf.name_scope('data'):
                self.inputs = tf.placeholder(tf.float32, shape=[None, 6])
                self.gt_frames = tf.placeholder(
                    tf.float32, shape=[None, self.height, self.width, 3])

                # use variable batch_size for more flexibility
                self.batch_size = tf.shape(self.inputs)[0]

            ##
            # Scale network setup and calculation
            ##

            self.summaries = []
            self.scale_preds = []  # the generated images at each scale
            self.scale_gts = []  # the ground truth images at each scale
            self.d_scale_preds = []  # the predictions from the discriminator model

            for scale_num in range(self.num_scale_nets):
                with tf.name_scope('scale_' + str(scale_num)):
                    with tf.name_scope('setup'):
                        with tf.name_scope('fully-connected'):
                            fc_ws = []
                            fc_bs = []

                            # create weights for fc layers
                            for i in range(len(self.scale_fc_layer_sizes[scale_num]) - 1):
                                fc_ws.append(w([self.scale_fc_layer_sizes[scale_num][i],
                                                self.scale_fc_layer_sizes[scale_num][i + 1]]))
                                fc_bs.append(b([self.scale_fc_layer_sizes[scale_num][i + 1]]))

                        with tf.name_scope('convolutions'):
                            conv_ws = []
                            conv_bs = []

                            # create weights for kernels
                            for i in range(len(self.scale_kernel_sizes[scale_num])):
                                conv_ws.append(w([self.scale_kernel_sizes[scale_num][i],
                                                  self.scale_kernel_sizes[scale_num][i],
                                                  self.scale_conv_layer_fms[scale_num][i],
                                                  self.scale_conv_layer_fms[scale_num][i + 1]]))
                                conv_bs.append(b([self.scale_conv_layer_fms[scale_num][i + 1]]))

                    with tf.name_scope('calculation'):

                        def calculate(height, width, inputs, gts,
                                      last_gen_frames):
                            # scale inputs and gts
                            scale_factor = 1. / 2**(
                                (self.num_scale_nets - 1) - scale_num)
                            scale_height = int(height * scale_factor)
                            scale_width = int(width * scale_factor)

                            scale_gts = tf.image.resize_images(
                                gts, [scale_height, scale_width])

                            # for all scales but the first, add the frame generated by the last
                            # scale to the input
                            # if scale_num > 0:
                            #     last_gen_frames = tf.image.resize_images(last_gen_frames,
                            #                                              scale_height,
                            #                                              scale_width)
                            #     inputs = tf.concat(3, [inputs, last_gen_frames])

                            # generated frame predictions
                            preds = inputs

                            # perform fc multiplications
                            with tf.name_scope('fully-connected'):
                                for i in range(len(self.scale_fc_layer_sizes[scale_num]) - 1):
                                    preds = tf.nn.relu(tf.matmul(preds, fc_ws[i]) + fc_bs[i])

                                # reshape for convolutions
                                preds = tf.reshape(
                                    preds, [-1, c.FRAME_HEIGHT, c.FRAME_WIDTH,
                                            self.scale_conv_layer_fms[scale_num][0]])

                            # perform convolutions
                            with tf.name_scope('convolutions'):
                                for i in range(len(self.scale_kernel_sizes[scale_num])):
                                    # Convolve layer
                                    preds = tf.nn.conv2d(
                                        preds, conv_ws[i], [1, 1, 1, 1], padding=c.PADDING_G)

                                    # Activate with ReLU (or Tanh for last layer)
                                    if i == len(self.scale_kernel_sizes[scale_num]) - 1:
                                        preds = tf.nn.tanh(preds + conv_bs[i])
                                    else:
                                        preds = tf.nn.relu(preds + conv_bs[i])

                            return preds, scale_gts

                        ##
                        # Perform train calculation
                        ##

                        # for all scales but the first, add the frame generated by the last
                        # scale to the input
                        if scale_num > 0:
                            last_scale_pred = self.scale_preds[scale_num - 1]
                        else:
                            last_scale_pred = None

                        # calculate
                        train_preds, train_gts = calculate(
                            self.height, self.width, self.inputs,
                            self.gt_frames, last_scale_pred)
                        self.scale_preds.append(train_preds)
                        self.scale_gts.append(train_gts)

                        # We need to run the network first to get generated frames, run the
                        # discriminator on those frames to get d_scale_preds, then run this
                        # again for the loss optimization.
                        if c.ADVERSARIAL:
                            self.d_scale_preds.append(
                                tf.placeholder(tf.float32, [None, 1]))

            ##
            # Training
            ##

            with tf.name_scope('train'):
                # global loss is the combined loss from every scale network
                self.global_loss = combined_loss(self.scale_preds,
                                                 self.scale_gts,
                                                 self.d_scale_preds)
                self.global_step = tf.Variable(0, trainable=False)
                self.optimizer = tf.train.AdamOptimizer(
                    learning_rate=c.LRATE_G, name='optimizer')
                self.train_op = self.optimizer.minimize(
                    self.global_loss,
                    global_step=self.global_step,
                    name='train_op')

                # train loss summary
                loss_summary = tf.summary.scalar('train_loss_G', self.global_loss)
                self.summaries.append(loss_summary)

            ##
            # Error
            ##

            with tf.name_scope('error'):
                # error computation
                # get error at largest scale
                self.psnr_error = psnr_error(self.scale_preds[-1],
                                             self.gt_frames)
                self.sharpdiff_error = sharp_diff_error(
                    self.scale_preds[-1], self.gt_frames)

                # train error summaries
                summary_psnr = tf.summary.scalar('train_PSNR', self.psnr_error)
                summary_sharpdiff = tf.summary.scalar('train_SharpDiff',
                                                      self.sharpdiff_error)
                self.summaries += [summary_psnr, summary_sharpdiff]

            # add summaries to visualize in TensorBoard
            self.summaries = tf.summary.merge(self.summaries)
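
Both examples leave the `d_scale_preds` placeholders to be fed from outside the graph when `c.ADVERSARIAL` is set; the in-code comments describe the intended two-pass pattern (generate frames, score them with the discriminator, then optimize the combined loss). A hedged sketch of that pattern using Example #1's attribute names, where `sess`, `g_model`, `d_model`, `d_model.input_frames`, and `g_feed_dict` are hypothetical names rather than attributes defined in the examples:

# Pass 1: run the generator alone to get the predicted frames at every scale.
gen_frames = sess.run(g_model.scale_preds_train, feed_dict=g_feed_dict)

# Score the generated frames with the discriminator, one prediction per scale.
d_preds = sess.run(d_model.scale_preds,
                   feed_dict=dict(zip(d_model.input_frames, gen_frames)))

# Pass 2: feed those scores into the generator's d_scale_preds placeholders
# and run the combined-loss optimization defined in the 'train' scope.
for placeholder, preds in zip(g_model.d_scale_preds, d_preds):
    g_feed_dict[placeholder] = preds
sess.run(g_model.train_op, feed_dict=g_feed_dict)
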