示例#1
0
    def run(self, train_set, valid_set, epochs: int, batch_size: int, num_workers: int = 0, device: str = 'cuda', **kwargs):  # pylint: disable=unused-argument
        """Train and validate for ``epochs`` epochs, save checkpoints, then optionally test.

        Args:
            train_set: ``torch.utils.data.Dataset`` used for training.
            valid_set: ``torch.utils.data.Dataset`` used to select the best checkpoint
                (lowest validation loss).
            epochs: Number of epochs to run; must be at least 1.
            batch_size: Batch size passed to ``get_dataloader``.
            num_workers: Number of data-loading workers passed to ``get_dataloader``.
            device: ``'cpu'`` or a string starting with ``'cuda'``.
            **kwargs: Optional ``logger`` (object with an ``info`` method),
                ``save_every`` (if truthy, an epoch-stamped copy is also written
                whenever the best checkpoint improves) and ``test_set``
                (dataset passed to ``self.test`` after training).

        Raises:
            ValueError: If ``epochs`` is smaller than 1.
        """
        assert isinstance(train_set, torch.utils.data.Dataset)
        assert isinstance(valid_set, torch.utils.data.Dataset)
        assert isinstance(epochs, int)
        assert isinstance(batch_size, int)
        assert isinstance(num_workers, int)
        assert device.startswith('cuda') or device == 'cpu'
        if epochs < 1:
            # The final checkpoint below needs the history of at least one epoch.
            raise ValueError(f"`epochs` must be >= 1, got {epochs}.")

        logger = kwargs.get('logger', None)

        self.backbone = self.backbone.to(device)
        self.projector = self.projector.to(device)

        train_loader = get_dataloader(train_set, batch_size, num_workers=num_workers)
        valid_loader = get_dataloader(valid_set, batch_size, num_workers=num_workers)

        with tqdm.tqdm(**get_tqdm_config(total=epochs, leave=True, color='blue')) as pbar:

            best_valid_loss = float('inf')
            best_epoch = 0

            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                train_history = self.train(train_loader, device=device)
                valid_history = self.evaluate(valid_loader, device=device)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss'),
                    }
                }

                # 2. Epoch history (other metrics if provided)
                if self.metrics is not None:
                    raise NotImplementedError

                # 3. TensorBoard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(
                            main_tag=metric_name,
                            tag_scalar_dict=metric_dict,
                            global_step=epoch
                        )
                    # Log the learning rate once per epoch, not once per metric.
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch
                        )

                # 4. Save model if it is the current best
                valid_loss = epoch_history['loss']['valid']
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    best_epoch = epoch
                    self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)
                    if kwargs.get('save_every', False):
                        new_ckpt = os.path.join(self.checkpoint_dir, f'epoch_{epoch:04d}.loss_{valid_loss:.4f}.pt')
                        self.save_checkpoint(new_ckpt, epoch=epoch, **epoch_history)

                # 5. Update learning rate scheduler
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(
                    history=epoch_history,
                    current=epoch,
                    total=epochs,
                    best=best_epoch
                )
                pbar.set_description_str(desc)
                pbar.update(1)
                if logger is not None:
                    logger.info(desc)

        # 7. Save last model
        self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

        # 8. Test model (optional)
        if 'test_set' in kwargs:
            test_loader = get_dataloader(kwargs['test_set'], batch_size=batch_size, num_workers=num_workers)
            self.test(test_loader, device=device, logger=logger)
