Example #1
    def _get_cluster_assignment_for_split(self, task: ClassyTask, split: str):
        task.model.eval()
        logging.info("Model set to eval mode during feature extraction...")

        cluster_assignments = {}
        task.data_iterator = iter(self.task.dataloaders[split.lower()])
        while True:
            try:
                sample = next(task.data_iterator)
                assert isinstance(sample, dict)
                assert "data_idx" in sample, "Indices not passed"

                input_sample = {
                    "images": torch.cat(sample["data"]).cuda(non_blocking=True),
                    "indices": torch.cat(sample["data_idx"]).cpu().numpy(),
                }

                with torch.no_grad():
                    features = task.model(input_sample["images"])
                    # the head output is the first element of the model output;
                    # its second entry holds the per-prototype scores
                    features = features[0]
                    prototype_score = features[1]
                    prototype_index = prototype_score.argmax(dim=-1)
                    num_images = input_sample["indices"].shape[0]
                    for idx in range(num_images):
                        image_index = input_sample["indices"][idx]
                        cluster_assignments[image_index] = prototype_index[idx].item()
            except StopIteration:
                break
        return cluster_assignments
Example #2
    def _init_training_state(cfg,
                             task: ClassyTask) -> Tuple[ClassyTask, int, int]:
        """
        If a checkpoint is present, recover the current training status.
        If not, initialize everything properly.

        Args:
            task {ClassyTask}: object consisting of all components the training requires
                               (meters, optimizers, model, loss, etc.)

        Returns:
            task {ClassyTask}: updated task
            phase_idx {int}: phase index
            iteration_num {int}: iteration number
        """

        phase_idx, iteration_num = -1, -1

        # Ensure that train loader exists. Will NOT exist if config.TEST_ONLY is True
        if "train" in task.dataloaders.keys():
            loader_key = "train"
        else:
            loader_key = "test"

        task.max_iteration = task.num_train_phases * len(
            task.dataloaders[loader_key])

        if task.checkpoint is not None:
            phase_idx = task.checkpoint["phase_idx"]
            task.train_phase_idx = task.checkpoint["train_phase_idx"]
            task.local_iteration_num = task.checkpoint["iteration_num"]
            task.iteration = task.checkpoint["iteration"]
        else:
            task.iteration = 0
            task.local_iteration_num = iteration_num

        num_iter_in_phase = len(task.dataloaders[loader_key])
        num_iter_in_epoch = num_iter_in_phase * task.num_train_phases_per_epoch

        num_samples = task.num_phase_samples(loader_key)
        task.start_time = time.time()
        task.batch_time = []
        task.metrics = {}
        logging.info(f"Training {task.num_epochs} epochs")
        logging.info(f"One epoch = {num_iter_in_epoch} iterations.")
        logging.info(f"Total {num_samples} samples in one epoch")

        if task.num_epochs != task.num_train_phases:
            logging.info(
                f"Training a total of {task.num_train_phases} train phases.")
            logging.info(f"One phase = {num_iter_in_phase} iterations.")

        logging.info(f"Total {task.max_iteration} iterations for training")

        return task, phase_idx, task.local_iteration_num
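As a quick sanity check on the bookkeeping above, here is a minimal standalone sketch with made-up numbers (the phase/epoch counts and dataloader length are illustrative, not taken from any real config): one train phase is one pass over the dataloader, one epoch may consist of several phases, and max_iteration is the total number of training steps.

    # Made-up numbers, only to illustrate the arithmetic in _init_training_state.
    num_train_phases = 20             # e.g. 10 epochs x 2 train phases per epoch
    num_train_phases_per_epoch = 2
    iters_per_dataloader_pass = 500   # corresponds to len(task.dataloaders["train"])

    num_iter_in_phase = iters_per_dataloader_pass
    num_iter_in_epoch = num_iter_in_phase * num_train_phases_per_epoch
    max_iteration = num_train_phases * iters_per_dataloader_pass

    assert num_iter_in_epoch == 1000
    assert max_iteration == 10000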
Example #3
    def _advance_phase(self, task: ClassyTask):
        """
        Advance the training phase to the next phase:
        - update the phase number,
        - reset the meters,
        - reset the losses,
        - recreate the data iterator and destroy the previous one,
        - set the model to train or eval mode depending on the phase,
        - execute any optimizer updates (normally learning rate updates, etc.,
          at the end of an epoch).
        """
        # reset the meters at the beginning of the epoch
        for meter in task.meters:
            meter.reset()

        # reset the loss history for this epoch
        task.losses = []

        # advance the epoch num to be current
        task.phase_idx += 1
        phase = task.phases[task.phase_idx]
        task.train = True if phase["train"] else False
        if task.train:
            task.train_phase_idx += 1

        # get a new data iterator - delete the iterator at the beginning explicitly
        # so that all dataloader processes are cleaned up
        phase_type = "train" if phase["train"] else "test"
        # we are advancing to the next epoch, so there is no need to compute start_iter;
        # just let it be 0 inside recreate_data_iterator. However, if we are resuming
        # training from a checkpoint, we want to compute the start iteration again
        # (via compute_start_iter) since we recreate the data iterator and delete
        # the old one.
        compute_start_iter = False
        if task.checkpoint is not None and task.checkpoint["train_phase_idx"] == (
            task.train_phase_idx - 1
        ):
            compute_start_iter = True

        task.recreate_data_iterator(
            phase_type,
            epoch=task.phase_idx,
            compute_start_iter=compute_start_iter,
            train_phase_idx=task.train_phase_idx,
        )

        # set the model to train or eval depending on what phase we are in
        task.model.train(phase["train"])

        if task.train and task.train_phase_idx >= 0:
            task.optimizer.on_epoch(task.where)

        local_rank, _ = get_machine_local_and_dist_rank()
        logging.info(f"Phase advanced. Rank: {local_rank}")
Example #4
File: dino_hooks.py  Project: zlapp/vissl
    def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
        """
        Create the teacher: it is an exponential moving average of the student.
        """
        logging.info("Building momentum encoder")

        # Same architecture but do not apply stochastic depth
        # TODO: make drop_path_rate configurable for teacher
        task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"][
            "DROP_PATH_RATE"] = 0.0
        task.loss.momentum_teacher = build_model(task.config["MODEL"],
                                                 task.config["OPTIMIZER"])
        task.loss.momentum_teacher.to(task.device)

        # Restore a checkpoint if one exists
        if task.loss.checkpoint is not None:
            task.loss.load_state_dict(task.loss.checkpoint)
        # Initialize from the model
        else:
            task_model = get_no_ddp_model(task.model)
            teacher_model = get_no_ddp_model(task.loss.momentum_teacher)
            teacher_model.load_state_dict(task_model.state_dict())

        # Set up SyncBN (useful for XCiT)
        task.loss.momentum_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(
            task.loss.momentum_teacher)
        task.loss.momentum_teacher = DistributedDataParallel(
            task.loss.momentum_teacher, device_ids=[task.device])

        # no gradients for teacher model
        for p in task.loss.momentum_teacher.parameters():
            p.requires_grad = False
Example #5
    def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
        """
        Create the teacher: it is an exponential moving average of the student.
        """
        logging.info("Building momentum encoder")

        # Same architecture but do not apply stochastic depth
        task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"][
            "DROP_PATH_RATE"] = 0
        task.loss.momentum_teacher = build_model(task.config["MODEL"],
                                                 task.config["OPTIMIZER"])
        task.loss.momentum_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(
            task.loss.momentum_teacher)
        task.loss.momentum_teacher.to(task.device)

        if get_world_size() > 1:
            task.loss.momentum_teacher = init_distributed_data_parallel_model(
                task.loss.momentum_teacher)

        # Restore a checkpoint if one exists
        if task.loss.checkpoint is not None:
            task.loss.load_state_dict(task.loss.checkpoint)
        # Initialize from the model
        else:
            task.loss.momentum_teacher.load_state_dict(task.model.state_dict())
Example #6
    def _get_split_features(self, feat_names: List[str], cfg: AttrDict,
                            task: ClassyTask):
        task.model.eval()
        logging.info("Model set to eval mode during feature extraction...")

        out_features, out_targets = {}, {}
        for layer in feat_names:
            out_features[layer], out_targets[layer] = {}, {}

        while True:
            try:
                sample = next(task.data_iterator)
                assert isinstance(sample, dict)
                assert "data_idx" in sample, "Indices not passed"
                input_sample = {
                    "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                    "target": torch.cat(sample["label"]).cpu().numpy(),
                    "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
                }
                with torch.no_grad():
                    features = task.model(input_sample["input"])
                    flat_features_list = self._flatten_features_list(features)
                    num_images = input_sample["inds"].shape[0]
                    for num, layer in enumerate(feat_names):
                        feature = flat_features_list[num].cpu().numpy()
                        targets = input_sample["target"]
                        for idx in range(num_images):
                            index = input_sample["inds"][idx]
                            if index not in out_features[layer]:
                                out_targets[layer][index] = targets[
                                    idx].reshape(-1)
                                out_features[layer][index] = feature[idx]
            except StopIteration:
                break
        barrier()

        output = {}
        for layer in feat_names:
            out_features[layer] = dict(sorted(out_features[layer].items()))
            out_targets[layer] = dict(sorted(out_targets[layer].items()))
            feats = np.array(list(out_features[layer].values()))
            N = feats.shape[0]
            output[layer] = {
                "features": feats.reshape(N, -1),
                "targets": np.array(list(out_targets[layer].values())),
                "inds": np.array(list(out_features[layer].keys())),
            }
        return output
Example #7
File: dino_hooks.py  Project: zlapp/vissl
    def _update_momentum_network(self, task: tasks.ClassyTask) -> None:
        """
        EMA update: each teacher parameter becomes a weighted average of its
        previous value and the corresponding student parameter.
        """
        # Cosine schedule for the teacher momentum
        m = 1 - 0.5 * (1 - task.loss.loss_config.momentum) * (
            math.cos(math.pi * task.iteration / task.max_iteration) + 1)
        task.additional_log_data["dino_teacher_momentum"] = m

        task_model = get_no_ddp_model(task.model)
        teacher_model = get_no_ddp_model(task.loss.momentum_teacher)

        # EMA update for the teacher parameters
        for param_q, param_k in zip(task_model.parameters(),
                                    teacher_model.parameters()):
            param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
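To make the cosine momentum schedule above concrete, here is a minimal standalone sketch (the base momentum of 0.996 and the iteration counts are assumed, illustrative values, not read from any config): the teacher momentum starts at the base value and rises to 1.0 by the end of training.

    import math

    def dino_teacher_momentum(iteration, max_iteration, base_momentum=0.996):
        # Cosine schedule from base_momentum at iteration 0 up to 1.0 at the
        # final iteration, matching the formula used in the hook above.
        return 1 - 0.5 * (1 - base_momentum) * (
            math.cos(math.pi * iteration / max_iteration) + 1
        )

    # For a hypothetical 10k-iteration run:
    for it in (0, 2500, 5000, 7500, 10000):
        print(it, round(dino_teacher_momentum(it, 10000), 5))
    # 0 -> 0.996, 5000 -> 0.998, 10000 -> 1.0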
Example #8
File: dino_hooks.py  Project: zlapp/vissl
    def update_teacher_temperature(self, task: tasks.ClassyTask) -> None:
        """
        Update the teacher temperature
        """
        if self.teacher_temp_schedule is None:
            teacher_temp_min = task.loss.loss_config["teacher_temp_min"]
            teacher_temp_max = task.loss.loss_config["teacher_temp_max"]
            teacher_temp_warmup_iters = task.loss.loss_config[
                "teacher_temp_warmup_iters"]
            self.teacher_temp_schedule = torch.cat((
                torch.linspace(teacher_temp_min, teacher_temp_max,
                               teacher_temp_warmup_iters),
                torch.ones(
                    max(0, task.max_iteration - teacher_temp_warmup_iters)) *
                teacher_temp_max,
            ))

        teacher_temp = self.teacher_temp_schedule[task.iteration].item()
        task.loss.teacher_temp = teacher_temp
        task.additional_log_data["dino_teacher_temp"] = teacher_temp
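The warmup schedule built above can be checked in isolation with a short sketch (the temperature bounds 0.04/0.07 and the iteration counts are assumed, illustrative values): the teacher temperature ramps up linearly for the first teacher_temp_warmup_iters iterations and then stays at its maximum.

    import torch

    # Assumed, illustrative values; the real ones come from task.loss.loss_config.
    teacher_temp_min, teacher_temp_max = 0.04, 0.07
    teacher_temp_warmup_iters, max_iteration = 10, 30

    schedule = torch.cat((
        torch.linspace(teacher_temp_min, teacher_temp_max, teacher_temp_warmup_iters),
        torch.ones(max(0, max_iteration - teacher_temp_warmup_iters)) * teacher_temp_max,
    ))
    assert schedule.shape[0] == max_iteration
    print(schedule[0].item(), schedule[9].item(), schedule[-1].item())
    # ~0.04 (start of warmup), ~0.07 (end of warmup), ~0.07 (rest of training)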
Example #9
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        task.prepare()
        assert isinstance(task, ClassyTask)

        # make sure all the workers start training at the same time
        # this helps catch hangs which would have happened elsewhere
        barrier()

        task.on_start()
        while not task.done_training():
            task.on_phase_start()
            while True:
                try:
                    task.step()
                except StopIteration:
                    break
            task.on_phase_end()
        task.on_end()
Example #10
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        # make sure all the workers start training at the same time
        # this helps catch hangs which would have happened elsewhere
        barrier()

        local_variables = {}

        task.on_start(local_variables)
        while not task.done_training():
            task.on_phase_start(local_variables)
            while True:
                try:
                    task.step(self.use_gpu, local_variables)
                except StopIteration:
                    break
            task.on_phase_end(local_variables)
        task.on_end(local_variables)
Example #11
    def train(self, task: ClassyTask):
        task.hooks = task.hooks + [LimitedPhaseHook(self.num_phases)]
        try:
            super().train(task)
        except LimitedPhaseException:
            pass
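LimitedPhaseHook and LimitedPhaseException are not shown in the snippet above. The following is a minimal, hypothetical sketch of the pattern it relies on: a hook counts completed phases and raises an exception, which the trainer catches in order to stop early. The real hook would subclass the framework's hook base class and may be implemented differently.

    # Hypothetical sketch only; the names mirror the snippet above, the
    # implementation is assumed.
    class LimitedPhaseException(Exception):
        """Raised to break out of the training loop after a fixed number of phases."""

    class LimitedPhaseHook:
        def __init__(self, num_phases: int):
            self.num_phases = num_phases
            self.phases_done = 0

        def on_phase_end(self, task) -> None:
            self.phases_done += 1
            if self.phases_done >= self.num_phases:
                raise LimitedPhaseException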
Example #12
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        if is_distributed_training_run():
            task.init_distributed_data_parallel_model()

        local_variables = {}
        task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)
        best_acc = {
            'top1_acc': 0,
            'top1_epoch': 0,
            'top5_acc': 0,
            'top5_epoch': 0
        }
        epoch = 0
        while not task.done_training():
            task.advance_phase()

            # Start phase hooks
            task.run_hooks(local_variables,
                           ClassyHookFunctions.on_phase_start.name)
            while True:
                # Process next sample
                try:
                    task.train_step(self.use_gpu, local_variables)
                except StopIteration:
                    break

            logging.info("Syncing meters on phase end...")
            for meter in task.meters:
                meter.sync_state()
            logging.info("...meters synced")
            barrier()
            meter = task.run_hooks(local_variables,
                                   ClassyHookFunctions.on_phase_end.name)
            if meter is not None:
                if meter[0].value['top_1'] > best_acc['top1_acc']:
                    best_acc['top1_acc'] = meter[0].value['top_1']
                    best_acc['top5_acc'] = meter[0].value['top_5']
                    best_acc['top1_epoch'] = epoch
                    best_acc['top5_epoch'] = epoch
            epoch += 1

        task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
        return best_acc
Example #13
    def _get_cluster_assignment_for_split(
        self, task: ClassyTask, split: str, output_folder: str
    ):
        task.model.eval()
        logging.info("Model set to eval mode during feature extraction...")
        dist_rank = torch.distributed.get_rank()

        cluster_assignments = {}
        soft_cluster_assignments = {}
        image_indices = []
        chunk_index, buffer_size = 0, 0

        task.data_iterator = iter(self.task.dataloaders[split.lower()])
        while True:
            try:
                sample = next(task.data_iterator)
                assert isinstance(sample, dict)
                assert "data_idx" in sample, "Indices not passed"
                input_sample = {
                    "images": torch.cat(sample["data"]).cuda(non_blocking=True),
                    "indices": torch.cat(sample["data_idx"]).cpu().numpy(),
                }
                with torch.no_grad():
                    outputs = task.model(input_sample["images"])
                    prototype_score = outputs[0][1]
                    prototype_index = prototype_score.argmax(dim=-1)
                    num_images = input_sample["indices"].shape[0]
                    buffer_size += num_images
                    for idx in range(num_images):
                        image_index = input_sample["indices"][idx]
                        cluster_assignments[image_index] = prototype_index[idx].item()
                        # store this image's soft assignment, not the whole batch
                        soft_cluster_assignments[
                            image_index
                        ] = prototype_score[idx].cpu().numpy()
                        image_indices.append(image_index)

                if buffer_size >= self.cfg.EXTRACT_FEATURES.CHUNK_THRESHOLD >= 0:
                    self._save_extracted_prototypes(
                        soft_assignments=soft_cluster_assignments,
                        out_indices=image_indices,
                        dist_rank=dist_rank,
                        chunk_index=chunk_index,
                        split=split,
                        output_folder=output_folder,
                    )
                    soft_cluster_assignments.clear()
                    image_indices.clear()
                    chunk_index += 1
                    buffer_size = 0

            except StopIteration:
                if buffer_size:
                    self._save_extracted_prototypes(
                        soft_assignments=soft_cluster_assignments,
                        out_indices=image_indices,
                        dist_rank=dist_rank,
                        chunk_index=chunk_index,
                        split=split,
                        output_folder=output_folder,
                    )
                break
        return cluster_assignments
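A note on the flush condition used above: the chained comparison buffer_size >= self.cfg.EXTRACT_FEATURES.CHUNK_THRESHOLD >= 0 only triggers a save when the threshold is non-negative, so a negative threshold disables chunked saving entirely. A tiny standalone illustration (the helper name is made up):

    def should_flush(buffer_size: int, chunk_threshold: int) -> bool:
        # Equivalent to: buffer_size >= chunk_threshold and chunk_threshold >= 0,
        # i.e. chunking is disabled when the threshold is negative.
        return buffer_size >= chunk_threshold >= 0

    assert should_flush(200, 100)         # enough buffered samples: flush the chunk
    assert not should_flush(50, 100)      # not enough samples buffered yet
    assert not should_flush(200, -1)      # negative threshold: chunking disabled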
Example #14
    def _extract_split_features(
        self,
        feat_names: List[str],
        task: ClassyTask,
        split_name: str,
        output_folder: str,
    ):
        task.model.eval()
        logging.info("Model set to eval mode during feature extraction...")
        dist_rank = torch.distributed.get_rank()

        out_features, out_targets = {}, {}
        for feat_name in feat_names:
            out_features[feat_name], out_targets[feat_name] = {}, {}

        chunk_index, feature_buffer_size, count = 0, 0, 0
        while True:
            if count % 100 == 0:
                logging.info(f"Feature extraction iteration: {count}")
            try:
                sample = next(task.data_iterator)
                assert isinstance(sample, dict)
                assert "data_idx" in sample, "Indices not passed"
                input_sample = {
                    "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                    "target": torch.cat(sample["label"]).cpu().numpy(),
                    "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
                }
                with torch.no_grad():
                    features = task.model(input_sample["input"])
                    flat_features_list = self._flatten_features_list(features)
                    num_images = input_sample["inds"].shape[0]
                    feature_buffer_size += num_images
                    for num, feat_name in enumerate(feat_names):
                        feature = flat_features_list[num].cpu().numpy()
                        targets = input_sample["target"]
                        for idx in range(num_images):
                            index = input_sample["inds"][idx]
                            out_features[feat_name][index] = feature[idx]
                            out_targets[feat_name][index] = targets[idx].reshape(-1)

                if (
                    feature_buffer_size
                    >= self.cfg.EXTRACT_FEATURES.CHUNK_THRESHOLD
                    >= 0
                ):
                    self._save_extracted_features(
                        features=out_features,
                        targets=out_targets,
                        dist_rank=dist_rank,
                        chunk_index=chunk_index,
                        split=split_name,
                        output_folder=output_folder,
                    )
                    for layer_name in out_features.keys():
                        out_features[layer_name].clear()
                    chunk_index += 1
                    feature_buffer_size = 0

            except StopIteration:
                self._save_extracted_features(
                    features=out_features,
                    targets=out_targets,
                    dist_rank=dist_rank,
                    chunk_index=chunk_index,
                    split=split_name,
                    output_folder=output_folder,
                )
                break
            count += 1
Example #15
    def _extract_split_label_predictions(
        self,
        feat_names: List[str],
        task: ClassyTask,
        split_name: str,
        output_folder: str,
    ):
        task.model.eval()
        logging.info("Model set to eval mode during feature extraction...")
        dist_rank = torch.distributed.get_rank()

        feat_names = self._to_unique_feature_names(feat_names)
        out_predictions, out_targets, out_scores = {}, {}, {}
        for feat_name in feat_names:
            out_predictions[feat_name] = {}
            out_scores[feat_name] = {}
            out_targets[feat_name] = {}

        assert len(task.meters) > 0, "Please specify one meter to extract predictions"
        assert len(task.meters) == 1, "Please use only one meter to extract predictions"
        for meter in task.meters:
            assert hasattr(
                meter, "get_predictions"
            ), f"Meter {meter.name} doesn't implement get_predictions function"

        dataset = task.datasets[split_name.lower()]
        all_image_paths = dataset.get_image_paths()
        assert (
            len(all_image_paths) == 1
        ), "Multi-dataset not supported yet for label predictions."
        all_image_paths = all_image_paths[0]

        for count in itertools.count(start=0, step=1):
            try:
                if count % 100 == 0:
                    logging.info(f"Label prediction extraction iteration: {count}")
                sample = next(task.data_iterator)
                assert isinstance(sample, dict)
                assert "data_idx" in sample, "Indices not passed"
                input_sample = {
                    "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                    "target": torch.cat(sample["label"]).cpu().numpy(),
                    "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
                }
                with torch.no_grad():
                    # Send the input sample to the model, tracking also the
                    # last batch for the hooks to refer to
                    task.last_batch = SimpleNamespace()
                    model_output = task.model(input_sample["input"])
                    task.last_batch.sample = input_sample
                    task.last_batch.model_output = model_output

                    # Run hooks on forward pass
                    task.run_hooks(SSLClassyHookFunctions.on_forward.name)

                    # get the model predictions using the meter
                    if isinstance(model_output, list):
                        model_output_cpu = [x.cpu() for x in model_output]
                    else:
                        model_output_cpu = model_output.cpu()
                    for meter in task.meters:
                        meter.update(
                            model_output_cpu, sample["label"][0].detach().cpu()
                        )
                    predictions, pred_scores = task.meters[0].get_predictions(
                        model_output_cpu
                    )
                    num_images = input_sample["inds"].shape[0]
                    for num, layer_name in enumerate(feat_names):
                        pred = predictions[num]
                        score = pred_scores[num]
                        targets = input_sample["target"]
                        for idx in range(num_images):
                            index = input_sample["inds"][idx]
                            if index not in out_predictions[layer_name]:
                                out_targets[layer_name][index] = targets[idx].reshape(
                                    -1
                                )
                                out_predictions[layer_name][index] = pred[idx]
                                out_scores[layer_name][index] = score[idx]
            except StopIteration:
                break

        # print the meters results. This can offer a validation
        # of the extracted predictions.
        self._sync_and_print_meters(task)
        # save the predictions, targets and image indices now
        self._save_extracted_label_predictions(
            all_image_paths=all_image_paths,
            predictions=out_predictions,
            confidence_scores=out_scores,
            targets=out_targets,
            dist_rank=dist_rank,
            split=split_name,
            output_folder=output_folder,
        )