def run(self):

    subtask = self.config.subtask

    # Load the network weights for the module of interest
    print("-------------------------------------------------")
    print(" Loading Trained Network ")
    print("-------------------------------------------------")
    # Try loading the jointly trained weights first; if that fails, silently
    # fall back to the weights for the current subtask.
    restore_res = False
    try:
        restore_res = restore_network(self, "joint")
    except Exception:
        pass
    if not restore_res:
        restore_res = restore_network(self, subtask)
    if not restore_res:
        raise RuntimeError("Could not load network weights!")

    # Run the compute function for the current subtask
    print("-------------------------------------------------")
    print(" Testing ")
    print("-------------------------------------------------")
    getattr(self, "_compute_{}".format(subtask))()
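# NOTE: `restore_network` is defined elsewhere in this codebase. The helper
# below is only an illustrative sketch of a per-subtask checkpoint restore,
# assuming weights are stored as TensorFlow (v1) checkpoints in a per-subtask
# sub-directory of some `logdir`. The name, signature, and directory layout
# are hypothetical, not the project's actual implementation.
def _restore_network_sketch(sess, subtask, logdir="logs"):
    """Restore the latest checkpoint saved for `subtask`; True on success."""
    import os
    import tensorflow as tf

    ckpt_dir = os.path.join(logdir, subtask)
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is None:
        # No checkpoint saved for this subtask yet
        return False
    tf.train.Saver().restore(sess, latest)
    return True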
def run(self):

    # For each module, check whether we have pre-trained weights and load them
    print("-------------------------------------------------")
    print(" Looking for previous results ")
    print("-------------------------------------------------")
    for _key in ["kp", "ori", "desc", "joint"]:
        restore_network(self, _key)

    print("-------------------------------------------------")
    print(" Training ")
    print("-------------------------------------------------")

    subtask = self.config.subtask
    batch_size = self.config.batch_size
    for step in trange(int(self.best_step[subtask]),
                       int(self.config.max_step),
                       desc="Subtask = {}".format(subtask),
                       ncols=self.config.tqdm_width):

        # ----------------------------------------
        # Forward pass: we only compute the loss here. Summary writing and
        # saving happen in the backward pass below.
        fw_data = []
        fw_loss = []
        batches = self.hardmine_scheduler(self.config, step)
        for num_cur in batches:
            cur_data = self.dataset.next_batch(
                task="train",
                subtask=subtask,
                batch_size=num_cur,
                aug_rot=self.use_aug_rot)
            cur_loss = self.network.forward(subtask, cur_data)
            # Sanity check
            if min(cur_loss) < 0:
                raise RuntimeError("Negative loss while mining?")
            # The batch may be padded with empty (zero-value) samples: zero
            # out the losses of the padded entries so they are never picked
            # as hard samples.
            if num_cur < batch_size:
                cur_loss[num_cur - batch_size:] = 0
            fw_data.append(cur_data)
            fw_loss.append(cur_loss)

        # Fill a single batch with the hardest samples
        if len(batches) > 1:
            cur_data = get_hard_batch(fw_loss, fw_data)

        # ----------------------------------------
        # Backward pass: the backward pass returns a summary only when asked
        # to. We also keep track of the step manually instead of using the
        # TensorFlow global step, to simplify a migration to another
        # framework if needed.
        do_validation = step % self.config.validation_interval == 0
        cur_summary = self.network.backward(
            subtask, cur_data, provide_summary=do_validation)

        if do_validation and cur_summary is not None:
            # Write training summary
            self.summary_writer[subtask].add_summary(cur_summary, step)

            # Do multiple rounds of validation
            cur_val_loss = np.zeros(self.config.validation_rounds)
            for _val_round in range(self.config.validation_rounds):
                # Fetch validation data
                cur_data = self.dataset.next_batch(
                    task="valid",
                    subtask=subtask,
                    batch_size=batch_size,
                    aug_rot=self.use_aug_rot)
                # Validate the model on this batch
                cur_val_loss[_val_round] = self.network.validate(
                    subtask, cur_data)
            cur_val_loss = np.mean(cur_val_loss)

            # Inject the validation result into the summary
            summaries = [
                tf.Summary.Value(
                    tag="validation/err-{}".format(subtask),
                    simple_value=cur_val_loss,
                )
            ]
            self.summary_writer[subtask].add_summary(
                tf.Summary(value=summaries), step)
            # Flush the writer
            self.summary_writer[subtask].flush()

            # TODO: Repeat without augmentation if necessary
            # ...

            # Keep the best model seen so far
            if cur_val_loss < self.best_val_loss[subtask]:
                self.best_val_loss[subtask] = cur_val_loss
                self.best_step[subtask] = step
                save_network(self, subtask)
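# NOTE: `hardmine_scheduler` and `get_hard_batch` are defined elsewhere in
# this codebase. The function below is only an illustrative sketch of the
# hard-mining selection step, assuming each element of `data_list` is a dict
# of numpy arrays batched along axis 0 and each element of `loss_list` is the
# per-sample loss for that mini-batch. The name and data layout are
# assumptions, not the project's actual implementation.
def _get_hard_batch_sketch(loss_list, data_list):
    """Gather the samples with the largest loss into a single batch."""
    import numpy as np

    batch_size = len(loss_list[0])
    losses = np.concatenate(loss_list)
    # Indices of the hardest samples, i.e. the largest per-sample losses
    hard_idx = np.argsort(losses)[::-1][:batch_size]
    # Re-assemble a single batch by gathering those samples for every key
    hard_batch = {}
    for key in data_list[0]:
        stacked = np.concatenate([_d[key] for _d in data_list], axis=0)
        hard_batch[key] = stacked[hard_idx]
    return hard_batch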