示例#2
0
    def run(self,
            train_set,
            valid_set,
            epochs: int,
            batch_size: int,
            num_workers: int = 0,
            **kwargs):
        """Train, evaluate and optionally test.

        Models are moved to ``self.local_rank``. Checkpoints and the optional
        test pass are executed only when ``self.local_rank == 0``.

        Args:
            train_set: Training dataset, wrapped with ``balanced_loader``.
            valid_set: Validation dataset used to select the best checkpoint
                (lowest validation loss, ties favour the later epoch).
            epochs: Number of epochs to run; must be at least 1.
            batch_size: Batch size for every loader.
            num_workers: Number of data-loading workers for every loader.
            **kwargs: Optional ``logger`` (object with an ``info`` method) and
                ``test_set`` (dataset passed to ``self.test`` after training).

        Raises:
            ValueError: If ``epochs`` is smaller than 1.
        """
        if epochs < 1:
            # The final checkpoint below needs the history of at least one epoch.
            raise ValueError(f"`epochs` must be >= 1, got {epochs}.")

        logger = kwargs.get('logger', None)

        self.backbone.to(self.local_rank)
        self.classifier.to(self.local_rank)

        train_loader = balanced_loader(train_set,
                                       batch_size,
                                       num_workers=num_workers,
                                       shuffle=False,
                                       pin_memory=False)
        valid_loader = DataLoader(valid_set,
                                  batch_size,
                                  num_workers=num_workers,
                                  shuffle=True,
                                  drop_last=False,
                                  pin_memory=False)

        with tqdm.tqdm(**get_tqdm_config(
                total=epochs, leave=True, color='blue')) as pbar:

            best_valid_loss = float('inf')
            best_epoch = 0

            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                train_history = self.train(train_loader)
                valid_history = self.evaluate(valid_loader)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss')
                    }
                }

                # 2. Epoch history (other metrics if provided)
                if isinstance(self.metrics, dict):
                    for metric_name in self.metrics:
                        epoch_history[metric_name] = {
                            'train': train_history[metric_name],
                            'valid': valid_history[metric_name],
                        }

                # 3. Tensorboard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(main_tag=metric_name,
                                                tag_scalar_dict=metric_dict,
                                                global_step=epoch)
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch)

                # 4. Save model if it is the current best
                valid_loss = epoch_history['loss']['valid']
                if valid_loss <= best_valid_loss:
                    best_valid_loss = valid_loss
                    best_epoch = epoch
                    if self.local_rank == 0:
                        self.save_checkpoint(self.best_ckpt,
                                             epoch=epoch,
                                             **epoch_history)

                # 5. Update learning rate scheduler (optional)
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(
                    history=epoch_history,
                    current=epoch,
                    total=epochs,
                    best=best_epoch,
                )
                pbar.set_description_str(desc)
                pbar.update(1)
                if logger is not None:
                    logger.info(desc)

        # 7. Save last model. Guarded like the best checkpoint: in a
        #    multi-process run every rank would otherwise write the same file.
        if self.local_rank == 0:
            self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

        # 8. Test model (optional)
        if self.local_rank == 0 and 'test_set' in kwargs:
            test_loader = DataLoader(kwargs['test_set'],
                                     batch_size,
                                     num_workers=num_workers,
                                     shuffle=True,
                                     drop_last=False,
                                     pin_memory=False)
            self.test(test_loader, logger=logger)
    def run(self,
            train_set,
            eval_set,
            test_set: torch.utils.data.Dataset = None,
            save_every: int = 10,
            finetune: bool = False,
            **kwargs):  # pylint: disable=unused-argument
        """Train and evaluate for ``self.epochs`` epochs, then optionally test.

        Epoch count, batch size and worker count are read from ``self``.
        Checkpoints and the test log line are written only when
        ``self.local_rank == 0``.

        Args:
            train_set: Training dataset; sharded with ``DistributedSampler``
                when ``self.distributed`` is set, otherwise shuffled.
            eval_set: Evaluation dataset used to select the best checkpoint
                (lowest eval loss, ties favour the later epoch).
            test_set: Optional dataset evaluated once after training.
            save_every: Write ``ckpt.{epoch}.pth.tar`` every ``save_every``
                epochs; a falsy value (0 / None) disables these checkpoints.
            finetune: Forwarded to ``self.train``.
            **kwargs: Optional ``logger`` (object with an ``info`` method).

        Raises:
            RuntimeError: If training has not been prepared.
            ValueError: If ``self.epochs`` is smaller than 1.
        """
        epochs = self.epochs
        batch_size = self.batch_size
        num_workers = self.num_workers

        if not self.prepared:
            raise RuntimeError("Training not prepared.")
        if epochs < 1:
            # The final checkpoint below needs at least one completed epoch.
            raise ValueError(f"`epochs` must be >= 1, got {epochs}.")

        is_main_process = self.local_rank == 0

        # DataLoader (train, eval); the test loader is built only if needed.
        sampler = DistributedSampler(train_set) if self.distributed else None
        shuffle = not self.distributed  # DataLoader forbids shuffle with a sampler
        train_loader = DataLoader(train_set,
                                  batch_size=batch_size,
                                  sampler=sampler,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  drop_last=False,
                                  pin_memory=True)
        eval_loader = DataLoader(eval_set,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 drop_last=False,
                                 pin_memory=True)

        # Logging
        logger = kwargs.get('logger', None)

        # Supervised training
        best_eval_loss = float('inf')
        best_epoch = 0

        for epoch in range(1, epochs + 1):

            if self.distributed:
                # Give the distributed sampler a different shuffle each epoch.
                sampler.set_epoch(epoch)

            # Train & evaluate
            train_history = self.train(train_loader, finetune=finetune)
            eval_history = self.evaluate(eval_loader)
            # {metric: {'train': value, 'eval': value}}; 'eval' is only present
            # for metrics that evaluate() also reports.
            epoch_history = collections.defaultdict(dict)
            for k, train_value in train_history.items():
                epoch_history[k]['train'] = train_value
                if k in eval_history:
                    epoch_history[k]['eval'] = eval_history[k]

            # Write TensorBoard summary
            if self.writer is not None:
                for k, v in epoch_history.items():
                    self.writer.add_scalars(k, v, global_step=epoch)
                if self.scheduler is not None:
                    lr = self.scheduler.get_last_lr()[0]
                    self.writer.add_scalar('lr', lr, global_step=epoch)

            # Save best model checkpoint
            eval_loss = eval_history['loss']
            if eval_loss <= best_eval_loss:
                best_eval_loss = eval_loss
                best_epoch = epoch
                if is_main_process:
                    ckpt = os.path.join(self.ckpt_dir, "ckpt.best.pth.tar")
                    self.save_checkpoint(ckpt, epoch=epoch)

            # Save intermediate model checkpoints
            if save_every and epoch % save_every == 0 and is_main_process:
                ckpt = os.path.join(self.ckpt_dir, f"ckpt.{epoch}.pth.tar")
                self.save_checkpoint(ckpt, epoch=epoch)

            # Write logs
            log = make_epoch_description(
                history=epoch_history,
                current=epoch,
                total=epochs,
                best=best_epoch,
            )
            if logger is not None:
                logger.info(log)

            # Update learning rate
            if self.scheduler is not None:
                self.scheduler.step()

        # Save final model checkpoint (main process only, like the others)
        if is_main_process:
            ckpt = os.path.join(self.ckpt_dir, "ckpt.last.pth.tar")
            self.save_checkpoint(ckpt, epoch=epoch)

        # Test (optional). Everything here depends on test_loader, so it must
        # stay inside the guard; otherwise test_set=None raises UnboundLocalError.
        if test_set is not None:
            test_loader = DataLoader(test_set,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers,
                                     drop_last=False,
                                     pin_memory=False)
            test_history = self.evaluate(test_loader)
            if is_main_process and logger is not None:
                log = "Test: " + "".join(
                    f" {k}: {v:.4f} |" for k, v in test_history.items())
                logger.info(log)
