def to_device(self, *params, force_list=False):
    """Move every positional argument onto the device used by this process.

    Device selection:
      * distributed run      -> ``cuda:<rank>`` (rank taken from ``get_rank()``)
      * config 'gpu' is None -> ``cpu``
      * otherwise            -> ``cuda``

    Args:
        *params: Objects exposing a ``.to(device)`` method.
        force_list: When True, always return a list, even for a single
            argument.

    Returns:
        A list with the moved objects, or the single moved object when
        exactly one argument was given and ``force_list`` is False.
    """
    if is_distributed():
        target = torch.device('cuda:{}'.format(get_rank()))
    elif self.configer.get('gpu') is None:
        target = torch.device('cpu')
    else:
        target = torch.device('cuda')

    moved = [item.to(target) for item in params]

    # Unwrap only in the "one argument, no list requested" case.
    if not force_list and len(params) == 1:
        return moved[0]
    return moved
def save_net(self, net, save_mode='iters'):
    """Save the latest checkpoint and, if warranted, a milestone checkpoint.

    The ``<name>_latest.pth`` file is always (re)written. Depending on
    ``save_mode`` an additional file is written and the matching bookkeeping
    entry in the configer is updated:

      * 'performance': ``<name>_max_performance.pth`` when config
        'performance' exceeds 'max_performance'.
      * 'val_loss':    ``<name>_min_loss.pth`` when config 'val_loss' is
        below 'min_val_loss'.
      * 'iters':       ``<name>_iters<N>.pth`` when at least
        ('checkpoints', 'save_iters') iterations passed since 'last_iters'.
      * 'epoch':       ``<name>_epoch<N>.pth`` when at least
        ('checkpoints', 'save_epoch') epochs passed since 'last_epoch'.

    In a distributed run only rank 0 writes anything.

    Args:
        net: Model whose ``state_dict()`` is stored.
        save_mode: One of 'performance', 'val_loss', 'iters', 'epoch'.

    Raises:
        SystemExit: If ``save_mode`` is not one of the supported values
            (the latest checkpoint has already been written at that point).
    """
    # Only rank 0 writes checkpoints when running distributed.
    if is_distributed() and get_rank() != 0:
        return

    state = {
        'config_dict': self.configer.to_dict(),
        'state_dict': net.state_dict(),
    }

    # Fall back to the project directory when no explicit root is configured.
    checkpoints_root = self.configer.get('checkpoints', 'checkpoints_root')
    if checkpoints_root is None:
        checkpoints_root = self.configer.get('project_dir')
    checkpoints_dir = os.path.join(
        checkpoints_root, self.configer.get('checkpoints', 'checkpoints_dir'))

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(checkpoints_dir, exist_ok=True)

    checkpoints_name = self.configer.get('checkpoints', 'checkpoints_name')

    def save_as(file_name):
        torch.save(state, os.path.join(checkpoints_dir, file_name))

    save_as('{}_latest.pth'.format(checkpoints_name))

    if save_mode == 'performance':
        performance = self.configer.get('performance')
        if performance > self.configer.get('max_performance'):
            save_as('{}_max_performance.pth'.format(checkpoints_name))
            self.configer.update(['max_performance'], performance)

    elif save_mode == 'val_loss':
        val_loss = self.configer.get('val_loss')
        if val_loss < self.configer.get('min_val_loss'):
            save_as('{}_min_loss.pth'.format(checkpoints_name))
            self.configer.update(['min_val_loss'], val_loss)

    elif save_mode == 'iters':
        iters = self.configer.get('iters')
        if iters - self.configer.get('last_iters') >= \
                self.configer.get('checkpoints', 'save_iters'):
            save_as('{}_iters{}.pth'.format(checkpoints_name, iters))
            self.configer.update(['last_iters'], iters)

    elif save_mode == 'epoch':
        epoch = self.configer.get('epoch')
        if epoch - self.configer.get('last_epoch') >= \
                self.configer.get('checkpoints', 'save_epoch'):
            save_as('{}_epoch{}.pth'.format(checkpoints_name, epoch))
            self.configer.update(['last_epoch'], epoch)

    else:
        # sys.exit is always available; the bare exit() builtin is injected
        # by the site module and is meant for interactive use only.
        import sys

        Log.error('Metric: {} is invalid.'.format(save_mode))
        sys.exit(1)
