class AutoSKUMerger(object):
    """Iteratively train a classifier and merge classes it confuses.

    Each round trains a fresh model on the current label file, validates it,
    and merges classes whose confusion rate exceeds a round-dependent
    threshold. The label file referenced by ``data.label_path`` is rewritten
    in place after every round and snapshots are copied to ``data.merge_dir``.
    """

    def __init__(self, configer):
        self.configer = configer
        self.cls_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.runner_state = None
        self.round = 1
        self._relabel()

    def _relabel(self):
        """Rewrite the label file with contiguous integer class ids.

        Reads ``data.label_path`` (one ``<relative_path> <label>`` pair per
        line), drops entries whose file does not exist under
        ``data.data_dir``, maps labels to ids in order of first appearance and
        writes the result to ``<label_path>_new``. The configer is updated to
        point at the new file and ``data.num_classes`` is set to
        ``[number_of_labels]``. The process exits if a path is listed twice.
        """
        old_label_path = self.configer.get('data', 'label_path')
        new_label_path = '{}_new'.format(old_label_path)
        self.configer.update('data.label_path', new_label_path)
        data_dir = self.configer.get('data', 'data_dir')

        label_dict = dict()
        seen_paths = set()
        # Both files are context-managed so the writer is flushed and closed
        # even when the duplicate check aborts the process.
        with open(old_label_path, 'r') as fr, open(new_label_path, 'w') as fw:
            for line in fr:
                line_items = line.split()
                if len(line_items) < 2:
                    # Blank or malformed line: nothing to relabel.
                    continue

                rel_path, label = line_items[0], line_items[1]
                if not os.path.exists(os.path.join(data_dir, rel_path)):
                    continue

                if label not in label_dict:
                    label_dict[label] = len(label_dict)

                if rel_path in seen_paths:
                    Log.error('Duplicate Error: {}'.format(rel_path))
                    exit()

                seen_paths.add(rel_path)
                fw.write('{} {}\n'.format(rel_path, label_dict[label]))

        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'ori_label.txt'))
        self.configer.update('data.num_classes', [len(label_dict)])
        Log.info('Num Classes is {}...'.format(self.configer.get('data', 'num_classes')))

    def _init(self):
        """(Re)build meters, model, optimizer, loaders and loss for a round."""
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.cls_model_manager = ModelManager(self.configer)
        self.cls_data_loader = DataLoader(self.configer)
        self.cls_running_score = RunningScore(self.configer)
        self.runner_state = dict(iters=0, last_iters=0, epoch=0, last_epoch=0,
                                 performance=0, val_loss=0, max_performance=0, min_val_loss=0)
        self.cls_net = self.cls_model_manager.get_model()
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.solver_dict = self.configer.get(self.configer.get('train', 'solver'))
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(), self.solver_dict)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(self, self.cls_net, self.optimizer)
        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        """Split trainable parameters into backbone / non-backbone groups.

        Returns:
            A two-element list of optimizer param groups: backbone parameters
            with ``base_lr * network.bb_lr_scale`` and the remaining
            parameters with ``base_lr``. When ``bb_lr_scale`` is ``0.0`` the
            backbone parameters are frozen (``requires_grad = False``) and
            left out of the groups.
        """
        backbone_params = []
        other_params = []
        for key, value in dict(self.cls_net.named_parameters()).items():
            if not value.requires_grad:
                continue

            if 'backbone' not in key:
                other_params.append(value)
            elif self.configer.get('network', 'bb_lr_scale') == 0.0:
                value.requires_grad = False
            else:
                backbone_params.append(value)

        base_lr = self.solver_dict['lr']['base_lr']
        return [
            {'params': backbone_params,
             'lr': base_lr * self.configer.get('network', 'bb_lr_scale')},
            {'params': other_params, 'lr': base_lr},
        ]

    @staticmethod
    def _find_root(parent, node):
        """Return the union-find root of ``node``, compressing the path."""
        root = node
        while parent[root] != root:
            root = parent[root]

        while node != root:
            parent[node], node = root, parent[node]

        return root

    def _merge_class(self, cmatrix, fscore_list):
        """Merge confused classes and rewrite the label file.

        Args:
            cmatrix: square confusion matrix (rows are summed to get the
                per-class sample count).
            fscore_list: per-class F1 scores; the last entry is used as the
                average F1.

        Returns:
            The number of classes removed by this merge
            (old class count minus new class count).
        """
        Log.info('Merging class...')
        Log.info('Avg F1-score: {}'.format(fscore_list[-1]))
        threshold = max(self.configer.get('merge', 'min_thre'),
                        self.configer.get('merge', 'max_thre')
                        - self.configer.get('merge', 'round_decay') * self.round)
        h, w = cmatrix.shape[0], cmatrix.shape[1]
        per_class_num = np.sum(cmatrix, 1)

        # Collect every ordered pair (i, j) where class i is predicted as j
        # more often than the threshold allows.
        pairs_list = list()
        parent = dict()
        for i in range(h):
            if per_class_num[i] == 0:
                # No samples of this class: the ratio is undefined, so it can
                # never exceed the threshold.
                continue
            for j in range(w):
                if i == j:
                    continue
                if cmatrix[i][j] * 1.0 / per_class_num[i] > threshold:
                    pairs_list.append([i, j])
                    parent[i] = i
                    parent[j] = j

        # Union the two classes of every confused pair.
        for first, second in pairs_list:
            root_first = self._find_root(parent, first)
            root_second = self._find_root(parent, second)
            if root_first != root_second:
                parent[root_first] = root_second

        # root -> list of non-root members of its group.
        pairs_dict = dict()
        for node in parent:
            root = self._find_root(parent, node)
            if root != node:
                pairs_dict.setdefault(root, []).append(node)

        # Every member of a group maps to all other members of that group.
        mutual_pairs_dict = {}
        for root, members in pairs_dict.items():
            mutual_pairs_dict[root] = members
            for member in members:
                mutual_pairs_dict[member] = [root] + [m for m in members if m != member]

        num_classes = self.configer.get('data', 'num_classes')[0]
        enable_fscore = self.configer.get('merge', 'enable_fscore')
        power = self.round / self.configer.get('merge', 'max_round')
        id_map_list = [-1] * num_classes
        label_cnt = 0
        for i in range(num_classes):
            if id_map_list[i] != -1:
                continue

            # Classes with a low relative F1 are not given an id of their
            # own; they survive only if a later group member pulls them in.
            if enable_fscore and \
                    fscore_list[i] / fscore_list[-1] < self.configer.get('merge', 'fscore_ratio') * power:
                continue

            id_map_list[i] = label_cnt
            for v in mutual_pairs_dict.get(i, ()):
                assert id_map_list[v] == -1
                id_map_list[v] = label_cnt

            label_cnt += 1

        label_path = self.configer.get('data', 'label_path')
        round_label_path = '{}_{}'.format(label_path, self.round)
        with open(label_path, 'r') as fr, open(round_label_path, 'w') as fw:
            for line in fr:
                path, label = line.strip().split()
                map_label = id_map_list[int(label)]
                if map_label == -1:
                    continue
                fw.write('{} {}\n'.format(path, map_label))

        shutil.move(round_label_path, label_path)
        shutil.copy(label_path,
                    os.path.join(self.configer.get('data', 'merge_dir'),
                                 'label_{}.txt'.format(self.round)))
        self.configer.update('data.num_classes', [label_cnt])
        return num_classes - label_cnt

    def run(self):
        """Run merge rounds until convergence, then train on the final labels.

        A round loop stops when fewer than ``merge.cnt_thre`` classes were
        merged or the accuracy gain over the previous round is below
        ``merge.acc_thre``.
        """
        last_acc = 0.0
        while self.round <= self.configer.get('merge', 'max_round'):
            Log.info('Merge Round: {}'.format(self.round))
            Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))
            self._init()
            self.train()
            acc, cmatrix, fscore_list = self.val(self.cls_data_loader.get_valloader())
            merge_cnt = self._merge_class(cmatrix, fscore_list)
            if merge_cnt < self.configer.get('merge', 'cnt_thre') \
                    or (acc - last_acc) < self.configer.get('merge', 'acc_thre'):
                break

            last_acc = acc
            self.round += 1

        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'merge_label.txt'))
        self._init()
        self.train()
        Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))

    def train(self):
        """Train until ``solver_dict['max_iters']`` iterations are reached."""
        self.cls_net.train()
        start_time = time.time()
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.runner_state['epoch'] += 1
            for data_dict in self.train_loader:
                Trainer.update(self, solver_dict=self.solver_dict)
                self.data_time.update(time.time() - start_time)

                # Forward pass, loss and backward.
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                loss = loss_dict['loss']
                self.train_losses.update(loss.item(), data_dict['img'].size(0))
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update the vars of the train phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                self.runner_state['iters'] += 1

                # Print the log info & reset the states.
                if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                    Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                             'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                             'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n'
                             'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                                 self.runner_state['epoch'], self.runner_state['iters'],
                                 self.solver_dict['display_iter'],
                                 RunnerHelper.get_lr(self.optimizer), batch_time=self.batch_time,
                                 data_time=self.data_time, loss=self.train_losses))
                    self.batch_time.reset()
                    self.data_time.reset()
                    self.train_losses.reset()

                if self.solver_dict['lr']['metric'] == 'iters' \
                        and self.runner_state['iters'] == self.solver_dict['max_iters']:
                    self.val()
                    break

                # Check to val the current model.
                if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                    self.val()

    def val(self, loader=None):
        """Validate the current model and save it.

        Args:
            loader: optional data loader; defaults to ``self.val_loader``.

        Returns:
            ``(acc, cmatrix, fscore_list)``: top-1 accuracy of ``out0``, the
            confusion matrix, and the values parsed from the second-to-last
            column of every non-empty row of ``classification_report``
            (per-class rows followed by the summary rows).
        """
        self.cls_net.eval()
        start_time = time.time()
        loader = self.val_loader if loader is None else loader
        list_y_true, list_y_pred = [], []
        with torch.no_grad():
            for data_dict in loader:
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.cls_running_score.update(out_dict, label_dict)
                list_y_true.extend(label_dict['out0'].view(-1).cpu().numpy().tolist())
                list_y_pred.extend(out_dict['out0'].max(1)[1].view(-1).cpu().numpy().tolist())
                self.val_losses.update(loss_dict['loss'].mean().item(), data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

        acc = self.cls_running_score.top1_acc.avg['out0']
        RunnerHelper.save_net(self, self.cls_net, performance=acc)
        self.runner_state['performance'] = acc

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
        Log.info('Test Set: {} images'.format(len(list_y_true)))
        Log.info('TestLoss = {loss.avg:.8f}'.format(loss=self.val_losses))
        Log.info('Top1 ACC = {}'.format(acc))

        cmatrix = confusion_matrix(list_y_true, list_y_pred)
        fscore_str = classification_report(list_y_true, list_y_pred, digits=5)
        # Skip the two header lines of the report; the F1 value is the
        # second-to-last column of every remaining non-empty row.
        fscore_list = [float(line.strip().split()[-2])
                       for line in fscore_str.split('\n')[2:]
                       if len(line.strip().split()) > 0]

        self.batch_time.reset()
        self.val_losses.reset()
        self.cls_running_score.reset()
        self.cls_net.train()
        return acc, cmatrix, fscore_list
class Trainer(object):
    """Training / validation runner for a semantic segmentation network.

    Builds the network, optimizer, scheduler, data loaders and loss from the
    configer and exposes ``train()`` as the entry point.
    """

    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.foward_time = AverageMeter()
        self.backward_time = AverageMeter()
        self.loss_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.data_helper = DataHelper(configer, self)
        self.evaluator = get_evaluator(configer, self)

        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.running_score = None

        self._init_model()

    def _init_model(self):
        """Create the network, optimizer, scheduler, loaders and loss."""
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert group_method is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()
        self.pixel_loss = self.loss_manager.get_seg_loss()
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(self.pixel_loss)

    @staticmethod
    def group_weight(module):
        """Split parameters into weight-decay and no-weight-decay groups.

        Weights of ``nn.Linear`` and convolution modules are decayed; their
        biases and the ``weight`` / ``bias`` parameters of every other module
        are not.

        Args:
            module: the ``nn.Module`` whose parameters are grouped.

        Returns:
            ``[dict(params=decay), dict(params=no_decay, weight_decay=.0)]``.

        Raises:
            AssertionError: if the two groups do not account for exactly
                ``len(list(module.parameters()))`` parameters.
        """
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                # A module may expose ``weight`` / ``bias`` attributes that
                # are None (e.g. normalisation layers built without affine
                # parameters). Only real parameters belong in an optimizer
                # group.
                for attr in ('weight', 'bias'):
                    param = getattr(m, attr, None)
                    if isinstance(param, nn.Parameter):
                        group_no_decay.append(param)

        assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
        return groups

    def _get_parameters(self):
        """Group parameters by whether their name contains ``'backbone'``.

        Returns:
            Two optimizer param groups: backbone parameters with
            ``lr.base_lr`` and the rest with ``lr.base_lr * lr.nbb_mult``.
        """
        bb_lr = []
        nbb_lr = []
        for key, value in dict(self.seg_net.named_parameters()).items():
            if 'backbone' in key:
                bb_lr.append(value)
            else:
                nbb_lr.append(value)

        base_lr = self.configer.get('lr', 'base_lr')
        params = [{'params': bb_lr, 'lr': base_lr},
                  {'params': nbb_lr, 'lr': base_lr * self.configer.get('lr', 'nbb_mult')}]
        return params

    @staticmethod
    def _reduce_tensor(inp):
        """Sum ``inp`` over all processes onto rank 0.

        The reduction is performed in place on ``inp`` and the same tensor is
        returned; with fewer than two processes ``inp`` is returned untouched.
        """
        import torch.distributed as dist

        world_size = get_world_size()
        if world_size < 2:
            return inp

        with torch.no_grad():
            reduced_inp = inp
            dist.reduce(reduced_inp, dst=0)

        return reduced_inp

    def __train(self):
        """Run one pass over the training loader (one epoch)."""
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        use_swa = 'swa' in self.configer.get('lr', 'lr_policy')
        if use_swa:
            # The last quarter of training is the SWA phase; snapshots are
            # taken at regular steps inside it.
            normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
            swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1

        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for data_dict in self.train_loader:
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0, ]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if is_distributed():
                backward_loss = self.pixel_loss(outputs, targets)
                display_loss = self._reduce_tensor(backward_loss) / get_world_size()
            else:
                backward_loss = display_loss = self.pixel_loss(
                    outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                    (not is_distributed() or get_rank() == 0):
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                             self.configer.get('epoch'), self.configer.get('iters'),
                             self.configer.get('solver', 'display_iter'),
                             self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                             foward_time=self.foward_time, backward_time=self.backward_time,
                             loss_time=self.loss_time,
                             data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Save checkpoints for swa.
            if use_swa and \
                    self.configer.get('iters') > normal_max_iters and \
                    ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or
                     self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
                self.optimizer.update_swa()

            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """Validate the network, record the loss and save checkpoints.

        Args:
            data_loader: optional loader; defaults to ``self.val_loader``.

        Raises:
            AssertionError: re-raised from the loss computation after the
                output / target lengths have been logged.
        """
        self.seg_net.eval()
        self.pixel_loss.eval()
        start_time = time.time()
        replicas = self.evaluator.prepare_validaton()

        data_loader = self.val_loader if data_loader is None else data_loader
        is_lip = self.configer.get('dataset') == 'lip'
        for j, data_dict in enumerate(data_loader):
            if j % 10 == 0:
                Log.info('{} images processed\n'.format(j))

            if is_lip:
                (inputs, targets, inputs_rev, targets_rev), batch_size = \
                    self.data_helper.prepare_data(data_dict, want_reverse=True)
            else:
                (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

            with torch.no_grad():
                if is_lip:
                    # The batch and its reversed copy are run together, then
                    # the two halves of the prediction are averaged.
                    inputs = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                    outputs = self.seg_net(inputs)
                    outputs_ = self.module_runner.gather(outputs)
                    if isinstance(outputs_, (list, tuple)):
                        outputs_ = outputs_[-1]

                    half = int(outputs_.size(0) / 2)
                    total = int(outputs_.size(0))
                    outputs = outputs_[0:half, :, :, :].clone()
                    outputs_rev = outputs_[half:total, :, :, :].clone()
                    if outputs_rev.shape[1] == 20:
                        # Swap paired channels of the reversed prediction
                        # (presumably left/right part labels - TODO confirm).
                        for a, b in ((14, 15), (16, 17), (18, 19)):
                            outputs_rev[:, a, :, :] = outputs_[half:total, b, :, :]
                            outputs_rev[:, b, :, :] = outputs_[half:total, a, :, :]

                    outputs_rev = torch.flip(outputs_rev, [3])
                    outputs = (outputs + outputs_rev) / 2.
                    self.evaluator.update_score(outputs, data_dict['meta'])

                elif self.data_helper.conditions.diverse_size:
                    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)

                    for i in range(len(outputs)):
                        loss = self.pixel_loss(outputs[i], targets[i])
                        self.val_losses.update(loss.item(), 1)
                        outputs_i = outputs[i]
                        if isinstance(outputs_i, torch.Tensor):
                            outputs_i = [outputs_i]
                        self.evaluator.update_score(outputs_i, data_dict['meta'][i:i + 1])

                else:
                    outputs = self.seg_net(*inputs)

                    try:
                        loss = self.pixel_loss(
                            outputs, targets,
                            gathered=self.configer.get('network', 'gathered')
                        )
                    except AssertionError:
                        # Without a loss there is nothing valid to record, so
                        # log the mismatch and let the error propagate instead
                        # of continuing with an undefined or stale value.
                        Log.error('Loss assertion failed: {} outputs vs {} targets'.format(
                            len(outputs), len(targets)))
                        raise

                    if not is_distributed():
                        outputs = self.module_runner.gather(outputs)
                    self.val_losses.update(loss.item(), batch_size)
                    self.evaluator.update_score(outputs, data_dict['meta'])

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.evaluator.update_performance()

        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')
        cudnn.benchmark = True

        # Print the log info & reset the states.
        if not is_distributed() or get_rank() == 0:
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(
                    batch_time=self.batch_time, loss=self.val_losses))
            self.evaluator.print_scores()

        self.batch_time.reset()
        self.val_losses.reset()
        self.evaluator.reset()
        self.seg_net.train()
        self.pixel_loss.train()

    def train(self):
        """Entry point: evaluate a resumed model, or train to ``max_iters``.

        When ``network.resume`` is set together with ``network.resume_val``
        (or ``network.resume_train``) only a validation pass over the ``val``
        (or ``train``) split is run. Otherwise the model is trained, SWA
        weights are swapped in when the lr policy contains ``'swa'``, and a
        final validation is run.
        """
        if self.configer.get('network', 'resume') is not None:
            if self.configer.get('network', 'resume_val'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
                return
            elif self.configer.get('network', 'resume_train'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))
                return

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        # Use swa to average the model.
        if 'swa' in self.configer.get('lr', 'lr_policy'):
            self.optimizer.swap_swa_sgd()
            self.optimizer.bn_update(self.train_loader, self.seg_net)

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))

    def summary(self):
        """Print a model summary computed on the first training image."""
        from lib.utils.summary import get_model_summary

        self.seg_net.eval()

        for data_dict in self.train_loader:
            print(get_model_summary(self.seg_net, data_dict['img'][0:1]))
            return
