コード例 #1
0
def save(artist, model_path, num_save):
    """
    Restores a trained model from a checkpoint and writes generated samples to text files.

    One file named '<i>.txt' is written per sample, for i in [0, num_save), into a
    per-artist directory below '../save/samples/' (assumes c.get_dir returns the path of
    the directory it is given - TODO confirm).

    @param artist: Passed to DataReader to select the data. Also used as the name of the
                   sample sub-directory.
    @param model_path: The checkpoint path handed to tf.train.Saver.restore.
    @param num_save: The number of samples to generate and save.
    """
    sample_save_dir = c.get_dir('../save/samples/')
    sess = tf.Session()

    # A session owns resources that are only released when it is closed, so close it even
    # if building, restoring or sampling from the model fails.
    try:
        print(artist)

        data_reader = DataReader(artist)
        vocab = data_reader.get_vocab()

        print('Init model...')
        model = LSTMModel(sess,
                          vocab,
                          c.BATCH_SIZE,
                          c.SEQ_LEN,
                          c.CELL_SIZE,
                          c.NUM_LAYERS,
                          test=True)

        saver = tf.train.Saver()
        # Initialize first so the graph is fully initialized before the restore below
        # overwrites the variables with the checkpoint values.
        sess.run(tf.initialize_all_variables())

        saver.restore(sess, model_path)
        print('Model restored from ' + model_path)

        artist_save_dir = c.get_dir(join(sample_save_dir, artist))
        for i in xrange(num_save):
            print(i)

            path = join(artist_save_dir, str(i) + '.txt')
            sample = model.generate()
            processed_sample = process_sample(sample)

            with open(path, 'w') as f:
                f.write(processed_sample)
    finally:
        sess.close()
コード例 #2
0
ファイル: g_model_wip.py プロジェクト: qui3n/Code
    def test_batch(self, batch, global_step, save_imgs=True):
        """
        Runs the test graph on a batch, prints the PSNR and sharpdiff errors and optionally
        saves the input, generated and ground-truth frames as images.

        @param batch: An array of shape
                      [batch_size x self.height x self.width x (3 * (c.HIST_LEN + 1))].
                      A batch of the input and output frames, concatenated along the channel axis
                      (index 3).
        @param global_step: The global step. Only used to name the image output directory.
        @param save_imgs: Whether or not to save the input/output images to file. Default = True.

        Nothing is returned; the errors are only printed.
        """
        divider = '-' * 30
        print(divider)
        print('Testing:')

        ##
        # The first 3 * c.HIST_LEN channels are the history; the next 3 are the target frame.
        ##

        hist_channels = 3 * c.HIST_LEN
        input_frames = batch[:, :, :, :hist_channels]
        gt_frames = batch[:, :, :, hist_channels:hist_channels + 3]

        ##
        # Evaluate the last scale's prediction together with the error metrics.
        ##

        feed_dict = {self.input_frames_test: input_frames,
                     self.gt_frames_test: gt_frames}
        fetches = [self.scale_preds_test[-1],
                   self.psnr_error_test,
                   self.sharpdiff_error_test,
                   self.summaries_test]
        # The summaries are fetched but not used in this method.
        gen_imgs, psnr, sharpdiff, _ = self.sess.run(fetches, feed_dict=feed_dict)

        print('PSNR Error     : ', psnr)
        print('Sharpdiff Error: ', sharpdiff)

        ##
        # Write one directory per clip: input_<n>.png, gen.png and gt.png.
        ##

        if save_imgs:
            step_name = 'Tests/Step_' + str(global_step)
            for pred_num, clip_inputs in enumerate(input_frames):
                pred_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR, step_name, str(pred_num)))

                # each history frame occupies 3 consecutive channels
                for frame_num in range(c.HIST_LEN):
                    first_channel = frame_num * 3
                    frame = clip_inputs[:, :, first_channel:first_channel + 3]
                    imsave(os.path.join(pred_dir, 'input_' + str(frame_num) + '.png'), frame)

                imsave(os.path.join(pred_dir, 'gen.png'), gen_imgs[pred_num])
                imsave(os.path.join(pred_dir, 'gt.png'), gt_frames[pred_num])

        print(divider)