def _make_parallel(self, net):
    """Wrap ``net`` in the parallel container matching the run mode.

    Distributed runs get ``DistributedDataParallel`` pinned to the rank
    returned by ``get_rank()``. Otherwise the net is wrapped in
    ``DataParallelModel``; when exactly one GPU is configured the config
    entry ('network', 'gathered') is forced to True beforehand.

    Args:
        net: The module to wrap.

    Returns:
        The wrapped module.
    """
    if not is_distributed():
        gpu_ids = self.configer.get('gpu')
        if len(gpu_ids) == 1:
            self.configer.update(['network', 'gathered'], True)

        gather_flag = self.configer.get('network', 'gathered')
        return DataParallelModel(net, gather_=gather_flag)

    rank = get_rank()
    ddp_kwargs = dict(device_ids=[rank], output_device=rank)
    return torch.nn.parallel.DistributedDataParallel(net, **ddp_kwargs)
def __val(self, data_loader=None):
    """
      Validation function during the train phase.

      Runs the network in eval mode over ``data_loader`` (``self.val_loader``
      when None), feeds predictions to ``self.evaluator`` and accumulates the
      loss in ``self.val_losses``. Afterwards it stores the average loss under
      config key 'val_loss', asks the module runner to save checkpoints in
      'performance' and 'val_loss' mode, logs the result on rank 0 (or in
      non-distributed runs), resets the meters and puts the network and the
      loss back into train mode.

      Args:
          data_loader: Optional iterable of batches; defaults to
              ``self.val_loader``.

      Raises:
          AssertionError: Re-raised from the loss computation of the default
              branch after the output/target counts have been logged.
    """
    self.seg_net.eval()
    self.pixel_loss.eval()
    start_time = time.time()
    replicas = self.evaluator.prepare_validaton()

    data_loader = self.val_loader if data_loader is None else data_loader
    for j, data_dict in enumerate(data_loader):
        if j % 10 == 0:
            Log.info('{} images processed\n'.format(j))

        is_lip = self.configer.get('dataset') == 'lip'
        if is_lip:
            (inputs, targets, inputs_rev, _targets_rev), batch_size = \
                self.data_helper.prepare_data(data_dict, want_reverse=True)
        else:
            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

        with torch.no_grad():
            if is_lip:
                # Original and reversed inputs go through the net as one
                # batch; the first half of the result belongs to the
                # original inputs, the second half to the reversed ones.
                inputs = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                outputs = self.seg_net(inputs)
                outputs_ = self.module_runner.gather(outputs)
                if isinstance(outputs_, (list, tuple)):
                    outputs_ = outputs_[-1]

                total = int(outputs_.size(0))
                half = int(outputs_.size(0) / 2)
                outputs = outputs_[0:half, :, :, :].clone()
                outputs_rev = outputs_[half:total, :, :, :].clone()
                if outputs_rev.shape[1] == 20:
                    # Swap channel pairs 14<->15, 16<->17, 18<->19 of the
                    # reversed prediction. NOTE(review): presumably the
                    # left/right body-part classes of LIP - confirm against
                    # the label map.
                    for ch_a, ch_b in ((14, 15), (16, 17), (18, 19)):
                        outputs_rev[:, ch_a, :, :] = outputs_[half:total, ch_b, :, :]
                        outputs_rev[:, ch_b, :, :] = outputs_[half:total, ch_a, :, :]
                # Flip back along dim 3 before averaging with the original
                # prediction.
                outputs_rev = torch.flip(outputs_rev, [3])
                outputs = (outputs + outputs_rev) / 2.
                self.evaluator.update_score(outputs, data_dict['meta'])

            elif self.data_helper.conditions.diverse_size:
                # One replica per input; each result is scored on its own.
                outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)

                for i, outputs_i in enumerate(outputs):
                    loss = self.pixel_loss(outputs_i, targets[i])
                    self.val_losses.update(loss.item(), 1)
                    if isinstance(outputs_i, torch.Tensor):
                        outputs_i = [outputs_i]
                    self.evaluator.update_score(outputs_i, data_dict['meta'][i:i + 1])

            else:
                outputs = self.seg_net(*inputs)

                try:
                    loss = self.pixel_loss(
                        outputs, targets,
                        gathered=self.configer.get('network', 'gathered')
                    )
                except AssertionError:
                    # Falling through here would either hit an undefined
                    # `loss` or silently reuse the previous batch's loss, so
                    # report the mismatch and propagate the failure.
                    Log.error('Validation loss failed: {} outputs vs {} targets.'.format(
                        len(outputs), len(targets)))
                    raise

                if not is_distributed():
                    outputs = self.module_runner.gather(outputs)
                self.val_losses.update(loss.item(), batch_size)
                self.evaluator.update_score(outputs, data_dict['meta'])

        self.batch_time.update(time.time() - start_time)
        start_time = time.time()

    self.evaluator.update_performance()

    self.configer.update(['val_loss'], self.val_losses.avg)
    self.module_runner.save_net(self.seg_net, save_mode='performance')
    self.module_runner.save_net(self.seg_net, save_mode='val_loss')
    cudnn.benchmark = True

    # Print the log info & reset the states.
    if not is_distributed() or get_rank() == 0:
        Log.info(
            'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
            'Loss {loss.avg:.8f}\n'.format(
                batch_time=self.batch_time, loss=self.val_losses))
        self.evaluator.print_scores()

    self.batch_time.reset()
    self.val_losses.reset()
    self.evaluator.reset()
    self.seg_net.train()
    self.pixel_loss.train()
