Example #1
class CheckpointSaver:
    def __init__(self,
                 estimator: SimPLEEstimator,
                 logger: Logger,
                 checkpoint_metric: str,
                 best_checkpoint_str: str,
                 best_checkpoint_pattern: str,
                 latest_checkpoint_str: str,
                 latest_checkpoint_pattern: str,
                 delayed_best_model_saving: bool = True):
        """

        Args:
            estimator: Estimator, used to get experiment-related data
            logger: Logger, used for logging
            checkpoint_metric: save the model when the best value of this key changes
            best_checkpoint_str: path str format to save the best checkpoint file
            best_checkpoint_pattern: regex pattern used to find the best checkpoint file
            latest_checkpoint_str: path str format to save the latest checkpoint file
            latest_checkpoint_pattern: regex pattern used to find the latest checkpoint file
            delayed_best_model_saving: if True, save best model after calling save_latest_checkpoint()
        """
        self.absolute_best_path = "best_checkpoint.pth"

        # metrics to keep track of
        self.monitor = MetricMonitor()
        self.monitor.track(key="mean_acc",
                           best_value=-np.inf,
                           mode=MetricMode.MAX,
                           prefix="test")
        self.monitor.track(key="mean_acc",
                           best_value=-np.inf,
                           mode=MetricMode.MAX,
                           prefix="validation")

        self.checkpoint_metric = checkpoint_metric

        # checkpoint path patterns
        self.best_checkpoint_str = best_checkpoint_str
        self.best_checkpoint_pattern = re.compile(best_checkpoint_pattern)

        self.latest_checkpoint_str = latest_checkpoint_str
        self.latest_checkpoint_pattern = re.compile(latest_checkpoint_pattern)

        # save estimator and logger
        # this will recover best metrics and register log hooks
        self.estimator = estimator
        self.logger = logger

        # assign flags
        self.delayed_save_best_model = delayed_best_model_saving
        self.is_best_model = False
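
The constructor above assumes a MetricMonitor that records the best observed value per metric key. A minimal sketch of the interface these calls imply (track(), update_metrics(), state_dict(), load_state_dict() and the MetricMode enum are inferred from the usage here and in Example #5, not taken from the actual library) could look like this:

from enum import Enum
from typing import Any, Dict


class MetricMode(Enum):
    MIN = "min"
    MAX = "max"


class MetricMonitor:
    """Keeps the best observed value for each tracked metric (sketch)."""

    def __init__(self) -> None:
        self._metrics: Dict[str, Dict[str, Any]] = {}

    def track(self, key: str, best_value: float, mode: MetricMode,
              prefix: str = "") -> None:
        full_key = f"{prefix}/{key}" if prefix else key
        self._metrics[full_key] = {"key": full_key,
                                   "best_value": best_value,
                                   "mode": mode}

    def __contains__(self, key: str) -> bool:
        return key in self._metrics

    def __getitem__(self, key: str) -> Dict[str, Any]:
        return self._metrics[key]

    def update_metrics(self, log_info: Dict[str, Any]) -> Dict[str, float]:
        # return only the tracked metrics whose best value just improved
        updated = {}
        for key, value in log_info.items():
            entry = self._metrics.get(key)
            if entry is None:
                continue
            better = (value > entry["best_value"]
                      if entry["mode"] is MetricMode.MAX
                      else value < entry["best_value"])
            if better:
                entry["best_value"] = value
                updated[key] = value
        return updated

    def state_dict(self) -> Dict[str, float]:
        return {key: entry["best_value"]
                for key, entry in self._metrics.items()}

    def load_state_dict(self, state: Dict[str, float]) -> None:
        for key, best_value in state.items():
            if key in self._metrics:
                self._metrics[key]["best_value"] = best_value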
Example #2
def train_model(model,
                criterion,
                optimizer,
                scheduler=None,
                save_path=None,
                num_epochs=25,
                iter_size=1):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice = 0
    monitor = MetricMonitor()

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            model.train(
                phase == 'train')  # Set model to training/evaluate mode
            optimizer.zero_grad()
            monitor.reset()
            stream = tqdm(dataloaders[phase], file=sys.stdout)
            # Iterate over data.
            for i, samples in enumerate(stream, start=1):
                # get the inputs
                inputs = torch.tensor(samples['image'],
                                      requires_grad=True).cuda(non_blocking=True)
                # get the targets
                targets = torch.tensor(samples['masks'],
                                       dtype=torch.long).cuda(non_blocking=True)

                # forward
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    if i % iter_size == 0 or i == len(dataloaders[phase]):
                        optimizer.step()
                        optimizer.zero_grad()

                # statistics
                dice = dice_value(outputs.data, targets.data, None)
                monitor.update('loss', loss.data, inputs.shape[0])
                monitor.update('dice', dice.data, inputs.shape[0])
                stream.set_description(f'epoch {epoch+1}/{num_epochs} | '
                                       f'{phase}: {monitor}')
            stream.close()

            epoch_loss = monitor.get_avg('loss')
            epoch_dice = monitor.get_avg('dice')

            if phase == 'valid' and scheduler is not None:
                scheduler.step(-epoch_dice)

            # deep copy the model
            if (phase == 'valid') and (epoch_dice > best_dice):
                best_dice = epoch_dice
                best_model_wts = copy.deepcopy(model.state_dict())
                if save_path is not None:
                    path = save_path.format(best_dice)
                    torch.save(best_model_wts, path)
                    print('Weights of model saved at {}'.format(path))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Dice: {:.4f}'.format(best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
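
Examples #2 through #4 use a different, running-average flavour of MetricMonitor: reset() at the start of each phase, update(name, value, n) per batch, get_avg(name) at the end of the epoch, and a readable __str__ for the tqdm description. A minimal sketch consistent with those calls (not the project's actual implementation) could be:

from collections import defaultdict


class MetricMonitor:
    """Tracks running averages of named metrics over an epoch (sketch)."""

    def __init__(self, float_precision: int = 4) -> None:
        self.float_precision = float_precision
        self.reset()

    def reset(self) -> None:
        self._metrics = defaultdict(lambda: {"sum": 0.0, "count": 0})

    def update(self, name: str, value, n: int = 1) -> None:
        # `value` may be a number or a scalar tensor; float() handles both
        entry = self._metrics[name]
        entry["sum"] += float(value) * n
        entry["count"] += n

    def get_avg(self, name: str) -> float:
        entry = self._metrics[name]
        return entry["sum"] / max(entry["count"], 1)

    def __str__(self) -> str:
        return " | ".join(
            f"{name} {self.get_avg(name):.{self.float_precision}f}"
            for name in self._metrics)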
Example #3
def train_model(model,
                criterion1,
                criterion2,
                criterion3,
                optimizer,
                scheduler=None,
                save_path=None,
                num_epochs=25,
                iter_size=1,
                compare='loss'):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val = -sys.maxsize
    monitor = MetricMonitor()

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            model.train(
                phase == 'train')  # Set model to training/evaluate mode
            optimizer.zero_grad()
            monitor.reset()
            stream = tqdm(dataloaders[phase], file=sys.stdout)
            # Iterate over data.
            for i, samples in enumerate(stream, start=1):
                # get the inputs
                inputs = torch.tensor(samples['image'],
                                      requires_grad=True).cuda(non_blocking=True)
                # get the targets
                vectors = torch.tensor(samples['vectors'],
                                       dtype=torch.float).cuda(non_blocking=True)
                masks = torch.tensor(samples['masks'],
                                     dtype=torch.float).cuda(non_blocking=True)
                areas = torch.tensor(samples['areas'],
                                     dtype=torch.float).cuda(non_blocking=True)

                # forward
                outputs1, outputs2, outputs3 = model(inputs)
                loss1 = criterion1(inputs, outputs1, vectors, masks,
                                   areas) if criterion1 is not None else 0
                loss2 = criterion2(inputs, outputs2, vectors, masks,
                                   areas) if criterion2 is not None else 0
                loss3 = criterion3(inputs, outputs3, vectors, masks,
                                   areas) if criterion3 is not None else 0
                loss = loss1 + loss2 + loss3

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    if i % iter_size == 0 or i == len(dataloaders[phase]):
                        optimizer.step()
                        optimizer.zero_grad()

                # statistics
                dice1 = dice_value(outputs1.data,
                                   torch.unsqueeze(masks[:, 0], 1).data)
                dice3 = dice_value(outputs3.data, masks.data)
                monitor.update('loss', loss.data, inputs.shape[0])
                monitor.update('dice1', dice1.data, inputs.shape[0])
                monitor.update('dice3', dice3.data, inputs.shape[0])
                stream.set_description(f'epoch {epoch+1}/{num_epochs} | '
                                       f'{phase}: {monitor}')
            stream.close()

            epoch_val = (monitor.get_avg('dice1') if compare == 'dice1'
                         else monitor.get_avg('dice3') if compare == 'dice3'
                         else -monitor.get_avg('loss'))

            if phase == 'valid' and scheduler is not None:
                scheduler.step(-epoch_val)

            # deep copy the model
            if (phase == 'valid') and (epoch_val > best_val):
                best_val = epoch_val
                best_model_wts = copy.deepcopy(model.state_dict())
                if save_path is not None:
                    path = save_path.format((epoch + 1),
                                            optimizer.param_groups[0]['lr'],
                                            abs(best_val))
                    torch.save(best_model_wts, path)
                    print('Weights of model saved at {}'.format(path))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val: {:.4f}'.format(abs(best_val)))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
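
Note that save_path in this variant is a format string with three positional placeholders, filled with the epoch number, the current learning rate, and the absolute best value. A hypothetical call, assuming the model, the loss callables and the global dataloaders dict already exist, might look like:

# Hypothetical invocation; `model`, the loss callables and the global
# `dataloaders` dict are assumed to be defined elsewhere.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       patience=3)

model = train_model(model,
                    criterion1=vector_loss,   # placeholder loss callables
                    criterion2=None,
                    criterion3=mask_loss,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    save_path='weights_epoch{}_lr{}_val{:.4f}.pth',
                    num_epochs=40,
                    iter_size=4,              # step the optimizer every 4 batches
                    compare='dice3')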
Example #4
def train_model(model,
                criterion,
                optimizer,
                scheduler=None,
                model_save_path=None,
                optim_save_path=None,
                log_save_path=None,
                num_epochs=25,
                iter_size=1,
                compare_Loss=False):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val = -sys.maxsize
    monitor = MetricMonitor()
    log = open(log_save_path, 'a') if log_save_path is not None else \
          type('DummyLog', (object,), {'write': lambda self, text: None,
                                       'flush': lambda self: None,
                                       'close': lambda self: None})()
    log.write(f'Training start at {time.strftime("%Y-%m-%d %H:%M")}\n\n')

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            model.train(
                phase == 'train')  # Set model to training/evaluate mode
            optimizer.zero_grad()
            monitor.reset()
            stream = tqdm(dataloaders[phase], file=sys.stdout)
            # Iterate over data.
            for i, samples in enumerate(stream, start=1):
                # get the inputs
                inputs = torch.tensor(samples['image'],
                                      requires_grad=True).cuda(non_blocking=True)
                # get the targets
                masks = torch.tensor(samples['masks'],
                                     dtype=torch.long).cuda(non_blocking=True)
                centroids = torch.tensor(samples['centroids'],
                                         dtype=torch.long).cuda(non_blocking=True)
                targets = masks + centroids

                # forward
                outputs = model(inputs)
                # outputs = F.avg_pool2d(outputs, 4, 4)
                loss = criterion(outputs, targets)

                # out5,out4,out3,out2,out1,out_fuse = model(inputs)
                # loss5 = criterion(out5, targets)
                # loss4 = criterion(out4, targets)
                # loss3 = criterion(out3, targets)
                # loss2 = criterion(out2, targets)
                # loss1 = criterion(out1, targets)
                # loss_fuse = criterion(out_fuse, targets)
                # loss = loss5 + loss4 + loss3 + loss2 + loss1 + loss_fuse

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    if i % iter_size == 0 or i == len(dataloaders[phase]):
                        optimizer.step()
                        optimizer.zero_grad()

                # statistics
                dice = ce_dice_value(outputs.data, targets.data, [0, 0, 1, 0])
                monitor.update('loss', loss.data, inputs.shape[0])
                monitor.update('dice', dice.data, inputs.shape[0])
                stream.set_description(
                    f'epoch {epoch+1}/{num_epochs} | {phase}: {monitor}')
            stream.close()

            epoch_loss = monitor.get_avg('loss')
            epoch_dice = monitor.get_avg('dice')
            epoch_val = epoch_dice if not compare_Loss else -epoch_loss

            log.write(
                f'epoch {epoch+1}/{num_epochs} | {phase}: {monitor} | lr {optimizer.param_groups[0]["lr"]:.0e}\n'
            )

            if phase == 'valid' and scheduler is not None:
                scheduler.step(-epoch_val)

            # save the model and optimizer
            if (phase == 'valid') and (epoch_val > best_val):
                best_val = epoch_val
                best_model_wts = copy.deepcopy(model.state_dict())
                if model_save_path is not None:
                    path = model_save_path.format((epoch + 1), abs(best_val))
                    torch.save(best_model_wts, path)
                    print(f'Weights of model saved at {path}')
                    log.write(f'Weights of model saved at {path}\n')
            if (phase == 'valid') and (optim_save_path is not None):
                path = optim_save_path.format((epoch + 1),
                                              optimizer.param_groups[0]['lr'])
                torch.save(optimizer.state_dict(), path)
            log.flush()

        log.write('\n')
        print()

    time_elapsed = time.time() - since
    print(
        f'Training complete in {(time_elapsed//60):.0f}m {(time_elapsed%60):.0f}s'
    )
    print(f'Best Val: {abs(best_val):.4f}')
    log.write(
        f'Training complete in {(time_elapsed//60):.0f}m {(time_elapsed%60):.0f}s\n'
    )
    log.write(f'Best Val: {abs(best_val):f}\n')
    log.close()

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
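
The three save-path arguments behave differently: model_save_path and optim_save_path are format strings (filled with epoch and best value, and with epoch and learning rate, respectively), while log_save_path is a plain file path opened in append mode. A hypothetical call, with file names chosen purely for illustration:

# Hypothetical call, reusing an existing model/criterion/optimizer/scheduler
# setup; the paths below are illustrative only.
model = train_model(model,
                    criterion,
                    optimizer,
                    scheduler=scheduler,
                    model_save_path='model_epoch{}_val{:.4f}.pth',  # epoch, best val
                    optim_save_path='optim_epoch{}_lr{}.pth',       # epoch, lr
                    log_save_path='training.log',
                    num_epochs=30,
                    iter_size=2,
                    compare_Loss=False)  # select the best model by dice, not loss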
Example #5
class CheckpointSaver:
    def __init__(self,
                 estimator: SimPLEEstimator,
                 logger: Logger,
                 checkpoint_metric: str,
                 best_checkpoint_str: str,
                 best_checkpoint_pattern: str,
                 latest_checkpoint_str: str,
                 latest_checkpoint_pattern: str,
                 delayed_best_model_saving: bool = True):
        """

        Args:
            estimator: Estimator, used to get experiment-related data
            logger: Logger, used for logging
            checkpoint_metric: save the model when the best value of this key changes
            best_checkpoint_str: path str format to save the best checkpoint file
            best_checkpoint_pattern: regex pattern used to find the best checkpoint file
            latest_checkpoint_str: path str format to save the latest checkpoint file
            latest_checkpoint_pattern: regex pattern used to find the latest checkpoint file
            delayed_best_model_saving: if True, save best model after calling save_latest_checkpoint()
        """
        self.absolute_best_path = "best_checkpoint.pth"

        # metrics to keep track of
        self.monitor = MetricMonitor()
        self.monitor.track(key="mean_acc",
                           best_value=-np.inf,
                           mode=MetricMode.MAX,
                           prefix="test")
        self.monitor.track(key="mean_acc",
                           best_value=-np.inf,
                           mode=MetricMode.MAX,
                           prefix="validation")

        self.checkpoint_metric = checkpoint_metric

        # checkpoint path patterns
        self.best_checkpoint_str = best_checkpoint_str
        self.best_checkpoint_pattern = re.compile(best_checkpoint_pattern)

        self.latest_checkpoint_str = latest_checkpoint_str
        self.latest_checkpoint_pattern = re.compile(latest_checkpoint_pattern)

        # save estimator and logger
        # this will recover best metrics and register log hooks
        self.estimator = estimator
        self.logger = logger

        # assign flags
        self.delayed_save_best_model = delayed_best_model_saving
        self.is_best_model = False

    @property
    def estimator(self) -> SimPLEEstimator:
        return self._estimator

    @estimator.setter
    def estimator(self, estimator: SimPLEEstimator) -> None:
        self._estimator = estimator

        # recover best value
        checkpoint_path = self.estimator.exp_args.checkpoint_path
        if checkpoint_path is not None:
            print(f"Recovering best metrics from {checkpoint_path}...")
            self.recover_metrics(
                torch.load(checkpoint_path, map_location=self.device))

    @property
    def logger(self) -> Logger:
        return self._logger

    @logger.setter
    def logger(self, logger: Logger) -> None:
        self._logger = logger

        # register log hooks
        print("Registering log hooks...")
        self.logger.register_log_hook(self.update_best_metric,
                                      logger=self.logger)

    @property
    def checkpoint_metric(self) -> str:
        return self._checkpoint_metric

    @checkpoint_metric.setter
    def checkpoint_metric(self, checkpoint_metric: str) -> None:
        assert checkpoint_metric in self.monitor, f"{checkpoint_metric} is not in metric monitor"

        self._checkpoint_metric = checkpoint_metric

    @property
    def log_dir(self) -> str:
        return self.estimator.exp_args.log_dir

    @property
    def best_full_checkpoint_str(self) -> str:
        return str(Path(self.log_dir) / self.best_checkpoint_str)

    @property
    def latest_full_checkpoint_str(self) -> str:
        return str(Path(self.log_dir) / self.latest_checkpoint_str)

    @property
    def device(self) -> torch.device:
        return self.estimator.device

    @property
    def global_step(self) -> int:
        return self.estimator.global_step

    @property
    def num_latest_checkpoints_kept(self) -> Optional[int]:
        return self.estimator.exp_args.num_latest_checkpoints_kept

    @property
    def is_save_latest_checkpoint(self) -> bool:
        return self.num_latest_checkpoints_kept is None or self.num_latest_checkpoints_kept > 0

    @property
    def is_remove_old_checkpoint(self) -> bool:
        return self.num_latest_checkpoints_kept is not None and self.num_latest_checkpoints_kept > 0

    def save_checkpoint(self,
                        checkpoint: Dict[str, Any],
                        checkpoint_path: Union[str, Path],
                        is_logger_save: bool = False) -> Path:
        checkpoint_path = str(checkpoint_path)
        torch.save(checkpoint, checkpoint_path)

        print(f"Checkpoint saved to \"{checkpoint_path}\"", flush=True)

        if is_logger_save:
            self.logger.save(checkpoint_path)

        return Path(checkpoint_path)

    def save_best_checkpoint(self,
                             checkpoint: Optional[Dict[str, Any]] = None,
                             is_logger_save: bool = False,
                             **kwargs) -> Path:
        if checkpoint is None:
            checkpoint = self.get_checkpoint()

        checkpoint_path = self.save_checkpoint(
            checkpoint_path=self.best_full_checkpoint_str.format(**kwargs),
            checkpoint=checkpoint,
            is_logger_save=is_logger_save)
        # reset flag
        self.is_best_model = False

        return checkpoint_path

    def save_latest_checkpoint(self,
                               checkpoint: Optional[Dict[str, Any]] = None,
                               is_logger_save: bool = False,
                               **kwargs) -> Optional[Path]:
        checkpoint_path: Optional[Path] = None

        if self.is_save_latest_checkpoint:
            if checkpoint is None:
                checkpoint = self.get_checkpoint()

            # save new checkpoint
            checkpoint_path = self.save_checkpoint(
                checkpoint_path=self.latest_full_checkpoint_str.format(
                    **kwargs),
                checkpoint=checkpoint,
                is_logger_save=is_logger_save)

            # cleanup old checkpoints
            self.cleanup_checkpoints()

        if self.delayed_save_best_model and self.is_best_model:
            self.save_best_checkpoint(**kwargs)

        return checkpoint_path

    def get_checkpoint(self) -> Dict[str, Any]:
        checkpoint = self.estimator.get_checkpoint()

        # add best metrics
        checkpoint.update({"monitor_state": self.monitor.state_dict()})

        return checkpoint

    def update_best_checkpoint(self) -> None:
        """
        Update the logged metrics for the best checkpoint

        Returns:

        """
        best_checkpoint_path = self.find_best_checkpoint_path()

        if best_checkpoint_path is None:
            warnings.warn("Cannot find best checkpoint")
            return

        best_checkpoint_path = str(best_checkpoint_path)
        best_checkpoint = torch.load(best_checkpoint_path,
                                     map_location=self.device)

        # update best metrics
        best_checkpoint.update({"monitor_state": self.monitor.state_dict()})

        self.save_checkpoint(checkpoint_path=str(
            Path(self.log_dir) / self.absolute_best_path),
                             checkpoint=best_checkpoint)

    def find_best_checkpoint_path(self, checkpoint_dir: Optional[str] = None, ignore_absolute_best: bool = True) \
            -> Optional[Path]:
        if checkpoint_dir is None:
            checkpoint_dir = self.log_dir

        abs_best_path = Path(checkpoint_dir) / self.absolute_best_path

        if not ignore_absolute_best and abs_best_path.is_file():
            # if not ignoring absolute best path and the path is a file, return the absolute best file path
            return abs_best_path

        checkpoint_path = find_checkpoint_path(
            checkpoint_dir, step_filter=self.best_checkpoint_pattern)

        if checkpoint_path is None:
            checkpoint_path = self.find_latest_checkpoint_path(
                checkpoint_dir=checkpoint_dir)

        return checkpoint_path

    def find_latest_checkpoint_path(self,
                                    checkpoint_dir: Optional[str] = None
                                    ) -> Optional[Path]:
        if checkpoint_dir is None:
            checkpoint_dir = self.log_dir

        return find_checkpoint_path(checkpoint_dir,
                                    step_filter=self.latest_checkpoint_pattern)

    def update_best_metric(self, log_info: Dict[str, Any],
                           logger: Logger) -> None:
        updated_dict = self.monitor.update_metrics(log_info)

        for updated_key, new_best_value in updated_dict.items():
            metric_dict = self.monitor[updated_key]

            translated_key = metric_dict["key"]

            # if new_best_value is better than current best value
            logger.log({translated_key: new_best_value}, step=self.global_step)

            if self.checkpoint_metric == updated_key:
                self.is_best_model = True

        if self.is_best_model and not self.delayed_save_best_model:
            # if not delayed_save_best_model save, then save checkpoint
            self.save_best_checkpoint(global_step=self.global_step)

    def recover_checkpoint(self,
                           checkpoint: Dict[str, Any],
                           recover_optimizer: bool = True,
                           recover_train_progress: bool = True) -> None:
        self.recover_metrics(checkpoint=checkpoint)

        self.estimator.load_checkpoint(
            checkpoint=checkpoint,
            recover_optimizer=recover_optimizer,
            recover_train_progress=recover_train_progress)

    def recover_metrics(self, checkpoint: Dict[str, Any]) -> None:
        if "monitor_state" in checkpoint:
            monitor_state = checkpoint["monitor_state"]
        else:
            # for backward compatibility
            monitor_state = {
                "validation/mean_acc": checkpoint.get("best_val_acc", -np.inf),
                "test/mean_acc": checkpoint.get("best_test_acc", -np.inf),
            }

        self.monitor.load_state_dict(monitor_state)

    def cleanup_checkpoints(self) -> None:
        if not self.is_remove_old_checkpoint:
            # do nothing if latest checkpoints are not being saved or if all checkpoints are kept
            return

        checkpoint_paths = find_all_files(
            checkpoint_dir=self.log_dir,
            search_pattern=self.latest_checkpoint_pattern)

        # sort by recency (largest step first)
        checkpoint_paths.sort(key=lambda x: int(
            re.search(self.latest_checkpoint_pattern, x.name).group(1)),
                              reverse=True)

        # remove old checkpoints
        for checkpoint_path in checkpoint_paths[self.num_latest_checkpoints_kept:]:
            print(f"Removing old checkpoint \"{checkpoint_path}\"", flush=True)
            checkpoint_path.unlink()
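
For completeness, here is a hypothetical way this class might be wired up. The path templates, regex patterns and metric key below are illustrative assumptions, not the project's actual configuration; the constraints they satisfy are that the format strings accept a global_step keyword and that the latest-checkpoint pattern captures the step number in group 1:

# Hypothetical setup; `estimator` and `logger` are assumed to exist already.
saver = CheckpointSaver(
    estimator=estimator,
    logger=logger,
    checkpoint_metric="validation/mean_acc",
    best_checkpoint_str="best_checkpoint_{global_step}.pth",
    best_checkpoint_pattern=r"best_checkpoint_(\d+)\.pth",
    latest_checkpoint_str="latest_checkpoint_{global_step}.pth",
    latest_checkpoint_pattern=r"latest_checkpoint_(\d+)\.pth",
    delayed_best_model_saving=True)

# Typically called after an evaluation step; with delayed saving enabled this
# also flushes any pending best checkpoint.
saver.save_latest_checkpoint(global_step=saver.global_step)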