コード例 #3
0
    def generator_train_step(self, batch):
        """
        Runs one generator training step, then periodically prints the loss and saves the
        input / compressed image pairs.

        @param batch: The batch of images fed to self.input_image. It is indexed as
                      batch[i, :, :, :] when saving, with i running over len(batch).

        @return: The global step, as reported by self.get_global_step().
        """
        feeds = {self.input_image: batch, self.is_train: True}

        if c.ADVERSARIAL:
            # The individual loss terms are fetched but not reported yet; precision drives
            # the discriminator accuracy moving average.
            (_, prediction, generator_loss, _adv_loss, _recon_loss, _entropy_loss,
             precision, summaries) = self.sess.run(
                [self.train_op_G,
                 self.prediction,
                 self.generator_loss,
                 self.loss_adv_G,
                 self.loss_recon,
                 self.loss_entropy,
                 self.precision,
                 self.summaries],
                feed_dict=feeds)

            self.update_discrim_accur_MA(precision)
        else:
            (_, prediction, generator_loss, _recon_loss, _entropy_loss,
             summaries) = self.sess.run(
                [self.train_op_G,
                 self.prediction,
                 self.generator_loss,
                 self.loss_recon,
                 self.loss_entropy,
                 self.summaries],
                feed_dict=feeds)

        global_step = self.get_global_step()

        # Kept on the instance so it can be picked up outside this method.
        self.most_recent_summary = summaries

        if global_step % c.STATS_FREQ == 0:
            print('GeneratorModel : Step ', global_step)
            print('                 Generator Loss    : ', generator_loss)
            print('                                     ')
            # TODO: add more user output for train time

        if global_step % c.IMG_SAVE_FREQ == 0:
            print('Saving images ..')
            step_name = 'Step_' + str(global_step)
            for idx in range(len(batch)):
                out_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR, step_name, str(idx)))

                # the network input
                source_img = denormalize_frames_mine(batch[idx, :, :, :])
                imsave(os.path.join(out_dir, 'input.png'), source_img)

                # the network output
                compressed_img = denormalize_frames_mine(prediction[idx, :, :, :])
                imsave(os.path.join(out_dir, 'compressed.png'), compressed_img)
            print('Saved images!')

        return global_step
コード例 #4
0
def main():
    """
    Parses the command line options and processes the training data into clips.

    Options:
      -w / --num_workers  sets c.NUM_WORKERS (int)
      -n / --num_clips    the number of clips passed to process_training_data (int)
      -t / --train_dir    sets c.TRAIN_DIR via c.get_dir
      -c / --clips_dir    sets c.TRAIN_DIR_CLIPS via c.get_dir
      -o / --overwrite    clears c.TRAIN_DIR_CLIPS before processing
      -H / --help         prints the usage and exits

    Exits with status 2 on an invalid option or when help is requested.
    """
    ##
    # Handle command line input
    ##

    num_clips = 5000000
    overwrite = False

    try:
        opts, _ = getopt.getopt(sys.argv[1:], 'w:n:t:c:oH', [
            'num_workers=', 'num_clips=', 'train_dir=', 'clips_dir=',
            'overwrite', 'help'
        ])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-w', '--num_workers'):
            c.NUM_WORKERS = int(arg)
        elif opt in ('-n', '--num_clips'):
            num_clips = int(arg)
        elif opt in ('-t', '--train_dir'):
            c.TRAIN_DIR = c.get_dir(arg)
        elif opt in ('-c', '--clips_dir'):
            c.TRAIN_DIR_CLIPS = c.get_dir(arg)
        elif opt in ('-o', '--overwrite'):
            # Only remember the request here. Clearing immediately would wipe whatever
            # c.TRAIN_DIR_CLIPS is at this point, i.e. the wrong directory when -o comes
            # before -c on the command line.
            overwrite = True
        elif opt in ('-H', '--help'):
            usage()
            sys.exit(2)

    # All options are known now, so this clears the clips directory that will be used.
    if overwrite:
        c.clear_dir(c.TRAIN_DIR_CLIPS)

    # set train frame dimensions
    assert os.path.exists(c.TRAIN_DIR), 'train directory does not exist: ' + str(c.TRAIN_DIR)
    c.FULL_HEIGHT, c.FULL_WIDTH = c.get_train_frame_dims()

    ##
    # Process data for training
    ##

    process_training_data(num_clips)
