    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        # make sure all the workers start training at the same time
        # this helps catch hangs which would have happened elsewhere
        barrier()

        local_variables = {}

        task.on_start(local_variables)
        while not task.done_training():
            task.on_phase_start(local_variables)
            while True:
                try:
                    task.step(self.use_gpu, local_variables)
                except StopIteration:
                    break
            task.on_phase_end(local_variables)
        task.on_end(local_variables)
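The loop above only terminates because task.step() is expected to raise StopIteration once the current phase's data is exhausted, and task.done_training() turns True after the last phase. A minimal, self-contained sketch of that contract is shown below; ToyTask is a hypothetical stand-in, not part of ClassyVision, and exists only so the control flow can be exercised in isolation.

class ToyTask:
    """Hypothetical stand-in for ClassyTask: two phases of three steps each."""

    def __init__(self, num_phases=2, steps_per_phase=3):
        self.num_phases = num_phases
        self.steps_per_phase = steps_per_phase
        self.phase_idx = -1
        self._step_iter = None

    def done_training(self):
        # True once every phase has run.
        return self.phase_idx + 1 >= self.num_phases

    def on_start(self, local_variables):
        print("training started")

    def on_phase_start(self, local_variables):
        self.phase_idx += 1
        self._step_iter = iter(range(self.steps_per_phase))

    def step(self, use_gpu, local_variables):
        # next() raises StopIteration when the phase is exhausted, which is
        # exactly the signal the trainer's inner while-loop catches.
        print(f"phase {self.phase_idx}, step {next(self._step_iter)}")

    def on_phase_end(self, local_variables):
        print(f"phase {self.phase_idx} done")

    def on_end(self, local_variables):
        print("training done")

Driving ToyTask() through the same on_start / phase / step / on_end sequence as the trainer above prints two phases of three steps and then falls out of the outer loop.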
Example #2
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        task.prepare()
        assert isinstance(task, ClassyTask)

        # make sure all the workers start training at the same time
        # this helps catch hangs which would have happened elsewhere
        barrier()

        task.on_start()
        while not task.done_training():
            task.on_phase_start()
            while True:
                try:
                    task.step()
                except StopIteration:
                    break
            task.on_phase_end()
        task.on_end()
Example #3
    def on_phase_end(self, local_variables):
        logging.info("Syncing meters on phase end...")
        for meter in self.meters:
            meter.sync_state()
        logging.info("...meters synced")
        barrier()
        self.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
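meter.sync_state() is what makes the phase-end metrics consistent across workers: until that call, each process holds only its own partial counts. A hedged sketch of such a sync for a simple counting meter, assuming torch.distributed has already been initialized (CountingMeter and its fields are illustrative, not the ClassyVision meter API):

import torch
import torch.distributed as dist


class CountingMeter:
    """Illustrative meter tracking correct/total counts on each worker."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def sync_state(self):
        # Sum the per-worker counts across all ranks so every process reports
        # the same global value in the phase-end hooks. With a NCCL backend
        # the tensor would need to be moved to the current GPU first.
        if dist.is_available() and dist.is_initialized():
            t = torch.tensor([self.correct, self.total], dtype=torch.float64)
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            self.correct, self.total = t.tolist()

    @property
    def value(self):
        return self.correct / max(self.total, 1)

The barrier() that follows the sync in the snippet above then keeps every rank in lock-step before the phase-end hooks run.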
Example #4
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        if is_distributed_training_run():
            task.init_distributed_data_parallel_model()

        local_variables = {}
        task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)
        best_acc = {
            'top1_acc': 0,
            'top1_epoch': 0,
            'top5_acc': 0,
            'top5_epoch': 0
        }
        epoch = 0
        while not task.done_training():
            task.advance_phase()

            # Start phase hooks
            task.run_hooks(local_variables,
                           ClassyHookFunctions.on_phase_start.name)
            while True:
                # Process next sample
                try:
                    task.train_step(self.use_gpu, local_variables)
                except StopIteration:
                    break

            logging.info("Syncing meters on phase end...")
            for meter in task.meters:
                meter.sync_state()
            logging.info("...meters synced")
            barrier()
            meter = task.run_hooks(local_variables,
                                   ClassyHookFunctions.on_phase_end.name)
            if meter is not None:
                if meter[0].value['top_1'] > best_acc['top1_acc']:
                    best_acc['top1_acc'] = meter[0].value['top_1']
                    best_acc['top5_acc'] = meter[0].value['top_5']
                    best_acc['top1_epoch'] = epoch
                    best_acc['top5_epoch'] = epoch
            epoch += 1

        task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
        return best_acc
Example #5
    def on_phase_end(self):
        self.log_phase_end("train")

        logging.info("Syncing meters on phase end...")
        for meter in self.meters:
            meter.sync_state()
        logging.info("...meters synced")
        barrier()

        for hook in self.hooks:
            hook.on_phase_end(self)
        self.perf_log = []

        self.log_phase_end("total")
Example #6
    def _get_split_features(self, feat_names: List[str], cfg: AttrDict,
                            task: ClassyTask):
        task.model.eval()
        logging.info("Model set to eval mode during feature extraction...")

        out_features, out_targets = {}, {}
        for layer in feat_names:
            out_features[layer], out_targets[layer] = {}, {}

        while True:
            try:
                sample = next(task.data_iterator)
                assert isinstance(sample, dict)
                assert "data_idx" in sample, "Indices not passed"
                input_sample = {
                    "input": torch.cat(sample["data"]).cuda(non_blocking=True),
                    "target": torch.cat(sample["label"]).cpu().numpy(),
                    "inds": torch.cat(sample["data_idx"]).cpu().numpy(),
                }
                with torch.no_grad():
                    features = task.model(input_sample["input"])
                    flat_features_list = self._flatten_features_list(features)
                    num_images = input_sample["inds"].shape[0]
                    for num, layer in enumerate(feat_names):
                        feature = flat_features_list[num].cpu().numpy()
                        targets = input_sample["target"]
                        for idx in range(num_images):
                            index = input_sample["inds"][idx]
                            if index not in out_features[layer]:
                                out_targets[layer][index] = targets[idx].reshape(-1)
                                out_features[layer][index] = feature[idx]
            except StopIteration:
                break
        barrier()

        output = {}
        for layer in feat_names:
            out_features[layer] = dict(sorted(out_features[layer].items()))
            out_targets[layer] = dict(sorted(out_targets[layer].items()))
            feats = np.array(list(out_features[layer].values()))
            N = feats.shape[0]
            output[layer] = {
                "features": feats.reshape(N, -1),
                "targets": np.array(list(out_targets[layer].values())),
                "inds": np.array(list(out_features[layer].keys())),
            }
        return output
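_flatten_features_list is called above but not shown. Judging from how its result is indexed, it turns whatever the model returns for the requested layers into a flat list of tensors ordered like feat_names. A minimal sketch under that assumption (the container handling is illustrative, not the VISSL implementation):

    def _flatten_features_list(self, features):
        # Normalize the model output into a flat list of per-layer tensors,
        # one entry per requested feature name, in order.
        if isinstance(features, (list, tuple)):
            return list(features)
        return [features]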
Example #7
    def _run_step(self, state, local_variables, use_gpu):
        # Check for training complete but only terminate when the last phase is done
        if state.task.done_training() and state.advance_to_next_phase:
            raise StopIteration

        if state.advance_to_next_phase:
            state.task.advance_phase()

            # Start phase hooks
            state.task.run_hooks(local_variables,
                                 ClassyHookFunctions.on_phase_start.name)

            state.advance_to_next_phase = False

        # Process one train step
        try:
            if state.skip_current_phase:
                state.advance_to_next_phase = True
                state.skip_current_phase = False  # Reset flag
            else:
                state.task.train_step(use_gpu, local_variables)
        except StopIteration:
            state.advance_to_next_phase = True
        if state.advance_to_next_phase:
            logging.info("Syncing meters on phase end...")
            for meter in state.task.meters:
                meter.sync_state()
            logging.info("...meters synced")
            barrier()
            # Phase complete
            # NOTE: this is a good time to checkpoint, as it guarantees
            # that loading from checkpoint will properly advance the phase.
            state.task.run_hooks(local_variables,
                                 ClassyHookFunctions.on_phase_end.name)

        progress_rate = None  # using None to signal 'unknown'
        perf_stats = local_variables.get("perf_stats", None)
        if perf_stats is not None:
            batch_time = perf_stats._cuda_stats[
                "train_step_total"].smoothed_value
            if batch_time is not None and batch_time > 0.0:
                # rate = number of mini-batches per second
                progress_rate = 1.0 / batch_time

        progress_stats = self._ClassyWorkerStats(progress_rate)
        return state, progress_stats
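self._ClassyWorkerStats is not defined in this snippet; it only needs to carry the mini-batch rate (or None when unknown), so a namedtuple-style container would be enough. The field name below is an assumption, not the actual ClassyVision definition.

from collections import namedtuple

# Hypothetical container for the per-worker progress stats returned above:
# mini-batches per second, or None when the rate is not yet known.
_ClassyWorkerStats = namedtuple("_ClassyWorkerStats", ["progress_rate"])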
Example #8
    def on_phase_end(self):
        self.log_phase_end("train")

        if self.train:
            self.optimizer.on_epoch(where=self.where)

        logging.debug("Syncing losses on phase end...")
        self.synchronize_losses()
        logging.debug("...losses synced")

        logging.debug("Syncing meters on phase end...")
        for meter in self.meters:
            meter.sync_state()
        logging.debug("...meters synced")
        barrier()

        for hook in self.hooks:
            hook.on_phase_end(self)
        self.perf_log = []

        self.log_phase_end("total")
Example #9
    def train(self):
        """
        The train workflow. We get the training loop to use (vissl default is
        standard_train_step) but the user can create their own training loop
        and specify the name TRAINER.TRAIN_STEP_NAME

        The training happens:
        1. Execute any hooks at the start of training (mostly resets the variable like
           iteration num phase_num etc)
        2. For each epoch (train or test), run the hooks at the start of an epoch. Mostly
           involves setting things like timer, setting dataloader epoch etc
        3. Execute the training loop (1 training iteration) involving forward, loss, backward,
           optimizer update, metrics collection etc.
        4. At the end of epoch, sync meters and execute hooks at the end of phase. Involves
           things like checkpointing model, logging timers, logging to tensorboard etc
        """
        train_step_fn = get_train_step(self.cfg["TRAINER"]["TRAIN_STEP_NAME"])
        self.task.prepare(pin_memory=self.cfg.DATA.PIN_MEMORY)
        self.task.init_distributed_data_parallel_model()

        # Find what phase, train_phase_idx, local_iteration_num we are starting from.
        # Recover it from the checkpoint (if available)
        task, phase_idx, iteration_num = self._init_training_state(
            self.cfg, self.task)

        # Good to go, (re) start training
        task.run_hooks(SSLClassyHookFunctions.on_start.name)

        if is_primary():
            logging.info("Model is:\n {}".format(task.model))
            logging.info("Loss is: {}".format(task.loss))
        logging.info("Starting training....")

        while phase_idx + 1 < len(task.phases):
            self._advance_phase(task)  # advances task.phase_idx
            phase_idx += 1
            iteration_num += 1
            task.local_iteration_num = iteration_num  # iteration_num=0 at this step
            task.run_hooks(SSLClassyHookFunctions.on_phase_start.name)
            while True:
                try:
                    if self.cfg.MODEL.CUDA_CACHE.CLEAR_CUDA_CACHE and (
                            iteration_num %
                            self.cfg.MODEL.CUDA_CACHE.CLEAR_FREQ == 0):
                        logging.info(
                            f"Emptying CUDA cache at step count: {iteration_num}"
                        )
                        torch.cuda.empty_cache()
                        logging.info("CUDA cache cleared")
                    task = train_step_fn(task)
                    iteration_num += 1
                    task.local_iteration_num = iteration_num
                    task.run_hooks(SSLClassyHookFunctions.on_step.name)
                except StopIteration:
                    break
            for meter in task.meters:
                meter.sync_state()
            logging.info("Meters synced")
            barrier()
            task.run_hooks(SSLClassyHookFunctions.on_phase_end.name)

        task.run_hooks(SSLClassyHookFunctions.on_end.name)
        if hasattr(task, "data_iterator"):
            del task.data_iterator
            gc.collect()
        if hasattr(task, "dataloaders"):
            del task.dataloaders
            gc.collect()
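The loop above calls task = train_step_fn(task) once per iteration, so a custom training loop selected via TRAINER.TRAIN_STEP_NAME only has to accept the task, do one iteration's worth of work, and return the task; a StopIteration raised by the exhausted data iterator is what ends the phase. A hedged sketch of such a step follows; the registration mechanism is omitted, and the attribute names used for the data iterator, model, loss, and optimizer are assumptions about the task object (a real step would also handle AMP, meters, `where` bookkeeping, etc.).

def my_train_step(task):
    """Illustrative custom train step: one forward/backward/update pass."""
    sample = next(task.data_iterator)  # raises StopIteration at phase end
    inputs = sample["data"][0].cuda(non_blocking=True)
    targets = sample["label"][0].cuda(non_blocking=True)

    output = task.model(inputs)
    loss = task.loss(output, targets)

    task.optimizer.zero_grad()
    loss.backward()
    task.optimizer.step()
    return task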