class Worker:
    """Train and/or evaluate a model for a single fold.

    Output layout used when training (``inference=False``)::

        <outputfolder>/<fold>/
            tensorboard/         tensorboard results
            results/             output files: images, illuminant, GT, etc...
            checkpoint.pth.tar   checkpoint to continue training after a failure
            model_best.pth.tar   best checkpoint, used for inference

    When ``inference=True`` all results are written directly under
    ``args.outputfolder`` and the checkpoint is taken from
    ``args.checkpointfile`` (indexed by ``fold`` when it is a list).
    """

    def __init__(self, fold, conf, data_conf, cache_manager, args,
                 inference=False, verbose=True):
        """Store the configuration, create output folders and helpers.

        Args:
            fold: fold index; used for the output sub-folder and to index
                ``args.checkpointfile`` when that is a list.
            conf: experiment configuration mapping (read and updated here
                and in :meth:`run`).
            data_conf: data configuration mapping; must contain ``'tmp'``.
            cache_manager: object forwarded to ``Factory`` and queried in
                :meth:`run` through ``transforms().length``.
            args: parsed command line arguments.
            inference: if True, only the inference paths are configured.
            verbose: if True, progress information is printed.
        """
        self._args = args
        self._fold = fold
        self._conf = conf
        self._data_conf = data_conf
        self._inference = inference
        self._verbose = verbose
        self.tmp_dir = self._data_conf['tmp']

        self._pretrained_model = None
        if not self._inference:
            output_dir = os.path.join(self._args.outputfolder,
                                      str(self._fold))
            self._tensorboard_dir = os.path.join(output_dir, 'tensorboard')
            self._results_dir = os.path.join(output_dir, 'results')
            self._best_checkpoint_file = os.path.join(output_dir,
                                                      'model_best.pth.tar')
            self._checkpoint_file = os.path.join(output_dir,
                                                 'checkpoint.pth.tar')
            self._pretrained_model = self._args.pretrainedmodel
            # create all directories
            os.makedirs(self._tensorboard_dir, exist_ok=True)
        else:
            # for inference all results are saved under the output directory
            # (images, illuminant, GT, etc...)
            self._results_dir = self._args.outputfolder
            if isinstance(self._args.checkpointfile, list):
                self._checkpoint_file = self._args.checkpointfile[fold]
            else:
                self._checkpoint_file = self._args.checkpointfile

        self._display = Display(self._conf)
        self._factory = Factory(self._conf, self._data_conf, cache_manager,
                                self._args, verbose)
        self._cache_manager = cache_manager

        # create output directory
        os.makedirs(self._results_dir, exist_ok=True)

        # TORCH_HOME is pointed at '<this file's dir>/../torch_model_zoo'
        os.environ['TORCH_HOME'] = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), os.pardir,
            'torch_model_zoo')

    def _compute_best(self, best, train_stats, val_stats):
        """Decide whether the current epoch is the best one so far.

        The metric is selected by ``conf['choose_best_epoch_by']`` and
        defaults to the training mean loss when that key is absent.

        Args:
            best: best (lowest) metric value seen so far.
            train_stats: training stats exposing ``mean_loss``, ``mean_err``
                and ``med_err``.
            val_stats: validation stats exposing ``med_err``, or None when
                no validation was run.

        Returns:
            Tuple ``(is_best, best)``: whether the metric strictly improved
            and the updated minimum.

        Raises:
            ValueError: the validation metric was requested but
                ``val_stats`` is None.
            Exception: ``choose_best_epoch_by`` holds an unknown option.
        """
        criterion = 'mean_loss'
        if 'choose_best_epoch_by' in self._conf:
            criterion = self._conf['choose_best_epoch_by']

        if criterion == 'mean_angular_error':
            metric = train_stats.mean_err
        elif criterion == 'median_angular_error':
            metric = train_stats.med_err
        elif criterion == 'mean_loss':
            metric = train_stats.mean_loss
        elif criterion == 'val_median_angular_error':
            if val_stats is None:
                # fail with an explicit message instead of an opaque
                # AttributeError on None
                raise ValueError(
                    '"choose_best_epoch_by" is "val_median_angular_error" '
                    'but no validation stats are available')
            metric = val_stats.med_err
        else:
            raise Exception('Invalid "choose_best_epoch_by" option')

        is_best = metric < best
        best = min(metric, best)
        return is_best, best

    def _log_epoch(self, epoch, train_stats, val_stats):
        """Print the epoch summary (if due) and log scalars to tensorboard.

        Printing happens only when verbose and ``epoch`` is a multiple of
        ``conf['print_frequency_epoch']``; scalar logging happens on every
        call. ``val_stats`` may be None.
        """
        if self._verbose and epoch % self._conf['print_frequency_epoch'] == 0:
            print(
                'Epoch [{}]: AE (mean={:.4f} med={:.4f}) loss {:.4f} time={:.1f}'
                .format(epoch, train_stats.mean_err, train_stats.med_err,
                        train_stats.mean_loss, train_stats.time),
                end='')
            if val_stats is not None:
                print(
                    ' (val: AE (mean={:.4f} med={:.4f}) loss={:.4f} time={:.4f})\t'
                    .format(val_stats.mean_err, val_stats.med_err,
                            val_stats.mean_loss, val_stats.time),
                    end='')
            print()

        # Log scalar values (scalar summary)
        info = {
            'Epoch Loss': train_stats.mean_loss,
            'Epoch Mean AE': train_stats.mean_err,
            'Epoch Median AE': train_stats.med_err
        }
        if val_stats is not None:
            info.update({
                'Epoch Loss (validation)': val_stats.mean_loss,
                'Epoch Mean AE (validation)': val_stats.mean_err,
                'Epoch Median AE (validation)': val_stats.med_err
            })

        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, epoch)

    def _build_checkpoint(self, epoch, best):
        """Return the checkpoint dictionary for the current model state."""
        return {
            'epoch': epoch,
            'arch': self._conf['network']['arch'],
            'subarch': self._conf['network']['subarch'],
            'state_dict': self.model.state_dict(),
            'best': best,
            'optimizer': self.optimizer.state_dict()
        }

    def run(self):
        """Run evaluation or the full training procedure.

        Returns:
            Whatever :meth:`validate` returns for the test set, i.e.
            ``(results, EpochStats)``. When training without a test file,
            ``([], EpochStats(-1, -1, -1, 0))`` is returned.
        """
        args = self._args
        gpu = args.gpu
        self._conf['use_gpu'] = gpu is not None

        if self._verbose:
            if gpu is not None:
                print("Using GPU: {}".format(gpu))
            else:
                print(
                    "WARNING: You're training on the CPU, this could be slow!")

        # create transforms
        transforms = create_all_transforms(self, self._conf['transforms'])
        # copy FFCC histogram settings to conf (from transform settings)
        self._conf['log_uv_warp_histogram'] = find_loguv_warp_conf(transforms)

        # create model
        self.model = self._factory.get_model()

        # if we're evaluating instead of training:
        # 1. init the model (without training illuminants)
        # 2. load model weights
        if args.evaluate:
            self.model.initialize()
            if self._inference:
                checkpoint = self._checkpoint_file
            else:
                checkpoint = self._best_checkpoint_file
            # optionally resume from a checkpoint
            start_epoch, best, self.model = \
                self._factory.resume_from_checkpoint(
                    checkpoint, self.model, None, gpu)
        else:
            checkpoint = self._checkpoint_file

        # create validation/test transforms if defined, otherwise, the same
        # as training (a missing key is treated like an explicit None)
        if ('transforms_valtest' in self._conf
                and self._conf['transforms_valtest'] is not None):
            transforms_valtest = create_all_transforms(
                self, self._conf['transforms_valtest'])
        else:
            transforms_valtest = transforms

        if gpu is not None:
            torch.cuda.set_device(gpu)
            cudnn.benchmark = True

        if args.testfile is not None:
            # test loader
            _, test_loader, test_loader_cache = self._factory.get_loader(
                args.testfile, transforms_valtest, gpu)

            # if evaluating, copy model to GPU, evaluate and die
            if args.evaluate:
                if gpu is not None:
                    self.model = self.model.cuda(gpu)
                return self.validate(test_loader)  # we finish here!

        # NOTE(review): with args.evaluate set and args.testfile None,
        # execution falls through to the training path below - confirm
        # whether that is intended.

        # if validation file is defined
        if args.valfile is not None:
            # to save memory, don't do it again if valfile==testfile
            if args.valfile == args.testfile:
                val_loader = test_loader
                val_loader_cache = test_loader_cache
            else:
                # validation loader
                _, val_loader, val_loader_cache = self._factory.get_loader(
                    args.valfile, transforms_valtest, gpu)

        # training loader
        train_dataset, train_loader, train_loader_cache = \
            self._factory.get_loader(args.trainfiles, transforms, gpu,
                                     train=True)

        # init model with the training set illuminants
        self.model.initialize(train_dataset.get_illuminants_by_sensor())

        # optionally pretrain model
        self._factory.pretrain_model(self._pretrained_model, self.model)

        # optionally resume from a checkpoint
        self.optimizer, _ = self._factory.get_optimizer(self.model)
        start_epoch, best, self.model = self._factory.resume_from_checkpoint(
            checkpoint, self.model, self.optimizer, gpu)

        # define loss function
        self.criterion = self._factory.get_criterion()

        # tensorboard logger
        self.logger = TensorBoardLogger(self._tensorboard_dir)

        # learning rate scheduler (if defined)
        scheduler, scheduler_name = self._factory.get_lr_scheduler(
            start_epoch, self.optimizer)

        # copy stuff to GPU
        if gpu is not None:
            self.criterion = self.criterion.cuda(gpu)
            self.model = self.model.cuda(gpu)

        # for FFCC, we reset the optimizer after some epochs
        # because they use two loss functions, ugly trick
        # TODO: fix
        reset_opt = -1
        if 'reset_optimizer_epoch' in self._conf:
            reset_opt = self._conf['reset_optimizer_epoch']

        # load data for the first time
        # we use the cache loaders, they define batch size=1
        # so that we can see the progress with tqdm
        if self._cache_manager.transforms().length > 0 and self._fold == 0:
            if self._verbose:
                print('Caching images...')
            for _ in tqdm(train_loader_cache, desc="Training set",
                          disable=not self._verbose):
                pass
            if args.testfile is not None:
                for _ in tqdm(test_loader_cache, desc="Test set",
                              disable=not self._verbose):
                    pass
            if args.valfile is not None and args.testfile != args.valfile:
                for _ in tqdm(val_loader_cache, desc="Validation set",
                              disable=not self._verbose):
                    pass

        # if epochs==0, we don't really want to train,
        # we only want to do the candidate selection process for our method
        if self._conf['epochs'] == 0:
            print('WARNING: Training 0 epochs')
            self._factory.save_checkpoint(
                self._checkpoint_file, self._best_checkpoint_file,
                self._build_checkpoint(0, float("inf")), is_best=True)

        # epoch loop
        for epoch in range(start_epoch, self._conf['epochs']):
            # ugly trick for FFCC 2 losses
            if epoch == reset_opt:
                if self._verbose:
                    print('Reset optimizer and lr scheduler')
                best = float("inf")
                self.optimizer, _ = self._factory.get_optimizer(self.model)
                # TODO: What if lr scheduler changes its internal API?
                if scheduler is not None:
                    scheduler.optimizer = self.optimizer

            # train for one epoch
            train_stats = self.train(train_loader, epoch)

            # validation
            val_stats = None
            if args.valfile is not None:
                _, val_stats = self.validate(val_loader, epoch)

            # compute the best training epoch
            is_best, best = self._compute_best(best, train_stats, val_stats)

            # log epoch details
            self._log_epoch(epoch, train_stats, val_stats)

            # learning rate scheduler
            if scheduler is not None:
                # TODO: hardcoded
                if scheduler_name == 'ReduceLROnPlateau':
                    scheduler.step(train_stats.mean_err)
                else:
                    scheduler.step()

            # save checkpoint!
            self._factory.save_checkpoint(
                self._checkpoint_file, self._best_checkpoint_file,
                self._build_checkpoint(epoch + 1, best), is_best)

        # get results for the best model
        _, _, self.model = self._factory.load_model(
            self._best_checkpoint_file, self.model, self.optimizer, gpu)

        # return results from best epoch
        if args.testfile is not None:
            start_time = time.time()
            results = self.validate(test_loader)
            if self._verbose:
                print(
                    'Final inference (including generation of output files) took {:.4f}'
                    .format(time.time() - start_time))
            return results

        # for some datasets, we have no validation ground truth,
        # so, no evaluation possible
        return [], EpochStats(-1, -1, -1, 0)

    def _log_iteration(self, epoch, step, len_epoch, loss, err, data, output):
        """Log per-iteration scalars, histograms and images to tensorboard.

        The global step is ``epoch * len_epoch + step``. Scalars and
        histograms are logged every ``conf['tensorboard_frequency']`` steps
        and images every ``conf['tensorboard_frequency_im']`` steps; a value
        of -1 disables the corresponding logging.
        """
        real_step = epoch * len_epoch + step

        scalar_freq = self._conf['tensorboard_frequency']
        if scalar_freq != -1 and real_step % scalar_freq == 0:
            # Log scalar values (scalar summary)
            info = {'Loss': loss, 'Angular Error': err}
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, real_step)

            # Log values and gradients of the parameters (histogram summary)
            for tag, value in self.model.named_parameters():
                tag = tag.replace('.', '/')
                if value.requires_grad:
                    if value.grad is None:
                        print('WARNING: variable ', tag, '.grad is None!')
                    else:
                        self.logger.histo_summary(
                            tag, value.data.cpu().numpy(), real_step)
                        self.logger.histo_summary(
                            tag + '/grad', value.grad.data.cpu().numpy(),
                            real_step)

            if 'confidence' in output:
                self.logger.histo_summary(
                    'confidence',
                    output['confidence'].data.cpu().numpy().flatten(),
                    real_step)

        image_freq = self._conf['tensorboard_frequency_im']
        if image_freq != -1 and real_step % image_freq == 0:
            # Log training images (image summary)
            info = self._display.get_images(data, output)
            for tag, images in info.items():
                self.logger.image_summary(tag, images, real_step)

    def train(self, train_loader, epoch):
        """Train the model for one epoch.

        Args:
            train_loader: iterable of batches; each batch is a mapping that
                contains ``'illuminant'`` and receives an ``'epoch'`` entry.
            epoch: current epoch index.

        Returns:
            ``EpochStats(mean_err, med_err, mean_loss, elapsed_seconds)``
            computed over every sample / batch of the epoch.
        """
        start_t = time.time()  # log starting time
        self.model.train()  # switch to train mode

        # angular errors and loss lists
        angular_errors = []
        loss_vec = []

        # batch loop
        for step, data in enumerate(train_loader):
            data['epoch'] = epoch  # we know what's the current epoch
            err = err_m = output = loss = None

            # the optimizer receives a closure that re-evaluates the model;
            # intermediate values are exported through nonlocal variables
            def closure():
                nonlocal err, err_m, output, loss
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, data, self.model)
                loss.backward()
                err_m = angular_error_degrees(
                    output['illuminant'],
                    Variable(data['illuminant'],
                             requires_grad=False)).detach()
                err = err_m.sum().item() / err_m.shape[0]
                return loss

            self.optimizer.step(closure)

            angular_errors += err_m.cpu().data.tolist()
            loss_value = loss.detach().item()
            loss_vec.append(loss_value)

            self._log_iteration(epoch, step, len(train_loader), loss_value,
                                err, data, output)

        angular_errors = np.array(angular_errors)
        mean_err = angular_errors.mean()
        med_err = np.median(angular_errors)
        mean_loss = np.array(loss_vec).mean()
        t = time.time() - start_t

        return EpochStats(mean_err, med_err, mean_loss, t)

    def validate(self, val_loader, epoch=None):
        """Evaluate the model on ``val_loader``.

        Args:
            val_loader: iterable of batches with a ``dataset`` attribute.
            epoch: current epoch when called during training, otherwise
                None. During training the loss is computed and no output
                files are saved; otherwise outputs are saved under the
                results directory and no loss is computed.

        Returns:
            Tuple ``(res, stats)`` where ``res`` is the accumulated return
            of ``Display.save_output`` (empty during training) and ``stats``
            is an ``EpochStats``; metrics that could not be computed are -1.
        """
        with torch.no_grad():  # don't compute gradients
            save_full_res = self._args.save_fullres
            training = epoch is not None
            start_t = time.time()

            # switch to evaluate mode
            self.model.eval()

            res = []
            angular_errors = []
            loss_vec = []

            for data in val_loader:
                if training:
                    data['epoch'] = epoch

                # compute output
                output = self.model(data)

                # measure accuracy and save loss
                err = None
                if 'illuminant' in data:
                    if training:
                        loss = self.criterion(output, data, self.model)
                        loss_vec.append(loss.detach().item())
                    err = angular_error_degrees(
                        output['illuminant'],
                        Variable(data['illuminant'],
                                 requires_grad=False)).data.cpu().tolist()
                    angular_errors += err

                # When training, we don't want to save validation images
                if not training:
                    res += self._display.save_output(
                        data, output, err, val_loader.dataset,
                        self._results_dir, save_full_res)

            # some datasets have no validation GT
            mean_err = med_err = mean_loss = -1
            if angular_errors:
                angular_errors = np.array(angular_errors)
                mean_err = angular_errors.mean()
                med_err = np.median(angular_errors)
            if loss_vec:
                mean_loss = np.array(loss_vec).mean()

            t = time.time() - start_t

            return res, EpochStats(mean_err, med_err, mean_loss, t)