Example 1
    def predict(self):
        """The testing process.
        """
        self.net.eval()
        dataloader = self.test_dataloader
        pbar = tqdm(dataloader, desc='test', ascii=True)

        epoch_log = EpochLog()
        for i, batch in enumerate(pbar):
            with torch.no_grad():
                test_dict = self._test_step(batch)
                loss = test_dict['loss']
                losses = test_dict.get('losses')
                metrics = test_dict.get('metrics')

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                # Fall back to a full batch when the dataset size divides
                # evenly by the batch size (the modulo would give 0).
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics)

            pbar.set_postfix(**epoch_log.on_step_end_log)
        test_log = epoch_log.on_epoch_end_log
        LOGGER.info(f'Test log: {test_log}.')
        return test_log
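
All four examples delegate their bookkeeping to an EpochLog object that is not shown on this page. The following is a minimal sketch of what such a helper might look like, assuming 'losses' and 'metrics' are dicts of scalar values and that the averages are weighted by batch size; the project's actual class may differ.

from collections import defaultdict


class EpochLog:
    """Hypothetical minimal EpochLog: batch-size-weighted running averages."""

    def __init__(self):
        self._count = 0
        self._sums = defaultdict(float)
        self._lr = None

    def update(self, batch_size, loss, losses=None, metrics=None, lr=None):
        # Weight every value by the number of samples in the batch so a
        # smaller final batch does not skew the epoch average.
        self._count += batch_size
        self._sums['loss'] += float(loss) * batch_size
        for name, value in (losses or {}).items():
            self._sums[name] += float(value) * batch_size
        for name, value in (metrics or {}).items():
            self._sums[name] += float(value) * batch_size
        if lr is not None:
            self._lr = lr

    @property
    def on_step_end_log(self):
        # Running averages shown in the tqdm postfix after every step.
        log = {name: total / self._count for name, total in self._sums.items()}
        if self._lr is not None:
            log['lr'] = self._lr
        return log

    @property
    def on_epoch_end_log(self):
        # Final averages reported once the epoch is finished.
        return self.on_step_end_log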
Example 2
    def _run_epoch(self, mode):
        """Run an epoch for training.
        Args:
            mode (str): The mode of running an epoch ('train' or 'valid').

        Returns:
            log (dict): The log information.
            batch (dict or sequence): The last batch of the data.
            outputs (torch.Tensor or sequence of torch.Tensor): The corresponding model outputs.
        """
        if mode == 'train':
            self.net.train()
            dataloader = self.train_dataloader
            dataloader_iterator = iter(dataloader)
            outer_pbar = tqdm(total=(
                (len(dataloader) + dataloader.grad_accumulation_steps() - 1) //
                dataloader.grad_accumulation_steps()),
                              desc='train',
                              ascii=True)
            inner_pbar = tqdm(total=math.ceil(
                dataloader.grad_accumulation_steps(0)),
                              desc='grad_accumulation',
                              leave=False,
                              ascii=True)

            epoch_log = EpochLog()
            for i in range(len(dataloader)):
                batch = next(dataloader_iterator)
                train_dict = self._train_step(batch)
                loss = train_dict.get('loss')
                if loss is None:
                    raise KeyError(
                        f"The train_dict must have the key named 'loss'. "
                        'Please check the returned keys as defined in MyTrainer._train_step().'
                    )
                losses = train_dict.get('losses')
                metrics = train_dict.get('metrics')
                outputs = train_dict.get('outputs')

                if self.use_amp:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        (scaled_loss /
                         dataloader.grad_accumulation_steps(i)).backward()
                else:
                    (loss / dataloader.grad_accumulation_steps(i)).backward()

                if (i + 1) % dataloader.grad_accumulation_steps() == 0 or (
                        i + 1) == len(dataloader):
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if isinstance(self.lr_scheduler, CyclicLR):
                        self.lr_scheduler.step()
                    elif isinstance(self.lr_scheduler,
                                    CosineAnnealingWarmRestarts):
                        self.lr_scheduler.step((self.epoch - 1) +
                                               i / len(dataloader))

                if (i + 1) == len(dataloader) and not dataloader.drop_last:
                    batch_size = (len(dataloader.dataset) %
                                  dataloader.batch_size or dataloader.batch_size)
                else:
                    batch_size = dataloader.batch_size
                epoch_log.update(batch_size, loss, losses, metrics)

                inner_pbar.update()
                if (i + 1) % dataloader.grad_accumulation_steps() == 0 or (
                        i + 1) == len(dataloader):
                    outer_pbar.update()
                    outer_pbar.set_postfix(**epoch_log.on_step_end_log)
                    inner_pbar.close()
                    inner_pbar = tqdm(total=math.ceil(
                        dataloader.grad_accumulation_steps(i + 1)),
                                      desc='grad_accumulation',
                                      leave=False,
                                      ascii=True)
            outer_pbar.close()
            inner_pbar.close()
        else:
            self.net.eval()
            dataloader = self.valid_dataloader
            pbar = tqdm(dataloader, desc='valid', ascii=True)

            epoch_log = EpochLog()
            for i, batch in enumerate(pbar):
                with torch.no_grad():
                    valid_dict = self._valid_step(batch)
                    loss = valid_dict.get('loss')
                    if loss is None:
                        raise KeyError(
                            f"The valid_dict must have the key named 'loss'. "
                            'Please check the returned keys as defined in MyTrainer._valid_step().'
                        )
                    losses = valid_dict.get('losses')
                    metrics = valid_dict.get('metrics')
                    outputs = valid_dict.get('outputs')

                if (i + 1) == len(dataloader) and not dataloader.drop_last:
                    batch_size = (len(dataloader.dataset) %
                                  dataloader.batch_size or dataloader.batch_size)
                else:
                    batch_size = dataloader.batch_size
                epoch_log.update(batch_size, loss, losses, metrics)

                pbar.set_postfix(**epoch_log.on_step_end_log)
        return epoch_log.on_epoch_end_log, batch, outputs
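
The train branch above scales the loss by dataloader.grad_accumulation_steps(i), a method of a custom dataloader wrapper that is not shown here. One plausible reading, sketched below as a standalone helper, is that calling it without an index returns the configured number of accumulation steps, while calling it with a batch index returns the size of the accumulation group containing that batch; the actual wrapper may behave differently.

def grad_accumulation_steps(num_batches, steps, batch_index=None):
    """Hypothetical stand-in for the dataloader method used above."""
    if batch_index is None:
        # No index: the configured number of gradient accumulation steps.
        return steps
    # With an index: the size of the accumulation group the batch belongs to,
    # so the last group is smaller when num_batches is not a multiple of steps.
    group_start = (batch_index // steps) * steps
    return min(steps, num_batches - group_start)


# 10 batches with 4 accumulation steps -> groups of size 4, 4 and 2.
assert grad_accumulation_steps(10, 4) == 4
assert grad_accumulation_steps(10, 4, batch_index=3) == 4
assert grad_accumulation_steps(10, 4, batch_index=9) == 2

Dividing each batch loss by this group size, rather than by a fixed step count, keeps the accumulated gradient equal to the mean over the group even when the last group is short.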
Example 3
    def _run_epoch(self, mode):
        """Run an epoch for training.
        Args:
            mode (str): The mode of running an epoch ('train' or 'valid').
        Returns:
            log (dict): The log information.
            batch (dict or sequence): The last batch of the data.
            outputs (torch.Tensor or sequence of torch.Tensor): The corresponding model outputs.
        """
        for fn in self.metric_fns:
            fn.reset_count()
        if mode == 'train':
            self.net.train()
            dataloader = self.train_dataloader
            dataloader_iterator = iter(dataloader)
            outer_pbar = tqdm(total=len(dataloader), desc=mode, ascii=True)

            epoch_log = EpochLog()
            for i in range(len(dataloader)):
                batch = next(dataloader_iterator)
                train_dict = self._train_step(batch)
                loss = train_dict['loss']
                losses = train_dict.get('losses')
                metrics = train_dict.get('metrics')
                outputs = train_dict.get('outputs')

                lr = self.optimizer.param_groups[0]['lr']

                loss.backward()
                if self.grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()

                if (i + 1) == len(dataloader) and not dataloader.drop_last:
                    batch_size = (len(dataloader.dataset) %
                                  dataloader.batch_size or dataloader.batch_size)
                else:
                    batch_size = dataloader.batch_size

                epoch_log.update(batch_size, loss, losses, metrics, lr)

                outer_pbar.update()
                outer_pbar.set_postfix(**epoch_log.on_step_end_log)

            outer_pbar.close()

        else:
            self.net.eval()
            dataloader = self.valid_dataloader
            pbar = tqdm(dataloader, desc=mode, ascii=True)

            epoch_log = EpochLog()
            for i, batch in enumerate(pbar):
                with torch.no_grad():
                    valid_dict = self._valid_step(batch)
                    loss = valid_dict['loss']
                    losses = valid_dict.get('losses')
                    metrics = valid_dict.get('metrics')
                    outputs = valid_dict.get('outputs')

                if (i + 1) == len(dataloader) and not dataloader.drop_last:
                    batch_size = (len(dataloader.dataset) %
                                  dataloader.batch_size or dataloader.batch_size)
                else:
                    batch_size = dataloader.batch_size

                epoch_log.update(batch_size, loss, losses, metrics)

                pbar.set_postfix(**epoch_log.on_step_end_log)

        return epoch_log.on_epoch_end_log, batch, outputs
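
Examples 2-4 consume the same dict contract from _train_step and _valid_step: 'loss' is required, while 'losses', 'metrics' and 'outputs' are optional. Below is a hedged sketch of what such a step method might return for a plain classifier, assuming an (inputs, targets) batch layout and a self.net module, neither of which is shown on this page.

import torch.nn.functional as F


class MyTrainer:
    # Only the step method is sketched; self.net is assumed to be set elsewhere.
    def _train_step(self, batch):
        inputs, targets = batch  # assumed batch layout
        outputs = self.net(inputs)
        loss = F.cross_entropy(outputs, targets)
        accuracy = (outputs.argmax(dim=1) == targets).float().mean()
        return {
            'loss': loss,                          # required by _run_epoch
            'losses': {'ce_loss': loss.detach()},  # optional per-term losses
            'metrics': {'accuracy': accuracy},     # optional metrics
            'outputs': outputs,                    # optional raw outputs
        }

A matching _valid_step would return the same keys; the caller wraps it in torch.no_grad(), as shown above.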
Example 4
    def _run_epoch(self, mode):
        """Run an epoch for training.
        Args:
            mode (str): The mode of running an epoch ('train' or 'valid').
        Returns:
            log (dict): The log information.
            batch (dict or sequence): The last batch of the data.
            outputs (torch.Tensor or sequence of torch.Tensor): The corresponding model outputs.
        """
        if mode == 'train':
            self.net.train()
            dataloader = self.train_dataloader
            dataloader_iterator = iter(dataloader)
            outer_pbar = tqdm(total=len(dataloader), desc=mode, ascii=True)

            epoch_log = EpochLog()
            for i in range(len(dataloader)):
                batch = next(dataloader_iterator)
                train_dict = self._train_step(batch)
                loss = train_dict['loss']
                losses = train_dict.get('losses')
                metrics = train_dict.get('metrics')
                outputs = train_dict.get('outputs')
                lr = self.optimizer.param_groups[0]['lr']

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                # if isinstance(self.lr_scheduler, (CyclicLR, OneCycleLR)):
                #     self.lr_scheduler.step()
                # elif isinstance(self.lr_scheduler, CosineAnnealingWarmRestarts):
                #     self.lr_scheduler.step((self.epoch - 1) + i / len(dataloader))

                if (i + 1) == len(dataloader) and not dataloader.drop_last:
                    batch_size = (len(dataloader.dataset) %
                                  dataloader.batch_size or dataloader.batch_size)
                else:
                    batch_size = dataloader.batch_size

                epoch_log.update(batch_size, loss, losses, metrics, lr)

                outer_pbar.update()
                outer_pbar.set_postfix(**epoch_log.on_step_end_log)

            outer_pbar.close()

        else:
            self.net.eval()
            dataloader = self.valid_dataloader
            pbar = tqdm(dataloader, desc=mode, ascii=True)

            epoch_log = EpochLog()
            for i, batch in enumerate(pbar):
                with torch.no_grad():
                    valid_dict = self._valid_step(batch)
                    loss = valid_dict['loss']
                    losses = valid_dict.get('losses')
                    metrics = valid_dict.get('metrics')
                    outputs = valid_dict.get('outputs')

                if (i + 1) == len(dataloader) and not dataloader.drop_last:
                    batch_size = (len(dataloader.dataset) %
                                  dataloader.batch_size or dataloader.batch_size)
                else:
                    batch_size = dataloader.batch_size

                epoch_log.update(batch_size, loss, losses, metrics)

                pbar.set_postfix(**epoch_log.on_step_end_log)

        return epoch_log.on_epoch_end_log, batch, outputs
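
In Example 4 the per-batch scheduler calls are commented out, which suggests the learning rate is stepped once per epoch by whatever loop drives _run_epoch. A possible driver, assuming a trainer object that exposes _run_epoch, lr_scheduler and an epoch attribute as in the examples above:

def fit(trainer, num_epochs):
    # Hypothetical outer loop; attribute names mirror the examples above.
    for epoch in range(1, num_epochs + 1):
        trainer.epoch = epoch
        train_log, _, _ = trainer._run_epoch('train')
        valid_log, _, _ = trainer._run_epoch('valid')
        if trainer.lr_scheduler is not None:
            trainer.lr_scheduler.step()  # epoch-level learning rate scheduling
        print(f'Epoch {epoch}: train {train_log}, valid {valid_log}')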