Beispiel #1
0
    def define_graph(self):
        """
        Builds the generator's TensorFlow graph.

        Creates the input and ground-truth placeholders, one fully-connected +
        convolutional network per scale, the combined loss over all scales with
        its Adam training op, and the PSNR / sharp-difference error summaries.

        Attributes assigned here: inputs, gt_frames, batch_size, summaries,
        scale_preds, scale_gts, d_scale_preds, global_loss, global_step,
        optimizer, train_op, psnr_error and sharpdiff_error.
        """
        with tf.name_scope('generator'):
            # ---- Data ----
            with tf.name_scope('data'):
                self.inputs = tf.placeholder(tf.float32, shape=[None, 6])
                self.gt_frames = tf.placeholder(
                    tf.float32, shape=[None, self.height, self.width, 3])

                # the batch dimension is dynamic, so it is read from the graph
                self.batch_size = tf.shape(self.inputs)[0]

            # ---- Scale networks: variables and forward pass ----
            self.summaries = []
            self.scale_preds = []    # generated frames, one entry per scale
            self.scale_gts = []      # resized ground truth, one entry per scale
            self.d_scale_preds = []  # placeholders for discriminator outputs

            for scale_num in xrange(self.num_scale_nets):
                with tf.name_scope('scale_' + str(scale_num)):
                    with tf.name_scope('setup'):
                        with tf.name_scope('fully-connected'):
                            fc_sizes = self.scale_fc_layer_sizes[scale_num]
                            fc_weights = []
                            fc_biases = []

                            # one (weight, bias) pair per consecutive pair of
                            # layer sizes; weight and bias of a layer are
                            # created back to back
                            for n_in, n_out in zip(fc_sizes[:-1],
                                                   fc_sizes[1:]):
                                fc_weights.append(w([n_in, n_out]))
                                fc_biases.append(b([n_out]))

                        with tf.name_scope('convolutions'):
                            kernel_sizes = self.scale_kernel_sizes[scale_num]
                            feature_maps = self.scale_conv_layer_fms[scale_num]
                            conv_weights = []
                            conv_biases = []

                            # square kernels mapping feature_maps[layer] input
                            # channels to feature_maps[layer + 1] outputs
                            for layer, ksize in enumerate(kernel_sizes):
                                fms_in = feature_maps[layer]
                                fms_out = feature_maps[layer + 1]
                                conv_weights.append(
                                    w([ksize, ksize, fms_in, fms_out]))
                                conv_biases.append(b([fms_out]))

                    with tf.name_scope('calculation'):

                        def calculate(height, width, inputs, gts,
                                      last_gen_frames):
                            # Forward pass for the current scale. It is called
                            # right below, within the same loop iteration, so
                            # the closed-over scale_num and variable lists are
                            # the ones built above.
                            #
                            # NOTE(review): last_gen_frames is accepted but
                            # never read, so the previous scale's output is not
                            # fed into this scale - confirm this is intended.
                            exponent = (self.num_scale_nets - 1) - scale_num
                            scale_factor = 1. / 2 ** exponent
                            scale_height = int(height * scale_factor)
                            scale_width = int(width * scale_factor)

                            # only the ground truth is resized to this scale
                            scale_gts = tf.image.resize_images(
                                gts, scale_height, scale_width)

                            preds = inputs

                            with tf.name_scope('fully-connected'):
                                # ReLU after every fully-connected layer
                                for fc_w, fc_b in zip(fc_weights, fc_biases):
                                    preds = tf.nn.relu(
                                        tf.matmul(preds, fc_w) + fc_b)

                                # NOTE(review): the reshape uses the
                                # c.FRAME_HEIGHT / c.FRAME_WIDTH constants, not
                                # scale_height / scale_width - confirm.
                                preds = tf.reshape(preds, [
                                    -1, c.FRAME_HEIGHT, c.FRAME_WIDTH,
                                    feature_maps[0]
                                ])

                            with tf.name_scope('convolutions'):
                                last_layer = len(conv_weights) - 1
                                layers = zip(conv_weights, conv_biases)
                                for layer, (kernel, bias) in enumerate(layers):
                                    preds = tf.nn.conv2d(preds,
                                                         kernel,
                                                         [1, 1, 1, 1],
                                                         padding=c.PADDING_G)

                                    # tanh on the final layer, ReLU elsewhere
                                    if layer == last_layer:
                                        preds = tf.nn.tanh(preds + bias)
                                    else:
                                        preds = tf.nn.relu(preds + bias)

                            return preds, scale_gts

                        # ---- Train calculation ----
                        # every scale after the first is handed the previous
                        # scale's prediction
                        last_scale_pred = (self.scale_preds[scale_num - 1]
                                           if scale_num > 0 else None)

                        train_preds, train_gts = calculate(
                            self.height, self.width, self.inputs,
                            self.gt_frames, last_scale_pred)
                        self.scale_preds.append(train_preds)
                        self.scale_gts.append(train_gts)

                        # The generator runs once to produce frames, the
                        # discriminator scores them, and those scores are fed
                        # back through these placeholders for the loss step.
                        if c.ADVERSARIAL:
                            self.d_scale_preds.append(
                                tf.placeholder(tf.float32, [None, 1]))

            # ---- Training ----
            with tf.name_scope('train'):
                # a single loss combining every scale network
                self.global_loss = combined_loss(self.scale_preds,
                                                 self.scale_gts,
                                                 self.d_scale_preds)
                self.global_step = tf.Variable(0, trainable=False)
                self.optimizer = tf.train.AdamOptimizer(
                    learning_rate=c.LRATE_G, name='optimizer')
                self.train_op = self.optimizer.minimize(
                    self.global_loss,
                    global_step=self.global_step,
                    name='train_op')

                self.summaries.append(
                    tf.scalar_summary('train_loss_G', self.global_loss))

            # ---- Error ----
            with tf.name_scope('error'):
                # errors are measured on the last (largest) scale only
                final_preds = self.scale_preds[-1]
                self.psnr_error = psnr_error(final_preds, self.gt_frames)
                self.sharpdiff_error = sharp_diff_error(final_preds,
                                                        self.gt_frames)

                self.summaries += [
                    tf.scalar_summary('train_PSNR', self.psnr_error),
                    tf.scalar_summary('train_SharpDiff',
                                      self.sharpdiff_error),
                ]

            # collapse the collected summaries into one op for TensorBoard
            self.summaries = tf.merge_summary(self.summaries)