コード例 #5
0
def main():
    """
    Extracts the frames of every video file found under a directory as JPEG images.

    For a file '<stem>.<ext>' the frames are written next to it, into a directory named
    '<stem>', as '<stem>_frame<N>.jpg' with N counting up from 0. '<stem>' is the part of
    the file name before its first '.'.

    Options:
      -v / --video_dir  the directory to search for video files (required)
      -H / --help       prints the usage and exits

    Exits with status 2 on an invalid option, when help is requested or when no video
    directory is given.
    """
    source = None

    try:
        opts, _ = getopt.getopt(sys.argv[1:], 'v:H', ['video_dir=', 'help'])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-v', '--video_dir'):
            source = c.get_dir(arg)
        elif opt in ('-H', '--help'):
            usage()
            sys.exit(2)

    if source is None:
        # Without -v there is nothing to walk; report it instead of failing with a
        # NameError further down.
        usage()
        sys.exit(2)

    for root, _, filenames in os.walk(source):
        for f in filenames:
            print('Reading video file:', f)
            stem = str(f).split(".")[0]
            frame_dir = os.path.join(root, stem)

            # os.walk reports file names relative to root, which differs from source for
            # files in sub-directories.
            vidcap = cv2.VideoCapture(os.path.join(root, f))
            try:
                count = 0
                success, image = vidcap.read()

                # Only create the output directory when there is at least one frame.
                if success and not os.path.exists(frame_dir):
                    print('creating dir:', frame_dir)
                    os.makedirs(frame_dir)

                # Write the frame that was just read before fetching the next one, so the
                # first frame is kept and nothing is written after a failed read.
                while success:
                    frame_name = stem + "_frame" + str(count) + ".jpg"
                    cv2.imwrite(os.path.join(frame_dir, frame_name), image)  # save frame as JPEG file
                    count += 1
                    success, image = vidcap.read()
            finally:
                vidcap.release()
コード例 #6
0
    def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True):
        """
        Runs the test graph on a batch, generating num_rec_out frames recursively, prints the
        errors of every recursion and optionally saves the frames as images.

        @param batch: An array of shape
                      [batch_size x self.height x self.width x (3 * (c.HIST_LEN+ num_rec_out))].
                      A batch of the input and output frames, concatenated along the channel axis
                      (index 3).
        @param global_step: The global step. Used for the summary and the image directory name.
        @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively,
                            using previously-generated frames as input. Default = 1.
        @param save_imgs: Whether or not to save the input/output images to file. Default = True.

        Nothing is returned; the errors are only printed.

        @raise ValueError: If num_rec_out < 1.
        """
        if num_rec_out < 1:
            raise ValueError('num_rec_out must be >= 1')

        divider = '-' * 30
        print(divider)
        print('Testing:')

        ##
        # The first 3 * c.HIST_LEN channels are the history; the rest are the target frames.
        ##

        hist_channels = 3 * c.HIST_LEN
        input_frames = batch[:, :, :, :hist_channels]
        gt_frames = batch[:, :, :, hist_channels:]

        ##
        # Predict one frame per recursion, feeding each prediction back in as history.
        ##

        rolling_inputs = deepcopy(input_frames)  # copy, so input_frames stays intact for saving
        rec_preds = []
        rec_summaries = []
        for rec_num in range(num_rec_out):
            first_channel = 3 * rec_num
            step_gt = gt_frames[:, :, :, first_channel:first_channel + 3]

            feed_dict = {self.input_frames_test: rolling_inputs,
                         self.gt_frames_test: step_gt}
            fetches = [self.scale_preds_test[-1],
                       self.psnr_error_test,
                       self.sharpdiff_error_test,
                       self.summaries_test]
            preds, psnr, sharpdiff, summaries = self.sess.run(fetches, feed_dict=feed_dict)

            # drop the oldest frame and append the new prediction as the newest one
            rolling_inputs = np.concatenate([rolling_inputs[:, :, :, 3:], preds], axis=3)

            rec_preds.append(preds)
            rec_summaries.append(summaries)

            print('Recursion ', rec_num)
            print('PSNR Error     : ', psnr)
            print('Sharpdiff Error: ', sharpdiff)

        # TODO: Think of a good way to write rec output summaries - for now only the summary
        #       of the first output is written.
        self.summary_writer.add_summary(rec_summaries[0], global_step)

        ##
        # Write one directory per clip: input_<n>.png plus gen_<r>.png / gt_<r>.png.
        ##

        if save_imgs:
            step_name = 'Tests/Step_' + str(global_step)
            for pred_num, clip_inputs in enumerate(input_frames):
                pred_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR, step_name, str(pred_num)))

                # each history frame occupies 3 consecutive channels
                for frame_num in range(c.HIST_LEN):
                    first_channel = frame_num * 3
                    frame = clip_inputs[:, :, first_channel:first_channel + 3]
                    imsave(os.path.join(pred_dir, 'input_' + str(frame_num) + '.png'), frame)

                # one generated / ground-truth pair per recursion
                for rec_num, rec_pred in enumerate(rec_preds):
                    first_channel = 3 * rec_num
                    gen_img = rec_pred[pred_num]
                    gt_img = gt_frames[pred_num, :, :, first_channel:first_channel + 3]
                    imsave(os.path.join(pred_dir, 'gen_' + str(rec_num) + '.png'), gen_img)
                    imsave(os.path.join(pred_dir, 'gt_' + str(rec_num) + '.png'), gt_img)

        print(divider)