def __train(self):
    """
      Train function of every epoch during train phase.

      Iterates once over ``self.train_loader``. Per batch it steps the
      scheduler (by 'iters' or 'epoch' depending on config ('lr', 'metric')),
      optionally applies LR warm-up, runs forward / loss / backward /
      optimizer step, updates the timing and loss meters and increments the
      'iters' counter. Every ('solver', 'display_iter') iterations the meters
      are logged (rank 0 only when distributed) and reset. With an 'swa' LR
      policy ``self.optimizer.update_swa()`` is called periodically during the
      last quarter of training. The loop stops early once 'iters' reaches
      ('solver', 'max_iters'); validation runs every
      ('solver', 'test_interval') iterations. The 'epoch' counter is
      incremented at the end.
    """
    self.seg_net.train()
    self.pixel_loss.train()
    start_time = time.time()

    if "swa" in self.configer.get('lr', 'lr_policy'):
        # SWA updates only happen after the first 75% of the iterations.
        normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
        swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1

    if hasattr(self.train_loader.sampler, 'set_epoch'):
        self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

    def reduce_tensor(inp):
        """
        Reduce the loss from all processes so that
        process with rank 0 has the averaged results.
        """
        import torch.distributed as dist

        world_size = get_world_size()
        if world_size < 2:
            return inp
        with torch.no_grad():
            # dist.reduce works in place; reduce a detached copy so the
            # tensor that is later used for backward() is left untouched.
            reduced_inp = inp.detach().clone()
            dist.reduce(reduced_inp, dst=0)
        return reduced_inp

    for data_dict in self.train_loader:
        if self.configer.get('lr', 'metric') == 'iters':
            self.scheduler.step(self.configer.get('iters'))
        else:
            self.scheduler.step(self.configer.get('epoch'))

        if self.configer.get('lr', 'is_warm'):
            self.module_runner.warm_lr(
                self.configer.get('iters'),
                self.scheduler, self.optimizer, backbone_list=[0, ]
            )

        (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
        self.data_time.update(time.time() - start_time)

        foward_start_time = time.time()
        outputs = self.seg_net(*inputs)
        self.foward_time.update(time.time() - foward_start_time)

        loss_start_time = time.time()
        if is_distributed():
            backward_loss = self.pixel_loss(outputs, targets)
            # Only used for logging; the reduced value is only meaningful on
            # the destination rank (0).
            display_loss = reduce_tensor(backward_loss) / get_world_size()
        else:
            backward_loss = display_loss = self.pixel_loss(
                outputs, targets,
                gathered=self.configer.get('network', 'gathered'))

        self.train_losses.update(display_loss.item(), batch_size)
        self.loss_time.update(time.time() - loss_start_time)

        backward_start_time = time.time()
        self.optimizer.zero_grad()
        backward_loss.backward()
        self.optimizer.step()
        self.backward_time.update(time.time() - backward_start_time)

        # Update the vars of the train phase.
        self.batch_time.update(time.time() - start_time)
        start_time = time.time()
        self.configer.plus_one('iters')

        # Print the log info & reset the states.
        if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                (not is_distributed() or get_rank() == 0):
            Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                     'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                     'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                     'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                     'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                     'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                     'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.configer.get('epoch'), self.configer.get('iters'),
                         self.configer.get('solver', 'display_iter'),
                         self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                         foward_time=self.foward_time, backward_time=self.backward_time,
                         loss_time=self.loss_time,
                         data_time=self.data_time, loss=self.train_losses))
            self.batch_time.reset()
            self.foward_time.reset()
            self.backward_time.reset()
            self.loss_time.reset()
            self.data_time.reset()
            self.train_losses.reset()

        # save checkpoints for swa
        if 'swa' in self.configer.get('lr', 'lr_policy') and \
                self.configer.get('iters') > normal_max_iters and \
                ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or
                 self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
            self.optimizer.update_swa()

        if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
            break

        # Check to val the current model (triggered by iteration count).
        if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
            self.__val()

    self.configer.plus_one('epoch')