    def load_training_data(self):
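        """Builds the training and validation input pipelines for the
           configured dataset and returns (train_batch, val_batch,
           train_iter, val_iter).
        """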
        with tf.name_scope("data_loading"):
            if self.config.dataset == 'DAVIS2016':
                reader = Davis2016Reader(
                    self.config.root_dir,
                    max_temporal_len=self.config.max_temporal_len,
                    min_temporal_len=self.config.min_temporal_len,
                    num_threads=self.config.num_threads)

                train_batch, train_iter = reader.image_inputs(
                    batch_size=self.config.batch_size,
                    train_crop=self.config.train_crop,
                    partition=self.config.train_partition)

                val_batch, val_iter = reader.test_inputs(
                    batch_size=self.config.batch_size,
                    t_len=self.config.test_temporal_shift,
                    test_crop=self.config.test_crop,
                    partition='val')
            elif self.config.dataset == 'FBMS':
                reader = FBMS59Reader(
                    self.config.root_dir,
                    max_temporal_len=self.config.max_temporal_len,
                    min_temporal_len=self.config.min_temporal_len,
                    num_threads=self.config.num_threads)
                train_batch, train_iter = reader.image_inputs(
                    batch_size=self.config.batch_size,
                    train_crop=self.config.train_crop,
                    partition=self.config.train_partition)

                val_batch, val_iter = reader.test_inputs(
                    batch_size=self.config.batch_size,
                    t_len=self.config.test_temporal_shift,
                    test_crop=self.config.test_crop,
                    with_fname=True,
                    partition='val')
                self.num_categories = reader.num_categories

            elif self.config.dataset == 'SEGTRACK':
                reader = SegTrackV2Reader(
                    self.config.root_dir,
                    max_temporal_len=self.config.max_temporal_len,
                    min_temporal_len=self.config.min_temporal_len,
                    num_threads=self.config.num_threads)
                train_batch, train_iter = reader.image_inputs(
                    batch_size=self.config.batch_size,
                    train_crop=self.config.train_crop)

                val_batch, val_iter = reader.test_inputs(
                    batch_size=self.config.batch_size,
                    t_len=self.config.test_temporal_shift,
                    test_crop=self.config.test_crop)

            else:
                raise IOError("Dataset should be DAVIS2016 / FBMS / SEGTRACK")

        self.num_samples_val = reader.val_samples
        return train_batch, val_batch, train_iter, val_iter
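
    # A minimal usage sketch for load_training_data() above (assumptions, not
    # confirmed by this file: `model` is an instance of this class, TF1-style
    # session execution, and the readers return initializable dataset
    # iterators):
    #
    #   train_batch, val_batch, train_iter, val_iter = model.load_training_data()
    #   with tf.Session() as sess:
    #       sess.run([train_iter.initializer, val_iter.initializer])
    #       batch = sess.run(train_batch)
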
    def build_aug_test_graph(self):
        """This graph will be used for the generation of the results
           with multiple crop of the images.
           This improves the results while doing ensembling.
           Requires batch size to be one (automatically handled)
        """
        test_crops = [0.85, 0.9, 0.95, 1.0]
        print("Evaluating the following crops {}".format(test_crops))
        with tf.name_scope("data_loading"):
            if self.config.dataset == 'DAVIS2016':
                reader = Davis2016Reader(self.config.root_dir)
                test_batch, fname_batch, test_iter = reader.augmented_inputs(
                    t_len=self.config.test_temporal_shift,
                    test_crops=test_crops,
                    partition=self.config.test_partition)

            elif self.config.dataset == 'FBMS':
                assert 'FBMS' in self.config.root_dir
                reader = FBMS59Reader(self.config.root_dir)
                test_batch, fname_batch, test_iter = reader.augmented_inputs(
                    t_len=self.config.test_temporal_shift,
                    test_crops=test_crops,
                    partition=self.config.test_partition)
            elif self.config.dataset == 'SEGTRACK':
                reader = SegTrackV2Reader(self.config.root_dir)
                test_batch, fname_batch, test_iter = reader.augmented_inputs(
                    t_len=self.config.test_temporal_shift,
                    test_crops=test_crops)
            else:
                raise IOError("Dataset should be DAVIS2016 / FBMS / SEGTRACK")

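        # Build one prediction branch per test crop; outputs are keyed by the
        # crop factor so they can be ensembled afterwards.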
        results = {'pred_masks': {}, 'gt_masks': {}, 'img_1s': {}}
        for crop in test_crops:
            image_batch = tf.expand_dims(test_batch['img_1s'][crop], axis=0)
            images_2_batch = tf.expand_dims(test_batch['img_2s'][crop], axis=0)
            gt_mask_batch = test_batch['seg_1s'][crop]

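            # Flow is computed at the original crop resolution before
            # everything is resized to (img_height, img_width) below.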
            flow_network = ModelPWCNet()
            flow_batch = flow_network.predict_from_img_pairs(
                image_batch, images_2_batch)

            # Reshape everything
            image_batch = tf.image.resize_images(
                image_batch, [self.config.img_height, self.config.img_width])
            flow_batch = tf.image.resize_images(
                flow_batch, [self.config.img_height, self.config.img_width])
            # Reshape mask to correct ratio
            gt_mask_batch = tf.image.resize_images(
                gt_mask_batch, [self.config.img_height, self.config.img_width],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

            # Normalize flow
            flow_batch = flow_batch / tf.constant(self.config.flow_normalizer)

            generated_masks = generator_net(
                images=image_batch,
                flows=preprocess_flow_batch(flow_batch),
                training=False,
                scope="MaskNet/",
                reuse=tf.AUTO_REUSE)

            results['pred_masks'][crop] = tf.squeeze(generated_masks, axis=0)
            results['gt_masks'][crop] = gt_mask_batch
            results['img_1s'][crop] = tf.squeeze(image_batch, axis=0)

        self.outputs = results
        self.fname_batch = tf.squeeze(fname_batch)
        self.test_samples = reader.val_samples
        self.test_crops = test_crops
        self.test_iterator = test_iter
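
    # A minimal ensembling sketch for the multi-crop outputs stored by
    # build_aug_test_graph() above (assumptions, not confirmed by this file:
    # `model` is an instance of this class, numpy is imported as np, the
    # iterator is initializable, checkpoint restoring is omitted, and a plain
    # mean over crops is used; the project's actual ensembling may differ):
    #
    #   model.build_aug_test_graph()
    #   with tf.Session() as sess:
    #       sess.run(model.test_iterator.initializer)
    #       preds = sess.run(model.outputs['pred_masks'])
    #       ensembled_mask = np.mean([preds[c] for c in model.test_crops], axis=0)
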
    def build_test_graph(self):
        """This graph will be used for testing. In particular, it will
           compute the loss on a testing set, or some other utilities.
        """
        with tf.name_scope("data_loading"):
            if self.config.dataset == 'DAVIS2016':
                reader = Davis2016Reader(self.config.root_dir, num_threads=1)
                test_batch, test_iter = reader.test_inputs(
                    batch_size=self.config.batch_size,
                    t_len=self.config.test_temporal_shift,
                    with_fname=True,
                    test_crop=self.config.test_crop,
                    partition=self.config.test_partition)

            elif self.config.dataset == 'FBMS':
                reader = FBMS59Reader(self.config.root_dir)
                test_batch, test_iter = reader.test_inputs(
                    batch_size=self.config.batch_size,
                    test_crop=self.config.test_crop,
                    t_len=self.config.test_temporal_shift,
                    with_fname=True,
                    partition=self.config.test_partition)
            elif self.config.dataset == 'SEGTRACK':
                reader = SegTrackV2Reader(self.config.root_dir, num_threads=1)
                test_batch, test_iter = reader.test_inputs(
                    batch_size=self.config.batch_size,
                    test_crop=self.config.test_crop,
                    t_len=self.config.test_temporal_shift,
                    with_fname=True)
            else:
                raise IOError("Dataset should be DAVIS2016 / FBMS / SEGTRACK")

            image_batch, images_2_batch, gt_mask_batch, fname_batch = (
                test_batch[0], test_batch[1], test_batch[2], test_batch[3])

        # Flow computed on original image size
        flow_network = ModelPWCNet()
        flow_batch = flow_network.predict_from_img_pairs(
            image_batch, images_2_batch)

        # Reshape everything
        image_batch = tf.image.resize_images(
            image_batch, [self.config.img_height, self.config.img_width])
        flow_batch = tf.image.resize_images(
            flow_batch, [self.config.img_height, self.config.img_width])
        # Reshape mask to correct ratio
        gt_mask_batch = tf.image.resize_images(
            gt_mask_batch, [self.config.img_height, self.config.img_width],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        # Normalize flow
        flow_batch = flow_batch / tf.constant(self.config.flow_normalizer)

        with tf.name_scope("MaskNet") as scope:
            generated_masks = generator_net(
                images=image_batch,
                flows=preprocess_flow_batch(flow_batch),
                training=False,
                scope=scope,
                reuse=False)

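        # Suppress the flow inside the predicted masks before handing it to the
        # flow recovery network.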
        flow_masked = flow_batch * (1.0 - generated_masks)

        with tf.name_scope("FlownetS") as scope:
            pred_flows = recover_net(image_batch,
                                     flow_masked,
                                     mask=generated_masks,
                                     scope=scope,
                                     reuse=False)

        self.input_image = image_batch
        self.gt_flow = flow_batch
        self.fname_batch = fname_batch
        self.generated_masks = generated_masks
        self.test_samples = reader.val_samples
        self.gt_masks = gt_mask_batch
        self.pred_flow = pred_flows
        self.test_iterator = test_iter
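
    # A minimal sketch of running the test graph built by build_test_graph()
    # above (assumptions, not confirmed by this file: `model` is an instance of
    # this class, TF1-style session execution, an initializable iterator, and
    # checkpoint restoring is omitted):
    #
    #   model.build_test_graph()
    #   with tf.Session() as sess:
    #       sess.run(model.test_iterator.initializer)
    #       masks, flows, fnames = sess.run(
    #           [model.generated_masks, model.pred_flow, model.fname_batch])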