コード例 #7
0
    def train_step(self, batch, discriminator=None):
        """
        Runs a training step using the global loss on each of the scale networks, then
        periodically prints stats, writes summaries and saves images.

        @param batch: An array of shape
                      [c.BATCH_SIZE x self.height x self.width x (3 * (c.HIST_LEN + 1))].
                      The input and output frames, concatenated along the channel axis (index 3).
                      The last 3 channels are taken as the ground-truth frame.
        @param discriminator: The discriminator model. Default = None, if not adversarial.
                              Only accessed when c.ADVERSARIAL is set.

        @return: The global step.
        """
        ##
        # Split into inputs and outputs
        ##

        input_frames = batch[:, :, :, :-3]
        gt_frames = batch[:, :, :, -3:]

        ##
        # Train
        ##

        feed_dict = {self.input_frames_train: input_frames, self.gt_frames_train: gt_frames}

        if c.ADVERSARIAL:
            # Run the generator first to get generated frames
            scale_preds = self.sess.run(self.scale_preds_train, feed_dict=feed_dict)

            # Run the discriminator nets on those frames to get predictions
            d_feed_dict = {}
            for scale_num, gen_frames in enumerate(scale_preds):
                d_feed_dict[discriminator.scale_nets[scale_num].input_frames] = gen_frames
            d_scale_preds = self.sess.run(discriminator.scale_preds, feed_dict=d_feed_dict)

            # Add the discriminator predictions to the generator's feed dict
            for i, preds in enumerate(d_scale_preds):
                feed_dict[self.d_scale_preds[i]] = preds

        _, global_loss, global_psnr_error, global_sharpdiff_error, global_step, summaries = \
            self.sess.run([self.train_op,
                           self.global_loss,
                           self.psnr_error_train,
                           self.sharpdiff_error_train,
                           self.global_step,
                           self.summaries_train],
                          feed_dict=feed_dict)

        ##
        # User output
        ##
        if global_step % c.STATS_FREQ == 0:
            print('GeneratorModel : Step ', global_step)
            print('                 Global Loss    : ', global_loss)
            print('                 PSNR Error     : ', global_psnr_error)
            print('                 Sharpdiff Error: ', global_sharpdiff_error)
        if global_step % c.SUMMARY_FREQ == 0:
            self.summary_writer.add_summary(summaries, global_step)
            print('GeneratorModel: saved summaries')
        if global_step % c.IMG_SAVE_FREQ == 0:
            print('-' * 30)
            print('Saving images...')

            # if not adversarial, we didn't get the preds for each scale net before for the
            # discriminator prediction, so do it now
            if not c.ADVERSARIAL:
                scale_preds = self.sess.run(self.scale_preds_train, feed_dict=feed_dict)

            # re-generate scale gt_frames to avoid having to run through TensorFlow.
            num_clips = len(gt_frames)
            scale_gts = []
            for scale_num in range(self.num_scale_nets):
                # the last scale net works at full resolution, each earlier one at half the
                # resolution of the next
                scale_factor = 1. / 2 ** ((self.num_scale_nets - 1) - scale_num)
                scale_height = int(self.height_train * scale_factor)
                scale_width = int(self.width_train * scale_factor)

                # Size the buffer from the batch that was actually passed in rather than
                # c.BATCH_SIZE, so a batch of a different length neither leaves
                # uninitialised rows nor overruns the buffer.
                scaled_gt_frames = np.empty([num_clips, scale_height, scale_width, 3])
                for i, img in enumerate(gt_frames):
                    # for skimage.transform.resize, images need to be in range [0, 1], so normalize
                    # to [0, 1] before resize and back to [-1, 1] after
                    sknorm_img = (img / 2) + 0.5
                    resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
                    scaled_gt_frames[i] = (resized_frame - 0.5) * 2
                scale_gts.append(scaled_gt_frames)

            # for every clip in the batch, save the inputs, scale preds and scale gts
            for pred_num in range(len(input_frames)):
                pred_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR, 'Step_' + str(global_step),
                                                  str(pred_num)))

                # save input images
                for frame_num in range(c.HIST_LEN):
                    img = input_frames[pred_num, :, :, (frame_num * 3):((frame_num + 1) * 3)]
                    imsave(os.path.join(pred_dir, 'input_' + str(frame_num) + '.png'), img)

                # save preds and gts at each scale
                # noinspection PyUnboundLocalVariable
                for scale_num, scale_pred in enumerate(scale_preds):
                    gen_img = scale_pred[pred_num]

                    path = os.path.join(pred_dir, 'scale' + str(scale_num))
                    gt_img = scale_gts[scale_num][pred_num]

                    imsave(path + '_gen.png', gen_img)
                    imsave(path + '_gt.png', gt_img)

            print('Saved images!')
            print('-' * 30)

        return global_step
