def set_device(_net, ctx, *args, **kwargs):
    if ctx == "cpu":
        if not isinstance(_net, DataParallel):
            _net = DataParallel(_net)
        return _net.cpu()
    elif any(map(lambda x: x in ctx, ["cuda", "gpu"])):  # pragma: no cover
        # todo: find a way to test gpu device
        if not torch.cuda.is_available():
            try:
                torch.ones((1,), device=torch.device("cuda:0"))
            except (AssertionError, RuntimeError) as e:
                raise TypeError(
                    "no cuda detected, only cpu is supported; detailed error: %s" % str(e))
        if torch.cuda.device_count() >= 1:
            if ":" in ctx:
                ctx_name, device_ids = ctx.split(":")
                assert ctx_name in ["cuda", "gpu"], \
                    "the device should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx
                device_ids = [int(i) for i in device_ids.strip().split(",")]
                try:
                    if not isinstance(_net, DataParallel):
                        return DataParallel(_net, device_ids).cuda()
                    return _net.cuda(device_ids[0])  # already wrapped: move to the first listed device
                except AssertionError as e:
                    logging.error(device_ids)
                    raise e
            elif ctx in ["cuda", "gpu"]:
                if not isinstance(_net, DataParallel):
                    _net = DataParallel(_net)
                return _net.cuda()
            else:
                raise TypeError(
                    "the device should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx)
        else:
            print(torch.cuda.device_count())
            raise TypeError("no gpu can be used, use cpu")
    else:  # pragma: no cover
        # todo: find a way to test gpu device
        if not isinstance(_net, DataParallel):
            return DataParallel(_net, device_ids=ctx).cuda()
        return _net.cuda(ctx)
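# Hypothetical usage sketch (not part of the original source). It exercises the "cpu" branch of
# set_device with a stand-in model (it assumes torch and DataParallel are imported at module level,
# as set_device itself requires); the "cuda:0,1" call is commented out because it needs GPUs.
def _demo_set_device():
    toy_net = torch.nn.Linear(4, 2)        # stand-in for a real model
    toy_net = set_device(toy_net, "cpu")   # wrapped in DataParallel and moved to CPU
    # toy_net = set_device(toy_net, "cuda:0,1")  # multi-GPU: DataParallel over devices 0 and 1
    print(type(toy_net))                   # -> torch.nn.parallel.DataParallel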
        # (inside the per-fold, per-epoch training loop: validation for this epoch just finished)
        model.train()
        validate_loss /= validate_num
        lr = optimizer.param_groups[0]['lr']
        print('Fold{} Epoch{}:\tValidate-{:.4f}\tlr-{}e-5'.format(
            fold, epoch, validate_loss, lr * 100000.))
        scheduler.step(validate_loss)
        if validate_loss < min_loss:
            min_loss = validate_loss
            early_stop_counter = 0
            if len(device_ids) > 1:
                torch.save(model.module.cpu().state_dict(),
                           os.path.join(save_dir, 'model_{}.pth'.format(fold)))
            else:
                torch.save(model.cpu().state_dict(),
                           os.path.join(save_dir, 'model_{}.pth'.format(fold)))
            model.cuda()
        else:
            early_stop_counter += 1
            if early_stop_counter == early_stop:
                mean_loss += min_loss
                break

    # after the epoch loop: report the best result for this fold
    print('Fold{} Stop after training {} epoch'.format(fold, epoch - early_stop))
    print('Fold{} Validate Loss:{}'.format(fold, min_loss))
    with open(os.path.join(save_dir, 'config'), 'a') as f:
        f.write('\nFold{} Stop after training {} epoch'.format(fold, epoch - early_stop))
        f.write('\nFold{} Validate Loss:{}\n'.format(fold, min_loss))
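# Hypothetical sketch (not in the original source) of the validation pass that the fragment above
# assumes has just finished; `validate_loader` and `criterion` are illustrative names, while
# `validate_loss` / `validate_num` match the variables used above.
#
#     model.eval()
#     validate_loss, validate_num = 0.0, 0
#     with torch.no_grad():
#         for inputs, targets in validate_loader:
#             outputs = model(inputs.cuda())
#             validate_loss += criterion(outputs, targets.cuda()).item() * inputs.size(0)
#             validate_num += inputs.size(0)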
def main():
    parser = argparse.ArgumentParser(description='N-Net training')
    parser.add_argument('--preprocess_result_path',
                        help='Directory to save preprocessed _clean and _label .npy files.',
                        default='F:\\LargeFiles\\lfz\\prep_result_sub\\')
    parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                        help='number of data loading workers (default: 32)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N',
                        help='mini-batch size (default: 16)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR',
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--save_freq', default=10, type=int, metavar='S',
                        help='save frequency')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--save_dir', default='', type=str, metavar='SAVE',
                        help='directory to save checkpoint (default: none)')
    parser.add_argument('--test', default=0, type=int, metavar='TEST',
                        help='1 to run test evaluation, 0 to train')
    parser.add_argument('--split', default=8, type=int, metavar='SPLIT',
                        help='In the test phase, split the image to 8 parts')
    parser.add_argument('--gpu', default='all', type=str, metavar='N',
                        help='use gpu, set to `none` to use CPU')
    parser.add_argument('--n_test', default=8, type=int, metavar='N',
                        help='number of gpu for test')
    parser.add_argument('--train_ids', default='./dsb/training/detector/kaggleluna_full.npy',
                        type=str,
                        help='Path to the npy file for training scan IDs stored in a Numpy list.')
    parser.add_argument('--val_ids', default='./dsb/training/detector/kaggleluna_full.npy',
                        type=str,  # TODO: replace with valsplit.npy when full datasets available.
                        help='Path to the npy file for validation scan IDs stored in a Numpy list.')
    parser.add_argument('--test_ids', default='./dsb/training/detector/full.npy', type=str,
                        help='Path to the npy file for test scan IDs stored in a Numpy list.')
    args = parser.parse_args()

    torch.manual_seed(0)
    use_gpu = False
    if 'none' not in args.gpu.lower() and torch.cuda.is_available():
        use_gpu = True
        torch.cuda.set_device(0)

    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    os.makedirs(save_dir, exist_ok=True)
    logfile = os.path.join(save_dir, 'log')

    if use_gpu:
        print('Use GPU for training.')
        n_gpu = setgpu(args.gpu)
        args.n_gpu = n_gpu
        net = net.cuda()
        loss = loss.cuda()
        cudnn.benchmark = True
        net = DataParallel(net)
    else:
        print('Use CPU for training.')
        net = net.cpu()

    datadir = args.preprocess_result_path

    if args.test == 1:
        margin = 32
        sidelen = 144
        split_comber = SplitComb(sidelen, config['max_stride'], config['stride'],
                                 margin, config['pad_value'])
        # Test sets.
        dataset = data.DataBowl3Detector(
            datadir, args.test_ids, config, phase='test', split_comber=split_comber)
        test_loader = DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=args.workers,
            collate_fn=data.collate, pin_memory=False)
        test(test_loader, net, get_pbb, save_dir, config, args)
        return

    # Train sets
    dataset = data.DataBowl3Detector(datadir, args.train_ids, config, phase='train')
    print('batch_size:', args.batch_size)
    train_loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Validation sets
    dataset = data.DataBowl3Detector(datadir, args.val_ids, config, phase='val')
    val_loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    optimizer = torch.optim.SGD(
        net.parameters(), args.lr, momentum=0.9, weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir, args)
        validate(val_loader, net, loss)
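# Illustrative sketch (not part of the original source): the step schedule implemented by get_lr
# above, shown with the parser defaults (epochs=100, lr=0.01). The helper below simply replicates
# that nested function so its values can be inspected outside main().
def _demo_step_schedule(epochs=100, base_lr=0.01):
    def get_lr(epoch):
        if epoch <= epochs * 0.5:
            return base_lr
        elif epoch <= epochs * 0.8:
            return 0.1 * base_lr
        return 0.01 * base_lr

    for e in (1, 50, 51, 80, 81, 100):
        # 1e-2 through epoch 50, 1e-3 through epoch 80, 1e-4 afterwards
        print('epoch {:3d} -> lr {:g}'.format(e, get_lr(e)))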
class Trainer(EnforceOverrides, metaclass=ABCMeta):
    def __init__(self, config: wandb.Config, rh: RunHelper, aux_config: Optional[Dict] = None):
        self.config = config  # wandb Config
        self.aux_config = aux_config  # aux dictionary
        self.rh = rh

        # legacy: still allow val_fcn_list to be created in val_step()
        self.val_fcn_list: Optional[List[MetricFcn]] = None  # should be set in the implementing subclass (in init)
        self.init_val_fcns()

        # pytorch trainer objects
        self.device: Optional[torch.device] = None
        self.gpu_ids: Optional[List[int]] = None
        self.model: Optional[PreTrainedModel] = None
        self.tokenizer: Optional[PreTrainedTokenizer] = None
        # self.model_parallel = None
        self.model_is_parallelized: bool = False
        self.setup_model_and_device()  # populate 3 above; potentially add special tokens

        self.train_loader: Optional[Union[ClueDataLoaderBatched, MultiTaskDataLoader]] = None
        self.dev_loader: Optional[ClueDataLoaderBatched] = None
        self.multitask_manager: Optional[util_multiloader.MultitaskManager] = None  # set if we are doing multitask
        self.setup_dataloaders()

        self.optimizer: Optional[Adafactor] = None
        self.scheduler: Optional[lr_scheduler] = None
        self._setup_optim_sched()

        # trainer state (must go here since ref'd by verify_and_log)
        self.state = TrainInfo(multitask_mgr=self.multitask_manager)
        self.verify_and_log_trainer_info()

        # if we're resuming
        if self.config.ckpt_path:
            if not self.config.no_train:
                assert self.config.resume_train is not None
            # if resume train, train state, optim, and scheduler will be changed
            self.load_from_ckpt(resume_train=self.config.resume_train)

        # todo misc attributes
        # metrics to track is stored in rh.checkpointsaver

    @abstractmethod
    def init_val_fcns(self):
        pass

    @final
    def load_from_ckpt(self, resume_train=False):
        # todo: print where the config dictionaries differ
        # loads model state
        log.info(f'Loading checkpoint: {self.config.ckpt_path}')
        ckpt_dict: CheckpointDict = util_checkpoint.load_ckpt(
            self.config.ckpt_path, self.model, map_location=self.device)
        if not self.config.no_train and resume_train:
            self.optimizer.load_state_dict(ckpt_dict['optimizer'])
            if self.scheduler is not None:
                raise NotImplementedError('Resume not implemented for adam')
            # todo: should also set other properties of state
            self.state.resume(epoch=ckpt_dict['epoch'], step=ckpt_dict['step'])

            # if we're resuming with multitask
            if self.config.multitask:
                # todo(hack): this needs to be cleaned up so that we can actually resume multitask;
                # multitask state needs to be moved into trainer state.
                # should also check for equivalence of the other params
                # todo: remove this after verifying that everything is okay
                assert self.config.hacky
                assert isinstance(k_hard_reset_warmup_iters_done, int) and k_hard_reset_warmup_iters_done > 0
                total_warmup_todo = self.multitask_manager.multitask_warmup
                total_warmup_done = k_hard_reset_warmup_iters_done
                warmup_remaining = total_warmup_todo - total_warmup_done
                assert warmup_remaining == self.state.warmup_remaining()  # before fixing epoch
                # self.state.epoch -= warmup_remaining  # will be incremented before running first epoch in run()

                # trainloader is a multiloader; reset its state correctly
                self.train_loader.num_iters = k_hard_reset_warmup_iters_done

                # # josh hack 04/14/2021
                # self.state.resume(epoch=0,  # go back to epoch 0
                #                   step=ckpt_dict['step'])
                # self._setup_optim_sched()  # reset

                log.info(
                    f'Set up at epoch {self.state.epoch}, with {self.train_loader.warmup_iters}'
                    f' total warmup, and {self.train_loader.num_iters} already done, ie '
                    f'{self.state.warmup_remaining()} warmup todo')

    @final
    def run(self):
        # if not training
        if self.config.no_train:
            log.info(
                f'arg no_train given. Just doing single validation. Setting to epoch == 1')
            self.state.increment_epoch()  # set to epoch == 1
            assert self.state.epoch == k_max_warmup_epochs + 1
            self.val_only()
            return

        log.warning(f'For actual train, epochs start at {k_max_warmup_epochs + 1}')

        # main training; includes warmup
        while self.state.epoch < self.config.num_epochs + k_max_warmup_epochs:
            was_last_warmup = self.state.increment_epoch()
            if was_last_warmup and self.multitask_manager.multitask_reset:
                log.info('Final warmup epoch done. Resetting optimizer')
                self._setup_optim_sched()
            self.train_step()

            # Validate; this will do both multitask and normal validation
            all_metrics = self.val_step()
            metrics_dict, preds = self.metrics_list_to_dict(all_metrics)

            # will have multisave appended if it is multitasking
            self.save_callback(metrics_dict, preds)
            if self.early_stopping_callback(metrics_dict):
                break

        # e.g., final eval
        self.post_run()

    def val_only(self):
        metrics: List[MetricsPredsWrapper] = self.val_step()
        metrics_dict, preds = self.metrics_list_to_dict(metrics)
        self.save_callback(metrics_dict, preds)

    # does not have to be implemented by subclasses
    def post_run(self):
        pass

    @final
    def setup_model_and_device(self):
        # device will be cuda:0
        self.device, self.gpu_ids = util.get_available_devices(assert_cuda=True)
        assert str(self.device) == "cuda:0", f'{self.device} != cuda:0'
        if len(self.gpu_ids) > 1 or self.config.multi_gpu is not None or k_data_parallel:
            logging.info(f'{len(self.gpu_ids)}, {self.config.multi_gpu}, {k_data_parallel}')
            assert k_data_parallel
            assert len(self.gpu_ids) == self.config.multi_gpu
        self._setup_model_and_tokenizer()  # implemented by subclasses
        if self.config.add_special_tokens:
            util.add_special_tokens(self.model, self.tokenizer)  # adds <SEP>
        self.model_to_device()

    # todo: we might be able to omit this and just have it in the self.train() function
    @final
    def model_to_device(self):
        if k_data_parallel:
            log.info('Using dataparallel')
            self.model = DataParallel(self.model, device_ids=self.gpu_ids)
            self.model_is_parallelized = True
        self.model.to(self.device)

    @abstractmethod
    def _setup_model_and_tokenizer(self):
        """
        Should load and make any tweaks (e.g. vocab changes) to the following:
        - self.model
        - self.tokenizer
        Called by setup_model_and_device()
        """
        pass

    @abstractmethod
    def setup_dataloaders(self):
        pass

    @abstractmethod
    def _get_dataloaders(self):
        """
        Should set
        - train_loader
        - dev_loader
        """
        pass

    @abstractmethod
    def _setup_optim_sched(self):
        """
        Should set
        - optimizer
        - scheduler (optional)
        """
        pass

    def verify_and_log_trainer_info(self):
        # verify that the metric we want to log is valid
        log.info('Verifying that all metrics are OK. '
                 'The outputs here are NOT from the model that was passed, if one was passed')
        metrics_dict, _ = self.metrics_list_to_dict(self.val_step(trial_run=True))
        for m in self.rh.metrics_to_track:
            if m[0] in ['epoch']:  # these won't be in the normal metrics returned
                continue
            assert m[0] in metrics_dict, f'{m} not in {metrics_dict}'
        log.info(f'Tracking metrics {self.rh.metrics_to_track} all verified')

        # verify everything else
        assert all(map(lambda x: x is not None, [
            self.config, self.model, self.tokenizer, self.device,
            self.optimizer, self.train_loader, self.dev_loader
        ]))

        # validation freq
        if self.config.val_freq is not None:
            assert self.config.val_freq * 1000 < self.config.num_train
            # we log as {epoch}.{intermed/100} so max is 99
            assert self.config.num_train / (self.config.val_freq * 1000) < 100

        log_string = '\n' \
                     f'total_train_steps (num_train_ex * epochs): {self.config.total_train}\n' \
                     f'machine: {socket.gethostname()}\n' \
                     f'num_train: {self.config.num_train}\n' \
                     f'num_val: {self.config.num_val}'
        # log_string += f'total_optim_steps: {self.config.total_optim_steps}\n'

        # can't use json for first config dict because not of type dict
        for k, v in sorted(self.config.items(), key=lambda x: x[0]):
            log_string += f'{k}: {v}\n'
        if self.aux_config:
            log_string += "multitask:\n"
            log_string += json.dumps(self.aux_config, sort_keys=True, indent=2,
                                     cls=util_dataloader.EnhancedJSONEncoder)
        else:
            log_string += "No aux config (e.g. multitask) given"
        log_string += "\n"
        log.info(log_string)

    @abstractmethod
    def _batch_to_objects(self, batch) -> ProcessedBatch:
        pass

    def val_end_epoch(self,
                      metrics_all_accum: Union[List[MetricsPredsWrapper], MetricsPredsWrapper],
                      num_val=None):
        if isinstance(metrics_all_accum, MetricsPredsWrapper):
            metrics_all_accum = [metrics_all_accum]
        for m_dict in metrics_all_accum:
            # get_all_metrics will already have a <val_label>:<set_label>: prefix
            for k, v, orig_v in m_dict.get_all_metrics(num_val):
                # Log val and avg val
                log.info(f'{k}: {orig_v:05.2f}\t avg: {v:05.4f}')
                # util.log_scalar(f'{k}', v / self.config.num_val, self.state.epoch, tbx=self.rh.tbx)
                util.log_wandb_new({f'{k}': v},
                                   step=self.state.step,
                                   epoch=self.state.epoch,
                                   use_step_for_logging=k_use_step_for_logging)

    @abstractmethod
    def model_forward(self, src_ids: torch.Tensor, src_mask: torch.Tensor,
                      tgt_ids: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        pass

    @abstractmethod
    def train_step(self) -> NoReturn:
        # will generally need to call model_forward method
        pass

    @abstractmethod
    def _generate_outputs_greedy(self, src_ids, src_mask, skip_special_tokens=True) -> Tuple:
        pass

    @abstractmethod
    def _generate_outputs_sampled(self, src_ids, src_mask, batch_size) -> List:
        pass

    def get_valstepdict_for_batch(self, pbatch: ProcessedBatch, do_sample: bool,
                                  do_generate: bool = True) -> PerBatchValStep:
        # evaluation for loss fcn
        perbatch_valstep = PerBatchValStep()
        loss, _ = self.model_forward(pbatch.src_ids, pbatch.src_mask,
                                     pbatch.tgt_ids)  # loss, logits, but don't need logits
        if k_data_parallel:
            loss = loss.mean()
        perbatch_valstep.loss_val = loss.detach().item()
        if do_generate:
            outputs_decoded_greedy, generated_ids_greedy = \
                self._generate_outputs_greedy(pbatch.src_ids, pbatch.src_mask)
            perbatch_valstep.outputs_greedy = outputs_decoded_greedy
            perbatch_valstep.outputs_greedy_ids = generated_ids_greedy
        if do_sample:
            outputs_decoded_sampled = \
                self._generate_outputs_sampled(pbatch.src_ids, pbatch.src_mask, pbatch.batch_size)
            perbatch_valstep.outputs_sampled = outputs_decoded_sampled
        return perbatch_valstep

    def val_step(self, trial_run: bool = False) -> List[MetricsPredsWrapper]:
        """
        :param trial_run: whether this is an initial check run - only one batch will be computed
        :return:
        """
        log.info(f'Evaluating at all_step {self.state.step} (epoch={self.state.epoch})...')
        self.eval()  # self.model.eval()  # put model in eval mode

        # accumulate all metrics over all of the val_dls
        all_metrics_wrappers: List[MetricsPredsWrapper] = []

        # if not self.state.epoch > 0 or trial_run:  # not warmup
        if not self.state.is_warmup() or trial_run:
            log.info(f'Primary eval; epoch: {self.state.epoch}')
            metrics_accum = self.validate_val_loader(self.dev_loader,
                                                     self.val_fcn_list,
                                                     trial_run,
                                                     label='dev',
                                                     do_print=True)
            all_metrics_wrappers.append(metrics_accum)

        # always do multitask
        if self.config.multitask:
            log.info(f'Multitask eval; epoch: {self.state.epoch}')
            for val in self.multitask_manager.val_dls:
                log.info(f'Validating DL {val.name}')
                metrics_accum = self.validate_val_loader(val.dataloader,
                                                         val.val_fcn_list,
                                                         trial_run=trial_run,
                                                         label=f'multi/{val.name}',
                                                         do_print=False)
                # we don't save predictions from the multiloaders
                metrics_accum.preds = None
                all_metrics_wrappers.append(metrics_accum)

        assert len(all_metrics_wrappers) > 0, 'Val step called with invalid params'
        if not trial_run:
            self.val_end_epoch(all_metrics_wrappers,
                               num_val=None)  # use the avg divisor as set in the constructor
        return all_metrics_wrappers

    def validate_val_loader(self, val_loader: ClueDataLoaderBatched, val_fcn: List[Callable],
                            trial_run: bool, label: str, do_print: bool):
        metrics_all_accum: MetricsPredsWrapper = MetricsPredsWrapper(
            label=label, avg_divisor=self.dev_loader.num_examples())
        loss_meter = util.AverageMeter()  # NLL (default metric for model) (reset each time)

        # todo: should total be num_examples or num_val
        with torch.no_grad(), tqdm(total=val_loader.num_examples()) as progress_bar:
            for batch_num, batch in enumerate(val_loader):
                # run a single batch and then return
                if trial_run and batch_num > 0:
                    break
                pbatch = self._batch_to_objects(batch)
                valstepbatch = self.get_valstepdict_for_batch(pbatch,
                                                              do_sample=self.config.do_sample)

                # update metrics and predictions tracking
                metrics_all_accum.update_for_batch(val_fcn, valstepbatch, pbatch, metric_label='')
                loss_meter.update(valstepbatch.loss_val, pbatch.batch_size)

                progress_bar.update(pbatch.batch_size)
                progress_bar.set_postfix(NLL=loss_meter.avg)

                # On first batch print one batch of generations for qualitative assessment
                if do_print and batch_num == 0:
                    for idx, orig_input, orig_target, output_greedy, *other in metrics_all_accum.preds[:1]:
                        log.info(f'\n idx: {idx}'
                                 f'\nSource: {orig_input}\n '
                                 f'\tTarget: {orig_target}\n'
                                 f'\t Actual: {output_greedy}\n')

        # append the NLL to the metrics
        metrics_all_accum.add_val('NLL', loss_meter.avg, avg=False, label='')
        return metrics_all_accum

    ###
    # Other helper functions
    ###
    def metrics_list_to_dict(self,
                             metrics_wrappers: List[MetricsPredsWrapper]) -> Tuple[Dict, List]:
        all_metrics_dict = dict()
        # we should have only a single set of preds; this is hacky: we set all multiloader
        # preds to None during val_step(), which is the only time these MetricsPredsWrappers
        # are produced
        preds = None
        for m in metrics_wrappers:
            all_metrics_dict.update(m.get_all_metrics_dict())
            if m.preds is not None:
                assert preds is None
                preds = m.preds
        # could also do this; but then change the code in util_checkpoint
        # all_metrics_dict.update(dict(epoch=self.state.epoch))

        # todo: hacky
        # if self.config.multitask and self.state.epoch <= 0:
        if self.config.multitask and self.state.is_warmup():
            all_metrics_dict.update(dict(multisave=self.state.epoch))
        return all_metrics_dict, preds

    def save_callback(self, metrics_dict, preds, intermed_epoch=None):
        # save_metrics = metrics.get_all_metrics_dict()
        # save_preds = metrics.preds
        if intermed_epoch is not None:
            save_epoch = self.state.epoch + intermed_epoch / 100
        else:
            save_epoch = self.state.epoch
        self.rh.ckpt_saver.save_if_best(
            save_epoch,
            self,
            # metric_dict=save_metrics,
            metric_dict=metrics_dict,
            # preds=save_preds,
            preds=preds,
            save_model=self.config.do_save)

    # todo(wrong): support max/min metrics
    def early_stopping_callback(self, metrics_dict: Dict):
        if not self.config.early_stopping:
            return False
        if self.config.early_stopping not in metrics_dict:
            log.warning(f'Early stopping but metric {self.config.early_stopping} not found')
            return False
        curr_metric = metrics_dict[self.config.early_stopping]
        if self.state.metric_best is not None:
            if self.state.metric_best < curr_metric:
                log.info(f"Early stopping: prev {self.state.metric_best}\t current: {curr_metric}")
                return True
            else:
                log.info(f"Not stopping: prev {self.state.metric_best}\t current: {curr_metric}")
        # otherwise store new best
        self.state.metric_best = curr_metric
        return False

    def make_ckpt_dict(self) -> CheckpointDict:
        self.model.cpu()  # todo(parallel): verify this isn't necessary for save
        model_for_ckpt = self._model_for_ckpt()
        sched = None
        if self.scheduler is not None:
            sched = self.scheduler.state_dict()
        ckpt_dict: CheckpointDict = {
            # 'model_state': self.model.state_dict(),
            'model_state': model_for_ckpt.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': sched,
            'config': dict(self.config.items()),  # todo: fix this (so that it can be reloaded)
            'step': self.state.step,
            'epoch': self.state.epoch
        }
        # was needed when we did self.model.cpu()
        self.model.to(self.device)
        return ckpt_dict

    def _model_for_ckpt(self):
        # todo(parallel): verify don't need cpu
        if k_data_parallel:
            return self.model.module
        else:
            return self.model
        # model_for_save = self.model.cpu()
        # return model_for_save

    def eval(self):
        if k_data_parallel and self.model_is_parallelized:
            self.model = self.model.module
            self.model_is_parallelized = False
        # self.model.to(self.device)  # todo(parallel): do we need this?
        self.model.eval()  # put model in eval mode

    def train(self):
        if k_data_parallel and not self.model_is_parallelized:
            self.model = DataParallel(self.model, self.gpu_ids)
            self.model_is_parallelized = True
        # self.model.to(self.device)  # todo(parallel): do we need this?
        self.model.train()
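# Hypothetical skeleton (not in the original source) showing how a concrete seq2seq trainer might
# fill in the abstract hooks of Trainer above. The attribute names the base class expects
# (val_fcn_list, model, tokenizer, train_loader, dev_loader, optimizer, scheduler) come from the
# class itself; the T5 model choice, the helper bodies, and the class name are illustrative only.
class ExampleT5Trainer(Trainer):
    def init_val_fcns(self):
        self.val_fcn_list = []  # e.g. exact-match / accuracy MetricFcn callables

    def _setup_model_and_tokenizer(self):
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        self.model = T5ForConditionalGeneration.from_pretrained('t5-small')
        self.tokenizer = T5Tokenizer.from_pretrained('t5-small')

    def setup_dataloaders(self):
        self.train_loader, self.dev_loader = self._get_dataloaders()

    def _get_dataloaders(self):
        raise NotImplementedError('build ClueDataLoaderBatched instances here')

    def _setup_optim_sched(self):
        from transformers.optimization import Adafactor
        self.optimizer = Adafactor(self.model.parameters(), lr=1e-3,
                                   relative_step=False, scale_parameter=False)
        self.scheduler = None

    def _batch_to_objects(self, batch) -> ProcessedBatch:
        raise NotImplementedError('convert a raw batch into a ProcessedBatch here')

    def model_forward(self, src_ids, src_mask, tgt_ids):
        out = self.model(input_ids=src_ids, attention_mask=src_mask, labels=tgt_ids)
        return out.loss, {}

    def train_step(self):
        raise NotImplementedError('loop over self.train_loader, call model_forward, step optimizer')

    def _generate_outputs_greedy(self, src_ids, src_mask, skip_special_tokens=True):
        ids = self.model.generate(input_ids=src_ids, attention_mask=src_mask)
        return self.tokenizer.batch_decode(ids, skip_special_tokens=skip_special_tokens), ids

    def _generate_outputs_sampled(self, src_ids, src_mask, batch_size):
        ids = self.model.generate(input_ids=src_ids, attention_mask=src_mask, do_sample=True)
        return self.tokenizer.batch_decode(ids, skip_special_tokens=True)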
def main():
    global args
    args = parser.parse_args()
    torch.manual_seed(0)
    # TODO: uncomment to use GPU for training.
    # torch.cuda.set_device(0)
    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))

    if 'none' not in args.gpu.lower() and torch.cuda.is_available():
        print('Use GPU for training.')
        n_gpu = setgpu(args.gpu)
        args.n_gpu = n_gpu
        net = net.cuda()
        loss = loss.cuda()
        cudnn.benchmark = True
        net = DataParallel(net)
    else:
        print('Use CPU for training.')
        net = net.cpu()

    datadir = config_training['preprocess_result_path']

    if args.test == 1:
        margin = 32
        sidelen = 144
        split_comber = SplitComb(sidelen, config['max_stride'], config['stride'],
                                 margin, config['pad_value'])
        dataset = data.DataBowl3Detector(datadir, 'full.npy', config,
                                         phase='test', split_comber=split_comber)
        test_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                                 num_workers=args.workers, collate_fn=data.collate,
                                 pin_memory=False)
        test(test_loader, net, get_pbb, save_dir, config)
        return

    # net = DataParallel(net)
    dataset = data.DataBowl3Detector(datadir, 'kaggleluna_full.npy', config, phase='train')
    print('batch_size:', args.batch_size)
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True)

    dataset = data.DataBowl3Detector(
        datadir,
        'kaggleluna_full.npy',  # TODO: replace with valsplit.npy when full datasets available.
        config, phase='val')
    val_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr, args.save_freq, save_dir)
        validate(val_loader, net, loss)
def SPNNetTrain(context):
    torch.manual_seed(0)
    args = context.args
    saveFolder = os.path.dirname(args.outputCheckpoint)
    dataFolder = args.inputDataFolder
    checkoutPointPath = args.inputCheckpoint
    trainIds = args.inputTrainData[args.idColumn]
    validateIds = args.inputValidateData[args.idColumn]
    workers = asyncio.WORKERS
    epochs = args.epochs
    batchSize = args.batchSize
    learningRate = args.learningRate
    momentum = args.momentum
    weightDecay = args.weightDecay
    useGpu = torch.cuda.is_available()

    config, net, loss, getPbb = model.get_model()
    if checkoutPointPath:
        checkpoint = torch.load(checkoutPointPath)
        startEpoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
    else:
        startEpoch = 1

    if useGpu:
        print("Use GPU {} for training.".format(torch.cuda.current_device()))
        net = net.cuda()
        loss = loss.cuda()
        cudnn.benchmark = True
        net = DataParallel(net)
    else:
        print("Use CPU for training.")
        net = net.cpu()

    # Train sets
    dataset = data.DataBowl3Detector(dataFolder, trainIds, config, phase="train")
    trainLoader = DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

    # Validation sets
    dataset = data.DataBowl3Detector(dataFolder, validateIds, config, phase="val")
    valLoader = DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )

    optimizer = torch.optim.SGD(net.parameters(), learningRate, momentum=momentum,
                                weight_decay=weightDecay)
    getlr = functools.partial(getLearningRate, epochs=epochs, lr=learningRate)

    for epoch in range(startEpoch, epochs + 1):
        train(trainLoader, net, loss, epoch, optimizer, getlr, saveFolder)
        validate(valLoader, net, loss)

    ckptPath = os.path.join(saveFolder, "model.ckpt")
    save(ckptPath, net, epochs=epochs)
    return ckptPath
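# Hypothetical sketch (not in the original source) of the getLearningRate helper that SPNNetTrain
# binds with functools.partial above. The signature is inferred from the
# partial(getLearningRate, epochs=..., lr=...) call, and the step decay mirrors the get_lr
# schedule used by the other training scripts in this section.
def getLearningRate(epoch, epochs, lr):
    if epoch <= epochs * 0.5:
        return lr
    elif epoch <= epochs * 0.8:
        return 0.1 * lr
    return 0.01 * lr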