Example #1
    def __init__(self, config, subtask, dataset, tfconfig):
        logdir = os.path.join(config.logdir, subtask)
        self.config = copy.deepcopy(config)
        self.config.subtask = subtask
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph, config=tfconfig)
        with self.graph.as_default():
            if os.path.exists(os.path.join(logdir, "mean.h5")):
                training_mean = loadh5(os.path.join(logdir, "mean.h5"))
                training_std = loadh5(os.path.join(logdir, "std.h5"))
                print("[{}] Loaded input normalizers for testing".format(
                    subtask))
    
                # Create the model instance
                self.network = Network(self.sess, self.config, dataset, {
                                       'mean': training_mean, 'std': training_std})
            else:
                self.network = Network(self.sess, self.config, dataset)
    
            self.saver = {}
            self.best_val_loss = {}
            self.best_step = {}
            # Create the saver instance for both joint and the current subtask
            for _key in ["joint", subtask]:
                self.saver[_key] = tf.train.Saver(self.network.allparams[_key])

            # We have everything ready. We finalize and initialize the network here.
        
            self.sess.run(tf.global_variables_initializer())
            restore_res = self.restore_network()
            if not restore_res:
                raise RuntimeError("Could not load network weights!")
Example #2
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng

        # Open a tensorflow session. I like keeping things simple, so I don't
        # use a supervisor. I'm just going to do everything manually. I also
        # just let the GPU memory grow unless a fixed usage fraction is set.
        tfconfig = tf.ConfigProto()
        if self.config.usage > 0:
            tfconfig.gpu_options.allow_growth = False
            tfconfig.gpu_options.per_process_gpu_memory_fraction = \
                self.config.usage
        else:
            tfconfig.gpu_options.allow_growth = True
        self.sess = tf.Session(config=tfconfig)

        # Create the dataset instance
        self.dataset = Dataset(self.config, rng)
        # import IPython
        # IPython.embed()
        # Create the model instance
        self.network = Network(self.sess, self.config, self.dataset)
        # Make individual saver instances and summary writers for each module
        self.saver = {}
        self.summary_writer = {}
        self.best_val_loss = {}
        self.best_step = {}
        # Saver (only if there are params)
        for _key in self.network.allparams:
            if len(self.network.allparams[_key]) > 0:
                with tf.variable_scope("saver-{}".format(_key)):
                    self.saver[_key] = tf.train.Saver(
                        self.network.allparams[_key])
        # Summary Writer
        self.summary_writer[self.config.subtask] = tf.summary.FileWriter(
            os.path.join(self.config.logdir, self.config.subtask),
            graph=self.sess.graph)
        # validation loss
        self.best_val_loss[self.config.subtask] = np.inf
        # step for each module
        self.best_step[self.config.subtask] = 0

        # We have everything ready. We finalize and initialize the network here.
        self.sess.run(tf.global_variables_initializer())

        # Enable augmentations and/or force the use of the augmented set
        self.use_aug_rot = 0
        if self.config.augment_rotations:
            self.use_aug_rot = 1
        elif self.config.use_augmented_set:
            self.use_aug_rot = -1
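The GPU-memory branch above is worth reading in isolation: a positive `config.usage` pins a fixed fraction of GPU memory, anything else lets the allocation grow on demand. A minimal standalone sketch of that logic (`make_tf_config` is a name invented here, not part of the codebase):

    import tensorflow as tf

    def make_tf_config(usage):
        # Pin a fixed fraction of GPU memory when usage > 0,
        # otherwise let the allocation grow on demand.
        tfconfig = tf.ConfigProto()
        if usage > 0:
            tfconfig.gpu_options.allow_growth = False
            tfconfig.gpu_options.per_process_gpu_memory_fraction = usage
        else:
            tfconfig.gpu_options.allow_growth = True
        return tfconfig

    # e.g. sess = tf.Session(config=make_tf_config(0.5))  # pin half the GPU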
Example #3
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng

        # Open a tensorflow session. I like keeping things simple, so I don't
        # use a supervisor. I'm just going to do everything manually. I also
        # just let the GPU memory grow.
        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True
        self.sess = tf.Session(config=tfconfig)

        # Create the dataset instance
        self.dataset = Dataset(self.config, rng)
        # Retrieve mean/std (yes it is hacky)
        logdir = os.path.join(self.config.logdir, self.config.subtask)
        if os.path.exists(os.path.join(logdir, "mean.h5")):
            training_mean = loadh5(os.path.join(logdir, "mean.h5"))
            training_std = loadh5(os.path.join(logdir, "std.h5"))
            print("[{}] Loaded input normalizers for testing".format(
                self.config.subtask))

            # Create the model instance
            self.network = Network(self.sess, self.config, self.dataset, {
                'mean': training_mean,
                'std': training_std
            })
        else:
            self.network = Network(self.sess, self.config, self.dataset)
        # Make individual saver instances for each module.
        self.saver = {}
        self.best_val_loss = {}
        self.best_step = {}
        # Create the saver instance for both joint and the current subtask
        for _key in ["joint", self.config.subtask]:
            self.saver[_key] = tf.train.Saver(self.network.allparams[_key])

        #print('\nNETWORK PARAMETERS')
        #for item in self.network.allparams[_key]:
        #    print(item)

        # We have everything ready. We finalize and initialize the network here.
        self.sess.run(tf.global_variables_initializer())
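The mean/std lookup in the constructor is a simple fall-back: if `mean.h5` and `std.h5` exist in the subtask's log directory, the `Network` is built with those normalizers, otherwise without them. The same pattern as a small helper, assuming `loadh5` behaves as in the examples (`load_normalizers` is a hypothetical name):

    import os

    def load_normalizers(logdir):
        # Return {'mean': ..., 'std': ...} if both h5 files exist, else None.
        mean_file = os.path.join(logdir, "mean.h5")
        std_file = os.path.join(logdir, "std.h5")
        if os.path.exists(mean_file) and os.path.exists(std_file):
            return {'mean': loadh5(mean_file), 'std': loadh5(std_file)}
        return None

    # norm = load_normalizers(logdir)
    # network = Network(sess, config, dataset, norm) if norm \
    #     else Network(sess, config, dataset)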
Example #4
class Trainer(object):
    """The Trainer Class

    LATER: Remove all unnecessary "dictionarization"

    """
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng

        # Open a tensorflow session. I like keeping things simple, so I don't
        # use a supervisor. I'm just going to do everything manually. I also
        # just let the GPU memory grow unless a fixed usage fraction is set.
        tfconfig = tf.ConfigProto()
        if self.config.usage > 0:
            tfconfig.gpu_options.allow_growth = False
            tfconfig.gpu_options.per_process_gpu_memory_fraction = \
                self.config.usage
        else:
            tfconfig.gpu_options.allow_growth = True
        self.sess = tf.Session(config=tfconfig)

        # Create the dataset instance
        self.dataset = Dataset(self.config, rng)
        # import IPython
        # IPython.embed()
        # Create the model instance
        self.network = Network(self.sess, self.config, self.dataset)
        # Make individual saver instances and summary writers for each module
        self.saver = {}
        self.summary_writer = {}
        self.best_val_loss = {}
        self.best_step = {}
        # Saver (only if there are params)
        for _key in self.network.allparams:
            if len(self.network.allparams[_key]) > 0:
                with tf.variable_scope("saver-{}".format(_key)):
                    self.saver[_key] = tf.train.Saver(
                        self.network.allparams[_key])
        # Summary Writer
        self.summary_writer[self.config.subtask] = tf.summary.FileWriter(
            os.path.join(self.config.logdir, self.config.subtask),
            graph=self.sess.graph)
        # validation loss
        self.best_val_loss[self.config.subtask] = np.inf
        # step for each module
        self.best_step[self.config.subtask] = 0

        # We have everything ready. We finalize and initialize the network here.
        self.sess.run(tf.global_variables_initializer())

        # Enable augmentations and/or force the use of the augmented set
        self.use_aug_rot = 0
        if self.config.augment_rotations:
            self.use_aug_rot = 1
        elif self.config.use_augmented_set:
            self.use_aug_rot = -1

    def run(self):
        # For each module, check we have pre-trained modules and load them
        print("-------------------------------------------------")
        print(" Looking for previous results ")
        print("-------------------------------------------------")
        for _key in ["kp", "ori", "desc", "joint"]:
            restore_network(self, _key)

        print("-------------------------------------------------")
        print(" Training ")
        print("-------------------------------------------------")

        subtask = self.config.subtask
        batch_size = self.config.batch_size
        for step in trange(int(self.best_step[subtask]),
                           int(self.config.max_step),
                           desc="Subtask = {}".format(subtask),
                           ncols=self.config.tqdm_width):
            # ----------------------------------------
            # Forward pass: Note that we only compute the loss in the forward
            # pass. We don't do summary writing or saving
            fw_data = []
            fw_loss = []
            batches = self.hardmine_scheduler(self.config, step)
            for num_cur in batches:
                cur_data = self.dataset.next_batch(task="train",
                                                   subtask=subtask,
                                                   batch_size=num_cur,
                                                   aug_rot=self.use_aug_rot)
                cur_loss = self.network.forward(subtask, cur_data)
                # Sanity check
                if min(cur_loss) < 0:
                    raise RuntimeError('Negative loss while mining?')
                # Data may contain empty (zero-value) samples: set loss to zero
                if num_cur < batch_size:
                    cur_loss[num_cur - batch_size:] = 0
                fw_data.append(cur_data)
                fw_loss.append(cur_loss)
            # Fill a single batch with hardest
            if len(batches) > 1:
                cur_data = get_hard_batch(fw_loss, fw_data)
            # ----------------------------------------
            # Backward pass: Note that the backward pass returns summary only
            # when it is asked. Also, we manually keep note of step here, and
            # not use the tensorflow version. This is to simplify the migration
            # to another framework, if needed.
            do_validation = step % self.config.validation_interval == 0
            cur_summary = self.network.backward(subtask,
                                                cur_data,
                                                provide_summary=do_validation)
            if do_validation and cur_summary is not None:
                # Write training summary
                self.summary_writer[subtask].add_summary(cur_summary, step)
                # Do multiple rounds of validation
                cur_val_loss = np.zeros(self.config.validation_rounds)
                for _val_round in xrange(self.config.validation_rounds):
                    # Fetch validation data
                    cur_data = self.dataset.next_batch(
                        task="valid",
                        subtask=subtask,
                        batch_size=batch_size,
                        aug_rot=self.use_aug_rot)
                    # Perform validation of the model using validation data
                    cur_val_loss[_val_round] = self.network.validate(
                        subtask, cur_data)
                cur_val_loss = np.mean(cur_val_loss)
                # Inject validation result to summary
                summaries = [
                    tf.Summary.Value(
                        tag="validation/err-{}".format(subtask),
                        simple_value=cur_val_loss,
                    )
                ]
                self.summary_writer[subtask].add_summary(
                    tf.Summary(value=summaries), step)
                # Flush the writer
                self.summary_writer[subtask].flush()

                # TODO: Repeat without augmentation if necessary
                # ...

                if cur_val_loss < self.best_val_loss[subtask]:
                    self.best_val_loss[subtask] = cur_val_loss
                    self.best_step[subtask] = step
                    save_network(self, subtask)

    def hardmine_scheduler(self, config, step, recursion=True):
        """The hard mining scheduler.

        Modes ("--mining-sched"):
            "none": no mining.
            "step": increase one batch at a time.
            "smooth": increase one sample at a time, filling the rest of the
            batch with zeros if necessary.

        Returns a list with the number of samples for every batch.
        """

        sched = config.mining_sched
        if sched == 'none':
            return [config.batch_size]
        elif sched not in ['step', 'smooth']:
            raise RuntimeError('Unknown scheduler')

        # Nothing to do if mining_step is not defined
        if config.mining_step <= 0:
            return [config.batch_size]

        max_batches = config.mining_ceil if config.mining_ceil > 0 else 32
        num_samples = int(
            round(config.batch_size *
                  (config.mining_base + step / config.mining_step)))
        if num_samples > max_batches * config.batch_size:
            # Limit has been reached
            batches = [config.batch_size] * max_batches
        else:
            batches = [config.batch_size] * int(
                num_samples // config.batch_size)
            # Do the remainder on the last batch
            remainder = num_samples % config.batch_size
            if remainder > 0:
                # 'smooth': add remainder to the last batch
                if sched == 'smooth':
                    batches[-1] += remainder
                # 'step': add a full batch when the remainder goes above 50%
                elif sched == 'step' and remainder >= config.batch_size / 2:
                    batches += [config.batch_size]

        # Feedback
        if recursion and step > 0:
            prev = self.hardmine_scheduler(config, step - 1, recursion=False)
            if sum(prev) < sum(batches):
                print(('\n[{}] Mining: increasing number of samples: ' +
                       '{} -> {} (batches: {} -> {}, last batch size: {})'
                       ).format(config.subtask, sum(prev), sum(batches),
                                len(prev), len(batches), batches[-1]))

        return batches
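To make the schedule concrete, here is a self-contained rerun of the 'smooth' arithmetic with illustrative values (batch_size=128, mining_base=1, mining_step=1000, mining_ceil=4; these are not the project defaults):

    from types import SimpleNamespace

    cfg = SimpleNamespace(batch_size=128, mining_base=1,
                          mining_step=1000, mining_ceil=4)

    def smooth_schedule(config, step):
        # Same arithmetic as hardmine_scheduler in 'smooth' mode.
        num_samples = int(round(config.batch_size *
                                (config.mining_base +
                                 step / float(config.mining_step))))
        if num_samples > config.mining_ceil * config.batch_size:
            return [config.batch_size] * config.mining_ceil
        batches = [config.batch_size] * (num_samples // config.batch_size)
        remainder = num_samples % config.batch_size
        if remainder > 0:
            batches[-1] += remainder
        return batches

    for step in (0, 500, 1500, 4000):
        print(step, smooth_schedule(cfg, step))
    # 0 [128]
    # 500 [192]
    # 1500 [128, 192]
    # 4000 [128, 128, 128, 128]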
Example #5
class Tester(object):
    """The Tester Class

    LATER: Clean up unnecessary dictionaries
    LATER: Make a superclass for Tester and Trainer

    """
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng

        # Open a tensorflow session. I like keeping things simple, so I don't
        # use a supervisor. I'm just going to do everything manually. I also
        # just let the GPU memory grow.
        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True
        self.sess = tf.Session(config=tfconfig)

        # Create the dataset instance
        self.dataset = Dataset(self.config, rng)
        # Retrieve mean/std (yes it is hacky)
        logdir = os.path.join(self.config.logdir, self.config.subtask)
        if os.path.exists(os.path.join(logdir, "mean.h5")):
            training_mean = loadh5(os.path.join(logdir, "mean.h5"))
            training_std = loadh5(os.path.join(logdir, "std.h5"))
            print("[{}] Loaded input normalizers for testing".format(
                self.config.subtask))

            # Create the model instance
            self.network = Network(self.sess, self.config, self.dataset, {
                'mean': training_mean,
                'std': training_std
            })
        else:
            self.network = Network(self.sess, self.config, self.dataset)
        # Make individual saver instances for each module.
        self.saver = {}
        self.best_val_loss = {}
        self.best_step = {}
        # Create the saver instance for both joint and the current subtask
        for _key in ["joint", self.config.subtask]:
            self.saver[_key] = tf.train.Saver(self.network.allparams[_key])

        # We have everything ready. We finalize and initialize the network here.
        self.sess.run(tf.global_variables_initializer())

    def run(self):

        subtask = self.config.subtask

        # Load the network weights for the module of interest
        print("-------------------------------------------------")
        print(" Loading Trained Network ")
        print("-------------------------------------------------")
        # Try loading the joint version, and then fall back to the current
        # task silently if that fails.
        restore_res = False
        try:
            restore_res = restore_network(self, "joint")
        except Exception:
            pass
        if not restore_res:
            restore_res = restore_network(self, subtask)
        if not restore_res:
            raise RuntimeError("Could not load network weights!")

        # Run the appropriate compute function
        print("-------------------------------------------------")
        print(" Testing ")
        print("-------------------------------------------------")

        getattr(self, "_compute_{}".format(subtask))()

    def _compute_kp(self):
        """Compute Keypoints.

        LATER: Clean up code

        """

        total_time = 0.0

        # Read image
        image_color, image_gray, load_prep_time = self.dataset.load_image()

        # check size
        image_height = image_gray.shape[0]
        image_width = image_gray.shape[1]

        # Multiscale Testing
        scl_intv = self.config.test_scl_intv
        # min_scale_log2 = 1  # min scale = 2
        # max_scale_log2 = 4  # max scale = 16
        min_scale_log2 = self.config.test_min_scale_log2
        max_scale_log2 = self.config.test_max_scale_log2
        # Test starting with double scale if small image
        min_hw = np.min(image_gray.shape[:2])
        # for the case of testing on same scale, do not double scale
        if min_hw <= 1600 and min_scale_log2 != max_scale_log2:
            print("INFO: Testing double scale")
            min_scale_log2 -= 1
        # range of scales to check
        num_division = (max_scale_log2 - min_scale_log2) * (scl_intv + 1) + 1
        scales_to_test = 2**np.linspace(min_scale_log2, max_scale_log2,
                                        num_division)

        # convert scale to image resizes
        resize_to_test = ((float(self.config.kp_input_size - 1) / 2.0) /
                          (get_ratio_scale(self.config) * scales_to_test))

        # check if resize is valid
        min_hw_after_resize = resize_to_test * np.min(image_gray.shape[:2])
        is_resize_valid = min_hw_after_resize > self.config.kp_filter_size + 1

        # if there are invalid scales and resizes
        if not np.all(is_resize_valid):
            # find the first invalid scale
            first_invalid = np.where(~is_resize_valid)[0][0]

            # remove scales from testing
            scales_to_test = scales_to_test[:first_invalid]
            resize_to_test = resize_to_test[:first_invalid]

        print('resize to test is {}'.format(resize_to_test))
        print('scales to test is {}'.format(scales_to_test))

        # Run for each scale
        test_res_list = []
        for resize in resize_to_test:

            # resize according to how we extracted patches when training
            new_height = np.cast['int'](np.round(image_height * resize))
            new_width = np.cast['int'](np.round(image_width * resize))
            start_time = time.clock()
            image = cv2.resize(image_gray, (new_width, new_height))
            end_time = time.clock()
            resize_time = (end_time - start_time) * 1000.0
            print("Time taken to resize image is {}ms".format(resize_time))
            total_time += resize_time

            # run test
            # LATER: Compatibility with the previous implementations
            start_time = time.clock()

            # Run the network to get the scoremap (the valid region only)
            scoremap = None
            if self.config.test_kp_use_tensorflow:
                scoremap = self.network.test(
                    self.config.subtask,
                    image.reshape(1, new_height, new_width, 1)).squeeze()
            else:
                # OpenCV Version
                raise NotImplementedError("TODO: Implement OpenCV Version")

            end_time = time.clock()
            compute_time = (end_time - start_time) * 1000.0
            print("Time taken for image size {}"
                  " is {} milliseconds".format(image.shape, compute_time))

            total_time += compute_time

            # pad invalid regions and add to list
            start_time = time.clock()
            test_res_list.append(
                np.pad(scoremap,
                       int((self.config.kp_filter_size - 1) / 2),
                       mode='constant',
                       constant_values=-np.inf))
            end_time = time.clock()
            pad_time = (end_time - start_time) * 1000.0
            print("Time taken for padding and stacking is {} ms".format(
                pad_time))
            total_time += pad_time

        # ------------------------------------------------------------------------
        # Non-max suppression and draw.

        # The nonmax suppression implemented here is very slow. Consider it a
        # proof-of-concept implementation for now.

        # Standard nearby : nonmax will check approximately the same area as
        # descriptor support region.
        nearby = int(
            np.round((0.5 * (self.config.kp_input_size - 1.0) *
                      float(self.config.desc_input_size) /
                      float(get_patch_size(self.config)))))
        fNearbyRatio = self.config.test_nearby_ratio
        # Multiply by a quarter to compensate
        fNearbyRatio *= 0.25
        nearby = int(np.round(nearby * fNearbyRatio))
        nearby = max(nearby, 1)

        nms_intv = self.config.test_nms_intv
        edge_th = self.config.test_edge_th

        print("Performing NMS")
        start_time = time.clock()
        res_list = test_res_list
        # check whether the returned score result is right
        # print(res_list[0][400:500, 300:400])
        XYZS = get_XYZS_from_res_list(
            res_list,
            resize_to_test,
            scales_to_test,
            nearby,
            edge_th,
            scl_intv,
            nms_intv,
            do_interpolation=True,
        )
        end_time = time.clock()
        XYZS = XYZS[:self.config.test_num_keypoint]

        # For debugging
        # TODO: Remove below
        draw_XYZS_to_img(XYZS, image_color, self.config.test_out_file + '.jpg')

        nms_time = (end_time - start_time) * 1000.0
        print("NMS time is {} ms".format(nms_time))
        total_time += nms_time
        print("Total time for detection is {} ms".format(total_time))
        # if bPrintTime:
        #     # Also print to a file by appending
        #     with open("../timing-code/timing.txt", "a") as timing_file:
        #         print("------ Keypoint Timing ------\n"
        #               "NMS time is {} ms\n"
        #               "Total time is {} ms\n".format(
        #                   nms_time, total_time
        #               ),
        #               file=timing_file)

        # # resize score to original image size
        # res_list = [cv2.resize(score,
        #                        (image_width, image_height),
        #                        interpolation=cv2.INTER_NEAREST)
        #             for score in test_res_list]
        # # make as np array
        # res_scores = np.asarray(res_list)
        # with h5py.File('test/scores.h5', 'w') as score_file:
        #     score_file['score'] = res_scores

        # ------------------------------------------------------------------------
        # Save as keypoint file to be used by the oxford thing
        print("Turning into kp_list")
        kp_list = XYZS2kpList(XYZS)  # note that this is already sorted

        # ------------------------------------------------------------------------
        # LATER: take care of the orientations somehow...
        # # Also compute angles with the SIFT method, since the keypoint
        # # component alone has no orientations.
        # print("Recomputing Orientations")
        # new_kp_list, _ = recomputeOrientation(image_gray, kp_list,
        #                                       bSingleOrientation=True)

        print("Saving to txt")
        saveKpListToTxt(kp_list, None, self.config.test_out_file)

    def _compute_ori(self):
        """Compute Orientations """

        total_time = 0.0

        # Read image
        start_time = time.clock()
        cur_data = self.dataset.load_data()
        end_time = time.clock()
        load_time = (end_time - start_time) * 1000.0
        print("Time taken to load patches is {} ms".format(load_time))
        total_time += load_time

        # -------------------------------------------------------------------------
        # Test using the test function
        start_time = time.clock()
        oris = self._test_multibatch(cur_data)
        end_time = time.clock()
        compute_time = (end_time - start_time) * 1000.0
        print("Time taken to compute is {} ms".format(compute_time))
        total_time += compute_time

        # update keypoints and save as new
        start_time = time.clock()
        kps = cur_data["kps"]
        for idxkp in xrange(len(kps)):
            kps[idxkp][IDX_ANGLE] = oris[idxkp] * 180.0 / np.pi % 360.0
            kps[idxkp] = update_affine(kps[idxkp])
        end_time = time.clock()
        update_time = (end_time - start_time) * 1000.0
        print("Time taken to update is {} ms".format(update_time))
        total_time += update_time
        print("Total time for orientation is {} ms".format(total_time))

        # save as new keypoints
        saveKpListToTxt(kps, self.config.test_kp_file,
                        self.config.test_out_file)

    def _compute_desc(self):
        """Compute Descriptors """

        total_time = 0.0

        # Read image
        start_time = time.clock()
        cur_data = self.dataset.load_data()
        end_time = time.clock()
        load_time = (end_time - start_time) * 1000.0
        print("Time taken to load patches is {} ms".format(load_time))
        total_time += load_time

        # import IPython
        # IPython.embed()

        # -------------------------------------------------------------------------
        # Test using the test function
        start_time = time.clock()
        descs = self._test_multibatch(cur_data)
        end_time = time.clock()
        compute_time = (end_time - start_time) * 1000.0
        print("Time taken to compute is {} ms".format(compute_time))
        total_time += compute_time
        print("Total time for descriptor is {} ms".format(total_time))

        # Overwrite angle
        kps = cur_data["kps"].copy()
        kps[:, 3] = cur_data["angle"][:, 0]

        # Save as h5 file
        save_dict = {}
        # save_dict['keypoints'] = cur_data["kps"]
        save_dict['keypoints'] = kps
        save_dict['descriptors'] = descs

        saveh5(save_dict, self.config.test_out_file)

    def _test_multibatch(self, cur_data):
        """A sub test routine.

        We do this since the spatial transformer implementation in tensorflow
        does not like undetermined batch sizes.

        LATER: Bypass the spatial transformer...somehow
        LATER: Fix the multibatch testing

        """

        subtask = self.config.subtask
        batch_size = self.config.batch_size
        num_patch = len(cur_data["patch"])
        num_batch = int(np.ceil(float(num_patch) / float(batch_size)))
        # Initialize the batch items
        cur_batch = {}
        for _key in cur_data:
            cur_batch[_key] = np.zeros_like(cur_data[_key][:batch_size])

        # Run the test in multiple fixed-size batches
        res = []
        for _idx_batch in xrange(num_batch):
            # start of the batch
            bs = _idx_batch * batch_size
            # end of the batch
            be = min(num_patch, (_idx_batch + 1) * batch_size)
            # number of elements in batch
            bn = be - bs
            for _key in cur_data:
                cur_batch[_key][:bn] = cur_data[_key][bs:be]
            cur_res = self.network.test(subtask, cur_batch).squeeze()[:bn]
            # Append
            res.append(cur_res)

        return np.concatenate(res, axis=0)
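The fixed-batch trick in `_test_multibatch` generalizes beyond this class: zero-pad the final chunk up to `batch_size`, run the network, and keep only the valid rows. A self-contained sketch over a plain NumPy array, with `run_fn` standing in for `network.test`:

    import numpy as np

    def test_in_fixed_batches(run_fn, data, batch_size):
        # Feed fixed-size batches, zero-padding the last one, and trim
        # the padded rows from the output before concatenating.
        num = len(data)
        out = []
        for start in range(0, num, batch_size):
            end = min(num, start + batch_size)
            batch = np.zeros_like(data[:batch_size])
            batch[:end - start] = data[start:end]
            out.append(run_fn(batch)[:end - start])
        return np.concatenate(out, axis=0)

    # e.g. descs = test_in_fixed_batches(lambda b: b * 2.0, patches, 128)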
Example #6
class ImportGraph(object):
    def __init__(self, config, subtask, dataset, tfconfig):
        logdir = os.path.join(config.logdir, subtask)
        self.config = copy.deepcopy(config)
        self.config.subtask = subtask
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph, config=tfconfig)
        with self.graph.as_default():
            if os.path.exists(os.path.join(logdir, "mean.h5")):
                training_mean = loadh5(os.path.join(logdir, "mean.h5"))
                training_std = loadh5(os.path.join(logdir, "std.h5"))
                print("[{}] Loaded input normalizers for testing".format(
                    subtask))
    
                # Create the model instance
                self.network = Network(self.sess, self.config, dataset, {
                                       'mean': training_mean, 'std': training_std})
            else:
                self.network = Network(self.sess, self.config, dataset)
    
            self.saver = {}
            self.best_val_loss = {}
            self.best_step = {}
            # Create the saver instance for both joint and the current subtask
            for _key in ["joint", subtask]:
                self.saver[_key] = tf.train.Saver(self.network.allparams[_key])

            # We have everything ready. We finalize and initialize the network here.
        
            self.sess.run(tf.global_variables_initializer())
            restore_res = self.restore_network()
            if not restore_res:
                raise RuntimeError("Could not load network weights!")
            
    def restore_network(self):
        """Restore training status"""
    
        # Skip if there's no saver of this subtask
        if self.config.subtask not in self.saver:
            return False
    
        is_loaded = False
    
        # Check if pretrain weight file is specified
        predir = getattr(self.config, "pretrained_{}".format(self.config.subtask))
        # Try loading the old weights first; each later attempt overwrites an
        # earlier load, and boolean += leaves is_loaded truthy on any success.
        is_loaded += self.load_legacy_network(predir)
        # Try loading the tensorflow weights
        is_loaded += self.load_network(predir)
    
        # Load network using tensorflow saver
        logdir = os.path.join(self.config.logdir, self.config.subtask)
        is_loaded += self.load_network(logdir)
    
        return is_loaded

    def load_legacy_network(self, load_dir):
        """Load function for our old framework"""
    
        print("[{}] Checking if old pre-trained weights exists in {}"
              "".format(self.config.subtask, load_dir))
        model_file = os.path.join(load_dir, "model.h5")
        norm_file = os.path.join(load_dir, "norm.h5")
        base_file = os.path.join(load_dir, "base.h5")
    
        if os.path.exists(model_file) and os.path.exists(norm_file) and \
           os.path.exists(base_file):
            model = loadh5(model_file)
            norm = loadh5(norm_file)
            base = loadh5(base_file)
            # Load the input normalization parameters.
            with self.graph.as_default():
                self.network.mean["kp"] = float(norm["mean_x"])
                self.network.mean["ori"] = float(norm["mean_x"])
                self.network.mean["desc"] = float(base["patch-mean"])
                self.network.std["kp"] = float(norm["std_x"])
                self.network.std["ori"] = float(norm["std_x"])
                self.network.std["desc"] = float(base["patch-std"])
                # Load weights for the component
                self.network.legacy_load_func[self.config.subtask](self.sess, model)
                print("[{}] Loaded previously trained weights".format(self.config.subtask))
            return True
        else:
            print("[{}] No pretrained weights from the old framework"
                  "".format(self.config.subtask))
            return False
    
    def load_network(self, load_dir):
        """Load function for our new framework"""
    
        print("[{}] Checking if previous Tensorflow run exists in {}"
              "".format(self.config.subtask, load_dir))
        latest_checkpoint = tf.train.latest_checkpoint(load_dir)
        if latest_checkpoint is not None:
            # Load parameters
            with self.graph.as_default():
                self.saver[self.config.subtask].restore(
                    self.sess,
                    latest_checkpoint
                )
                print("[{}] Loaded previously trained weights".format(self.config.subtask))
                # Load mean/std (falls back to defaults if files are absent)
                if os.path.exists(os.path.join(load_dir, "mean.h5")):
                    self.network.mean = loadh5(os.path.join(load_dir, "mean.h5"))
                    self.network.std = loadh5(os.path.join(load_dir, "std.h5"))
                    print("[{}] Loaded input normalizers".format(self.config.subtask))
                # Load best validation result
                self.best_val_loss[self.config.subtask] = loadh5(
                    os.path.join(load_dir, best_val_loss_filename)
                )[self.config.subtask]
                print("[{}] Loaded best validation result = {}".format(
                    self.config.subtask,self.best_val_loss[self.config.subtask]))
                # Load best validation result
                self.best_step[self.config.subtask] = loadh5(
                    os.path.join(load_dir, best_step_filename)
                )[self.config.subtask]
                print("[{}] Loaded best step = {}".format(
                    self.config.subtask, self.best_step[self.config.subtask]))
            return True
    
        else:
            print("[{}] No previous Tensorflow result".format(self.config.subtask))
            return False
        
    def test_squeeze(self, data):
        with self.graph.as_default():
            return self.network.test(
                self.config.subtask, data).squeeze()
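Because each `ImportGraph` owns its own `tf.Graph` and `tf.Session`, several subtask networks can be loaded side by side without variable-name collisions. A hedged usage sketch, assuming `config`, `dataset`, and `tfconfig` are built as in the earlier examples and `image` is a grayscale array:

    # One isolated graph per module; weights are restored inside __init__.
    modules = {}
    for subtask in ["kp", "ori", "desc"]:
        modules[subtask] = ImportGraph(config, subtask, dataset, tfconfig)

    # Run the keypoint scorer on a single-image batch (NHWC layout).
    height, width = image.shape[:2]
    scoremap = modules["kp"].test_squeeze(
        image.reshape(1, height, width, 1))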