Example #1
0
def set_device(_net, ctx, *args, **kwargs):
    if ctx == "cpu":
        if not isinstance(_net, DataParallel):
            _net = DataParallel(_net)
        return _net.cpu()
    elif any(map(lambda x: x in ctx, ["cuda", "gpu"])):  # pragma: no cover
        # todo: find a way to test gpu device
        if not torch.cuda.is_available():
            try:
                torch.ones((1, ), device=torch.device("cuda:0"))
            except AssertionError as e:
                raise TypeError(
                    "no cuda detected, only cpu is supported; detailed error msg: %s"
                    % str(e))
        if torch.cuda.device_count() >= 1:
            if ":" in ctx:
                ctx_name, device_ids = ctx.split(":")
                assert ctx_name in [
                    "cuda", "gpu"
                ], "the equipment should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx
                device_ids = [int(i) for i in device_ids.strip().split(",")]
                try:
                    if not isinstance(_net, DataParallel):
                        return DataParallel(_net, device_ids).cuda()
                    return _net.cuda(device_ids[0])
                except AssertionError as e:
                    logging.error(device_ids)
                    raise e
            elif ctx in ["cuda", "gpu"]:
                if not isinstance(_net, DataParallel):
                    _net = DataParallel(_net)
                return _net.cuda()
            else:
                raise TypeError(
                    "the equipment should be 'cpu', 'cuda' or 'gpu', now is %s"
                    % ctx)
        else:
            print(torch.cuda.device_count())
            raise TypeError("0 gpu can be used, use cpu")
    else:  # pragma: no cover
        # todo: find a way to test gpu device
        if not isinstance(_net, DataParallel):
            return DataParallel(_net, device_ids=ctx).cuda()
        return _net.cuda(ctx)
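# A minimal usage sketch of set_device above (not part of the original snippet):
# ToyNet is a hypothetical torch.nn.Module; the imports mirror what the function
# already assumes (torch, logging, torch.nn.DataParallel).
import torch
from torch import nn
from torch.nn import DataParallel

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = set_device(ToyNet(), "cpu")       # wrapped in DataParallel, parameters stay on CPU
# net = set_device(ToyNet(), "cuda:0")  # single-GPU variant; requires a visible CUDA device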
Example #2
0
            model.train()
        validate_loss /= validate_num
        lr = optimizer.param_groups[0]['lr']
        print('Fold{} Epoch{}:\tValidate-{:.4f}\tlr-{}e-5'.format(
            fold, epoch, validate_loss, lr * 100000.))
        scheduler.step(validate_loss)

        if validate_loss < min_loss:
            min_loss = validate_loss
            early_stop_counter = 0
            if len(device_ids) > 1:
                torch.save(model.module.cpu().state_dict(),
                           os.path.join(save_dir, 'model_{}.pth'.format(fold)))
            else:
                torch.save(model.cpu().state_dict(),
                           os.path.join(save_dir, 'model_{}.pth'.format(fold)))
            model.cuda()
        else:
            early_stop_counter += 1
        if early_stop_counter == early_stop:
            mean_loss += min_loss
            break
    print('Fold{} Stop after training {} epoch'.format(fold,
                                                       epoch - early_stop))
    print('Fold{} Validate Loss:{}'.format(fold, min_loss))
    with open(os.path.join(save_dir, 'config'), 'a') as f:
        f.write('\nFold{} Stop after training {} epoch'.format(
            fold, epoch - early_stop))
        f.write('\nFold{} Validate Loss:{}\n'.format(fold, min_loss))
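# A standalone sketch (assumed names, not from the fragment above) of why the code
# branches on len(device_ids): DataParallel stores the real network under .module,
# so that is the state_dict worth saving when more than one GPU is in use.
import torch
from torch import nn
from torch.nn import DataParallel

model = nn.Linear(4, 2)
wrapped = DataParallel(model)
assert wrapped.module is model                           # the underlying module is unchanged
torch.save(wrapped.module.state_dict(), 'model_0.pth')   # same tensors as model.state_dict()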
Example #3
0
def main():
    parser = argparse.ArgumentParser(description='N-Net training')
    parser.add_argument('--preprocess_result_path', help='Directory to save preprocessed _clean and _label .npy files.',
                        default='F:\\LargeFiles\\lfz\\prep_result_sub\\')
    parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                        help='number of data loading workers (default: 32)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        metavar='N', help='mini-batch size (default: 16)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--save_freq', default=10, type=int, metavar='S',
                        help='save frequency')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--save_dir', default='', type=str, metavar='SAVE',
                        help='directory to save checkpoint (default: none)')
    parser.add_argument('--test', default=0, type=int, metavar='TEST',
                        help='1 do test evaluation, 0 not')
    parser.add_argument('--split', default=8, type=int, metavar='SPLIT',
                        help='In the test phase, split the image into this many parts (default: 8)')
    parser.add_argument('--gpu', default='all', type=str, metavar='N',
                        help='use gpu, set to `none` to use CPU')
    parser.add_argument('--n_test', default=8, type=int, metavar='N',
                        help='number of gpu for test')
    parser.add_argument('--train_ids', default='./dsb/training/detector/kaggleluna_full.npy', type=str,
                        help='Path to the npy file for training scan IDs stored in a Numpy list.')
    parser.add_argument('--val_ids', default='./dsb/training/detector/kaggleluna_full.npy', type=str, # TODO: replace with valsplit.npy when full datasets available.
                        help='Path to the npy file for validation scan IDs stored in a Numpy list.')
    parser.add_argument('--test_ids', default='./dsb/training/detector/full.npy', type=str,
                        help='Path to the npy file for test scan IDs stored in a Numpy list.')
    args = parser.parse_args()

    torch.manual_seed(0)
    use_gpu = False
    if 'none' not in args.gpu.lower() and torch.cuda.is_available():
        use_gpu = True
        torch.cuda.set_device(0)

    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    os.makedirs(save_dir, exist_ok=True)

    logfile = os.path.join(save_dir, 'log')
    if use_gpu:
        print('Use GPU for training.')
        n_gpu = setgpu(args.gpu)
        args.n_gpu = n_gpu
        net = net.cuda()
        loss = loss.cuda()
        cudnn.benchmark = True
        net = DataParallel(net)
    else:
        print('Use CPU for training.')
        net = net.cpu()
    datadir = args.preprocess_result_path

    if args.test == 1:
        margin = 32
        sidelen = 144

        split_comber = SplitComb(sidelen, config['max_stride'], config['stride'], margin, config['pad_value'])
        # Test sets.
        dataset = data.DataBowl3Detector(
            datadir,
            args.test_ids,
            config,
            phase='test',
            split_comber=split_comber)
        test_loader = DataLoader(
            dataset,
            batch_size = 1,
            shuffle = False,
            num_workers = args.workers,
            collate_fn = data.collate,
            pin_memory=False)

        test(test_loader, net, get_pbb, save_dir, config, args)
        return

    # Train sets
    dataset = data.DataBowl3Detector(
        datadir,
        args.train_ids,
        config,
        phase = 'train')

    print('batch_size:', args.batch_size)
    train_loader = DataLoader(
        dataset,
        batch_size = args.batch_size,
        shuffle = True,
        num_workers = args.workers,
        pin_memory=True)

    # Validation sets
    dataset = data.DataBowl3Detector(
        datadir,
        args.val_ids,
        config,
        phase = 'val')
    val_loader = DataLoader(
        dataset,
        batch_size = args.batch_size,
        shuffle = False,
        num_workers = args.workers,
        pin_memory=True)

    optimizer = torch.optim.SGD(
        net.parameters(),
        args.lr,
        momentum = 0.9,
        weight_decay = args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr


    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr, args.save_freq, save_dir, args)
        validate(val_loader, net, loss)
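# A minimal sketch (assumed; not the project's actual train() helper) of how the step
# schedule above is typically applied: get_lr(epoch) is computed once per epoch and
# written into every optimizer param group. With the defaults (epochs=100, lr=0.01)
# this yields 0.01 for epochs 1-50, 0.001 for 51-80 and 0.0001 for 81-100.
def apply_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr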
Example #4
0
class Trainer(EnforceOverrides, metaclass=ABCMeta):
    def __init__(self,
                 config: wandb.Config,
                 rh: RunHelper,
                 aux_config: Optional[Dict] = None):
        self.config = config  # wandb Config
        self.aux_config = aux_config  # aux dictionary
        self.rh = rh

        # legacy: still allow val_fcn_list to be created in val_step()
        self.val_fcn_list: Optional[List[
            MetricFcn]] = None  # should be set in the implementing subclass (in init)
        self.init_val_fcns()

        # pytorch trainer objects
        self.device: Optional[torch.device] = None
        self.gpu_ids: Optional[List[int]] = None
        self.model: Optional[PreTrainedModel] = None
        self.tokenizer: Optional[PreTrainedTokenizer] = None
        # self.model_parallel = None
        self.model_is_parallelized: bool = False
        # populates the three attributes above; potentially adds special tokens
        self.setup_model_and_device()

        self.train_loader: Optional[Union[ClueDataLoaderBatched,
                                          MultiTaskDataLoader]] = None
        self.dev_loader: Optional[ClueDataLoaderBatched] = None
        self.multitask_manager: Optional[
            util_multiloader.
            MultitaskManager] = None  # set if we are doing multitask
        self.setup_dataloaders()

        self.optimizer: Optional[Adafactor] = None
        self.scheduler: Optional[lr_scheduler] = None
        self._setup_optim_sched()

        # trainer state (must go here since ref'd by verify_and_log)
        self.state = TrainInfo(multitask_mgr=self.multitask_manager)

        self.verify_and_log_trainer_info()

        # if we're resuming
        if self.config.ckpt_path:
            if not self.config.no_train:
                assert self.config.resume_train is not None
            # if resume train, train state, optim, and scheduler will be changed
            self.load_from_ckpt(resume_train=self.config.resume_train)

        # todo misc attributes
        # metrics to track is stored in rh.checkpointsaver

    @abstractmethod
    def init_val_fcns(self):
        pass

    @final
    def load_from_ckpt(self, resume_train=False):
        # todo: print where the config dictionaries differ
        # loads model state
        log.info(f'Loading checkpoint: {self.config.ckpt_path}')
        ckpt_dict: CheckpointDict = util_checkpoint.load_ckpt(
            self.config.ckpt_path, self.model, map_location=self.device)
        if not self.config.no_train and resume_train:
            self.optimizer.load_state_dict(ckpt_dict['optimizer'])
            if self.scheduler is not None:
                raise NotImplementedError('Resume not implemented for Adam')

            # todo: should also set other properties of state
            self.state.resume(epoch=ckpt_dict['epoch'], step=ckpt_dict['step'])

            # if we're resuming with multitask
            if self.config.multitask:
                # todo(hack): this needs to be cleaned up
                # so that we can actually resume multitask
                # multitask state needs to be moved into trainer state
                # should also check for equivalence of the other params
                # todo: remove this after verifying that everything is okay

                assert self.config.hacky
                assert isinstance(k_hard_reset_warmup_iters_done,
                                  int) and k_hard_reset_warmup_iters_done > 0
                total_warmup_todo = self.multitask_manager.multitask_warmup
                total_warmup_done = k_hard_reset_warmup_iters_done
                warmup_remaining = total_warmup_todo - total_warmup_done
                assert warmup_remaining == self.state.warmup_remaining()

                # before fixing epoch
                # self.state.epoch -= warmup_remaining      # will be incremented before running first epoch in run()

                # trainloader is a multiloader
                # reset its state correctly
                self.train_loader.num_iters = k_hard_reset_warmup_iters_done

                # # josh hack 04/14/2021
                # self.state.resume(epoch=0,                  # go back to epoch 0
                #                   step=ckpt_dict['step'])
                # self._setup_optim_sched()       # reset

                log.info(
                    f'Set up at epoch {self.state.epoch}, with {self.train_loader.warmup_iters}'
                    f' total warmup, and {self.train_loader.num_iters} already done, i.e. '
                    f'{self.state.warmup_remaining()} warmup still to do')

    @final
    def run(self):
        # if not training
        if self.config.no_train:
            log.info(
                f'arg no_train given. Just doing single validation. Setting to epoch == 1'
            )
            self.state.increment_epoch()  # set to epoch == 1
            assert self.state.epoch == k_max_warmup_epochs + 1
            self.val_only()
            return

        log.warning(
            f'For actual train, epochs start at {k_max_warmup_epochs + 1}')
        # main training; includes warmup
        while self.state.epoch < self.config.num_epochs + k_max_warmup_epochs:
            was_last_warmup = self.state.increment_epoch()
            if was_last_warmup and self.multitask_manager.multitask_reset:
                log.info('Final warmup epoch done. Resetting optimizer')
                self._setup_optim_sched()

            self.train_step()

            # Validate; this will do both multitask and normal validation
            all_metrics = self.val_step()
            metrics_dict, preds = self.metrics_list_to_dict(all_metrics)
            # will have multisave appended if it is multitasking
            self.save_callback(metrics_dict, preds)
            if self.early_stopping_callback(metrics_dict):
                break

        # e.g., final eval
        self.post_run()

    def val_only(self):
        metrics: List[MetricsPredsWrapper] = self.val_step()
        metrics_dict, preds = self.metrics_list_to_dict(metrics)
        self.save_callback(metrics_dict, preds)

    # does not have to be implemented by subclasses
    def post_run(self):
        pass

    @final
    def setup_model_and_device(self):
        # device will be cuda:0
        self.device, self.gpu_ids = util.get_available_devices(
            assert_cuda=True)
        assert str(self.device) == "cuda:0", f'{self.device} != cuda:0'

        if (len(self.gpu_ids) > 1 or self.config.multi_gpu is not None
                or k_data_parallel):
            logging.info(
                f'{len(self.gpu_ids)}, {self.config.multi_gpu}, {k_data_parallel}'
            )
            assert k_data_parallel
            assert len(self.gpu_ids) == self.config.multi_gpu

        self._setup_model_and_tokenizer()  # implemented by subclasses

        if self.config.add_special_tokens:
            util.add_special_tokens(self.model, self.tokenizer)  # adds <SEP>

        self.model_to_device()

    # todo: we might be able to omit this and just have it in the self.train() function
    @final
    def model_to_device(self):
        if k_data_parallel:
            log.info('Using dataparallel')
            self.model = DataParallel(self.model, device_ids=self.gpu_ids)
            self.model_is_parallelized = True

        self.model.to(self.device)

    @abstractmethod
    def _setup_model_and_tokenizer(self):
        """
        Should load, and make any tweaks (e.g. vocab changes) to, the following:
        - self.model
        - self.tokenizer

        Called by setup_model_and_device()
        """
        pass

    @abstractmethod
    def setup_dataloaders(self):
        pass

    @abstractmethod
    def _get_dataloaders(self):
        """
        Should set
        - train_loader
        - dev_loader
        """
        pass

    @abstractmethod
    def _setup_optim_sched(self):
        """
        Should set
        - optimizer
        - scheduler (optional)
        """
        pass
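    # For reference, a plausible subclass implementation of _setup_optim_sched above
    # (hypothetical, using the transformers Adafactor optimizer already named in the
    # type hints; `self.config.lr` is an assumed config field):
    #
    #     def _setup_optim_sched(self):
    #         self.optimizer = Adafactor(self.model.parameters(),
    #                                    lr=self.config.lr,
    #                                    scale_parameter=False,
    #                                    relative_step=False)
    #         self.scheduler = None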

    def verify_and_log_trainer_info(self):
        # verify that the metric we want to log is valid
        log.info(
            'Verifying that all metrics are OK. The outputs here are NOT from the model that was passed, if '
            'one was passed')
        metrics_dict, _ = self.metrics_list_to_dict(
            self.val_step(trial_run=True))
        for m in self.rh.metrics_to_track:
            if m[0] in ['epoch']:  # these won't be in the normal metrics returned
                continue
            assert m[0] in metrics_dict, f'{m} not in {metrics_dict}'

        log.info(f'Tracking metrics {self.rh.metrics_to_track} all verified')

        # verify everything else
        assert all(
            map(lambda x: x is not None, [
                self.config, self.model, self.tokenizer, self.device,
                self.optimizer, self.train_loader, self.dev_loader
            ]))

        # validation freq
        if self.config.val_freq is not None:
            assert self.config.val_freq * 1000 < self.config.num_train
            # we log as {epoch}.{intermed/100}, so at most 99 intermediate validations fit per epoch
            assert self.config.num_train / (self.config.val_freq * 1000) < 100

        log_string = '\n' \
                      f'total_train_steps (num_train_ex * epochs): {self.config.total_train}\n' \
                      f'machine: {socket.gethostname()}\n' \
                      f'num_train: {self.config.num_train}\n' \
                      f'num_val: {self.config.num_val}\n'
        # log_string += f'total_optim_steps: {self.config.total_optim_steps}\n' \

        # can't use json for the first config dict because it is not of type dict
        for k, v in sorted(self.config.items(), key=lambda x: x[0]):
            log_string += f'{k}: {v}\n'
        if self.aux_config:
            log_string += "multitask:\n"
            log_string += json.dumps(self.aux_config,
                                     sort_keys=True,
                                     indent=2,
                                     cls=util_dataloader.EnhancedJSONEncoder)
        else:
            log_string += "No aux config (e.g. multitask) given"
        log_string += "\n"
        log.info(log_string)

    @abstractmethod
    def _batch_to_objects(self, batch) -> ProcessedBatch:
        pass

    def val_end_epoch(self,
                      metrics_all_accum: Union[List[MetricsPredsWrapper],
                                               MetricsPredsWrapper],
                      num_val=None):
        if isinstance(metrics_all_accum, MetricsPredsWrapper):
            metrics_all_accum = [metrics_all_accum]

        for m_dict in metrics_all_accum:
            # get_all_metrics will already have a <val_label>:<set_label>:
            for k, v, orig_v in m_dict.get_all_metrics(num_val):
                # Log val and avg val
                log.info(f'{k}: {orig_v:05.2f}\t avg: {v:05.4f}')
                # util.log_scalar(f'{k}', v/self.config.num_val, self.state.epoch, tbx=self.rh.tbx)
                util.log_wandb_new({f'{k}': v},
                                   step=self.state.step,
                                   epoch=self.state.epoch,
                                   use_step_for_logging=k_use_step_for_logging)

    @abstractmethod
    def model_forward(self, src_ids: torch.Tensor, src_mask: torch.Tensor, tgt_ids: torch.Tensor) -> \
        Tuple[torch.Tensor, Dict]:
        pass

    @abstractmethod
    def train_step(self) -> NoReturn:
        # will generally need to call model_forward method
        pass

    @abstractmethod
    def _generate_outputs_greedy(self,
                                 src_ids,
                                 src_mask,
                                 skip_special_tokens=True) -> Tuple:
        pass

    @abstractmethod
    def _generate_outputs_sampled(self, src_ids, src_mask, batch_size) -> List:
        pass

    def get_valstepdict_for_batch(self,
                                  pbatch: ProcessedBatch,
                                  do_sample: bool,
                                  do_generate: bool = True) -> PerBatchValStep:
        # evaluation for loss fcn
        perbatch_valstep = PerBatchValStep()
        loss, _ = self.model_forward(
            pbatch.src_ids, pbatch.src_mask,
            pbatch.tgt_ids)  # loss, logits, but don't need logits
        if k_data_parallel:
            loss = loss.mean()
        perbatch_valstep.loss_val = loss.detach().item()

        if do_generate:
            outputs_decoded_greedy, generated_ids_greedy = \
                self._generate_outputs_greedy(pbatch.src_ids, pbatch.src_mask)
            perbatch_valstep.outputs_greedy = outputs_decoded_greedy
            perbatch_valstep.outputs_greedy_ids = generated_ids_greedy

        if do_sample:
            outputs_decoded_sampled = \
                self._generate_outputs_sampled(pbatch.src_ids, pbatch.src_mask, pbatch.batch_size)
            perbatch_valstep.outputs_sampled = outputs_decoded_sampled

        return perbatch_valstep

    def val_step(self, trial_run: bool = False) -> List[MetricsPredsWrapper]:
        """
        :param trial_run: whether this is an initial check run - only one batch will be computed
        :return:
        """
        log.info(
            f'Evaluating at all_step {self.state.step} (epoch={self.state.epoch})...'
        )
        self.eval()
        # self.model.eval()  # put model in eval mode

        # accumulate all metrics over all of the val_dls
        all_metrics_wrappers: List[MetricsPredsWrapper] = []

        # if not self.state.epoch > 0 or trial_run:     # not warmup
        if not self.state.is_warmup() or trial_run:
            log.info(f'Primary eval; epoch: {self.state.epoch}')
            metrics_accum = self.validate_val_loader(self.dev_loader,
                                                     self.val_fcn_list,
                                                     trial_run,
                                                     label='dev',
                                                     do_print=True)
            all_metrics_wrappers.append(metrics_accum)

        # always do multitask
        if self.config.multitask:
            log.info(f'Multitask eval; epoch: {self.state.epoch}')
            for val in self.multitask_manager.val_dls:
                log.info(f'Validating DL {val.name}')
                metrics_accum = self.validate_val_loader(
                    val.dataloader,
                    val.val_fcn_list,
                    trial_run=trial_run,
                    label=f'multi/{val.name}',
                    do_print=False)
                # we don't save predictions from the multiloaders
                metrics_accum.preds = None
                all_metrics_wrappers.append(metrics_accum)

        assert len(all_metrics_wrappers) > 0, 'Val step called with invalid params'

        if not trial_run:
            self.val_end_epoch(
                all_metrics_wrappers,
                num_val=None)  # use the avg divisor as set in the constructor
        return all_metrics_wrappers

    def validate_val_loader(self, val_loader: ClueDataLoaderBatched,
                            val_fcn: List[Callable], trial_run: bool,
                            label: str, do_print: bool):
        metrics_all_accum: MetricsPredsWrapper = MetricsPredsWrapper(
            label=label, avg_divisor=self.dev_loader.num_examples())
        # NLL (default metric for model); reset each time
        loss_meter = util.AverageMeter()

        # todo: should total be num_examples or num_val
        with torch.no_grad(), \
            tqdm(total=val_loader.num_examples()) as progress_bar:
            for batch_num, batch in enumerate(val_loader):
                # run a single batch and then return
                if trial_run and batch_num > 0:
                    break

                pbatch = self._batch_to_objects(batch)
                valstepbatch = self.get_valstepdict_for_batch(
                    pbatch, do_sample=self.config.do_sample)

                # update metrics and predictions tracking
                metrics_all_accum.update_for_batch(val_fcn,
                                                   valstepbatch,
                                                   pbatch,
                                                   metric_label='')

                loss_meter.update(valstepbatch.loss_val, pbatch.batch_size)
                progress_bar.update(pbatch.batch_size)
                progress_bar.set_postfix(NLL=loss_meter.avg)

                # On first batch print one batch of generations for qualitative assessment
                if do_print and batch_num == 0:
                    for idx, orig_input, orig_target, output_greedy, *other in metrics_all_accum.preds[:1]:
                        log.info(f'\n idx: {idx}'
                                 f'\nSource: {orig_input}\n '
                                 f'\tTarget: {orig_target}\n'
                                 f'\t Actual: {output_greedy}\n')

        # append the NLL to the metrics
        metrics_all_accum.add_val('NLL', loss_meter.avg, avg=False, label='')

        return metrics_all_accum

    ##
    # Other helper functions
    ###
    def metrics_list_to_dict(
            self,
            metrics_wrappers: List[MetricsPredsWrapper]) -> Tuple[Dict, List]:
        all_metrics_dict = dict()

        # we should have only a single set of preds; this is hacky. We set all multiloader
        # preds to None during val_step(), which is the only time these MetricsPredsWrapper objects are produced
        preds = None
        for m in metrics_wrappers:
            all_metrics_dict.update(m.get_all_metrics_dict())
            if m.preds is not None:
                assert preds is None
                preds = m.preds

        # could also do this; but then change the code in util_checkpoint
        # all_metrics_dict.update(dict(epoch=self.state.epoch))
        # todo: hacky
        # if self.config.multitask and self.state.epoch <= 0:
        if self.config.multitask and self.state.is_warmup():
            all_metrics_dict.update(dict(multisave=self.state.epoch))

        return all_metrics_dict, preds

    def save_callback(self, metrics_dict, preds, intermed_epoch=None):
        # save_metrics = metrics.get_all_metrics_dict()
        # save_preds = metrics.preds
        if intermed_epoch is not None:
            save_epoch = self.state.epoch + intermed_epoch / 100
        else:
            save_epoch = self.state.epoch
        self.rh.ckpt_saver.save_if_best(
            save_epoch,
            self,
            # metric_dict=save_metrics,
            metric_dict=metrics_dict,
            # preds=save_preds,
            preds=preds,
            save_model=self.config.do_save)

    # todo(wrong): support max/min metrics
    def early_stopping_callback(self, metrics_dict: Dict):
        if not self.config.early_stopping:
            return False
        if self.config.early_stopping not in metrics_dict:
            log.warning(
                f'Early stopping but metric {self.config.early_stopping} not found'
            )
            return False
        curr_metric = metrics_dict[self.config.early_stopping]
        if self.state.metric_best is not None:
            if self.state.metric_best < curr_metric:
                log.info(
                    f"Early stopping: prev {self.state.metric_best}\t current: {curr_metric}"
                )
                return True
            else:
                log.info(
                    f"Not stopping: prev {self.state.metric_best}\t current: {curr_metric}"
                )
        # otherwise store new best
        self.state.metric_best = curr_metric
        return False

    def make_ckpt_dict(self) -> CheckpointDict:
        self.model.cpu()  # todo(parallel): verify this isn't necessary for save
        model_for_ckpt = self._model_for_ckpt()

        sched = None
        if self.scheduler is not None:
            sched = self.scheduler.state_dict()

        ckpt_dict: CheckpointDict = {
            # 'model_state': self.model.state_dict(),
            'model_state': model_for_ckpt.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': sched,
            'config': dict(self.config.items()
                           ),  # todo: fix this (so that it can be reloaded)
            'step': self.state.step,
            'epoch': self.state.epoch
        }
        # was needed when we did self.model.cpu()
        self.model.to(self.device)
        return ckpt_dict

    def _model_for_ckpt(self):
        #todo(parallel): verify don't need cpu
        if k_data_parallel:
            return self.model.module
        else:
            return self.model
        # model_for_save = self.model.cpu()
        # return model_for_save

    def eval(self):
        if k_data_parallel and self.model_is_parallelized:
            self.model = self.model.module
            self.model_is_parallelized = False
            # self.model.to(self.device)      # todo(parallel): do we need this?
        self.model.eval()  # put model in eval mode

    def train(self):
        if k_data_parallel and not self.model_is_parallelized:
            self.model = DataParallel(self.model, self.gpu_ids)
            self.model_is_parallelized = True
            #self.model.to(self.device)      # todo(parallel): do we need this?
        self.model.train()
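# A self-contained sketch (hypothetical; not the project's util_checkpoint API) of the
# checkpoint layout make_ckpt_dict() produces and how load_from_ckpt() consumes it.
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
ckpt = {
    'model_state': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': None,
    'config': {},
    'step': 0,
    'epoch': 0,
}
torch.save(ckpt, 'ckpt.pt')

restored = torch.load('ckpt.pt')
model.load_state_dict(restored['model_state'])
optimizer.load_state_dict(restored['optimizer'])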
Example #5
0
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    # TODO: uncomment to use GPU for training.
    # torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    if 'none' not in args.gpu.lower() and torch.cuda.is_available():
        print('Use GPU for training.')
        n_gpu = setgpu(args.gpu)
        args.n_gpu = n_gpu
        net = net.cuda()
        loss = loss.cuda()
        cudnn.benchmark = True
        net = DataParallel(net)
    else:
        print('Use CPU for training.')
        net = net.cpu()
    datadir = config_training['preprocess_result_path']

    if args.test == 1:
        margin = 32
        sidelen = 144

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(datadir,
                                         'full.npy',
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)

        test(test_loader, net, get_pbb, save_dir, config)
        return

    #net = DataParallel(net)

    dataset = data.DataBowl3Detector(datadir,
                                     'kaggleluna_full.npy',
                                     config,
                                     phase='train')

    print('batch_size:', args.batch_size)
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(
        datadir,
        'kaggleluna_full.npy',  # TODO: replace with valsplit.npy when full datasets available.
        config,
        phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
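# A plausible sketch of the Logger assigned to sys.stdout above (hypothetical; the real
# class is defined elsewhere in the repo): it tees everything printed to stdout into
# the log file so console output and the on-disk log stay in sync.
import sys

class Logger(object):
    def __init__(self, logfile):
        self.terminal = sys.stdout
        self.log = open(logfile, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()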
Example #6
0
def SPNNetTrain(context):
    torch.manual_seed(0)

    args = context.args

    saveFolder = os.path.dirname(args.outputCheckpoint)
    dataFolder = args.inputDataFolder
    checkoutPointPath = args.inputCheckpoint
    trainIds = args.inputTrainData[args.idColumn]
    validateIds = args.inputValidateData[args.idColumn]
    workers = asyncio.WORKERS
    epochs = args.epochs
    batchSize = args.batchSize
    learningRate = args.learningRate
    momentum = args.momentum
    weightDecay = args.weightDecay
    useGpu = torch.cuda.is_available()

    config, net, loss, getPbb = model.get_model()

    if checkoutPointPath:
        checkpoint = torch.load(checkoutPointPath)
        startEpoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
    else:
        startEpoch = 1

    if useGpu:
        print("Use GPU {} for training.".format(torch.cuda.current_device()))
        net = net.cuda()
        loss = loss.cuda()
        cudnn.benchmark = True
        net = DataParallel(net)
    else:
        print("Use CPU for training.")
        net = net.cpu()

    # Train sets
    dataset = data.DataBowl3Detector(dataFolder,
                                     trainIds,
                                     config,
                                     phase="train")
    trainLoader = DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

    # Validation sets
    dataset = data.DataBowl3Detector(dataFolder,
                                     validateIds,
                                     config,
                                     phase="val")
    valLoader = DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )
    optimizer = torch.optim.SGD(net.parameters(),
                                learningRate,
                                momentum=momentum,
                                weight_decay=weightDecay)

    getlr = functools.partial(getLearningRate, epochs=epochs, lr=learningRate)

    for epoch in range(startEpoch, epochs + 1):
        train(trainLoader, net, loss, epoch, optimizer, getlr, saveFolder)
        validate(valLoader, net, loss)

    ckptPath = os.path.join(saveFolder, "model.ckpt")
    save(ckptPath, net, epochs=epochs)

    return ckptPath
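# A plausible sketch of getLearningRate (hypothetical; it is defined elsewhere in the
# repo), matching the functools.partial call above and the step schedules used by the
# get_lr helpers in the earlier examples, so that
# getlr(epoch) == getLearningRate(epoch, epochs=epochs, lr=learningRate).
def getLearningRate(epoch, epochs, lr):
    if epoch <= epochs * 0.5:
        return lr
    elif epoch <= epochs * 0.8:
        return 0.1 * lr
    return 0.01 * lr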