def train(self, task: ClassyTask):
    """Runs training phases; phases are generated from the config.

    Args:
        task: Task to be used in training. It should contain
            everything that is needed for training.
    """
    pin_memory = self.use_gpu and torch.cuda.device_count() > 1
    task.prepare(
        num_dataloader_workers=self.num_dataloader_workers,
        pin_memory=pin_memory,
        use_gpu=self.use_gpu,
        dataloader_mp_context=self.dataloader_mp_context,
    )
    assert isinstance(task, ClassyTask)

    # make sure all the workers start training at the same time
    # this helps catch hangs which would have happened elsewhere
    barrier()

    local_variables = {}
    task.on_start(local_variables)
    while not task.done_training():
        task.on_phase_start(local_variables)
        while True:
            try:
                task.step(self.use_gpu, local_variables)
            except StopIteration:
                break
        task.on_phase_end(local_variables)
    task.on_end(local_variables)
def train(self, task: ClassyTask):
    """Runs training phases; phases are generated from the config.

    Args:
        task: Task to be used in training. It should contain
            everything that is needed for training.
    """
    task.prepare()
    assert isinstance(task, ClassyTask)

    # make sure all the workers start training at the same time
    # this helps catch hangs which would have happened elsewhere
    barrier()

    task.on_start()
    while not task.done_training():
        task.on_phase_start()
        while True:
            try:
                task.step()
            except StopIteration:
                break
        task.on_phase_end()
    task.on_end()
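For context, this is roughly how a caller drives the train() loops above. The construction below assumes a ClassyVision-style API (a task built from a config dict, a LocalTrainer); treat the exact names as illustrative, not something the snippets themselves confirm.

import json

from classy_vision.tasks import build_task
from classy_vision.trainer import LocalTrainer

with open("config.json") as f:
    config = json.load(f)  # dataset, model, loss, optimizer, num_epochs, ...

task = build_task(config)  # builds a ClassyTask from the config
trainer = LocalTrainer()
trainer.train(task)        # runs phases until task.done_training() is True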
def on_phase_end(self, local_variables):
    logging.info("Syncing meters on phase end...")
    for meter in self.meters:
        meter.sync_state()
    logging.info("...meters synced")
    barrier()
    self.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
def train(self, task: ClassyTask):
    """Runs training phases; phases are generated from the config.

    Args:
        task: Task to be used in training. It should contain
            everything that is needed for training.
    """
    pin_memory = self.use_gpu and torch.cuda.device_count() > 1
    task.prepare(
        num_dataloader_workers=self.num_dataloader_workers,
        pin_memory=pin_memory,
        use_gpu=self.use_gpu,
        dataloader_mp_context=self.dataloader_mp_context,
    )
    assert isinstance(task, ClassyTask)

    if is_distributed_training_run():
        task.init_distributed_data_parallel_model()

    local_variables = {}
    task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)

    # Track the best top-1/top-5 accuracies seen so far and the epochs
    # in which they were reached.
    best_acc = {
        'top1_acc': 0,
        'top1_epoch': 0,
        'top5_acc': 0,
        'top5_epoch': 0,
    }
    epoch = 0
    while not task.done_training():
        task.advance_phase()

        # Start phase hooks
        task.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name)
        while True:
            # Process next sample
            try:
                task.train_step(self.use_gpu, local_variables)
            except StopIteration:
                break

        logging.info("Syncing meters on phase end...")
        for meter in task.meters:
            meter.sync_state()
        logging.info("...meters synced")
        barrier()
        meter = task.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
        if meter is not None:
            if meter[0].value['top_1'] > best_acc['top1_acc']:
                best_acc['top1_acc'] = meter[0].value['top_1']
                best_acc['top5_acc'] = meter[0].value['top_5']
                best_acc['top1_epoch'] = epoch
                best_acc['top5_epoch'] = epoch
        epoch += 1

    task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
    return best_acc
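Unlike the other variants, this train() returns the best-accuracy dictionary, so the caller can report or persist it. A minimal sketch, assuming the same trainer wiring as in the earlier usage example:

best = trainer.train(task)
logging.info(
    "best top-1 %.3f (epoch %d), best top-5 %.3f (epoch %d)",
    best['top1_acc'], best['top1_epoch'],
    best['top5_acc'], best['top5_epoch'],
)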
def on_phase_end(self):
    self.log_phase_end("train")

    logging.info("Syncing meters on phase end...")
    for meter in self.meters:
        meter.sync_state()
    logging.info("...meters synced")
    barrier()
    for hook in self.hooks:
        hook.on_phase_end(self)
    self.perf_log = []

    self.log_phase_end("total")
def _get_split_features(
    self, feat_names: List[str], cfg: AttrDict, task: ClassyTask
):
    task.model.eval()
    logging.info("Model set to eval mode during feature extraction...")

    # One {dataset_index: array} dict per requested layer.
    out_features, out_targets = {}, {}
    for layer in feat_names:
        out_features[layer], out_targets[layer] = {}, {}

    while True:
        try:
            sample = next(task.data_iterator)
            assert isinstance(sample, dict)
            assert "data_idx" in sample, "Indices not passed"
            input_sample = {
                "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                "target": torch.cat(sample["label"]).cpu().numpy(),
                "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
            }
            with torch.no_grad():
                features = task.model(input_sample["input"])
                flat_features_list = self._flatten_features_list(features)
                num_images = input_sample["inds"].shape[0]
                for num, layer in enumerate(feat_names):
                    feature = flat_features_list[num].cpu().numpy()
                    targets = input_sample["target"]
                    for idx in range(num_images):
                        index = input_sample["inds"][idx]
                        if index not in out_features[layer]:
                            out_targets[layer][index] = targets[idx].reshape(-1)
                            out_features[layer][index] = feature[idx]
        except StopIteration:
            break

    barrier()
    output = {}
    for layer in feat_names:
        # Sort by dataset index so features, targets and inds stay aligned.
        out_features[layer] = dict(sorted(out_features[layer].items()))
        out_targets[layer] = dict(sorted(out_targets[layer].items()))
        feats = np.array(list(out_features[layer].values()))
        N = feats.shape[0]
        output[layer] = {
            "features": feats.reshape(N, -1),
            "targets": np.array(list(out_targets[layer].values())),
            "inds": np.array(list(out_features[layer].keys())),
        }
    return output
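The returned dict maps each requested layer name to aligned "features", "targets", and "inds" arrays, sorted by dataset index. A small sketch of consuming that output, e.g. dumping one .npy file per array; save_extracted_features and its out_dir/split parameters are hypothetical helpers, not part of the snippet above.

import os

import numpy as np


def save_extracted_features(output, out_dir, split="train"):
    # output: the dict returned by _get_split_features
    os.makedirs(out_dir, exist_ok=True)
    for layer, data in output.items():
        for key in ("features", "targets", "inds"):
            np.save(os.path.join(out_dir, f"{split}_{layer}_{key}.npy"), data[key])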
def _run_step(self, state, local_variables, use_gpu):
    # Check for training complete but only terminate when the last phase is done
    if state.task.done_training() and state.advance_to_next_phase:
        raise StopIteration

    if state.advance_to_next_phase:
        state.task.advance_phase()

        # Start phase hooks
        state.task.run_hooks(
            local_variables, ClassyHookFunctions.on_phase_start.name
        )
        state.advance_to_next_phase = False

    # Process one train step
    try:
        if state.skip_current_phase:
            state.advance_to_next_phase = True
            state.skip_current_phase = False  # Reset flag
        else:
            state.task.train_step(use_gpu, local_variables)
    except StopIteration:
        state.advance_to_next_phase = True

    if state.advance_to_next_phase:
        logging.info("Syncing meters on phase end...")
        for meter in state.task.meters:
            meter.sync_state()
        logging.info("...meters synced")
        barrier()

        # Phase complete
        # NOTE: this is a good time to checkpoint, as it guarantees
        # that loading from checkpoint will properly advance the phase.
        state.task.run_hooks(
            local_variables, ClassyHookFunctions.on_phase_end.name
        )

    progress_rate = None  # using None to signal 'unknown'
    perf_stats = local_variables.get("perf_stats", None)
    if perf_stats is not None:
        batch_time = perf_stats._cuda_stats["train_step_total"].smoothed_value
        if batch_time is not None and batch_time > 0.0:
            # rate = number of mini-batches per second
            progress_rate = 1.0 / batch_time
    progress_stats = self._ClassyWorkerStats(progress_rate)

    return state, progress_stats
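_run_step keeps no loop state of its own; the caller threads a state object through successive calls. A hedged driver sketch, using types.SimpleNamespace as a stand-in for whatever state class the elastic trainer actually uses; the field names are taken from the reads and writes inside _run_step itself, and worker and task are assumed to exist.

from types import SimpleNamespace

state = SimpleNamespace(
    task=task,                   # a prepared ClassyTask
    advance_to_next_phase=True,  # forces advance_phase() on the first call
    skip_current_phase=False,
)
local_variables = {}
while True:
    try:
        state, stats = worker._run_step(state, local_variables, use_gpu=True)
    except StopIteration:
        break  # done_training() was True and the last phase had ended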
def on_phase_end(self):
    self.log_phase_end("train")

    if self.train:
        self.optimizer.on_epoch(where=self.where)

    logging.debug("Syncing losses on phase end...")
    self.synchronize_losses()
    logging.debug("...losses synced")

    logging.debug("Syncing meters on phase end...")
    for meter in self.meters:
        meter.sync_state()
    logging.debug("...meters synced")
    barrier()
    for hook in self.hooks:
        hook.on_phase_end(self)
    self.perf_log = []

    self.log_phase_end("total")
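All of the on_phase_end variants above fan out to hooks via hook.on_phase_end(self) or run_hooks, i.e. each hook receives the task object itself. A minimal duck-typed hook that fits this calling convention; a real ClassyHook subclass would also have to provide the other lifecycle callbacks, typically as no-ops.

import logging


class LogMetersHook:
    # Only the callback used by the snippets above is shown.
    def on_phase_end(self, task):
        for meter in task.meters:
            logging.info("phase-end meter state: %s", meter)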
def train(self):
    """
    The train workflow. We get the training loop to use (the vissl
    default is standard_train_step), but the user can create their own
    training loop and specify its name via TRAINER.TRAIN_STEP_NAME.

    The training happens:
    1. Execute any hooks at the start of training (mostly resets
       variables like the iteration number, phase number, etc.)
    2. For each epoch (train or test), run the hooks at the start of
       the epoch. This mostly involves setting things like the timer
       and the dataloader epoch.
    3. Execute the training loop (one training iteration) involving
       forward, loss, backward, optimizer update, metrics collection, etc.
    4. At the end of the epoch, sync meters and execute hooks at the
       end of the phase. This involves things like checkpointing the
       model, logging timers, and logging to tensorboard.
    """
    train_step_fn = get_train_step(self.cfg["TRAINER"]["TRAIN_STEP_NAME"])
    self.task.prepare(pin_memory=self.cfg.DATA.PIN_MEMORY)
    self.task.init_distributed_data_parallel_model()

    # Find what phase, train_phase_idx, local_iteration_num we are starting from.
    # Recover it from the checkpoint (if available)
    task, phase_idx, iteration_num = self._init_training_state(self.cfg, self.task)

    # Good to go, (re) start training
    task.run_hooks(SSLClassyHookFunctions.on_start.name)

    if is_primary():
        logging.info("Model is:\n {}".format(task.model))
        logging.info("Loss is: {}".format(task.loss))
        logging.info("Starting training....")

    while phase_idx + 1 < len(task.phases):
        self._advance_phase(task)  # advances task.phase_idx
        phase_idx += 1
        iteration_num += 1
        task.local_iteration_num = iteration_num  # iteration_num=0 at this step
        task.run_hooks(SSLClassyHookFunctions.on_phase_start.name)
        while True:
            try:
                if self.cfg.MODEL.CUDA_CACHE.CLEAR_CUDA_CACHE and (
                    iteration_num % self.cfg.MODEL.CUDA_CACHE.CLEAR_FREQ == 0
                ):
                    logging.info(
                        f"Emptying CUDA cache at step count: {iteration_num}"
                    )
                    torch.cuda.empty_cache()
                    logging.info("CUDA cache cleared")
                task = train_step_fn(task)
                iteration_num += 1
                task.local_iteration_num = iteration_num
                task.run_hooks(SSLClassyHookFunctions.on_step.name)
            except StopIteration:
                break
        for meter in task.meters:
            meter.sync_state()
        logging.info("Meters synced")
        barrier()
        task.run_hooks(SSLClassyHookFunctions.on_phase_end.name)

    task.run_hooks(SSLClassyHookFunctions.on_end.name)
    if hasattr(task, "data_iterator"):
        del task.data_iterator
        gc.collect()
    if hasattr(task, "dataloaders"):
        del task.dataloaders
        gc.collect()
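The contract visible in the inner loop above is small: a train step receives the task, raises StopIteration when the data iterator is exhausted, and returns the (possibly mutated) task, which the trainer reassigns via task = train_step_fn(task). A skeleton of a custom step honoring that contract; the body is schematic (it paraphrases step 3 of the docstring, not vissl's actual standard_train_step), and it would still need to be registered under the name given in TRAINER.TRAIN_STEP_NAME.

def my_train_step(task):
    # next() propagates StopIteration, which the trainer uses to end the phase
    sample = next(task.data_iterator)
    output = task.model(sample["data"])
    loss = task.loss(output, sample["label"])
    task.optimizer.zero_grad()
    loss.backward()
    task.optimizer.step()
    return task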