class ImageClassifier(object):
    """Training / validation runner for a single-source image classifier."""

    def __init__(self, configer):
        self.configer = configer
        self.runner_state = dict(iters=0,
                                 last_iters=0,
                                 epoch=0,
                                 last_epoch=0,
                                 performance=0,
                                 val_loss=0,
                                 max_performance=0,
                                 min_val_loss=0)

        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        self.cls_net = self.cls_model_manager.get_model()
        self.solver_dict = self.configer.get(
            self.configer.get('train', 'solver'))
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(),
                                                      self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(
            self, self.cls_net, self.optimizer)

        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        """Split trainable parameters into backbone / non-backbone groups.

        Returns:
            A two-element list of optimizer param groups: backbone parameters
            with ``base_lr * network.bb_lr_scale`` and the remaining
            parameters with ``base_lr``. When ``bb_lr_scale`` is ``0.0`` the
            backbone parameters are frozen (``requires_grad = False``) and
            left out of the groups.
        """
        backbone_params = []
        other_params = []
        for key, value in dict(self.cls_net.named_parameters()).items():
            if not value.requires_grad:
                continue

            if 'backbone' not in key:
                other_params.append(value)
            elif self.configer.get('network', 'bb_lr_scale') == 0.0:
                value.requires_grad = False
            else:
                backbone_params.append(value)

        base_lr = self.solver_dict['lr']['base_lr']
        return [
            {'params': backbone_params,
             'lr': base_lr * self.configer.get('network', 'bb_lr_scale')},
            {'params': other_params, 'lr': base_lr},
        ]

    def run(self):
        """Train epoch by epoch until ``max_iters``, then validate once more.

        A validation pass is run first when ``network.resume_val`` is set.
        """
        if self.configer.get('network', 'resume_val'):
            self.val()

        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.train()

        self.val()

    def train(self):
        """Run one pass over the training loader (one epoch)."""
        self.cls_net.train()
        start_time = time.time()
        self.runner_state['epoch'] += 1
        for data_dict in self.train_loader:
            # The network expects every key to carry a source prefix.
            data_dict = {'src0_{}'.format(k): v for k, v in data_dict.items()}
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            data_dict = RunnerHelper.to_device(self, data_dict)

            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, _ = self.loss(out)

            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        self.runner_state['iters'],
                        self.solver_dict['display_iter'],
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.solver_dict['lr']['metric'] == 'iters' \
                    and self.runner_state['iters'] == self.solver_dict['max_iters']:
                break

            if self.runner_state['iters'] % self.solver_dict['save_iters'] == 0 \
                    and self.configer.get('local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                self.val()

    def val(self):
        """Validate the current model on ``self.val_loader`` and log scores."""
        self.cls_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for data_dict in self.val_loader:
                data_dict = {
                    'src0_{}'.format(k): v
                    for k, v in data_dict.items()
                }
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                # The loss returns a pair whose first item is the loss dict
                # (train() unpacks it the same way).
                loss_dict, _ = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {key: value.item() for key, value in loss_dict.items()},
                    data_dict['src0_img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))
        Log.info('TestLoss = {}'.format(self.val_losses.info()))
        Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
        Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
        Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.cls_net.train()
class MultiTaskClassifier(object):
    """Training / validation runner for a classifier fed by several sources.

    One train loader and one val loader are created per data source
    (``data.num_data_sources``); every batch handed to the network merges one
    batch from each source under ``src<N>_<key>`` keys.
    """

    def __init__(self, configer):
        self.configer = configer
        self.runner_state = dict(iters=0,
                                 last_iters=0,
                                 epoch=0,
                                 last_epoch=0,
                                 performance=0,
                                 val_loss=0,
                                 max_performance=0,
                                 min_val_loss=0)

        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        self.cls_net = self.cls_model_manager.get_model()
        self.solver_dict = self.configer.get(
            self.configer.get('train', 'solver'))
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(),
                                                      self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(
            self, self.cls_net, self.optimizer)

        self.train_loaders = dict()
        self.val_loaders = dict()
        for source in range(self.configer.get('data', 'num_data_sources')):
            self.train_loaders[source] = self.cls_data_loader.get_trainloader(
                source=source)
            self.val_loaders[source] = self.cls_data_loader.get_valloader(
                source=source)
        if self.configer.get('data', 'mixup'):
            assert (self.configer.get('data', 'num_data_sources') == 2
                    ), "mixup only support src0 and src1 load the same dataset"

        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        """Split trainable parameters into backbone / non-backbone groups.

        Returns:
            A two-element list of optimizer param groups: backbone parameters
            with ``base_lr * network.bb_lr_scale`` and the remaining
            parameters with ``base_lr``. When ``bb_lr_scale`` is ``0.0`` the
            backbone parameters are frozen (``requires_grad = False``) and
            left out of the groups.
        """
        backbone_params = []
        other_params = []
        for key, value in dict(self.cls_net.named_parameters()).items():
            if not value.requires_grad:
                continue

            if 'backbone' not in key:
                other_params.append(value)
            elif self.configer.get('network', 'bb_lr_scale') == 0.0:
                value.requires_grad = False
            else:
                backbone_params.append(value)

        base_lr = self.solver_dict['lr']['base_lr']
        return [
            {'params': backbone_params,
             'lr': base_lr * self.configer.get('network', 'bb_lr_scale')},
            {'params': other_params, 'lr': base_lr},
        ]

    def run(self):
        """Train until ``max_iters`` iterations, validating periodically.

        Every iteration draws one batch from each source; an exhausted loader
        is restarted, and the epoch counter advances when source ``0`` wraps
        around. A final validation is run after the loop.
        """
        if self.configer.get('network', 'resume_val'):
            self.val()

        self.cls_net.train()
        train_loaders = {
            source: iter(loader)
            for source, loader in self.train_loaders.items()
        }
        start_time = time.time()
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            data_dict = dict()
            for source in train_loaders:
                try:
                    tmp_data_dict = next(train_loaders[source])
                except StopIteration:
                    if source == 0 or source == '0':
                        self.runner_state['epoch'] += 1
                    train_loaders[source] = iter(self.train_loaders[source])
                    tmp_data_dict = next(train_loaders[source])

                for k, v in tmp_data_dict.items():
                    data_dict['src{}_{}'.format(source, k)] = v

            if self.configer.get('data', 'multiscale') is not None:
                # One random scale per iteration, applied to every image.
                scale_ratios = self.configer.get('data', 'multiscale')
                scale_ratio = random.uniform(scale_ratios[0], scale_ratios[-1])
                for key in data_dict:
                    if key.endswith('_img'):
                        data_dict[key] = F.interpolate(
                            data_dict[key],
                            scale_factor=[scale_ratio, scale_ratio],
                            mode='bilinear',
                            align_corners=True)

            if self.configer.get('data', 'mixup'):
                # Paste a shrunken copy of the src0 images into a random
                # corner of the src1 images.
                src0_resize = F.interpolate(data_dict['src0_img'],
                                            scale_factor=[
                                                random.uniform(0.4, 0.6),
                                                random.uniform(0.4, 0.6)
                                            ],
                                            mode='bilinear',
                                            align_corners=True)
                b, c, h, w = src0_resize.shape
                pos = random.randint(0, 3)
                if pos == 0:  # top-left
                    data_dict['src1_img'][:, :, 0:h, 0:w] = src0_resize
                elif pos == 1:  # top-right
                    data_dict['src1_img'][:, :, 0:h, -w:] = src0_resize
                elif pos == 2:  # bottom-left
                    data_dict['src1_img'][:, :, -h:, 0:w] = src0_resize
                else:  # bottom-right
                    data_dict['src1_img'][:, :, -h:, -w:] = src0_resize

            data_dict = RunnerHelper.to_device(self, data_dict)
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, loss_weight_dict = self.loss(out)

            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            if self.configer.get('dtype') == 'fp16':
                with amp.scale_loss(loss, self.optimizer) as scaled_losses:
                    scaled_losses.backward()
            else:
                loss.backward()

            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n'
                    'Learning rate = {5}\tLoss = {3}\nLossWeight = {4}\n'.
                    format(self.runner_state['epoch'],
                           self.runner_state['iters'],
                           self.solver_dict['display_iter'],
                           self.train_losses.info(),
                           loss_weight_dict,
                           RunnerHelper.get_lr(self.optimizer),
                           batch_time=self.batch_time,
                           data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.solver_dict['lr']['metric'] == 'iters' \
                    and self.runner_state['iters'] == self.solver_dict['max_iters']:
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(self, self.cls_net, postfix='final')
                break

            if self.runner_state['iters'] % self.solver_dict['save_iters'] == 0 \
                    and self.configer.get('local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                self.val()
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(
                        self,
                        self.cls_net,
                        performance=self.runner_state['performance'])

        self.val()

    def val(self):
        """Validate over all sources and record the ``src0_label0`` accuracy.

        Batches are drawn from every source in lockstep; an exhausted loader
        is restarted so shorter sources keep supplying data. The loop ends
        after the batch on which every source has been exhausted at least
        once.
        """
        self.cls_net.eval()
        start_time = time.time()
        val_loaders = dict()
        val_to_end = dict()
        for source in self.val_loaders:
            val_loaders[source] = iter(self.val_loaders[source])
            val_to_end[source] = False

        all_to_end = False
        with torch.no_grad():
            while not all_to_end:
                data_dict = dict()
                for source in val_loaders:
                    try:
                        tmp_data_dict = next(val_loaders[source])
                    except StopIteration:
                        val_to_end[source] = True
                        val_loaders[source] = iter(self.val_loaders[source])
                        tmp_data_dict = next(val_loaders[source])

                    for k, v in tmp_data_dict.items():
                        data_dict['src{}_{}'.format(source, k)] = v

                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                loss_dict, loss_weight_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)

                # Compute the loss of the val batch.
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {
                        key: value.mean().item()
                        for key, value in loss_dict.items()
                    }, data_dict['src0_img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

                # Check whether every data source has been scanned.
                all_to_end = all(val_to_end.values())

        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))
        Log.info('TestLoss = {}'.format(self.val_losses.info()))
        Log.info('TestLossWeight = {}'.format(loss_weight_dict))
        Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
        Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
        Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))

        # safe_load: yaml.load() without an explicit Loader is rejected by
        # current PyYAML and was unsafe on older versions; the accuracy
        # summary only needs plain scalars and mappings.
        top1_acc = yaml.safe_load(self.running_score.get_top1_acc())
        for key in top1_acc:
            if 'src0_label0' in key:
                self.runner_state['performance'] = top1_acc[key]
                Log.info('Use acc of {} to compare performace'.format(key))
                break

        self.running_score.reset()
        self.val_losses.reset()
        self.batch_time.reset()
        self.cls_net.train()