コード例 #8
0
    def test_batch(self, batch, global_step, save_feature_maps=True, save_imgs=True):
        """
        Runs the test graph on a batch, prints PSNR / sharpdiff / SSIM and optionally saves
        the images and the intermediate feature maps.

        @param batch: A mapping with the entries 'ref_image', 'multi_plane' and 'gt', which are
                      fed to self.ref_image, self.multi_plane and self.gt. Each entry is indexed
                      with four axes, the first one being the example within the batch. Plane i
                      of 'multi_plane' is read from channels [3 * i, 3 * (i + 1)).
        @param global_step: The global step. Only used to name the output directory.
        @param save_feature_maps: Whether or not to fetch and save the feature maps, blending
                                  weights and alpha images. Default = True.
        @param save_imgs: Whether or not to save the input/output images to file. Default = True.

        Nothing is returned; the metrics are only printed.
        """

        print('-' * 30)
        print('Testing:')

        ##
        # Split into inputs and outputs
        ##

        ref_image = batch['ref_image']
        multi_plane = batch['multi_plane']
        gt = batch['gt']
        batch_size = ref_image.shape[0]

        feed_dict = {self.ref_image: ref_image,
                     self.multi_plane: multi_plane,
                     self.gt: gt}

        if save_feature_maps:
            # fetch the intermediate tensors in the same run as the predictions
            (preds, psnr, sharpdiff, ssim,
             feature_maps, blending_weights, alpha_images) = self.sess.run(
                [self.preds,
                 self.psnr,
                 self.sharpdiff,
                 self.ssim,
                 self.feature_maps,
                 self.blending_weights,
                 self.alpha_images],
                feed_dict=feed_dict)
        else:
            preds, psnr, sharpdiff, ssim = self.sess.run(
                [self.preds,
                 self.psnr,
                 self.sharpdiff,
                 self.ssim],
                feed_dict=feed_dict)

        ##
        # User output
        ##

        print('PSNR      : ', psnr)
        print('Sharpdiff : ', sharpdiff)
        print('SSIM      : ', ssim)

        ##
        # Save images. Each flag is honoured on its own; previously save_imgs was ignored and
        # the images were always written.
        ##

        if save_imgs or save_feature_maps:
            print('-' * 30)
            print('Saving images...')
            for pred_num in range(batch_size):
                pred_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR,
                                                  'Test/Step_' + str(global_step),
                                                  str(pred_num)))

                if save_imgs:
                    # save input images
                    ref_img = ref_image[pred_num, :, :, :]
                    imsave(os.path.join(pred_dir, 'ref_image.png'),
                           imresize(ref_img, [128, 128]))

                    # the first plane (channels 0:3) is saved as the high-resolution image
                    hr_img = multi_plane[pred_num, :, :, 0:3]
                    imsave(os.path.join(pred_dir, 'hr_image.png'), hr_img)

                    # the remaining planes, named plane_2.png ... plane_<c.NUM_PLANE>.png
                    plane_dir = c.get_dir(os.path.join(pred_dir, 'planes'))
                    for i in range(1, c.NUM_PLANE):
                        multi_plane_img = multi_plane[pred_num, :, :, 3 * i:3 * (i + 1)]
                        imsave(os.path.join(plane_dir, 'plane_%d.png' % (i + 1)),
                               multi_plane_img)

                    gt_img = gt[pred_num, :, :, :]
                    imsave(os.path.join(pred_dir, 'gt.png'), gt_img)

                    gen_img = preds[pred_num, :, :, :]
                    imsave(os.path.join(pred_dir, 'pred.png'), gen_img)

                if save_feature_maps:
                    # one directory per layer, one image per feature channel
                    feature_dir = c.get_dir(os.path.join(pred_dir, 'features'))
                    for layer, layer_maps in feature_maps.items():
                        layer_dir = c.get_dir(os.path.join(feature_dir, layer))
                        layer_img = layer_maps[pred_num, :, :, :]
                        num_feature = layer_img.shape[2]
                        for k in range(num_feature):
                            imsave(os.path.join(layer_dir, '%d.bmp' % k), layer_img[:, :, k])

                    blending_weights_dir = c.get_dir(
                        os.path.join(feature_dir, 'blending_weights'))
                    alpha_dir = c.get_dir(os.path.join(feature_dir, 'alpha'))
                    for k in range(c.NUM_PLANE):
                        imsave(os.path.join(blending_weights_dir, '%d.bmp' % k),
                               blending_weights[pred_num, :, :, k])
                        imsave(os.path.join(alpha_dir, '%d.bmp' % k),
                               alpha_images[pred_num, :, :, k])
            print('Saved images!')
        print('-' * 30)