示例#4
0
    def run(self,
            train_set,
            valid_set,
            epochs: int,
            batch_size: int,
            num_workers: int = 0,
            device: str = 'cuda',
            **kwargs):
        """Train, evaluate and optionally test.

        Args:
            train_set: ``torch.utils.data.Dataset`` used for training.
            valid_set: ``torch.utils.data.Dataset`` used for model selection.
            epochs: Number of epochs to run; must be at least 1.
            batch_size: Batch size passed to ``get_dataloader``.
            num_workers: Number of data-loading workers passed to ``get_dataloader``.
            device: ``'cpu'`` or a string starting with ``'cuda'``.
            **kwargs: Optional
                ``logger`` (object with an ``info`` method),
                ``disable_mixup`` (use ``self.train`` instead of ``self.train_with_mixup``),
                ``balance`` (class-balanced training loader, disables shuffling),
                ``eval_metric`` (model-selection metric, default ``'loss'``),
                ``test_set`` (dataset passed to ``self.test`` after training).

        Raises:
            ValueError: If ``epochs`` is smaller than 1.
            NotImplementedError: If ``eval_metric`` is not supported.
        """
        assert isinstance(train_set, torch.utils.data.Dataset)
        assert isinstance(valid_set, torch.utils.data.Dataset)
        assert isinstance(epochs, int)
        assert isinstance(batch_size, int)
        assert isinstance(num_workers, int)
        assert device.startswith('cuda') or device == 'cpu'
        if epochs < 1:
            # The final checkpoint below needs the history of at least one epoch.
            raise ValueError(f"`epochs` must be >= 1, got {epochs}.")

        # Determine model selection metric up front, before any work is done.
        # 'loss' is minimized; every other supported metric is maximized. The
        # same set is used for validation and for the comparison below, so a
        # metric accepted here can never fail later at checkpoint time.
        eval_metric = kwargs.get('eval_metric', 'loss')
        higher_is_better = ('accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc')
        if eval_metric == 'loss':
            best_metric_val = float('inf')
        elif eval_metric in higher_is_better:
            best_metric_val = 0
        else:
            raise NotImplementedError

        logger = kwargs.get('logger', None)
        disable_mixup = kwargs.get('disable_mixup', False)

        self.backbone = self.backbone.to(device)
        self.classifier = self.classifier.to(device)

        balance = kwargs.get('balance', False)
        if logger is not None:
            logger.info(f"Class balance: {balance}")
        shuffle = not balance

        train_loader = get_dataloader(train_set,
                                      batch_size,
                                      num_workers=num_workers,
                                      shuffle=shuffle,
                                      balance=balance)
        valid_loader = get_dataloader(valid_set,
                                      batch_size,
                                      num_workers=num_workers,
                                      balance=False)

        with tqdm.tqdm(**get_tqdm_config(
                total=epochs, leave=True, color='blue')) as pbar:

            best_epoch = 0
            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                if disable_mixup:
                    train_history = self.train(train_loader, device)
                else:
                    train_history = self.train_with_mixup(train_loader, device)
                valid_history = self.evaluate(valid_loader, device)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss')
                    }
                }

                # 2. Epoch history (other metrics if provided)
                if isinstance(self.metrics, dict):
                    for metric_name in self.metrics:
                        epoch_history[metric_name] = {
                            'train': train_history[metric_name],
                            'valid': valid_history[metric_name],
                        }

                # 3. Tensorboard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(main_tag=metric_name,
                                                tag_scalar_dict=metric_dict,
                                                global_step=epoch)
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch)

                # 4. Save model if it is the current best (ties favour the later epoch)
                metric_val = epoch_history[eval_metric]['valid']
                if eval_metric == 'loss':
                    improved = metric_val <= best_metric_val
                else:
                    improved = metric_val >= best_metric_val
                if improved:
                    best_metric_val = metric_val
                    best_epoch = epoch
                    self.save_checkpoint(self.best_ckpt,
                                         epoch=epoch,
                                         **epoch_history)

                # 5. Update learning rate scheduler (optional)
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(
                    history=epoch_history,
                    current=epoch,
                    total=epochs,
                    best=best_epoch,
                )
                pbar.set_description_str(desc)
                pbar.update(1)
                if logger is not None:
                    logger.info(desc)

        # 7. Save last model
        self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

        # 8. Test model (optional)
        if 'test_set' in kwargs:
            test_loader = get_dataloader(kwargs['test_set'],
                                         batch_size,
                                         num_workers=num_workers)
            self.test(test_loader, device=device, logger=logger)
