Example #1
def metric_interpolated_frame(mid_frames,
	rec_mid_frames):
	'''
	Take the difference between the true frames
	and the network's interpolation

	Args:
		mid_frames: tensor, [B,inter_frames,H,W,1]
		rec_mid_frames: tensor, [B,inter_frames,H,W,1]
		

	Output:
		l2 loss and PSNR between the prediction and the ground truth
	'''

	return [l2_loss(mid_frames,rec_mid_frames),\
			compute_psnr(mid_frames,rec_mid_frames)]
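
A minimal usage sketch (not part of the original listing), assuming l2_loss and compute_psnr are the project's own helpers that return scalar tensors, with illustrative placeholder shapes:

import tensorflow as tf

# illustrative shapes: batch of sequences with 3 intermediate 100x100 grayscale frames
mid_frames = tf.placeholder(tf.float32, [None, 3, 100, 100, 1])
rec_mid_frames = tf.placeholder(tf.float32, [None, 3, 100, 100, 1])
l2, psnr = metric_interpolated_frame(mid_frames, rec_mid_frames)
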
Example #2
def metric_repeat_fframe(fframes,mid_frames):
	'''
	Assumes we predict the first frame for all 
	intermediate ones, then takes the difference
	by broadcasting along the inter_frames axis

	Args:
		fframes: tensor, [B,H,W,1]
		mid_frames: tensor, [B,inter_frames,H,W,1]

	Output:
		l2 loss and PSNR between the prediction and the ground truth
	'''

	return [l2_loss(mid_frames,
				tf.expand_dims(fframes,axis=1)),\
			compute_psnr(mid_frames,
				tf.expand_dims(fframes,axis=1))]
Example #3
def metric_weighted_frame(fframes,mid_frames,lframes):
	'''
	Do a weighted interpolation between the first frame
	and the last, then take the difference between
	the weighted sum and the true intermediate frames

	Args:
		fframes: tensor, [B,H,W,1]
		mid_frames: tensor, [B,inter_frames,H,W,1]
		lframes: tensor, [B,H,W,1]

	Output:
		l2 loss and PSNR between the weighted frames and the ground truth
	'''

	inter_frames = mid_frames.get_shape()[1]


	fframes_expanded = tf.expand_dims(fframes,axis=1)
	lframes_expanded = tf.expand_dims(lframes,axis=1)

	fframes_tiled = tf.tile(fframes_expanded, 
		[1,inter_frames,1,1,1])
	lframes_tiled = tf.tile(lframes_expanded, 
		[1,inter_frames,1,1,1])


	weighting = tf.range(1.0,inter_frames+1)
	weighting = tf.reshape(weighting, 
		[1,inter_frames,1,1,1])/tf.cast(
		(inter_frames+1),dtype=tf.float32)

	weighted_sum = (fframes_tiled * weighting + 
		lframes_tiled *(1-weighting))

	return [l2_loss(mid_frames,weighted_sum),\
			compute_psnr(mid_frames,weighted_sum)]
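
For intuition, with inter_frames = 3 the weighting above evaluates to [0.25, 0.5, 0.75], so intermediate frame i (1-indexed) is blended as i/(inter_frames+1) times the first frame plus the remainder times the last. A plain-NumPy sketch of the same computation (illustrative only, not from the original source):

import numpy as np

inter_frames = 3
# linear weights in (0, 1): [0.25, 0.5, 0.75] for inter_frames = 3
weighting = np.arange(1.0, inter_frames + 1) / (inter_frames + 1)
weighting = weighting.reshape(1, inter_frames, 1, 1, 1)

first = np.zeros((1, 1, 4, 4, 1))   # stands in for fframes_expanded
last = np.ones((1, 1, 4, 4, 1))     # stands in for lframes_expanded
weighted_sum = first * weighting + last * (1 - weighting)  # broadcasts over the time axis
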
Example #4
def training(args):
    
    # DIRECTORY FOR CKPTS and META FILES
    # ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    ROOT_DIR = '/media/data/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'train.tfrecords')
    VAL_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'val.tfrecords')
    CKPT_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        args.ckpt_folder_name + '/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer(
            [TRAIN_REC_PATH], num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size)

        val_queue = tf.train.string_input_producer(
            [VAL_REC_PATH], num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size)

        with tf.variable_scope('bipn'):
            print('TRAIN FRAMES (first):')
            train_rec_iFrames = bipn.build_bipn(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True)

        with tf.variable_scope('bipn', reuse=tf.AUTO_REUSE):
            print('VAL FRAMES (first):')
            val_rec_iFrames = bipn.build_bipn(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                is_training=False)
            
        print('Model parameters:{}'.format(
            count_parameters()))

        if args.perceptual_loss_weight:
            # VGG16 weights should be kept locally ~ 500 MB space
            with tf.variable_scope('vgg16'):
                train_iFrames_features = vgg16.build_vgg16(
                    train_iFrames,
                    end_point='conv4_3').features
            with tf.variable_scope('vgg16', reuse=tf.AUTO_REUSE):
                train_rec_iFrames_features = vgg16.build_vgg16(
                    train_rec_iFrames,
                    end_point='conv4_3').features

        # DEFINE METRICS
        if args.loss_id == 0:
            train_loss = huber_loss(
                train_iFrames, train_rec_iFrames,
                delta=1.)
            val_loss = huber_loss(
                val_iFrames, val_rec_iFrames,
                delta=1.)

        elif args.loss_id == 1:
            train_loss = l2_loss(
                train_iFrames, train_rec_iFrames)
            val_loss = l2_loss(
                val_iFrames, val_rec_iFrames)

        total_train_loss = train_loss
        tf.summary.scalar('train_l2_loss', train_loss)
        tf.summary.scalar('total_val_l2_loss', val_loss)

        if args.perceptual_loss_weight:
            train_perceptual_loss = perceptual_loss(
                train_iFrames_features,
                train_rec_iFrames_features)

            tf.summary.scalar('train_perceptual_loss',\
                train_perceptual_loss)

            total_train_loss += train_perceptual_loss\
                * args.perceptual_loss_weight

        # SUMMARIES
        tf.summary.scalar('total_train_loss',\
            total_train_loss)
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            CKPT_PATH + 'train',
            sess.graph)

        # DEFINE OPTIMIZER
        # optimize the total loss so the perceptual term (when enabled) is included
        optimizer = get_optimizer(
            total_train_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate,
            use_batch_norm=True)

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(
            coord=coord)

        # START TRAINING HERE
        try:
            for iteration in range(args.train_iters):
                _, t_summ, t_loss = sess.run(
                    [optimizer, merged, total_train_loss])

                train_writer.add_summary(t_summ, iteration)
                print('Iter:{}/{}, Train Loss:{}'.format(
                    iteration,
                    args.train_iters,
                    t_loss))

                if iteration % args.val_every == 0:
                    v_loss = sess.run(val_loss)
                    print('Iter:{}, Val Loss:{}'.format(
                        iteration,
                        v_loss))

                if iteration % args.save_every == 0:
                    saver.save(
                        sess,
                        CKPT_PATH + 'iter:{}_val:{}'.format(
                            str(iteration),
                            str(round(v_loss, 3))))

                if iteration % args.plot_every == 0:
                    start_frames, end_frames, mid_frames,\
                        rec_mid_frames = sess.run(
                            [train_fFrames, train_lFrames,\
                                train_iFrames,\
                                train_rec_iFrames])

                    visualize_frames(
                        start_frames,
                        end_frames,
                        mid_frames,
                        rec_mid_frames,
                        iteration=iteration,
                        save_path=os.path.join(
                            CKPT_PATH,
                            'plots/'))

        except Exception as e:
            coord.request_stop(e)
        finally:
            coord.join(threads)
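
All of the training functions in this listing rely on a get_optimizer helper whose implementation is not shown. Below is a minimal sketch of how such a helper is usually written in TF1, assuming optim_id selects the optimizer and that batch-norm moving-average updates must run with every step; this is an assumption for illustration, not the project's actual code:

import tensorflow as tf

def get_optimizer_sketch(loss, optim_id=0, learning_rate=1e-3,
                         use_batch_norm=True, var_list=None):
    # pick an optimizer by id (assumed mapping)
    if optim_id == 0:
        opt = tf.train.AdamOptimizer(learning_rate)
    else:
        opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    # with batch norm, the moving-average update ops must run alongside the train op
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) if use_batch_norm else []
    with tf.control_dependencies(update_ops):
        return opt.minimize(loss, var_list=var_list)
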
Example #5
def training(args):
    
    # DIRECTORY FOR CKPTS and META FILES
    ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'train.tfrecords')
    VAL_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'val.tfrecords')
    CKPT_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'skip_merge_conv_summary_image_tanh_separate_bipn_wd_l2_adam_1e-3/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer(
            [TRAIN_REC_PATH], num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size)

        val_queue = tf.train.string_input_producer(
            [VAL_REC_PATH], num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size)

        with tf.variable_scope('separate_bipn'):
            print('TRAIN FRAMES (first):')
            train_rec_iFrames = skip_merge_conv_separate_encoder_bipn.build_bipn(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True)

        with tf.variable_scope('separate_bipn', reuse=tf.AUTO_REUSE):
            print('VAL FRAMES (first):')
            val_rec_iFrames = skip_merge_conv_separate_encoder_bipn.build_bipn(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                is_training=False)
            
        print('Model parameters:{}'.format(
            count_parameters()))

        # DEFINE METRICS
        if args.loss_id == 0:
            train_loss = huber_loss(
                train_iFrames, train_rec_iFrames,
                delta=1.)
            val_loss = huber_loss(
                val_iFrames, val_rec_iFrames,
                delta=1.)

        elif args.loss_id == 1:
            train_loss = l2_loss(
                train_iFrames, train_rec_iFrames)
            val_loss = l2_loss(
                val_iFrames, val_rec_iFrames) 

        if args.weight_decay:
            decay_loss = ridge_weight_decay(
                tf.trainable_variables())
            train_loss += args.weight_decay * decay_loss

        # SUMMARIES
        tf.summary.scalar('train_loss', train_loss)
        tf.summary.scalar('val_loss', val_loss)

        with tf.contrib.summary.\
            record_summaries_every_n_global_steps(
                n=args.summary_image_every):
            summary_true, summary_fake = visualize_tensorboard(
                train_fFrames,
                train_lFrames,
                train_iFrames,
                train_rec_iFrames,
                num_plots=3)
            tf.summary.image('true frames', summary_true)
            tf.summary.image('fake frames', summary_fake)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            CKPT_PATH + 'train',
            sess.graph)

        # DEFINE OPTIMIZER
        optimizer = get_optimizer(
            train_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate,
            use_batch_norm=True)

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(
            coord=coord)

        # START TRAINING HERE
        try:
            for iteration in range(args.train_iters):
                _, t_summ, t_loss = sess.run(
                    [optimizer, merged, train_loss])

                train_writer.add_summary(t_summ, iteration)
                print('Iter:{}/{}, Train Loss:{}'.format(
                    iteration,
                    args.train_iters,
                    t_loss))

                if iteration % args.val_every == 0:
                    v_loss = sess.run(val_loss)
                    print('Iter:{}, Val Loss:{}'.format(
                        iteration,
                        v_loss))

                if iteration % args.save_every == 0:
                    saver.save(
                        sess,
                        CKPT_PATH + 'iter:{}_val:{}'.format(
                            str(iteration),
                            str(round(v_loss, 3))))

                '''
                if iteration % args.plot_every == 0:
                    start_frames, end_frames, mid_frames,\
                        rec_mid_frames = sess.run(
                            [train_fFrames, train_lFrames,\
                                train_iFrames,\
                                train_rec_iFrames])

                    visualize_frames(
                        start_frames,
                        end_frames,
                        mid_frames,
                        rec_mid_frames,
                        iteration=iteration,
                        save_path=os.path.join(
                            CKPT_PATH,
                            'plots/'))
                '''

            coord.join(threads)

        except Exception as e:
            coord.request_stop(e)
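
The ridge_weight_decay helper used above is also not shown. A minimal sketch of the usual ridge (L2) penalty over trainable variables, skipping biases and batch-norm parameters by name (assumed behaviour, not the project's actual implementation):

import tensorflow as tf

def ridge_weight_decay_sketch(variables):
    # sum of L2 norms over weight kernels only
    penalty = 0.
    for var in variables:
        if 'kernel' in var.name or 'weights' in var.name:
            penalty += tf.nn.l2_loss(var)
    return penalty
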
Example #6
def training(args):

    # DIRECTORY FOR CKPTS and META FILES
    # ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    ROOT_DIR = '/media/data/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(ROOT_DIR, args.experiment_name,
                                  'train.tfrecords')
    VAL_REC_PATH = os.path.join(ROOT_DIR, args.experiment_name,
                                'val.tfrecords')
    CKPT_PATH = os.path.join(ROOT_DIR, args.experiment_name,
                             args.ckpt_folder_name + '/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer([TRAIN_REC_PATH],
                                                     num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size,
                n_intermediate_frames=args.n_IF)

        val_queue = tf.train.string_input_producer([VAL_REC_PATH],
                                                   num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size,
                n_intermediate_frames=args.n_IF)

        # Apply gaussian blurring manually
        '''
        train_fFrames = gaussian_filter(train_fFrames, std=std_dev)
        train_lFrames = gaussian_filter(train_lFrames, std=std_dev)
        train_iFrames = gaussian_filter(train_iFrames, std=std_dev)
        val_fFrames = gaussian_filter(val_fFrames, std=std_dev)
        val_lFrames = gaussian_filter(val_lFrames, std=std_dev)
        val_iFrames = gaussian_filter(val_iFrames, std=std_dev)
        '''

        # TRAINABLE
        print('---------------------------------------------')
        print('----------------- GENERATOR -----------------')
        print('---------------------------------------------')
        with tf.variable_scope('generator'):
            train_rec_iFrames = generator_unet.build_generator(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True,
                n_IF=args.n_IF,
                starting_out_channels=args.starting_out_channels,
                use_attention=args.use_attention,
                spatial_attention=args.spatial_attention,
                is_verbose=True)

        print('---------------------------------------------')
        print('-------------- DISCRIMINATOR ----------------')
        print('---------------------------------------------')
        # discriminator for classifying real images
        with tf.variable_scope('discriminator'):
            train_real_output_discriminator = discriminator.build_discriminator(
                train_iFrames,
                use_batch_norm=True,
                is_training=True,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=True)
        # discriminator for classifying fake images
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            train_fake_output_discriminator = discriminator.build_discriminator(
                train_rec_iFrames,
                use_batch_norm=True,
                is_training=True,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=False)

        # VALIDATION
        with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
            val_rec_iFrames = generator_unet.build_generator(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                n_IF=args.n_IF,
                is_training=False,
                starting_out_channels=args.starting_out_channels,
                use_attention=args.use_attention,
                spatial_attention=args.spatial_attention,
                is_verbose=False)

        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            val_real_output_discriminator = discriminator.build_discriminator(
                val_iFrames,
                use_batch_norm=True,
                is_training=False,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=False)
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            val_fake_output_discriminator = discriminator.build_discriminator(
                val_rec_iFrames,
                use_batch_norm=True,
                is_training=False,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=False)

        if args.perceptual_loss_weight:
            # Weights should be kept locally ~ 500 MB space
            with tf.variable_scope('vgg16'):
                train_iFrames_features = vgg16.build_vgg16(
                    train_iFrames,
                    end_point=args.perceptual_loss_endpoint).features
            with tf.variable_scope('vgg16', reuse=tf.AUTO_REUSE):
                train_rec_iFrames_features = vgg16.build_vgg16(
                    train_rec_iFrames,
                    end_point=args.perceptual_loss_endpoint).features

        print('Global parameters:{}'.format(
            count_parameters(tf.global_variables())))
        print('Learnable model parameters:{}'.format(
            count_parameters(tf.trainable_variables())))

        # DEFINE GAN losses:
        train_discri_real_loss = tf.reduce_sum(
            tf.square(train_real_output_discriminator - 1)) / (2 *
                                                               args.batch_size)
        train_discri_fake_loss = tf.reduce_sum(
            tf.square(train_fake_output_discriminator)) / (2 * args.batch_size)
        train_discriminator_loss = train_discri_real_loss + train_discri_fake_loss

        train_generator_fake_loss = tf.reduce_sum(
            tf.square(train_fake_output_discriminator - 1)) / args.batch_size
        train_reconstruction_loss = l2_loss(
            train_rec_iFrames, train_iFrames) * args.reconstruction_loss_weight
        train_generator_loss = train_generator_fake_loss + train_reconstruction_loss

        val_discri_real_loss = tf.reduce_sum(
            tf.square(val_real_output_discriminator - 1)) / (2 *
                                                             args.batch_size)
        val_discri_fake_loss = tf.reduce_sum(
            tf.square(val_fake_output_discriminator)) / (2 * args.batch_size)
        val_discriminator_loss = val_discri_real_loss + val_discri_fake_loss

        val_generator_fake_loss = tf.reduce_sum(
            tf.square(val_fake_output_discriminator - 1)) / args.batch_size
        val_reconstruction_loss = l2_loss(
            val_rec_iFrames, val_iFrames) * args.reconstruction_loss_weight
        val_generator_loss = val_generator_fake_loss + val_reconstruction_loss

        if args.perceptual_loss_weight:
            train_percp_loss = perceptual_loss(train_rec_iFrames_features,
                                               train_iFrames_features)
            train_generator_loss += args.perceptual_loss_weight * train_percp_loss

        # SUMMARIES
        tf.summary.scalar('train_discri_real_loss', train_discri_real_loss)
        tf.summary.scalar('train_discri_fake_loss', train_discri_fake_loss)
        tf.summary.scalar('train_discriminator_loss', train_discriminator_loss)
        tf.summary.scalar('train_generator_fake_loss',
                          train_generator_fake_loss)
        tf.summary.scalar('train_reconstruction_loss',
                          train_reconstruction_loss)
        tf.summary.scalar('train_generator_loss', train_generator_loss)

        tf.summary.scalar('val_discri_real_loss', val_discri_real_loss)
        tf.summary.scalar('val_discri_fake_loss', val_discri_fake_loss)
        tf.summary.scalar('val_discriminator_loss', val_discriminator_loss)
        tf.summary.scalar('val_generator_fake_loss', val_generator_fake_loss)
        tf.summary.scalar('val_reconstruction_loss', val_reconstruction_loss)
        tf.summary.scalar('val_generator_loss', val_generator_loss)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(CKPT_PATH + 'train', sess.graph)

        # get variables responsible for generator and discriminator
        trainable_vars = tf.trainable_variables()
        generator_vars = [
            var for var in trainable_vars if 'generator' in var.name
        ]
        discriminator_vars = [
            var for var in trainable_vars if 'discriminator' in var.name
        ]

        # DEFINE OPTIMIZERS
        generator_optimizer = get_optimizer(train_generator_loss,
                                            optim_id=args.optim_id,
                                            learning_rate=args.learning_rate,
                                            use_batch_norm=True,
                                            var_list=generator_vars)
        discriminator_optimizer = get_optimizer(
            train_discriminator_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate * 2.,
            use_batch_norm=True,
            var_list=discriminator_vars)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # START TRAINING HERE
        for iteration in range(args.train_iters):

            for d_iteration in range(args.disc_train_iters):
                disc_, td_loss = sess.run(
                    [discriminator_optimizer, train_discriminator_loss])

            gene_, tgf_loss, tr_loss, t_summ = sess.run(
                [generator_optimizer, train_generator_fake_loss,\
                    train_reconstruction_loss, merged])

            train_writer.add_summary(t_summ, iteration)

            print(
                'Iter:{}/{}, Disc. Loss:{}, Gen. Loss:{}, Rec. Loss:{}'.format(
                    iteration, args.train_iters, str(round(td_loss, 6)),
                    str(round(tgf_loss, 6)), str(round(tr_loss, 6))))

            if iteration % args.val_every == 0:
                vd_loss, vgf_loss, vr_loss = sess.run(
                    [val_discriminator_loss, val_generator_fake_loss,\
                        val_reconstruction_loss])
                print('Iter:{}, VAL Disc. Loss:{}, Gen. Loss:{}, Rec. Loss:{}'.
                      format(iteration, str(round(vd_loss, 6)),
                             str(round(vgf_loss, 6)), str(round(vr_loss, 6))))

            if iteration % args.save_every == 0:
                saver.save(
                    sess, CKPT_PATH +
                    'iter:{}_valDisc:{}_valGen:{}_valRec:{}'.format(
                        str(iteration), str(round(vd_loss, 6)),
                        str(round(vgf_loss, 6)), str(round(vr_loss, 6))))

            if iteration % args.plot_every == 0:
                start_frames, end_frames, mid_frames,\
                    rec_mid_frames = sess.run(
                        [train_fFrames, train_lFrames,\
                            train_iFrames,\
                            train_rec_iFrames])

                visualize_frames(start_frames,
                                 end_frames,
                                 mid_frames,
                                 rec_mid_frames,
                                 training=True,
                                 iteration=iteration,
                                 save_path=os.path.join(
                                     CKPT_PATH, 'train_plots/'))

                start_frames, end_frames, mid_frames,\
                    rec_mid_frames = sess.run(
                        [val_fFrames, val_lFrames,\
                            val_iFrames,
                            val_rec_iFrames])

                visualize_frames(start_frames,
                                 end_frames,
                                 mid_frames,
                                 rec_mid_frames,
                                 training=False,
                                 iteration=iteration,
                                 save_path=os.path.join(
                                     CKPT_PATH, 'validation_plots/'))

        coord.request_stop()
        coord.join(threads)

        print('Training complete.....')
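
The discriminator and generator objectives built above are the least-squares GAN (LSGAN) losses with real label 1 and fake label 0, averaged over the batch. A compact, equivalent sketch of that objective (illustrative only):

import tensorflow as tf

def lsgan_losses_sketch(real_scores, fake_scores, batch_size):
    # discriminator pushes real scores towards 1 and fake scores towards 0
    d_loss = (tf.reduce_sum(tf.square(real_scores - 1)) +
              tf.reduce_sum(tf.square(fake_scores))) / (2 * batch_size)
    # generator pushes fake scores towards 1
    g_loss = tf.reduce_sum(tf.square(fake_scores - 1)) / batch_size
    return d_loss, g_loss
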
Example #7
def cal_l2_losses(pred_traj_gt, pred_traj_gt_rel, pred_traj_sampled, pred_traj_sampled_rel, loss_mask):
    l2_loss_abs = l2_loss(pred_traj_sampled, pred_traj_gt, loss_mask, mode='sum')
    l2_loss_rel = l2_loss(pred_traj_sampled_rel, pred_traj_gt_rel, loss_mask, mode='sum')
    return l2_loss_abs, l2_loss_rel
Example #8
def testing(info):

    # Get the best checkpoint path
    weight_path = os.listdir(info['model_path'])

    # keep only the '.meta' files, one per saved checkpoint
    weight_paths = [i for i in weight_path if 'meta' in i]
    # map each checkpoint to the integer part of the last ':'-separated value
    # in its filename, sort ascending, and pick the checkpoint with the smallest value
    di_weight = {}
    for path in weight_paths:
        di_weight[path] = int(path.split(':')[-1].split('.')[0])
    di_weight = {
        k: v
        for k, v in sorted(di_weight.items(), key=lambda item: item[1])
    }
    weight_path = [*di_weight][0]

    # strip the '.meta' suffix to obtain the checkpoint prefix expected by Saver
    weight_path = weight_path[:-5]
    weight_path = os.path.join(info['model_path'], weight_path)

    n_IF = info['n_IF']
    batch_size = info['batch_size']
    # get #test_samples based on experiment
    if n_IF == 3: test_samples = 18300
    elif n_IF == 4: test_samples = 18296
    elif n_IF == 5: test_samples = 18265
    elif n_IF == 6: test_samples = 18235
    elif n_IF == 7: test_samples = 18204
    test_iters = test_samples // batch_size
    test_samples = test_samples - (test_samples % batch_size)

    # get attention
    if info['attention']:
        use_attention = 1
        if info['use_spatial_attention']:
            spatial_attention = 1
        else:
            spatial_attention = 0
    else:
        use_attention = 0
        spatial_attention = 0

    # SCOPING BEGINS HERE
    tf.reset_default_graph()
    with tf.Session() as sess:
        global_step = tf.train.get_global_step()

        test_queue = tf.train.string_input_producer([info['TEST_REC_PATH']],
                                                    num_epochs=1)
        test_fFrames, test_lFrames, test_iFrames, test_mfn =\
            read_and_decode(
                filename_queue=test_queue,
                is_training=False,
                batch_size=batch_size,
                n_intermediate_frames=n_IF,
                allow_smaller_final_batch=False)

        if info['model_name'] in ['skip', 'wnet']:
            with tf.variable_scope('separate_bipn'):
                print('TEST FRAMES (first):')
                if info['model_name'] == 'skip':
                    test_rec_iFrames = skip_separate_encoder_bipn.build_bipn(
                        test_fFrames,
                        test_lFrames,
                        use_batch_norm=True,
                        is_training=False,
                        n_IF=n_IF,
                        starting_out_channels=info['out_channels'],
                        use_attention=use_attention,
                        spatial_attention=spatial_attention,
                        is_verbose=False)

                elif info['model_name'] == 'wnet':
                    test_rec_iFrames = wnet.build_wnet(
                        test_fFrames,
                        test_lFrames,
                        use_batch_norm=True,
                        is_training=False,
                        n_IF=n_IF,
                        starting_out_channels=info['out_channels'],
                        use_attention=use_attention,
                        spatial_attention=spatial_attention,
                        is_verbose=False)

        elif info['model_name'] == 'slomo':
            with tf.variable_scope('slomo'):
                test_output = slomo.SloMo_model(test_fFrames,
                                                test_lFrames,
                                                first_kernel=7,
                                                second_kernel=5,
                                                reuse=False,
                                                t_steps=n_IF,
                                                verbose=False)
                test_rec_iFrames = test_output[0]

        elif info['model_name'] == 'bipn':
            with tf.variable_scope('bipn'):
                test_rec_iFrames = BiPN.build_bipn(test_fFrames,
                                                   test_lFrames,
                                                   n_IF=n_IF,
                                                   use_batch_norm=True,
                                                   is_training=False)

        print('Global parameters:{}'.format(
            count_parameters(tf.global_variables())))
        print('Learnable model parameters:{}'.format(
            count_parameters(tf.trainable_variables())))

        # DEFINE LOSS
        if info['loss'] == 'l2':
            test_loss = l2_loss(test_iFrames, test_rec_iFrames)
        elif info['loss'] == 'l1':
            test_loss = l1_loss(test_iFrames, test_rec_iFrames)

        # DEFINE METRICS
        repeat_fFrame = metric_repeat_fframe(test_fFrames, test_iFrames)
        repeat_lFrame = metric_repeat_lframe(test_lFrames, test_iFrames)
        weighted_frame = metric_weighted_frame(test_fFrames, test_iFrames,
                                               test_lFrames)
        inter_frame = metric_interpolated_frame(test_iFrames, test_rec_iFrames)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        # Load checkpoints
        saver.restore(sess, weight_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        metrics = {}
        metrics['learnable_parameters'] = count_parameters(
            tf.trainable_variables())
        metrics['repeat_first'] = []
        metrics['repeat_last'] = []
        metrics['weighted_frames'] = []
        metrics['inter_frames'] = []
        metrics['repeat_first_psnr'] = []
        metrics['repeat_last_psnr'] = []
        metrics['weighted_frames_psnr'] = []
        metrics['inter_frames_psnr'] = []

        print('EVALUATING:{}--------------------------->'.format(
            info['model_path']))

        # START TESTING HERE
        for iteration in range(test_iters):

            # get frames and metrics
            start_frames, end_frames, mid_frames, rec_mid_frames,\
            repeat_first, repeat_last, weighted, true_metric = sess.run(
                [test_fFrames, test_lFrames, test_iFrames, test_rec_iFrames,\
                repeat_fFrame, repeat_lFrame, weighted_frame,\
                    inter_frame])

            samples = start_frames.shape[0]
            metrics['repeat_first'].append(repeat_first[0] * samples)
            metrics['repeat_last'].append(repeat_last[0] * samples)
            metrics['weighted_frames'].append(weighted[0] * samples)
            metrics['inter_frames'].append(true_metric[0] * samples)
            metrics['repeat_first_psnr'].append(repeat_first[1] * samples)
            metrics['repeat_last_psnr'].append(repeat_last[1] * samples)
            metrics['weighted_frames_psnr'].append(weighted[1] * samples)
            metrics['inter_frames_psnr'].append(true_metric[1] * samples)

            visualize_frames(start_frames,
                             end_frames,
                             mid_frames,
                             rec_mid_frames,
                             training=False,
                             iteration=iteration,
                             save_path=os.path.join(info['model_path'],
                                                    'test_plots' + '/'))

            if iteration % 50 == 0:
                print('{}/{} iters complete'.format(iteration, test_iters))

        print('Testing complete.....')

    # Calculate metrics:
    mean_rf = sum(metrics['repeat_first']) / test_samples
    mean_rl = sum(metrics['repeat_last']) / test_samples
    mean_wf = sum(metrics['weighted_frames']) / test_samples
    mean_if = sum(metrics['inter_frames']) / test_samples

    metrics['mean_repeat_first'] = mean_rf
    metrics['mean_repeat_last'] = mean_rl
    metrics['mean_weighted_frames'] = mean_wf
    metrics['mean_inter_frames'] = mean_if

    mean_rf_psnr = sum(metrics['repeat_first_psnr']) / test_samples
    mean_rl_psnr = sum(metrics['repeat_last_psnr']) / test_samples
    mean_wf_psnr = sum(metrics['weighted_frames_psnr']) / test_samples
    mean_if_psnr = sum(metrics['inter_frames_psnr']) / test_samples

    metrics['mean_psnr_repeat_first'] = mean_rf_psnr
    metrics['mean_psnr_repeat_last'] = mean_rl_psnr
    metrics['mean_psnr_weighted_frames'] = mean_wf_psnr
    metrics['mean_psnr_inter_frames'] = mean_if_psnr

    with open(info['model_path'] + '/evaluation.pkl', 'wb') as handle:
        pickle.dump(metrics, handle)

    print('Pickle file dumped.....')
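
The dumped pickle can be read back later to compare runs; a small usage sketch (the path is illustrative):

import pickle

with open('path/to/model_dir/evaluation.pkl', 'rb') as handle:
    metrics = pickle.load(handle)
print(metrics['mean_inter_frames'], metrics['mean_psnr_inter_frames'])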