Beispiel #2
0
    def define_graph(self):
        """
        Sets up the model graph in TensorFlow.

        Builds a train pipeline and a test pipeline that share the same
        per-scale convolution weights. For each scale the inputs and ground
        truth are resized, the previous scale's prediction (if any) is resized
        and concatenated onto the inputs along the channel axis, and the result
        is passed through the scale's convolution stack.

        Attributes assigned here: input_frames_train/_test,
        gt_frames_train/_test, batch_size_train/_test, summaries_train/_test,
        scale_preds_train/_test, scale_gts_train/_test, d_scale_preds,
        global_loss, global_step, optimizer, train_op, psnr_error_train/_test
        and sharpdiff_error_train/_test.
        """
        with tf.name_scope('generator'):
            ##
            # Data
            ##

            with tf.name_scope('data'):
                # inputs stack c.HIST_LEN frames of 3 channels each
                self.input_frames_train = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_train, self.width_train,
                           3 * c.HIST_LEN])
                self.gt_frames_train = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_train, self.width_train, 3])

                self.input_frames_test = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_test, self.width_test,
                           3 * c.HIST_LEN])
                self.gt_frames_test = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_test, self.width_test, 3])

                # the batch dimension is dynamic, so it is read from the graph
                self.batch_size_train = tf.shape(self.input_frames_train)[0]
                self.batch_size_test = tf.shape(self.input_frames_test)[0]

            ##
            # Scale network setup and calculation
            ##

            self.summaries_train = []
            self.scale_preds_train = []  # generated frames, one per scale
            self.scale_gts_train = []  # resized ground truth, one per scale
            self.d_scale_preds = []  # placeholders for discriminator outputs

            self.summaries_test = []
            self.scale_preds_test = []  # generated frames, one per scale
            self.scale_gts_test = []  # resized ground truth, one per scale

            for scale_num in range(self.num_scale_nets):
                with tf.name_scope('scale_' + str(scale_num)):
                    with tf.name_scope('setup'):
                        kernel_sizes = self.scale_kernel_sizes[scale_num]
                        layer_fms = self.scale_layer_fms[scale_num]
                        ws = []
                        bs = []

                        # Square kernels mapping layer_fms[layer] channels to
                        # layer_fms[layer + 1]. Weight and bias of a layer are
                        # created back to back so variable creation order is
                        # stable.
                        for layer, ksize in enumerate(kernel_sizes):
                            fms_in = layer_fms[layer]
                            fms_out = layer_fms[layer + 1]
                            ws.append(w([ksize, ksize, fms_in, fms_out]))
                            bs.append(b([fms_out]))

                    with tf.name_scope('calculation'):
                        def calculate(height, width, inputs, gts, last_gen_frames):
                            # Forward pass for the current scale. It is called
                            # twice below (train, then test) within the same
                            # loop iteration, so both passes share this scale's
                            # ws / bs.
                            scale_factor = 1. / 2 ** ((self.num_scale_nets - 1) - scale_num)
                            scale_size = [int(height * scale_factor),
                                          int(width * scale_factor)]

                            inputs = tf.image.resize_images(inputs, scale_size)
                            scale_gts = tf.image.resize_images(gts, scale_size)

                            # for all scales but the first, add the frame
                            # generated by the last scale to the input
                            if scale_num > 0:
                                last_gen_frames = tf.image.resize_images(
                                    last_gen_frames, scale_size)
                                inputs = tf.concat([inputs, last_gen_frames], 3)

                            # generated frame predictions
                            preds = inputs

                            with tf.name_scope('convolutions'):
                                last_layer = len(kernel_sizes) - 1
                                for layer in range(len(kernel_sizes)):
                                    preds = tf.nn.conv2d(
                                        preds, ws[layer], [1, 1, 1, 1],
                                        padding=c.PADDING_G)

                                    # tanh on the final layer, ReLU elsewhere
                                    if layer == last_layer:
                                        preds = tf.nn.tanh(preds + bs[layer])
                                    else:
                                        preds = tf.nn.relu(preds + bs[layer])

                            return preds, scale_gts

                        ##
                        # Perform train calculation
                        ##

                        # every scale after the first is handed the previous
                        # scale's train prediction
                        if scale_num > 0:
                            last_scale_pred_train = self.scale_preds_train[scale_num - 1]
                        else:
                            last_scale_pred_train = None

                        train_preds, train_gts = calculate(self.height_train,
                                                           self.width_train,
                                                           self.input_frames_train,
                                                           self.gt_frames_train,
                                                           last_scale_pred_train)
                        self.scale_preds_train.append(train_preds)
                        self.scale_gts_train.append(train_gts)

                        # We need to run the network first to get generated frames, run the
                        # discriminator on those frames to get d_scale_preds, then run this
                        # again for the loss optimization.
                        if c.ADVERSARIAL:
                            self.d_scale_preds.append(tf.placeholder(tf.float32, [None, 1]))

                        ##
                        # Perform test calculation
                        ##

                        # every scale after the first is handed the previous
                        # scale's test prediction
                        if scale_num > 0:
                            last_scale_pred_test = self.scale_preds_test[scale_num - 1]
                        else:
                            last_scale_pred_test = None

                        test_preds, test_gts = calculate(self.height_test,
                                                         self.width_test,
                                                         self.input_frames_test,
                                                         self.gt_frames_test,
                                                         last_scale_pred_test)
                        self.scale_preds_test.append(test_preds)
                        self.scale_gts_test.append(test_gts)

            ##
            # Training
            ##

            with tf.name_scope('train'):
                # global loss is the combined loss from every scale network
                self.global_loss = combined_loss(self.scale_preds_train,
                                                 self.scale_gts_train,
                                                 self.d_scale_preds)
                self.global_step = tf.Variable(0, trainable=False)
                self.optimizer = tf.train.AdamOptimizer(learning_rate=c.LRATE_G, name='optimizer')
                self.train_op = self.optimizer.minimize(self.global_loss,
                                                        global_step=self.global_step,
                                                        name='train_op')

                # train loss summary
                loss_summary = tf.summary.scalar('train_loss_G', self.global_loss)
                self.summaries_train.append(loss_summary)

            ##
            # Error
            ##

            with tf.name_scope('error'):
                # errors are measured on the last (largest) scale only
                self.psnr_error_train = psnr_error(self.scale_preds_train[-1],
                                                   self.gt_frames_train)
                self.sharpdiff_error_train = sharp_diff_error(self.scale_preds_train[-1],
                                                              self.gt_frames_train)
                self.psnr_error_test = psnr_error(self.scale_preds_test[-1],
                                                  self.gt_frames_test)
                self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1],
                                                             self.gt_frames_test)

                # train error summaries
                summary_psnr_train = tf.summary.scalar('train_PSNR',
                                                       self.psnr_error_train)
                summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff',
                                                            self.sharpdiff_error_train)
                self.summaries_train += [summary_psnr_train, summary_sharpdiff_train]

                # test error summaries
                summary_psnr_test = tf.summary.scalar('test_PSNR',
                                                      self.psnr_error_test)
                summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff',
                                                           self.sharpdiff_error_test)
                self.summaries_test += [summary_psnr_test, summary_sharpdiff_test]

            # collapse the collected summaries into one op each for TensorBoard
            self.summaries_train = tf.summary.merge(self.summaries_train)
            self.summaries_test = tf.summary.merge(self.summaries_test)
