Example #1
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        # Pin host memory only when multiple GPUs are available: pinning
        # speeds up host-to-device copies at the cost of page-locked RAM.
        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        # make sure all the workers start training at the same time
        # this helps catch hangs which would have happened elsewhere
        barrier()

        local_variables = {}

        task.on_start(local_variables)
        while not task.done_training():
            task.on_phase_start(local_variables)
            while True:
                try:
                    task.step(self.use_gpu, local_variables)
                except StopIteration:
                    break
            task.on_phase_end(local_variables)
        task.on_end(local_variables)
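The pin_memory flag computed in this example is handed to the task's dataloader setup; torch.utils.data.DataLoader accepts the same flag. A minimal standalone sketch of the heuristic (the dataset and loader here are illustrative assumptions, not part of the trainer):

import torch
from torch.utils.data import DataLoader, TensorDataset

use_gpu = torch.cuda.is_available()
# Same heuristic as the trainer: only pin host memory when more than one
# CUDA device is present.
pin_memory = use_gpu and torch.cuda.device_count() > 1

dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
loader = DataLoader(dataset, batch_size=4, pin_memory=pin_memory)
for batch, labels in loader:
    pass  # a training step would consume the batch here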
Example #2
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        task.prepare()
        assert isinstance(task, ClassyTask)

        # make sure all the workers start training at the same time
        # this helps catch hangs which would have happened elsewhere
        barrier()

        task.on_start()
        while not task.done_training():
            task.on_phase_start()
            while True:
                try:
                    task.step()
                except StopIteration:
                    break
            task.on_phase_end()
        task.on_end()
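Both variants drive the task through the same protocol: phases run until task.step() raises StopIteration. A toy stand-in satisfying the no-argument protocol of this second example (hypothetical, for illustration only; barrier/distributed setup omitted):

class ToyTask:
    """Toy stand-in for a ClassyTask-like object."""

    def __init__(self, num_phases=2, steps_per_phase=3):
        self.num_phases = num_phases
        self.steps_per_phase = steps_per_phase
        self.phase = 0
        self.step_in_phase = 0

    def prepare(self):
        pass  # real tasks build dataloaders, move models to devices, etc.

    def done_training(self):
        return self.phase >= self.num_phases

    def on_start(self):
        print("training started")

    def on_phase_start(self):
        self.step_in_phase = 0

    def step(self):
        if self.step_in_phase >= self.steps_per_phase:
            raise StopIteration  # signals the end of the current phase
        self.step_in_phase += 1

    def on_phase_end(self):
        self.phase += 1

    def on_end(self):
        print("training finished")

# Drive the toy task with the same loop shape as the trainer above.
task = ToyTask()
task.prepare()
task.on_start()
while not task.done_training():
    task.on_phase_start()
    while True:
        try:
            task.step()
        except StopIteration:
            break
    task.on_phase_end()
task.on_end()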
Example #3
    def train(self, task: ClassyTask):
        """Runs training phases, phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        # Wrap the model for distributed data parallel training when this
        # process is part of a distributed run.
        if is_distributed_training_run():
            task.init_distributed_data_parallel_model()

        local_variables = {}
        task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)
        best_acc = {
            'top1_acc': 0,
            'top1_epoch': 0,
            'top5_acc': 0,
            'top5_epoch': 0
        }
        epoch = 0
        while not task.done_training():
            task.advance_phase()

            # Start phase hooks
            task.run_hooks(local_variables,
                           ClassyHookFunctions.on_phase_start.name)
            while True:
                # Process next sample
                try:
                    task.train_step(self.use_gpu, local_variables)
                except StopIteration:
                    break

            logging.info("Syncing meters on phase end...")
            for meter in task.meters:
                meter.sync_state()
            logging.info("...meters synced")
            barrier()
            # In this variant, the phase-end hooks are expected to return the
            # task's meters (or None); track the best top-1 accuracy seen so
            # far, recording top-5 at the same epoch.
            meter = task.run_hooks(local_variables,
                                   ClassyHookFunctions.on_phase_end.name)
            if meter is not None and meter[0].value['top_1'] > best_acc['top1_acc']:
                best_acc['top1_acc'] = meter[0].value['top_1']
                best_acc['top5_acc'] = meter[0].value['top_5']
                best_acc['top1_epoch'] = epoch
                best_acc['top5_epoch'] = epoch
            # Advances once per phase, which may include test phases, so this
            # is not strictly one train epoch.
            epoch += 1

        task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
        return best_acc
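The bookkeeping in this third example can be factored into a small helper; a sketch of such a refactor (hypothetical, not part of the original), which also makes explicit that top-5 is recorded at the best top-1 epoch rather than at its own best epoch:

def update_best(best_acc, meter_value, epoch):
    """Update the running best-accuracy record in place.

    Mirrors the loop above: when top-1 improves, top-5 and both epoch
    fields are recorded at that same epoch.
    """
    if meter_value['top_1'] > best_acc['top1_acc']:
        best_acc['top1_acc'] = meter_value['top_1']
        best_acc['top5_acc'] = meter_value['top_5']
        best_acc['top1_epoch'] = epoch
        best_acc['top5_epoch'] = epoch
    return best_acc

# Usage inside the phase loop (illustrative):
#     if meter is not None:
#         update_best(best_acc, meter[0].value, epoch)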