示例#5
0
    def run(self,
            train_set: torch.utils.data.Dataset,
            valid_set: torch.utils.data.Dataset,
            epochs: int,
            batch_size: int,
            num_workers: int = 0,
            **kwargs):
        """Train and validate for ``epochs`` epochs, saving model and memory checkpoints.

        Only the non-distributed path is implemented. Checkpoints are written
        only when ``self.local_rank == 0``.

        Args:
            train_set: Training dataset; also used to initialize ``self.memory``
                if it is not initialized yet.
            valid_set: Validation dataset used to select the best checkpoint
                (strictly lowest validation loss).
            epochs: Number of epochs to run; must be at least 1.
            batch_size: Batch size for both loaders.
            num_workers: Number of data-loading workers for both loaders.
            **kwargs: Optional ``logger`` (object with an ``info`` method) and
                ``save_every`` (write an epoch-stamped checkpoint every
                ``save_every`` epochs; defaults to ``epochs``; a falsy value
                disables these checkpoints).

        Raises:
            ValueError: If ``epochs`` is smaller than 1.
            NotImplementedError: If ``self.distributed`` is set.
        """
        if epochs < 1:
            # The final checkpoint below needs the history of at least one epoch.
            raise ValueError(f"`epochs` must be >= 1, got {epochs}.")

        logger = kwargs.get('logger', None)
        save_every = kwargs.get('save_every', epochs)

        self.backbone.to(self.local_rank)
        self.projector.to(self.local_rank)

        if self.distributed:
            raise NotImplementedError
        train_loader = DataLoader(train_set,
                                  batch_size,
                                  num_workers=num_workers,
                                  shuffle=True,
                                  pin_memory=False)
        valid_loader = DataLoader(valid_set,
                                  batch_size,
                                  num_workers=num_workers,
                                  shuffle=True,
                                  pin_memory=False)

        # Initialize memory representations for the training data
        if not self.memory.initialized:
            self.memory.initialize(self.backbone, self.projector, train_loader)

        with tqdm.tqdm(**get_tqdm_config(
                total=epochs, leave=True, color='blue')) as pbar:

            best_valid_loss = float('inf')
            best_epoch = 0

            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                train_history = self.train(train_loader)
                valid_history = self.evaluate(valid_loader)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss')
                    },
                }

                # 2. Epoch history (other metrics if provided)
                if self.metrics is not None:
                    assert isinstance(self.metrics, dict)
                    for metric_name in self.metrics:
                        epoch_history[metric_name] = {
                            'train': train_history.get(metric_name),
                            'valid': valid_history.get(metric_name),
                        }

                # 3. Tensorboard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(main_tag=metric_name,
                                                tag_scalar_dict=metric_dict,
                                                global_step=epoch)
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch)

                # 4-1. Save model (and memory) if it is the current best
                valid_loss = epoch_history['loss']['valid']
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    best_epoch = epoch
                    if self.local_rank == 0:
                        self.save_checkpoint(self.best_ckpt,
                                             epoch=epoch,
                                             **epoch_history)
                        best_memory = os.path.join(
                            os.path.dirname(self.best_ckpt), 'best_memory.pt')
                        self.memory.save(best_memory, epoch=epoch)

                # 4-2. Save intermediate models (a falsy save_every disables
                #      this instead of raising ZeroDivisionError)
                if save_every and epoch % save_every == 0 and self.local_rank == 0:
                    new_ckpt = os.path.join(
                        self.checkpoint_dir,
                        f'epoch_{epoch:04d}.loss_{valid_loss:.4f}.pt')
                    self.save_checkpoint(new_ckpt,
                                         epoch=epoch,
                                         **epoch_history)

                # 5. Update learning rate scheduler
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(history=epoch_history,
                                              current=epoch,
                                              total=epochs,
                                              best=best_epoch)
                pbar.set_description_str(desc)
                pbar.update(1)

                if logger is not None:
                    logger.info(desc)

        # 7. Save last model (and memory)
        if self.local_rank == 0:
            self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)
            last_memory = os.path.join(os.path.dirname(self.last_ckpt),
                                       'last_memory.pt')
            self.memory.save(last_memory, epoch=epoch)