コード例 #9
0
    def train_step(self, batch, discriminator=None):
        """
        Runs a training step, then periodically prints stats, writes summaries and saves images.

        @param batch: A mapping with the entries 'ref_image', 'multi_plane' and 'gt', which are
                      fed to self.ref_image, self.multi_plane and self.gt. Each entry is indexed
                      with four axes, the first one being the example within the batch. Plane i
                      of 'multi_plane' is read from channels [3 * i, 3 * i + 3).
        @param discriminator: Not used by this method; kept so the signature stays unchanged.

        @return: The global step.
        """
        ##
        # Split into inputs and outputs
        ##

        ref_image = batch['ref_image']
        multi_plane = batch['multi_plane']
        gt = batch['gt']
        batch_size = ref_image.shape[0]

        ##
        # Train
        ##

        feed_dict = {self.ref_image: ref_image,
                     self.multi_plane: multi_plane,
                     self.gt: gt}

        _, loss, global_psnr, global_sharpdiff, global_ssim, global_step, summaries = \
            self.sess.run([self.train_op,
                           self.loss,
                           self.psnr,
                           self.sharpdiff,
                           self.ssim,
                           self.global_step,
                           self.summaries],
                          feed_dict=feed_dict)

        ##
        # User output
        ##
        if global_step % c.STATS_FREQ == 0:
            print('EXAMPLE-BASED EDSR FlowModel : Step ', global_step)
            print('                 Global Loss    : ', loss)
            print('                 PSNR           : ', global_psnr)
            print('                 Sharpdiff      : ', global_sharpdiff)
            print('                 SSIM           : ', global_ssim)

        if global_step % c.SUMMARY_FREQ == 0:
            self.summary_writer.add_summary(summaries, global_step)
            print('GeneratorModel: saved summaries')
        if global_step % c.IMG_SAVE_FREQ == 0:
            print('-' * 30)
            print('Saving images...')

            # The predictions are only needed for the images below, so run the generator
            # here instead of on every training step. They are therefore produced after
            # this step's weight update.
            preds = self.sess.run(self.preds, feed_dict=feed_dict)

            for pred_num in range(batch_size):
                pred_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR, 'Step_' + str(global_step),
                                                  str(pred_num)))

                # save input images
                ref_img = ref_image[pred_num, :, :, :]
                imsave(os.path.join(pred_dir, 'ref_image.png'), imresize(ref_img, [128, 128]))

                # the first plane (channels 0:3) is saved as the high-resolution image
                hr_img = multi_plane[pred_num, :, :, 0:3]
                imsave(os.path.join(pred_dir, 'hr_image.png'), hr_img)

                # the remaining planes, named plane_2.png ... plane_<c.NUM_PLANE>.png
                plane_dir = c.get_dir(os.path.join(pred_dir, 'planes'))
                for i in range(1, c.NUM_PLANE):
                    multi_plane_img = multi_plane[pred_num, :, :, 3 * i:3 * i + 3]
                    imsave(os.path.join(plane_dir, 'plane_%d.png' % (i + 1)), multi_plane_img)

                gt_img = gt[pred_num, :, :, :]
                imsave(os.path.join(pred_dir, 'gt.png'), gt_img)

                gen_img = preds[pred_num, :, :, :]
                imsave(os.path.join(pred_dir, 'pred.png'), gen_img)

            print('Saved images!')
            print('-' * 30)

        return global_step