class Trainer(object):
    """
      Trainer for a semantic segmentation network. Include train & val.

      The network, optimizer, scheduler, data loaders and loss are all built from
      the ``configer`` handed to the constructor; iteration / epoch counters and
      the latest validation metrics are stored back into that same ``configer``.
    """

    def __init__(self, configer):
        """
          Create the meters and helper managers, then build the model state.

          Args:
              configer: project configuration object. This class reads it through
                  ``get`` / ``exists`` and writes it through ``update`` / ``plus_one``.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.running_score = RunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None

        self._init_model()

    def _init_model(self):
        """
          Build the network, parameter groups, optimizer/scheduler, loaders and loss.

          ``optim.group_method`` selects the parameter grouping: ``'decay'`` uses
          ``group_weight``; ``None`` uses ``_get_parameters``. Any other value trips
          the assertion below.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert group_method is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()
        self.pixel_loss = self.loss_manager.get_seg_loss()

    @staticmethod
    def group_weight(module):
        """
          Split the parameters of ``module`` into a decay and a no-decay group.

          Weights of ``nn.Linear`` and convolution layers go to the decay group.
          Their biases, and the ``weight`` / ``bias`` of every other sub-module,
          go to the no-decay group. Attributes that exist but are ``None``
          (e.g. a normalisation layer built with ``affine=False``) are skipped.

          Args:
              module: the ``nn.Module`` whose parameters are grouped.

          Returns:
              A two-element list of optimizer parameter-group dicts: the decay
              group, then the no-decay group with ``weight_decay=0``.

          Raises:
              AssertionError: if some parameter of ``module`` was not collected
                  (for example a parameter not named ``weight`` or ``bias``).
        """
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                # hasattr() alone is not enough here: the attribute can be
                # registered and still be None, and None must never be handed
                # to the optimizer.
                for attr_name in ('weight', 'bias'):
                    param = getattr(m, attr_name, None)
                    if param is not None:
                        group_no_decay.append(param)

        # Sanity check that every parameter landed in exactly one group.
        assert len(list(
            module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [
            dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=.0)
        ]
        return groups

    def _get_parameters(self):
        """
          Build two parameter groups split on whether the name contains 'backbone'.

          Returns:
              A list of two optimizer parameter-group dicts: parameters whose name
              contains ``'backbone'`` at ``lr.base_lr``, and all remaining ones at
              ``lr.base_lr * lr.nbb_mult``.
        """
        bb_lr = []
        nbb_lr = []
        for key, value in self.seg_net.named_parameters():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{
            'params': bb_lr,
            'lr': self.configer.get('lr', 'base_lr')
        }, {
            'params': nbb_lr,
            'lr': self.configer.get('lr', 'base_lr') * self.configer.get('lr', 'nbb_mult')
        }]
        return params

    def __train(self):
        """
          Train function of every epoch during train phase.

          Runs one pass over ``self.train_loader``. It stops early when the
          ``iters`` counter reaches ``solver.max_iters``, validates every
          ``solver.test_interval`` iterations, and bumps the ``epoch`` counter
          at the end.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            # The scheduler is driven either by the iteration or the epoch counter.
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(self.configer.get('iters'),
                                           self.scheduler,
                                           self.optimizer,
                                           backbone_list=[0, ])

            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)

            # NOTE: unlike __val, the training path passes the batch to the
            # network without an explicit module_runner.to_device() call.
            # Forward pass.
            outputs = self.seg_net(inputs)
            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs, targets,
                                   gathered=self.configer.get('network', 'gathered'))
            if self.configer.exists('train', 'loader') and self.configer.get(
                    'train', 'loader') == 'ade20k':
                # For the 'ade20k' loader the meter weight comes from the config
                # rather than from the input's leading dimension.
                batch_size = self.configer.get(
                    'train', 'batch_size') * self.configer.get('train', 'batch_per_gpu')
                self.train_losses.update(loss.item(), batch_size)
            else:
                self.train_losses.update(loss.item(), inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.module_runner.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """
          Validation function during the train phase.

          Args:
              data_loader: loader to evaluate on; ``self.val_loader`` is used
                  when it is ``None``.

          Stores mean IoU as ``performance`` and the average loss as ``val_loss``
          in the configer, asks the module runner to save the network once per
          criterion, logs the results, resets the meters and returns the network
          to train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        data_loader = self.val_loader if data_loader is None else data_loader
        for data_dict in data_loader:
            inputs = data_dict['img']
            targets = data_dict['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = self.module_runner.to_device(inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs, targets,
                                       gathered=self.configer.get('network', 'gathered'))
                outputs = self.module_runner.gather(outputs)

            self.val_losses.update(loss.item(), inputs.size(0))
            # Only the last element of the network outputs is scored.
            self._update_running_score(outputs[-1], data_dict['meta'])

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.configer.update(['performance'], self.running_score.get_mean_iou())
        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(self.running_score.get_mean_iou()))
        Log.info('Pixel ACC: {}\n'.format(self.running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Score each prediction of a batch against its original-size target.

          Args:
              pred: 4-D prediction tensor; dimension 1 is moved last before use
                  (presumably N x C x H x W logits -- confirm against the network).
              metas: per-sample dicts providing ``'ori_img_size'``,
                  ``'border_size'`` and ``'ori_target'``.
        """
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            ori_img_size = metas[i]['ori_img_size']
            border_size = metas[i]['border_size']
            ori_target = metas[i]['ori_target']
            # Crop to the border size, then resize the logits to the original
            # image size before taking the arg-max over the last axis.
            # NOTE(review): border_size looks like (width, height) -- confirm.
            total_logits = cv2.resize(
                pred[i, :border_size[1], :border_size[0]].cpu().numpy(),
                tuple(ori_img_size),
                interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.running_score.update(labelmap[None], ori_target[None])

    def train(self):
        """
          Run training until ``iters`` reaches ``solver.max_iters``.

          When resuming (``network.resume`` is set) with ``network.resume_val``
          enabled, one validation pass runs first. After training finishes the
          model is evaluated on the ``'val'`` and then the ``'train'`` dataset.
        """
        # cudnn.benchmark = True
        if self.configer.get('network', 'resume') is not None and self.configer.get(
                'network', 'resume_val'):
            self.__val()

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
        self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))