Example #1
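Both examples are methods on a trainer/tester class, and their module-level imports are not shown. A minimal sketch of what the code appears to assume (the helpers restore_network, save_network, get_hard_batch, and hardmine_scheduler are project-specific, not library calls):

    import numpy as np
    import tensorflow as tf  # TF 1.x: tf.Summary is used below
    from tqdm import trange
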
    def run(self):

        subtask = self.config.subtask

        # Load the network weights for the module of interest
        print("-------------------------------------------------")
        print(" Loading Trained Network ")
        print("-------------------------------------------------")
        # Try loading the joint version first, then fall back to the current
        # subtask if that fails
        restore_res = False
        try:
            restore_res = restore_network(self, "joint")
        except Exception:
            # Fall through silently; the subtask-specific weights are tried next
            pass
        if not restore_res:
            restore_res = restore_network(self, subtask)
        if not restore_res:
            raise RuntimeError("Could not load network weights!")

        # Run the appropriate compute function
        print("-------------------------------------------------")
        print(" Testing ")
        print("-------------------------------------------------")

        # Dispatch to the subtask's compute function, e.g. self._compute_kp(),
        # via getattr rather than eval
        getattr(self, "_compute_{}".format(subtask))()
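
restore_network itself is project-specific and not shown. As a rough sketch of the contract the method above relies on, assuming a TF 1.x Saver and one checkpoint directory per subtask (the attribute names and directory layout here are guesses, not the original code):

    import os
    import tensorflow as tf

    def restore_network(task, subtask):
        # Hypothetical sketch: return True iff weights for `subtask` loaded
        ckpt_dir = os.path.join(task.config.logdir, subtask)  # assumed layout
        ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt is None:
            return False
        # `saver` and `network.sess` are assumed attribute names
        task.saver[subtask].restore(task.network.sess, ckpt)
        return True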
Example #2
    def run(self):
        # For each module, check for pre-trained weights and load them if found
        print("-------------------------------------------------")
        print(" Looking for previous results ")
        print("-------------------------------------------------")
        for _key in ["kp", "ori", "desc", "joint"]:
            restore_network(self, _key)

        print("-------------------------------------------------")
        print(" Training ")
        print("-------------------------------------------------")

        subtask = self.config.subtask
        batch_size = self.config.batch_size
        for step in trange(int(self.best_step[subtask]),
                           int(self.config.max_step),
                           desc="Subtask = {}".format(subtask),
                           ncols=self.config.tqdm_width):
            # ----------------------------------------
            # Forward pass: Note that we only compute the loss in the forward
            # pass. We don't do summary writing or saving
            fw_data = []
            fw_loss = []
            batches = self.hardmine_scheduler(self.config, step)
            for num_cur in batches:
                cur_data = self.dataset.next_batch(task="train",
                                                   subtask=subtask,
                                                   batch_size=num_cur,
                                                   aug_rot=self.use_aug_rot)
                cur_loss = self.network.forward(subtask, cur_data)
                # Sanity check
                if min(cur_loss) < 0:
                    raise RuntimeError('Negative loss while mining?')
                # The batch may be padded with empty (zero-value) samples;
                # zero out their losses so they are never selected as "hard"
                if num_cur < batch_size:
                    cur_loss[num_cur - batch_size:] = 0
                fw_data.append(cur_data)
                fw_loss.append(cur_loss)
            # Build a single batch from the hardest samples across all batches
            if len(batches) > 1:
                cur_data = get_hard_batch(fw_loss, fw_data)
            # ----------------------------------------
            # Backward pass: Note that the backward pass returns a summary only
            # when asked for one. Also, we keep track of the step manually here
            # instead of using the TensorFlow global step, to simplify a later
            # migration to another framework if needed.
            do_validation = step % self.config.validation_interval == 0
            cur_summary = self.network.backward(subtask,
                                                cur_data,
                                                provide_summary=do_validation)
            if do_validation:
                # Make sure the backward pass actually returned the summary
                assert cur_summary is not None
                # Write training summary
                self.summary_writer[subtask].add_summary(cur_summary, step)
                # Do multiple rounds of validation
                cur_val_loss = np.zeros(self.config.validation_rounds)
                for _val_round in range(self.config.validation_rounds):
                    # Fetch validation data
                    cur_data = self.dataset.next_batch(
                        task="valid",
                        subtask=subtask,
                        batch_size=batch_size,
                        aug_rot=self.use_aug_rot)
                    # Perform validation of the model using validation data
                    cur_val_loss[_val_round] = self.network.validate(
                        subtask, cur_data)
                cur_val_loss = np.mean(cur_val_loss)
                # Inject validation result to summary
                summaries = [
                    tf.Summary.Value(
                        tag="validation/err-{}".format(subtask),
                        simple_value=cur_val_loss,
                    )
                ]
                self.summary_writer[subtask].add_summary(
                    tf.Summary(value=summaries), step)
                # Flush the writer
                self.summary_writer[subtask].flush()

                # TODO: Repeat without augmentation if necessary
                # ...

                # Keep the best model according to the validation loss
                if cur_val_loss < self.best_val_loss[subtask]:
                    self.best_val_loss[subtask] = cur_val_loss
                    self.best_step[subtask] = step
                    save_network(self, subtask)
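
hardmine_scheduler and get_hard_batch are likewise project-specific. A sketch of the behavior the loop above appears to assume: the scheduler returns a list of batch sizes to run forward passes over, and get_hard_batch keeps the batch_size highest-loss samples across them. Both bodies are illustrative assumptions (including the mining_step config field), not the original implementation:

    import numpy as np

    def hardmine_scheduler(config, step):
        # Assumption: mine over more batches as training progresses
        num_batches = 1 + step // config.mining_step  # hypothetical field
        return [config.batch_size] * num_batches

    def get_hard_batch(fw_loss, fw_data):
        # Keep the batch_size samples with the highest loss; padded samples
        # had their loss zeroed above, so they sort last and are dropped
        losses = np.concatenate(fw_loss)
        batch_size = len(fw_loss[0])
        hard_idx = np.argsort(losses)[::-1][:batch_size]
        # Assumption: each batch is a dict of equally-indexed numpy arrays
        merged = {k: np.concatenate([d[k] for d in fw_data])
                  for k in fw_data[0]}
        return {k: v[hard_idx] for k, v in merged.items()}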