コード例 #10
0
    def generator_test_batch(self, batch, global_step, save_imgs=True):
        """
        Compresses a test batch, prints size / entropy / PSNR figures for it and optionally
        saves the input and compressed images.

        @param batch: The batch of images fed to self.input_image_test. It is indexed as
                      batch[i, :, :, :] when saving, with i running over len(batch).
        @param global_step: The global step. Only used to name the image output directory.
        @param save_imgs: Whether or not to save the input/compressed images. Default = True.

        Nothing is returned; all figures are only printed.
        """
        print('Compressing test batch')

        feeds = {self.input_image_test: batch, self.is_train: False}
        prediction, code = self.sess.run([self.prediction_test, self.code_test],
                                         feed_dict=feeds)

        ##
        # Everything else in the function is currently rough and needs
        # finalising
        ##

        # Time a second run that only evaluates the code; its result is discarded.
        tic = timer()
        self.sess.run(self.code_test, feed_dict=feeds)
        toc = timer()
        print("Time taken for compression ", toc - tic)

        # Flatten the quantised coefficients of every image into one integer row, which is
        # the form handed to the entropy coder below.
        print("Size of coffiecnt tensor")
        print(code.shape)

        int_code = code.astype(np.int32)
        row_len = int_code.shape[1] * int_code.shape[2] * int_code.shape[3]
        flat_code = np.reshape(int_code, (-1, row_len))

        print("Size of reshaped coefficent array")
        print(flat_code.shape)

        num_images = flat_code.shape[0]

        # Per-image size figures from the entropy coder.
        num_bits_total = 0
        for i in range(num_images):
            num_bits = cabac.encode(flat_code[i, :])
            print("Bytes for image ", i, " : ", num_bits / 8)
            print("Bits per pixel image ", i, " : ",
                  num_bits / (c.FULL_HEIGHT * c.FULL_WIDTH * 3))
            print("Effective entropy ", i, " : ", num_bits / flat_code.shape[1])
            num_bits_total += num_bits

        av_bits_per_pixel = num_bits_total / (num_images * c.FULL_HEIGHT *
                                              c.FULL_WIDTH * 3)

        # Per-image entropy estimates and their mean.
        entropy_total = 0
        for i in range(num_images):
            image_entropy = cabac.estimate_entropy(flat_code[i, :])
            print("Entropy image ", i, " : ", image_entropy)
            entropy_total += image_entropy
        mean_entropy = entropy_total / num_images
        print("Average entropy")
        print(mean_entropy)

        print("Average number of bpp")
        print(av_bits_per_pixel)

        # PSNR is computed on denormalized copies, leaving prediction and batch untouched.
        print("PSNR calcuation")
        compressed_images = denormalize_frames_mine(deepcopy(prediction))
        input_images = denormalize_frames_mine(deepcopy(batch))
        print(pc.PSNR_error_np(compressed_images, input_images))

        if save_imgs:
            step_name = 'Tests/Step_' + str(global_step)
            for idx in range(len(batch)):
                out_dir = c.get_dir(os.path.join(c.IMG_SAVE_DIR, step_name, str(idx)))

                # the network input
                source_img = denormalize_frames_mine(batch[idx, :, :, :])
                imsave(os.path.join(out_dir, 'input.png'), source_img)

                # the network output
                compressed_img = denormalize_frames_mine(prediction[idx, :, :, :])
                imsave(os.path.join(out_dir, 'compressed.png'), compressed_img)
        print('Saved test batch')
