def _get_cluster_assignment_for_split(self, task: ClassyTask, split: str):
    task.model.eval()
    logging.info("Model set to eval mode during feature extraction...")

    cluster_assignments = {}
    task.data_iterator = iter(self.task.dataloaders[split.lower()])
    while True:
        try:
            sample = next(task.data_iterator)
            assert isinstance(sample, dict)
            assert "data_idx" in sample, "Indices not passed"
            input_sample = {
                "images": torch.cat(sample["data"]).cuda(non_blocking=True),
                "indices": torch.cat(sample["data_idx"]).cpu().numpy(),
            }
            with torch.no_grad():
                features = task.model(input_sample["images"])
                # each head returns (embedding, prototype_scores); take the
                # scores of the first head and assign each image to its
                # highest-scoring prototype
                prototype_score = features[0][1]
                prototype_index = prototype_score.argmax(dim=-1)
                num_images = input_sample["indices"].shape[0]
                for idx in range(num_images):
                    image_index = input_sample["indices"][idx]
                    cluster_assignments[image_index] = prototype_index[idx].item()
        except StopIteration:
            break
    return cluster_assignments
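# A minimal usage sketch (not part of the library): the mapping returned by
# _get_cluster_assignment_for_split is {image_index: prototype_id}, so cluster
# sizes can be summarized with a Counter. The `assignments` dict below is a
# hypothetical stand-in for a real return value.
from collections import Counter

def summarize_cluster_sizes(assignments):
    """Count how many images fall into each prototype/cluster."""
    sizes = Counter(assignments.values())
    # Largest clusters first, e.g. [(12, 2), (7, 1), ...]
    return sizes.most_common()

assignments = {0: 12, 1: 7, 2: 12, 3: 3}  # hypothetical output
print(summarize_cluster_sizes(assignments))  # [(12, 2), (7, 1), (3, 1)]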
def _init_training_state(cfg, task: ClassyTask) -> Tuple[ClassyTask, int, int]:
    """
    If a checkpoint is present, recover the current training state.
    If not, initialize everything from scratch.

    Args:
        cfg: the training config (not used directly in this function)
        task {ClassyTask}: object consisting of all components a training
            requires (meters, optimizers, model, loss etc.)

    Returns:
        task {ClassyTask}: updated task
        phase_idx {int}: phase index
        iteration_num {int}: iteration number
    """
    phase_idx, iteration_num = -1, -1

    # Ensure that train loader exists. Will NOT exist if config.TEST_ONLY is True
    if "train" in task.dataloaders.keys():
        loader_key = "train"
    else:
        loader_key = "test"

    task.max_iteration = task.num_train_phases * len(task.dataloaders[loader_key])

    if task.checkpoint is not None:
        phase_idx = task.checkpoint["phase_idx"]
        task.train_phase_idx = task.checkpoint["train_phase_idx"]
        task.local_iteration_num = task.checkpoint["iteration_num"]
        task.iteration = task.checkpoint["iteration"]
    else:
        task.iteration = 0
        task.local_iteration_num = iteration_num

    num_iter_in_phase = len(task.dataloaders[loader_key])
    num_iter_in_epoch = num_iter_in_phase * task.num_train_phases_per_epoch
    num_samples = task.num_phase_samples(loader_key)
    task.start_time = time.time()
    task.batch_time = []
    task.metrics = {}

    logging.info(f"Training {task.num_epochs} epochs")
    logging.info(f"One epoch = {num_iter_in_epoch} iterations.")
    logging.info(f"Total {num_samples} samples in one epoch")
    if task.num_epochs != task.num_train_phases:
        logging.info(f"Training a total of {task.num_train_phases} train phases.")
        logging.info(f"One phase = {num_iter_in_phase} iterations.")
    logging.info(f"Total {task.max_iteration} iterations for training")

    return task, phase_idx, task.local_iteration_num
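# A worked example of the bookkeeping above (assumed values, for illustration
# only): max_iteration is simply num_train_phases * iterations_per_phase.
num_train_phases = 90      # e.g. one train phase per epoch
iters_per_phase = 5005     # e.g. len(task.dataloaders["train"])
max_iteration = num_train_phases * iters_per_phase
assert max_iteration == 450450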
def _advance_phase(self, task: ClassyTask):
    """
    Advance the training to the next phase:
    - updates the phase number,
    - resets the meters,
    - resets the losses,
    - recreates the data iterator and destroys the previous iterator,
    - sets the model to train or eval mode depending on the phase,
    - executes any optimizer update (normally learning rate updates etc.
      at the end of an epoch)
    """
    # reset the meters at the beginning of the epoch
    for meter in task.meters:
        meter.reset()

    # reset the loss history for this epoch
    task.losses = []

    # advance the epoch num to be current
    task.phase_idx += 1
    phase = task.phases[task.phase_idx]
    task.train = True if phase["train"] else False
    if task.train:
        task.train_phase_idx += 1

    # get a new data iterator - delete the iterator at the beginning explicitly
    # so that all dataloader processes are cleaned up
    phase_type = "train" if phase["train"] else "test"
    # we are advancing to the next epoch, so no need to compute start_iter;
    # just let it be 0 inside recreate_data_iterator. However, if we are just
    # starting from a resumed training run, we want to compute_start_iter
    # again (if applicable) since we recreate the data iterator and delete
    # the old ones.
    compute_start_iter = False
    if task.checkpoint is not None and task.checkpoint["train_phase_idx"] == (
        task.train_phase_idx - 1
    ):
        compute_start_iter = True

    task.recreate_data_iterator(
        phase_type,
        epoch=task.phase_idx,
        compute_start_iter=compute_start_iter,
        train_phase_idx=task.train_phase_idx,
    )

    # set the model to train or eval depending on what phase we are in
    task.model.train(phase["train"])

    if task.train and task.train_phase_idx >= 0:
        task.optimizer.on_epoch(task.where)

    local_rank, _ = get_machine_local_and_dist_rank()
    logging.info(f"Phase advanced. Rank: {local_rank}")
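# For context: task.phases is a list of {"train": bool} dicts generated from
# the config; _advance_phase steps through it one entry at a time. A toy,
# framework-free illustration of that schema (assumed values):
phases = [{"train": True}, {"train": True}, {"train": False}]
for phase_idx, phase in enumerate(phases):
    mode = "train" if phase["train"] else "test"
    print(f"phase {phase_idx}: running in {mode} mode")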
def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
    """
    Create the teacher: it is an exponential moving average of the student.
    """
    logging.info("Building momentum encoder")

    # Same architecture, but do not apply stochastic depth
    # TODO: make drop_path_rate configurable for teacher
    task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"]["DROP_PATH_RATE"] = 0.0
    task.loss.momentum_teacher = build_model(
        task.config["MODEL"], task.config["OPTIMIZER"]
    )
    task.loss.momentum_teacher.to(task.device)

    # Restore a checkpoint if one is available
    if task.loss.checkpoint is not None:
        task.loss.load_state_dict(task.loss.checkpoint)
    # Otherwise initialize the teacher from the student model
    else:
        task_model = get_no_ddp_model(task.model)
        teacher_model = get_no_ddp_model(task.loss.momentum_teacher)
        teacher_model.load_state_dict(task_model.state_dict())

    # Setup SyncBN (useful for the XCiT)
    task.loss.momentum_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(
        task.loss.momentum_teacher
    )
    task.loss.momentum_teacher = DistributedDataParallel(
        task.loss.momentum_teacher, device_ids=[task.device]
    )

    # the teacher model receives no gradients
    for p in task.loss.momentum_teacher.parameters():
        p.requires_grad = False
def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
    """
    Create the teacher: it is an exponential moving average of the student.
    """
    logging.info("Building momentum encoder")

    # Same architecture, but do not apply stochastic depth
    task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"]["DROP_PATH_RATE"] = 0
    task.loss.momentum_teacher = build_model(
        task.config["MODEL"], task.config["OPTIMIZER"]
    )
    task.loss.momentum_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(
        task.loss.momentum_teacher
    )
    task.loss.momentum_teacher.to(task.device)
    if get_world_size() > 1:
        task.loss.momentum_teacher = init_distributed_data_parallel_model(
            task.loss.momentum_teacher
        )

    # Restore a checkpoint if one is available
    if task.loss.checkpoint is not None:
        task.loss.load_state_dict(task.loss.checkpoint)
    # Otherwise initialize the teacher from the student model
    else:
        task.loss.momentum_teacher.load_state_dict(task.model.state_dict())
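# A minimal, framework-free sketch of the pattern used in both variants above:
# clone the student into a teacher, then freeze the teacher's parameters so it
# is only ever updated via the EMA rule (see _update_momentum_network below).
# The two Linear layers stand in for real student/teacher trunks.
import torch.nn as nn

student = nn.Linear(4, 2)
teacher = nn.Linear(4, 2)
teacher.load_state_dict(student.state_dict())  # start from the student weights
for p in teacher.parameters():
    p.requires_grad = False  # no gradients ever flow to the teacher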
def _get_split_features(
    self, feat_names: List[str], cfg: AttrDict, task: ClassyTask
):
    task.model.eval()
    logging.info("Model set to eval mode during feature extraction...")

    out_features, out_targets = {}, {}
    for layer in feat_names:
        out_features[layer], out_targets[layer] = {}, {}

    while True:
        try:
            sample = next(task.data_iterator)
            assert isinstance(sample, dict)
            assert "data_idx" in sample, "Indices not passed"
            input_sample = {
                "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                "target": torch.cat(sample["label"]).cpu().numpy(),
                "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
            }
            with torch.no_grad():
                features = task.model(input_sample["input"])
                flat_features_list = self._flatten_features_list(features)
                num_images = input_sample["inds"].shape[0]
                for num, layer in enumerate(feat_names):
                    feature = flat_features_list[num].cpu().numpy()
                    targets = input_sample["target"]
                    for idx in range(num_images):
                        index = input_sample["inds"][idx]
                        # store each dataset index only once
                        if index not in out_features[layer]:
                            out_targets[layer][index] = targets[idx].reshape(-1)
                            out_features[layer][index] = feature[idx]
        except StopIteration:
            break

    barrier()
    output = {}
    for layer in feat_names:
        out_features[layer] = dict(sorted(out_features[layer].items()))
        out_targets[layer] = dict(sorted(out_targets[layer].items()))
        feats = np.array(list(out_features[layer].values()))
        N = feats.shape[0]
        output[layer] = {
            "features": feats.reshape(N, -1),
            "targets": np.array(list(out_targets[layer].values())),
            "inds": np.array(list(out_features[layer].keys())),
        }
    return output
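# A small sketch of the output contract of _get_split_features (assumed shapes
# and layer name, for illustration): per layer, features are flattened to
# (N, D) and kept aligned with targets and dataset indices.
import numpy as np

output = {
    "res5": {
        "features": np.zeros((8, 2048)),  # N images x flattened feature dim
        "targets": np.zeros((8, 1)),      # one label per image
        "inds": np.arange(8),             # sorted dataset indices
    }
}
layer = output["res5"]
assert layer["features"].shape[0] == layer["targets"].shape[0] == layer["inds"].shape[0]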
def _update_momentum_network(self, task: tasks.ClassyTask) -> None:
    """
    EMA update
    Each teacher parameter becomes a weighted average of its old self
    and the newest student.
    """
    # Cosine schedule for the teacher momentum
    m = 1 - 0.5 * (1 - task.loss.loss_config.momentum) * (
        math.cos(math.pi * task.iteration / task.max_iteration) + 1
    )
    task.additional_log_data["dino_teacher_momentum"] = m

    task_model = get_no_ddp_model(task.model)
    teacher_model = get_no_ddp_model(task.loss.momentum_teacher)

    # EMA update for the teacher parameters
    for param_q, param_k in zip(
        task_model.parameters(), teacher_model.parameters()
    ):
        param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
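# A quick check of the cosine schedule above (pure math, no framework needed):
# at iteration 0 the teacher momentum equals the configured base value, and it
# ramps smoothly to 1.0 by the last iteration, so the teacher freezes over
# time. The base value 0.996 is an assumed example, not read from any config.
import math

def teacher_momentum(iteration, max_iteration, base_momentum=0.996):
    return 1 - 0.5 * (1 - base_momentum) * (
        math.cos(math.pi * iteration / max_iteration) + 1
    )

assert abs(teacher_momentum(0, 10_000) - 0.996) < 1e-9
assert abs(teacher_momentum(10_000, 10_000) - 1.0) < 1e-9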
def update_teacher_temperature(self, task: tasks.ClassyTask) -> None:
    """
    Update the teacher temperature
    """
    if self.teacher_temp_schedule is None:
        teacher_temp_min = task.loss.loss_config["teacher_temp_min"]
        teacher_temp_max = task.loss.loss_config["teacher_temp_max"]
        teacher_temp_warmup_iters = task.loss.loss_config[
            "teacher_temp_warmup_iters"
        ]
        self.teacher_temp_schedule = torch.cat((
            torch.linspace(
                teacher_temp_min, teacher_temp_max, teacher_temp_warmup_iters
            ),
            torch.ones(max(0, task.max_iteration - teacher_temp_warmup_iters))
            * teacher_temp_max,
        ))

    teacher_temp = self.teacher_temp_schedule[task.iteration].item()
    task.loss.teacher_temp = teacher_temp
    task.additional_log_data["dino_teacher_temp"] = teacher_temp
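# The schedule built above is a linear warmup from teacher_temp_min to
# teacher_temp_max, then constant. A minimal standalone sketch with assumed
# values (0.04 -> 0.07 over 30 of 100 iterations, for illustration only):
import torch

t_min, t_max, warmup_iters, max_iters = 0.04, 0.07, 30, 100
schedule = torch.cat((
    torch.linspace(t_min, t_max, warmup_iters),
    torch.ones(max_iters - warmup_iters) * t_max,
))
assert schedule.shape[0] == max_iters
assert abs(schedule[0].item() - t_min) < 1e-6   # starts at the min temperature
assert abs(schedule[-1].item() - t_max) < 1e-6  # stays at the max temperature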
def train(self, task: ClassyTask):
    """Runs training phases, phases are generated from the config.

    Args:
        task: Task to be used in training. It should contain
            everything that is needed for training
    """
    task.prepare()
    assert isinstance(task, ClassyTask)

    # make sure all the workers start training at the same time.
    # this helps catch hangs which would have happened elsewhere
    barrier()

    task.on_start()
    while not task.done_training():
        task.on_phase_start()
        while True:
            try:
                task.step()
            except StopIteration:
                break
        task.on_phase_end()
    task.on_end()
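# A minimal, framework-free sketch of the control flow above: one iterator per
# phase, with StopIteration ending the phase. All names here are illustrative
# stand-ins, not the real ClassyTask API.
class ToyTask:
    def __init__(self, phases, steps_per_phase):
        self.phases, self.steps = phases, steps_per_phase
        self.phase_idx = -1

    def done_training(self):
        return self.phase_idx + 1 >= self.phases

    def on_phase_start(self):
        self.phase_idx += 1
        self._iter = iter(range(self.steps))

    def step(self):
        next(self._iter)  # raises StopIteration at the end of the phase

    def on_phase_end(self):
        print(f"phase {self.phase_idx} done")

task = ToyTask(phases=2, steps_per_phase=3)
while not task.done_training():
    task.on_phase_start()
    while True:
        try:
            task.step()
        except StopIteration:
            break
    task.on_phase_end()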
def train(self, task: ClassyTask):
    """Runs training phases, phases are generated from the config.

    Args:
        task: Task to be used in training. It should contain
            everything that is needed for training
    """
    pin_memory = self.use_gpu and torch.cuda.device_count() > 1
    task.prepare(
        num_dataloader_workers=self.num_dataloader_workers,
        pin_memory=pin_memory,
        use_gpu=self.use_gpu,
        dataloader_mp_context=self.dataloader_mp_context,
    )
    assert isinstance(task, ClassyTask)

    # make sure all the workers start training at the same time.
    # this helps catch hangs which would have happened elsewhere
    barrier()

    local_variables = {}
    task.on_start(local_variables)
    while not task.done_training():
        task.on_phase_start(local_variables)
        while True:
            try:
                task.step(self.use_gpu, local_variables)
            except StopIteration:
                break
        task.on_phase_end(local_variables)
    task.on_end(local_variables)
def train(self, task: ClassyTask):
    task.hooks = task.hooks + [LimitedPhaseHook(self.num_phases)]
    try:
        super().train(task)
    except LimitedPhaseException:
        pass
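# LimitedPhaseHook / LimitedPhaseException are referenced above but not defined
# in this listing. A minimal sketch of how such a pair could work (assumed
# design, not the real implementation): the hook counts completed phases and
# raises to break out of the base trainer's loop.
class LimitedPhaseException(Exception):
    pass

class LimitedPhaseHook:
    def __init__(self, num_phases):
        self.num_phases = num_phases
        self.phase_counter = 0

    def on_phase_end(self, task):
        self.phase_counter += 1
        if self.phase_counter >= self.num_phases:
            raise LimitedPhaseException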
def train(self, task: ClassyTask):
    """Runs training phases, phases are generated from the config.

    Args:
        task: Task to be used in training. It should contain
            everything that is needed for training
    """
    pin_memory = self.use_gpu and torch.cuda.device_count() > 1
    task.prepare(
        num_dataloader_workers=self.num_dataloader_workers,
        pin_memory=pin_memory,
        use_gpu=self.use_gpu,
        dataloader_mp_context=self.dataloader_mp_context,
    )
    assert isinstance(task, ClassyTask)
    if is_distributed_training_run():
        task.init_distributed_data_parallel_model()

    local_variables = {}
    task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)

    # track the best top-1/top-5 accuracies and the epochs they occurred in
    best_acc = {
        "top1_acc": 0,
        "top1_epoch": 0,
        "top5_acc": 0,
        "top5_epoch": 0,
    }
    epoch = 0
    while not task.done_training():
        task.advance_phase()

        # Start phase hooks
        task.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name)
        while True:
            # Process next sample
            try:
                task.train_step(self.use_gpu, local_variables)
            except StopIteration:
                break

        logging.info("Syncing meters on phase end...")
        for meter in task.meters:
            meter.sync_state()
        logging.info("...meters synced")
        barrier()

        meters = task.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
        if meters is not None:
            # top-5 is recorded at the epoch of the best top-1 accuracy
            if meters[0].value["top_1"] > best_acc["top1_acc"]:
                best_acc["top1_acc"] = meters[0].value["top_1"]
                best_acc["top5_acc"] = meters[0].value["top_5"]
                best_acc["top1_epoch"] = epoch
                best_acc["top5_epoch"] = epoch
        epoch += 1

    task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
    return best_acc
def _get_cluster_assignment_for_split(
    self, task: ClassyTask, split: str, output_folder: str
):
    task.model.eval()
    logging.info("Model set to eval mode during feature extraction...")

    dist_rank = torch.distributed.get_rank()
    cluster_assignments = {}
    soft_cluster_assignments = {}
    image_indices = []
    chunk_index, buffer_size = 0, 0

    task.data_iterator = iter(self.task.dataloaders[split.lower()])
    while True:
        try:
            sample = next(task.data_iterator)
            assert isinstance(sample, dict)
            assert "data_idx" in sample, "Indices not passed"
            input_sample = {
                "images": torch.cat(sample["data"]).cuda(non_blocking=True),
                "indices": torch.cat(sample["data_idx"]).cpu().numpy(),
            }
            with torch.no_grad():
                outputs = task.model(input_sample["images"])
                prototype_score = outputs[0][1]
                prototype_index = prototype_score.argmax(dim=-1)
                num_images = input_sample["indices"].shape[0]
                buffer_size += num_images
                for idx in range(num_images):
                    image_index = input_sample["indices"][idx]
                    cluster_assignments[image_index] = prototype_index[idx].item()
                    # store the per-image scores, not the whole batch
                    soft_cluster_assignments[image_index] = (
                        prototype_score[idx].cpu().numpy()
                    )
                    image_indices.append(image_index)

            if buffer_size >= self.cfg.EXTRACT_FEATURES.CHUNK_THRESHOLD >= 0:
                self._save_extracted_prototypes(
                    soft_assignments=soft_cluster_assignments,
                    out_indices=image_indices,
                    dist_rank=dist_rank,
                    chunk_index=chunk_index,
                    split=split,
                    output_folder=output_folder,
                )
                soft_cluster_assignments.clear()
                image_indices.clear()
                chunk_index += 1
                buffer_size = 0
        except StopIteration:
            # flush whatever is left in the buffer before exiting
            if buffer_size:
                self._save_extracted_prototypes(
                    soft_assignments=soft_cluster_assignments,
                    out_indices=image_indices,
                    dist_rank=dist_rank,
                    chunk_index=chunk_index,
                    split=split,
                    output_folder=output_folder,
                )
            break
    return cluster_assignments
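# Note on the chained comparison used above: `buffer >= THRESHOLD >= 0` is only
# true when THRESHOLD is non-negative, so a negative threshold disables chunked
# saving entirely. A standalone illustration:
def should_flush(buffer_size, chunk_threshold):
    return buffer_size >= chunk_threshold >= 0

assert should_flush(100, 50) is True    # buffer full enough -> flush
assert should_flush(10, 50) is False    # keep buffering
assert should_flush(100, -1) is False   # negative threshold -> never flush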
def _extract_split_features(
    self,
    feat_names: List[str],
    task: ClassyTask,
    split_name: str,
    output_folder: str,
):
    task.model.eval()
    logging.info("Model set to eval mode during feature extraction...")

    dist_rank = torch.distributed.get_rank()
    out_features, out_targets = {}, {}
    for feat_name in feat_names:
        out_features[feat_name], out_targets[feat_name] = {}, {}

    chunk_index, feature_buffer_size, count = 0, 0, 0
    while True:
        if count % 100 == 0:
            logging.info(f"Feature extraction iteration: {count}")
        try:
            sample = next(task.data_iterator)
            assert isinstance(sample, dict)
            assert "data_idx" in sample, "Indices not passed"
            input_sample = {
                "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                "target": torch.cat(sample["label"]).cpu().numpy(),
                "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
            }
            with torch.no_grad():
                features = task.model(input_sample["input"])
                flat_features_list = self._flatten_features_list(features)
                num_images = input_sample["inds"].shape[0]
                feature_buffer_size += num_images
                for num, feat_name in enumerate(feat_names):
                    feature = flat_features_list[num].cpu().numpy()
                    targets = input_sample["target"]
                    for idx in range(num_images):
                        index = input_sample["inds"][idx]
                        out_features[feat_name][index] = feature[idx]
                        out_targets[feat_name][index] = targets[idx].reshape(-1)

            if feature_buffer_size >= self.cfg.EXTRACT_FEATURES.CHUNK_THRESHOLD >= 0:
                self._save_extracted_features(
                    features=out_features,
                    targets=out_targets,
                    dist_rank=dist_rank,
                    chunk_index=chunk_index,
                    split=split_name,
                    output_folder=output_folder,
                )
                for layer_name in out_features.keys():
                    out_features[layer_name].clear()
                chunk_index += 1
                feature_buffer_size = 0
        except StopIteration:
            # save any leftover features before exiting
            self._save_extracted_features(
                features=out_features,
                targets=out_targets,
                dist_rank=dist_rank,
                chunk_index=chunk_index,
                split=split_name,
                output_folder=output_folder,
            )
            break
        count += 1
def _extract_split_label_predictions(
    self,
    feat_names: List[str],
    task: ClassyTask,
    split_name: str,
    output_folder: str,
):
    task.model.eval()
    logging.info("Model set to eval mode during feature extraction...")

    dist_rank = torch.distributed.get_rank()
    feat_names = self._to_unique_feature_names(feat_names)
    out_predictions, out_targets, out_scores = {}, {}, {}
    for feat_name in feat_names:
        out_predictions[feat_name] = {}
        out_scores[feat_name] = {}
        out_targets[feat_name] = {}

    assert len(task.meters) > 0, "Please specify one meter to extract predictions"
    assert len(task.meters) == 1, "Please use only one meter to extract predictions"
    for meter in task.meters:
        assert hasattr(
            meter, "get_predictions"
        ), f"Meter {meter.name} doesn't implement get_predictions function"

    dataset = task.datasets[split_name.lower()]
    all_image_paths = dataset.get_image_paths()
    assert (
        len(all_image_paths) == 1
    ), "Multi-dataset not supported yet for label predictions."
    all_image_paths = all_image_paths[0]

    for count in itertools.count(start=0, step=1):
        try:
            if count % 100 == 0:
                logging.info(f"Label prediction extraction iteration: {count}")
            sample = next(task.data_iterator)
            assert isinstance(sample, dict)
            assert "data_idx" in sample, "Indices not passed"
            input_sample = {
                "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                "target": torch.cat(sample["label"]).cpu().numpy(),
                "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
            }
            with torch.no_grad():
                # Send the input sample to the model, tracking also the
                # last batch for the hooks to refer to
                task.last_batch = SimpleNamespace()
                model_output = task.model(input_sample["input"])
                task.last_batch.sample = input_sample
                task.last_batch.model_output = model_output

                # Run hooks on forward pass
                task.run_hooks(SSLClassyHookFunctions.on_forward.name)

                # get the model predictions using the meter
                if isinstance(model_output, list):
                    model_output_cpu = [x.cpu() for x in model_output]
                else:
                    model_output_cpu = model_output.cpu()
                for meter in task.meters:
                    meter.update(
                        model_output_cpu, sample["label"][0].detach().cpu()
                    )
                predictions, pred_scores = task.meters[0].get_predictions(
                    model_output_cpu
                )
                num_images = input_sample["inds"].shape[0]
                for num, layer_name in enumerate(feat_names):
                    pred = predictions[num]
                    score = pred_scores[num]
                    targets = input_sample["target"]
                    for idx in range(num_images):
                        index = input_sample["inds"][idx]
                        if index not in out_predictions[layer_name]:
                            out_targets[layer_name][index] = targets[idx].reshape(-1)
                            out_predictions[layer_name][index] = pred[idx]
                            out_scores[layer_name][index] = score[idx]
        except StopIteration:
            break

    # print the meters results. This can offer a validation
    # of the extracted predictions.
    self._sync_and_print_meters(task)

    # save the predictions, targets and image indices now
    self._save_extracted_label_predictions(
        all_image_paths=all_image_paths,
        predictions=out_predictions,
        confidence_scores=out_scores,
        targets=out_targets,
        dist_rank=dist_rank,
        split=split_name,
        output_folder=output_folder,
    )
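# A small sketch of the meter contract assumed above: get_predictions returns,
# per output head, the predicted label and its confidence for every image.
# This toy version works on raw logits; the real meter API may differ.
import torch

def get_predictions(logits):
    scores = torch.softmax(logits, dim=-1)
    conf, pred = scores.max(dim=-1)  # max returns (values, indices)
    return pred, conf

logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
pred, conf = get_predictions(logits)
assert pred.tolist() == [0, 1]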