Beispiel #3
0
    def define_graph(self, discriminator):
        """
        Sets up the model graph in TensorFlow.

        Builds train and test pipelines over per-scale weights, optionally
        wires the discriminator's per-scale predictions into the generator
        loss, and records PSNR / sharp-difference / SSIM errors for every
        ground-truth frame at the largest scale.

        @param discriminator: The discriminator model that discriminates frames generated by this
                              model. Only accessed when c.ADVERSARIAL is set, through
                              discriminator.scale_nets[i].generate_all_predictions(...).
        """

        def frame_slice(frames, frame_num):
            # the 3 channels along the last axis that belong to frame frame_num
            first = frame_num * 3
            return frames[:, :, :, first:first + 3]

        def per_frame(metric, preds, gts):
            # metric evaluated separately on each of the c.GT_LEN frames
            return [
                metric(frame_slice(preds, frame_num),
                       frame_slice(gts, frame_num))
                for frame_num in xrange(c.GT_LEN)
            ]

        with tf.name_scope('generator'):
            # ---- Data ----
            with tf.name_scope('input'):
                # inputs stack c.HIST_LEN frames and ground truth stacks
                # c.GT_LEN frames, 3 channels each
                self.input_frames_train = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_train, self.width_train,
                           3 * c.HIST_LEN],
                    name='input_frames_train')
                self.gt_frames_train = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_train, self.width_train,
                           3 * c.GT_LEN],
                    name='gt_frames_train')

                self.input_frames_test = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_test, self.width_test,
                           3 * c.HIST_LEN],
                    name='input_frames_test')
                self.gt_frames_test = tf.placeholder(
                    tf.float32,
                    shape=[None, self.height_test, self.width_test,
                           3 * c.GT_LEN],
                    name='gt_frames_test')

                # the batch dimension is dynamic, so it is read from the graph
                with tf.name_scope('batch_size_train'):
                    train_shape = tf.shape(self.input_frames_train,
                                           name='input_frames_train_shape')
                    self.batch_size_train = train_shape[0]
                with tf.name_scope('batch_size_test'):
                    test_shape = tf.shape(self.input_frames_test,
                                          name='input_frames_test_shape')
                    self.batch_size_test = test_shape[0]

            # ---- Scale networks: variables and forward passes ----
            self.train_vars = []  # variables handed to the optimizer

            self.summaries_train = []
            self.scale_preds_train = []  # generated frames, one per scale
            self.scale_gts_train = []  # ground truth, one per scale
            self.d_scale_preds = []  # discriminator predictions per scale

            self.summaries_test = []
            self.scale_preds_test = []  # generated frames, one per scale
            self.scale_gts_test = []  # ground truth, one per scale

            self.ws = []
            self.bs = []
            for scale_num in xrange(self.num_scale_nets):
                with tf.name_scope('scale_net_' + str(scale_num)):
                    with tf.name_scope('setup'):
                        kernel_sizes = self.scale_kernel_sizes[scale_num]
                        layer_fms = self.scale_layer_fms[scale_num]

                        # all kernels of the scale first, then all biases
                        with tf.name_scope('weights'):
                            scale_ws = [
                                w([
                                    ksize, ksize, layer_fms[layer],
                                    layer_fms[layer + 1]
                                ], 'gen_' + str(scale_num) + '_' + str(layer))
                                for layer, ksize in enumerate(kernel_sizes)
                            ]

                        with tf.name_scope('biases'):
                            scale_bs = [
                                b([layer_fms[layer + 1]])
                                for layer in xrange(len(kernel_sizes))
                            ]

                        # register as trainable parameters
                        self.train_vars += scale_ws + scale_bs

                        # published on self before the predictions are built
                        self.ws.append(scale_ws)
                        self.bs.append(scale_bs)

                    with tf.name_scope('calculation'):
                        with tf.name_scope('calculation_train'):
                            # every scale after the first is handed the
                            # previous scale's train prediction
                            last_scale_pred_train = (
                                self.scale_preds_train[scale_num - 1]
                                if scale_num > 0 else None)

                            train_preds, train_gts = self.generate_predictions(
                                scale_num, self.height_train, self.width_train,
                                self.input_frames_train, self.gt_frames_train,
                                last_scale_pred_train)

                        with tf.name_scope('calculation_test'):
                            # same for the test pipeline
                            last_scale_pred_test = (
                                self.scale_preds_test[scale_num - 1]
                                if scale_num > 0 else None)

                            test_preds, test_gts = self.generate_predictions(
                                scale_num, self.height_test, self.width_test,
                                self.input_frames_test, self.gt_frames_test,
                                last_scale_pred_test, 'test')

                        self.scale_preds_train.append(train_preds)
                        self.scale_gts_train.append(train_gts)

                        self.scale_preds_test.append(test_preds)
                        self.scale_gts_test.append(test_gts)

            # ---- Discriminator predictions ----
            if c.ADVERSARIAL:
                with tf.name_scope('d_preds'):
                    # one prediction tensor per scale network
                    self.d_scale_preds = []

                    for scale_num in xrange(self.num_scale_nets):
                        with tf.name_scope('scale_' + str(scale_num)):
                            with tf.name_scope('calculation'):
                                inverse_factor = (
                                    self.scale_gt_inverse_scale_factor[scale_num])
                                input_scale_factor = 1. / inverse_factor
                                scaled_height = int(self.height_train *
                                                    input_scale_factor)
                                scaled_width = int(self.width_train *
                                                   input_scale_factor)

                                scale_inputs_train = tf.image.resize_images(
                                    self.input_frames_train,
                                    [scaled_height, scaled_width])

                                # this scale's discriminator sees the resized
                                # inputs together with the generated frames
                                d_net = discriminator.scale_nets[scale_num]
                                self.d_scale_preds.append(
                                    d_net.generate_all_predictions(
                                        scale_inputs_train,
                                        self.scale_preds_train[scale_num]))

            # ---- Training ----
            with tf.name_scope('training'):
                # a single loss combining every scale network
                self.global_loss = temporal_combined_loss(
                    self.scale_preds_train, self.scale_gts_train,
                    self.d_scale_preds)

                with tf.name_scope('train_step'):
                    self.global_step = tf.Variable(0,
                                                   trainable=False,
                                                   name='global_step')
                    self.optimizer = tf.train.AdamOptimizer(
                        learning_rate=c.LRATE_G, name='optimizer')

                    # only this model's own variables are updated
                    self.train_op = self.optimizer.minimize(
                        self.global_loss,
                        global_step=self.global_step,
                        var_list=self.train_vars,
                        name='train_op')

                    self.summaries_train.append(
                        tf.summary.scalar('train_loss_G', self.global_loss))

            # ---- Error ----
            with tf.name_scope('error'):
                # errors are measured on the last (largest) scale only, one
                # value per ground-truth frame
                with tf.name_scope('psnr_train'):
                    self.psnr_error_train = per_frame(
                        psnr_error, self.scale_preds_train[-1],
                        self.gt_frames_train)
                with tf.name_scope('sharpdiff_train'):
                    self.sharpdiff_error_train = per_frame(
                        sharp_diff_error, self.scale_preds_train[-1],
                        self.gt_frames_train)
                with tf.name_scope('ssim_train'):
                    self.ssim_error_train = per_frame(
                        ssim_error, self.scale_preds_train[-1],
                        self.gt_frames_train)
                with tf.name_scope('psnr_test'):
                    self.psnr_error_test = per_frame(
                        psnr_error, self.scale_preds_test[-1],
                        self.gt_frames_test)
                with tf.name_scope('sharpdiff_test'):
                    self.sharpdiff_error_test = per_frame(
                        sharp_diff_error, self.scale_preds_test[-1],
                        self.gt_frames_test)
                with tf.name_scope('ssim_test'):
                    self.ssim_error_test = per_frame(
                        ssim_error, self.scale_preds_test[-1],
                        self.gt_frames_test)

                for gt_num in xrange(c.GT_LEN):
                    frame_tag = str(gt_num)

                    # train error summaries
                    self.summaries_train += [
                        tf.summary.scalar('train_PSNR_' + frame_tag,
                                          self.psnr_error_train[gt_num]),
                        tf.summary.scalar('train_SharpDiff_' + frame_tag,
                                          self.sharpdiff_error_train[gt_num]),
                        tf.summary.scalar('train_SSIM_' + frame_tag,
                                          self.ssim_error_train[gt_num]),
                    ]

                    # test error summaries
                    self.summaries_test += [
                        tf.summary.scalar('test_PSNR_' + frame_tag,
                                          self.psnr_error_test[gt_num]),
                        tf.summary.scalar('test_SharpDiff_' + frame_tag,
                                          self.sharpdiff_error_test[gt_num]),
                        tf.summary.scalar('test_SSIM_' + frame_tag,
                                          self.ssim_error_test[gt_num]),
                    ]

            # collapse the collected summaries into one op each for TensorBoard
            self.summaries_train = tf.summary.merge(self.summaries_train)
            self.summaries_test = tf.summary.merge(self.summaries_test)