コード例 #11
0
    def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True):
        """
        Runs a training step using the global loss on each of the scale networks.

        @param batch: An array of shape
                      [batch_size x self.height x self.width x (3 * (c.HIST_LEN+ num_rec_out))].
                      A batch of the input and output frames, concatenated along the channel axis
                      (index 3).
        @param global_step: The global step.
        @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively,
                            using previously-generated frames as input. Default = 1.
        @param save_imgs: Whether or not to save the input/output images to file. Default = True.

        @return: A tuple of (psnr error, sharpdiff error) for the batch.
        """
        if num_rec_out < 1:
            raise ValueError('num_rec_out must be >= 1')

        print '-' * 30
        print 'Testing:'

        ##
        # Split into inputs and outputs
        ##

        input_frames = batch[:, :, :, :3 * c.HIST_LEN]
        gt_frames = batch[:, :, :, 3 * c.HIST_LEN:]

        ##
        # Generate num_rec_out recursive predictions
        ##

        feed_dict = {
            self.input_frames_test: input_frames,
            self.gt_frames_test: gt_frames
        }
        preds, psnr, sharpdiff, ssim, summaries = self.sess.run(
            [
                self.scale_preds_test[-1], self.psnr_error_test,
                self.sharpdiff_error_test, self.ssim_error_test,
                self.summaries_test
            ],
            feed_dict=feed_dict)

        print 'SSIM Errors     : ', [ssim_error for ssim_error in ssim]
        print 'PSNR Errors     : ', [psnr_error for psnr_error in psnr]
        print 'Sharpdiff Errors: ', [
            sharpdiff_error for sharpdiff_error in sharpdiff
        ]

        # write summaries
        self.summary_writer.add_summary(summaries, global_step)

        ##
        # Save images
        ##

        if save_imgs:
            for pred_num in xrange(len(input_frames)):
                pred_dir = c.get_dir(
                    os.path.join(c.IMG_SAVE_DIR,
                                 'Tests/Step_' + str(global_step),
                                 str(pred_num)))

                # save input images
                for frame_num in xrange(c.HIST_LEN):
                    img = input_frames[pred_num, :, :,
                                       (frame_num * 3):((frame_num + 1) * 3)]
                    imsave(
                        os.path.join(pred_dir,
                                     'input_' + str(frame_num) + '.jpg'), img)

                # save recursive outputs
                for rec_num in xrange(num_rec_out):
                    gen_img = preds[pred_num, :, :,
                                    3 * rec_num:3 * (rec_num + 1)]
                    gt_img = gt_frames[pred_num, :, :,
                                       3 * rec_num:3 * (rec_num + 1)]
                    imsave(
                        os.path.join(pred_dir,
                                     '_gen_' + str(rec_num) + '.jpg'), gen_img)
                    imsave(
                        os.path.join(pred_dir, '_gt_' + str(rec_num) + '.jpg'),
                        gt_img)

        print '-' * 30