class AutoSKUMerger(object):
    """Train an image classifier and iteratively merge confusable classes.

    Every round, ``run`` trains a classifier on the current label file,
    evaluates it, and merges classes whose confusion rate exceeds a
    round-dependent threshold (see ``_merge_class``). The label file named by
    ``data.label_path`` is rewritten after each round, and a copy of every
    round's labels is stored under ``data.merge_dir``.
    """

    def __init__(self, configer):
        self.configer = configer
        self.cls_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.runner_state = None
        self.round = 1
        self._relabel()

    def _relabel(self):
        """Write a compacted copy of the label file and point the config at it.

        Reads ``<label_path>`` (lines of ``<image path> <label>``), drops
        entries whose image does not exist under ``data.data_dir`` and blank
        or malformed lines, and remaps the remaining labels to contiguous ids
        ``0..K-1`` in order of first appearance. The result is written to
        ``<label_path>_new``. ``data.label_path`` and ``data.num_classes``
        (as ``[K]``) are updated, and the new file is copied to
        ``<merge_dir>/ori_label.txt``.

        Exits the process with status 1 if an image path is listed twice.
        """
        import sys

        old_label_path = self.configer.get('data', 'label_path')
        new_label_path = '{}_new'.format(old_label_path)
        self.configer.update('data.label_path', new_label_path)
        data_dir = self.configer.get('data', 'data_dir')

        label_dict = dict()  # original label string -> new contiguous id
        seen_paths = set()
        with open(new_label_path, 'w') as fw, open(old_label_path, 'r') as fr:
            for line in fr:
                line_items = line.strip().split()
                if len(line_items) < 2:
                    # Blank or malformed line: nothing to relabel.
                    continue
                img_path, old_label = line_items[0], line_items[1]
                if not os.path.exists(os.path.join(data_dir, img_path)):
                    continue
                if old_label not in label_dict:
                    label_dict[old_label] = len(label_dict)
                if img_path in seen_paths:
                    Log.error('Duplicate Error: {}'.format(img_path))
                    # Non-zero status: a duplicate entry is a data error.
                    sys.exit(1)
                seen_paths.add(img_path)
                fw.write('{} {}\n'.format(img_path, label_dict[old_label]))

        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'ori_label.txt'))
        self.configer.update('data.num_classes', [len(label_dict)])
        Log.info('Num Classes is {}...'.format(self.configer.get('data', 'num_classes')))

    def _init(self):
        """(Re)build meters, network, optimizer/scheduler, data loaders and loss."""
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.cls_model_manager = ModelManager(self.configer)
        self.cls_data_loader = DataLoader(self.configer)
        self.cls_running_score = RunningScore(self.configer)
        self.runner_state = dict(iters=0, last_iters=0, epoch=0, last_epoch=0,
                                 performance=0, val_loss=0,
                                 max_performance=0, min_val_loss=0)
        self.cls_net = self.cls_model_manager.get_model()
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        # 'train.solver' holds the name of another config key that contains
        # the solver settings.
        self.solver_dict = self.configer.get(self.configer.get('train', 'solver'))
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(), self.solver_dict)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(self, self.cls_net, self.optimizer)
        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        """Build optimizer parameter groups.

        Trainable parameters whose name contains ``'backbone'`` get learning
        rate ``base_lr * network.bb_lr_scale``. All others get ``base_lr``.
        If ``bb_lr_scale == 0.0``, backbone parameters are frozen
        (``requires_grad = False``) and the backbone group is left empty.

        Returns:
            A list of two dicts ``{'params': [...], 'lr': float}``:
            backbone group first, then the rest.
        """
        bb_lr_scale = self.configer.get('network', 'bb_lr_scale')
        base_lr = self.solver_dict['lr']['base_lr']
        backbone_params, other_params = [], []
        for name, param in self.cls_net.named_parameters():
            if not param.requires_grad:
                continue
            if 'backbone' not in name:
                other_params.append(param)
            elif bb_lr_scale == 0.0:
                param.requires_grad = False
            else:
                backbone_params.append(param)
        return [{'params': backbone_params, 'lr': base_lr * bb_lr_scale},
                {'params': other_params, 'lr': base_lr}]

    def _merge_class(self, cmatrix, fscore_list):
        """Merge confusable classes and rewrite the label file in place.

        Two classes i != j are linked when ``cmatrix[i][j] / sum(cmatrix[i])``
        exceeds ``max(merge.min_thre, merge.max_thre - merge.round_decay *
        round)``. Linked classes are grouped with union-find, and each group
        receives one new id. If ``merge.enable_fscore`` is set, a class that
        is not absorbed by an earlier group is dropped when
        ``fscore_list[i] / fscore_list[-1] < merge.fscore_ratio * round /
        merge.max_round``. The filter is skipped when the average F1
        (``fscore_list[-1]``) is 0, because the ratio is then undefined.

        Args:
            cmatrix: square confusion matrix indexed by class id
                (rows = true class).
            fscore_list: per-class F1 scores indexed by class id; the last
                element is used as the average F1.

        Returns:
            The number of classes removed (old count - new count).
        """
        Log.info('Merging class...')
        Log.info('Avg F1-score: {}'.format(fscore_list[-1]))
        threshold = max(self.configer.get('merge', 'min_thre'),
                        self.configer.get('merge', 'max_thre')
                        - self.configer.get('merge', 'round_decay') * self.round)
        h, w = cmatrix.shape[0], cmatrix.shape[1]
        per_class_num = np.sum(cmatrix, 1)

        # Collect strongly-confused (true, predicted) pairs.
        pairs_list = list()
        pre_dict = dict()  # union-find parent pointers
        for i in range(h):
            if per_class_num[i] == 0:
                # No samples of class i: its confusion rate is 0/0, which
                # never exceeds the threshold.
                continue
            for j in range(w):
                if i == j:
                    continue
                if cmatrix[i][j] * 1.0 / per_class_num[i] > threshold:
                    pairs_list.append([i, j])
                    pre_dict[i] = i
                    pre_dict[j] = j

        def find_root(node):
            """Return the union-find root of ``node``, compressing its path."""
            root = node
            while pre_dict[root] != root:
                root = pre_dict[root]
            while node != root:
                parent = pre_dict[node]
                pre_dict[node] = root
                node = parent
            return root

        for a, b in pairs_list:
            root_a, root_b = find_root(a), find_root(b)
            if root_a != root_b:
                pre_dict[root_a] = root_b

        # root -> non-root members of its group
        pairs_dict = dict()
        for k in pre_dict.keys():
            root = find_root(k)
            if root != k:
                pairs_dict.setdefault(root, []).append(k)

        # Every group member -> all the other members of its group.
        mutual_pairs_dict = {}
        for root, members in pairs_dict.items():
            mutual_pairs_dict[root] = members
            if len(members) > 1:  # multi relation
                for p in members:
                    mutual_pairs_dict[p] = [root] + [q for q in members if q != p]
            else:  # mutual relation
                mutual_pairs_dict[members[0]] = [root]

        num_classes = self.configer.get('data', 'num_classes')[0]
        enable_fscore = self.configer.get('merge', 'enable_fscore')
        avg_fscore = fscore_list[-1]
        id_map_list = [-1] * num_classes  # old id -> new id (-1 = dropped)
        label_cnt = 0
        for i in range(num_classes):
            if id_map_list[i] != -1:
                continue
            power = self.round / self.configer.get('merge', 'max_round')
            if enable_fscore and avg_fscore > 0 and \
                    fscore_list[i] / avg_fscore < self.configer.get('merge', 'fscore_ratio') * power:
                continue
            id_map_list[i] = label_cnt
            for v in mutual_pairs_dict.get(i, ()):
                assert id_map_list[v] == -1
                id_map_list[v] = label_cnt
            label_cnt += 1

        label_path = self.configer.get('data', 'label_path')
        tmp_label_path = '{}_{}'.format(label_path, self.round)
        with open(tmp_label_path, 'w') as fw, open(label_path, 'r') as fr:
            for line in fr:
                items = line.strip().split()
                if not items:
                    continue
                path, label = items
                map_label = id_map_list[int(label)]
                if map_label == -1:
                    continue
                fw.write('{} {}\n'.format(path, map_label))

        shutil.move(tmp_label_path, label_path)
        shutil.copy(label_path,
                    os.path.join(self.configer.get('data', 'merge_dir'),
                                 'label_{}.txt'.format(self.round)))
        self.configer.update('data.num_classes', [label_cnt])
        return num_classes - label_cnt

    def run(self):
        """Alternate train / validate / merge rounds, then train on the final labels.

        Stops early when a round removes fewer than ``merge.cnt_thre`` classes
        or the top-1 accuracy gain over the previous round is below
        ``merge.acc_thre``. The final label file is copied to
        ``<merge_dir>/merge_label.txt``.
        """
        last_acc = 0.0
        while self.round <= self.configer.get('merge', 'max_round'):
            Log.info('Merge Round: {}'.format(self.round))
            Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))
            self._init()
            self.train()
            acc, cmatrix, fscore_list = self.val(self.cls_data_loader.get_valloader())
            merge_cnt = self._merge_class(cmatrix, fscore_list)
            if merge_cnt < self.configer.get('merge', 'cnt_thre') \
                    or (acc - last_acc) < self.configer.get('merge', 'acc_thre'):
                break
            last_acc = acc
            self.round += 1

        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'merge_label.txt'))
        self._init()
        self.train()
        Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))

    def train(self):
        """Train until ``solver.max_iters`` iterations, validating periodically."""
        self.cls_net.train()
        start_time = time.time()
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            # One pass over the train loader per epoch.
            self.runner_state['epoch'] += 1
            for data_dict in self.train_loader:
                Trainer.update(self, solver_dict=self.solver_dict)
                self.data_time.update(time.time() - start_time)

                # Forward pass.
                out = self.cls_net(data_dict)
                # Compute the loss of the train batch & backward.
                loss_dict = self.loss(out)
                loss = loss_dict['loss']
                self.train_losses.update(loss.item(), data_dict['img'].size(0))
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update the vars of the train phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                self.runner_state['iters'] += 1

                # Print the log info & reset the states.
                if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                    Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                             'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                             'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                             'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                                 self.runner_state['epoch'], self.runner_state['iters'],
                                 self.solver_dict['display_iter'],
                                 RunnerHelper.get_lr(self.optimizer), batch_time=self.batch_time,
                                 data_time=self.data_time, loss=self.train_losses))
                    self.batch_time.reset()
                    self.data_time.reset()
                    self.train_losses.reset()

                if self.solver_dict['lr']['metric'] == 'iters' \
                        and self.runner_state['iters'] == self.solver_dict['max_iters']:
                    self.val()
                    break

                # Check to val the current model.
                if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                    self.val()

    def val(self, loader=None):
        """Evaluate the classifier on ``loader`` (default: ``self.val_loader``).

        Also saves the network through ``RunnerHelper.save_net`` with the
        top-1 accuracy as its performance, and puts the network back in
        train mode afterwards.

        Returns:
            ``(acc, cmatrix, fscore_list)``:

            - ``acc``: top-1 accuracy of head ``'out0'``.
            - ``cmatrix``: ``num_classes x num_classes`` confusion matrix
              indexed by class id, including classes absent from this split.
            - ``fscore_list``: per-class F1 scores parsed from the sklearn
              classification report (one per class id), followed by the F1
              column of its summary rows. ``_merge_class`` uses the last
              entry as the average F1.
        """
        self.cls_net.eval()
        start_time = time.time()
        loader = self.val_loader if loader is None else loader
        list_y_true, list_y_pred = [], []
        with torch.no_grad():
            for data_dict in loader:
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.cls_running_score.update(out_dict, label_dict)
                list_y_true.extend(label_dict['out0'].view(-1).cpu().numpy().tolist())
                list_y_pred.extend(out_dict['out0'].max(1)[1].view(-1).cpu().numpy().tolist())
                # Compute the loss of the val batch.
                self.val_losses.update(loss_dict['loss'].mean().item(), data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

        acc = self.cls_running_score.top1_acc.avg['out0']
        RunnerHelper.save_net(self, self.cls_net, performance=acc)
        self.runner_state['performance'] = acc
        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
        Log.info('Test Set: {} images'.format(len(list_y_true)))
        Log.info('TestLoss = {loss.avg:.8f}'.format(loss=self.val_losses))
        Log.info('Top1 ACC = {}'.format(acc))

        # Pin the label set so row/column i always means class id i, even
        # when a class is missing from both y_true and y_pred on this split.
        class_ids = list(range(self.configer.get('data', 'num_classes')[0]))
        cmatrix = confusion_matrix(list_y_true, list_y_pred, labels=class_ids)
        fscore_str = classification_report(list_y_true, list_y_pred,
                                           labels=class_ids, digits=5)
        # Skip the header lines; the F1 value is the second-to-last column
        # of every non-empty row.
        fscore_list = [float(line.strip().split()[-2])
                       for line in fscore_str.split('\n')[2:]
                       if len(line.strip().split()) > 0]

        self.batch_time.reset()
        self.val_losses.reset()
        self.cls_running_score.reset()
        self.cls_net.train()
        return acc, cmatrix, fscore_list
class ImageTester(object):
    """Semantic-segmentation inference on single images.

    Loads a segmentor via ``ModelManager.semantic_segmentor`` and
    ``ModuleRunner.load_net``, writes predicted label maps for individual
    images, and offers single-/multi-scale, sliding-crop and depth-fused
    test-time helpers plus a few visualization utilities.
    """

    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        self.save_dir = self.configer.get('test', 'out_dir')
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        """Build the segmentor, pass it through ``load_net`` and set eval mode."""
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)
        print(f"self.save_dir {self.save_dir}")
        self.seg_net.eval()

    def __relabel(self, label_map):
        """Map class ids ``0..num_classes-1`` to ``data.label_list`` values.

        Pixels whose id is outside ``range(data.num_classes)`` become 0.
        ``label_map`` must be 2-D.
        """
        height, width = label_map.shape
        label_list = self.configer.get('data', 'label_list')
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = label_list[i]
        return label_dst

    def test(self, img_path=None, output_dir=None, data_loader=None):
        """Segment one image and save its label map as a palette ('P') PNG.

        Args:
            img_path: image to segment; defaults to the ``input_image`` config
                value.
            output_dir: directory for ``label_<name>.png``; defaults to
                ``test.out_dir``.
            data_loader: unused; kept for interface compatibility.

        Raises:
            FileNotFoundError: if OpenCV cannot read the input image.
        """
        print("test!!!")
        self.seg_net.eval()
        start_time = time.time()
        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        input_path = img_path if img_path is not None else self.configer.get('input_image')
        input_image = cv2.imread(input_path)
        if input_image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError('Cannot read image: {}'.format(input_path))

        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')),
        ])
        aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer, split='val')

        ori_h, ori_w = input_image.shape[:2]
        ori_img_size = (ori_w, ori_h)  # cv2.resize takes (width, height)
        input_image = transform(aug_val_transform(input_image)[0])

        with torch.no_grad():
            # Forward pass.
            outputs = self.ss_test([input_image])
        if isinstance(outputs, torch.Tensor):
            outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
        else:
            outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze() for output in outputs]

        logits = cv2.resize(outputs[0], ori_img_size, interpolation=cv2.INTER_CUBIC)
        label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
        if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
            label_img = label_img + 1
            label_img = label_img.astype(np.uint8)
        if self.configer.exists('data', 'label_list'):
            label_img_ = self.__relabel(label_img)
        else:
            label_img_ = label_img
        label_img_ = Image.fromarray(label_img_, 'P')

        input_name = os.path.splitext(os.path.basename(input_path))[0]
        target_dir = self.save_dir if output_dir is None else output_dir
        label_path = os.path.join(target_dir, 'label_{}.png'.format(input_name))
        FileHelper.make_dirs(label_path, is_file=True)
        ImageHelper.save(label_img_, label_path)

        self.batch_time.update(time.time() - start_time)
        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))

    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        """Forward a 4-D batch with offset maps and resize logits to the input size.

        For ``fs_auxce_loss``/``triple_auxce_loss`` the last output is used;
        for ``pyramid_auxce_loss`` outputs 1-4 are summed.
        """
        if not isinstance(inputs, torch.Tensor):
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))
        h, w = inputs.size(2), inputs.size(3)
        outputs = self.seg_net.forward(inputs, offset_h_maps, offset_w_maps)
        torch.cuda.synchronize()
        loss_type = self.configer.get('loss', 'loss_type')
        if loss_type in ("fs_auxce_loss", "triple_auxce_loss"):
            outputs = outputs[-1]
        elif loss_type == "pyramid_auxce_loss":
            outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]
        return F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)

    def ss_test(self, inputs, scale=1):
        """Single-scale inference; logits are resized back to the input size.

        ``inputs`` is either a 4-D batch tensor, or a sequence of 3-D (C, H, W)
        tensors that are run in parallel, one per id in the ``gpu`` config
        entry (``zip`` stops at the shorter of the two). Uses the last network
        output.
        """
        from collections.abc import Sequence  # `collections.Sequence` is gone in Python 3.10+

        if isinstance(inputs, torch.Tensor):
            h, w = inputs.size(2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs, size=(int(h * scale), int(w * scale)),
                                          mode="bilinear", align_corners=True)
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            outputs = outputs[-1]
            return F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)

        if isinstance(inputs, Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0), size=(int(h * scale), int(w * scale)),
                                  mode="bilinear", align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(F.interpolate(output[-1], size=ori_size[i],
                                             mode='bilinear', align_corners=True))
            return outputs

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def flip(self, x, dim):
        """Reverse tensor ``x`` along dimension ``dim``."""
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
        return x[tuple(indices)]

    def sscrop_test(self, inputs, crop_size, scale=1):
        """Sliding-window single-scale inference on a 4-D batch.

        Overlapping windows are averaged.
        Currently, sscrop_test does not support diverse_size testing.
        """
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs, size=(int(ori_h * scale), int(ori_w * scale)),
                                      mode="bilinear", align_corners=True)
        n, h, w = scaled_inputs.size(0), scaled_inputs.size(2), scaled_inputs.size(3)
        num_classes = self.configer.get('data', 'num_classes')
        full_probs = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)
        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])

        for height in height_starts:
            for width in width_starts:
                h_end, w_end = height + crop_size[0], width + crop_size[1]
                crop_inputs = scaled_inputs[:, :, height:h_end, width:w_end]
                prediction = self.ss_test(crop_inputs)
                count_predictions[:, :, height:h_end, width:w_end] += 1
                full_probs[:, :, height:h_end, width:w_end] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        full_probs /= count_predictions
        return F.interpolate(full_probs, size=(ori_h, ori_w), mode='bilinear', align_corners=True)

    def _scale_weight_pairs(self):
        """Pair each ``test.scale_search`` entry with its weight (default 1)."""
        scales = self.configer.get('test', 'scale_search')
        if self.configer.exists('test', 'scale_weights'):
            return list(zip(scales, self.configer.get('test', 'scale_weights')))
        return [(scale, 1) for scale in scales]

    def ms_test(self, inputs):
        """Multi-scale + horizontal-flip inference, optionally weighted per scale."""
        from collections.abc import Sequence  # `collections.Sequence` is gone in Python 3.10+

        if isinstance(inputs, torch.Tensor):
            n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale, weight in self._scale_weight_pairs():
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                full_probs += weight * probs
            return full_probs

        if isinstance(inputs, Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [torch.zeros(1, self.configer.get('data', 'num_classes'),
                                      i.size(1), i.size(2)).cuda(device_ids[index], non_blocking=True)
                          for index, i in enumerate(inputs)]
            # Sequence items are (C, H, W): width is dim 2 here, dim 3 on outputs.
            flip_inputs = [self.flip(i, 2) for i in inputs]
            for scale, weight in self._scale_weight_pairs():
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(flip_inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += weight * (probs[i] + self.flip(flip_probs[i], 3))
            return full_probs

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ms_test_depth(self, inputs, names):
        """Multi-scale + flip inference fused with per-pixel depth weights."""
        if not isinstance(inputs, torch.Tensor):
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))
        prob_list = []
        scale_list = []
        for scale in self.configer.get('test', 'scale_search'):
            probs = self.ss_test(inputs, scale)
            flip_probs = self.ss_test(self.flip(inputs, 3), scale)
            prob_list.append(probs + self.flip(flip_probs, 3))
            scale_list.append(scale)
        return self.fuse_with_depth(prob_list, scale_list, names)

    def fuse_with_depth(self, probs, scales, names):
        """Weight each scale's probabilities by closeness of depth bin to scale index.

        Reads ``<stereo_path><name>.png`` for every image; the stereo root is
        chosen by whether ``'test'`` occurs in ``save_dir``.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, h, w = probs[0].size(0), probs[0].size(2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            stereo_map = cv2.imread(stereo_path + name + '.png', -1)
            depth_map = stereo_map / 256.0
            depth_map = 0.5 / depth_map
            depth_map = 500 * depth_map
            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                weight_map = np.abs(depth_map - scale_index)
                weight_map = np.power(POWER_BASE, weight_map)
                weight_map = cv2.resize(weight_map, (w, h))
                full_probs[index, :, :, :] += torch.from_numpy(np.expand_dims(weight_map, axis=0)).type(
                    torch.cuda.FloatTensor) * prob[index, :, :, :]
        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        """Return the index of ``scale`` in ``scales``, or 0 if absent."""
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0

    def ms_test_wo_flip(self, inputs):
        """Multi-scale inference without flipping (unweighted sum over scales)."""
        from collections.abc import Sequence  # `collections.Sequence` is gone in Python 3.10+

        if isinstance(inputs, torch.Tensor):
            n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                full_probs += self.ss_test(inputs, scale)
            return full_probs

        if isinstance(inputs, Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [torch.zeros(1, self.configer.get('data', 'num_classes'),
                                      i.size(1), i.size(2)).cuda(device_ids[index], non_blocking=True)
                          for index, i in enumerate(inputs)]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def mscrop_test(self, inputs, crop_size):
        """Multi-scale + flip inference; scales >= 1 use sliding-window crops.

        Currently, mscrop_test does not support diverse_size testing.
        """
        n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))
            if scale < 1:
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
            else:
                probs = self.sscrop_test(inputs, crop_size, scale)
                flip_probs = self.sscrop_test(self.flip(inputs, 3), crop_size, scale)
            full_probs += probs + self.flip(flip_probs, 3)
        return full_probs

    def _decide_intersection(self, total_length, crop_length):
        """Return crop start offsets (stride = crop_length) covering ``total_length``.

        When the image is not longer than the crop, a single window at 0 is
        returned (the crop slice is then clipped to the image).
        """
        if total_length <= crop_length:
            return [0]
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * i for i in range(times)]
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image
        return cropped_starting

    def dense_crf_process(self, images, outputs):
        """Refine softmax outputs with a dense CRF (in place on ``outputs``).

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py
        """
        # Hyperparameters of the dense CRF (tuning history):
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        iter_max = 10
        pos_w = 3
        pos_xy_std = 1
        bi_w = 4
        bi_xy_std = 10
        bi_rgb_std = 3

        b = images.size(0)
        # Per-channel mean added back to the normalized images, shape (3, 1, 1).
        mean_vector = np.expand_dims(np.expand_dims(np.transpose(np.array([102.9801, 115.9465, 122.7717])), axis=1), axis=2)
        outputs = F.softmax(outputs, dim=1)
        for i in range(b):
            unary = outputs[i].data.cpu().numpy()
            C, H, W = unary.shape
            unary = dcrf_utils.unary_from_softmax(unary)
            unary = np.ascontiguousarray(unary)

            image = np.ascontiguousarray(images[i]) + mean_vector
            image = image.astype(np.ubyte)
            image = np.ascontiguousarray(image.transpose(1, 2, 0))

            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
            d.addPairwiseBilateral(sxy=bi_xy_std, srgb=bi_rgb_std, rgbim=image, compat=bi_w)
            out_crf = np.array(d.inference(iter_max))
            outputs[i] = torch.from_numpy(out_crf).cuda().view(C, H, W)

        return outputs

    def visualize(self, label_img):
        """Colorize a grayscale label map using ``HYUNDAI_POC_CATEGORIES`` (BGR output).

        Id 14 is mapped to 0 before colorizing.
        """
        ignored_id = 14
        img = cv2.cvtColor(label_img.copy(), cv2.COLOR_GRAY2RGB)
        img[img == ignored_id] = 0
        out_img = img.copy()
        red = img[:, :, 0]  # all three channels are equal after GRAY2RGB
        for label in HYUNDAI_POC_CATEGORIES:
            out_img[:, :, :3][red == label['id']] = label['color']
        return cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)

    def get_ratio_all(self, anno_img):
        """Return ``[name, percent]`` per category of ``HYUNDAI_POC_CATEGORIES``.

        Percentages are relative to the total counted pixels (all 0 if none).
        """
        total_size = 0
        lines = []
        for label in HYUNDAI_POC_CATEGORIES:
            total = self.get_ratio(anno_img.copy(), label)
            lines.append([label['name'], total])
            total_size += total
        for line in lines:
            line[1] = line[1] / total_size * 100 if total_size else 0
        return lines

    def get_ratio(self, anno_img, label):
        """Count pixels of ``anno_img`` equal to ``label['id']`` (id 14 counts 0)."""
        label_id = label['id']
        if label_id == 14:
            return 0
        return np.count_nonzero(anno_img == label_id)

    def visualize_ratio(self, ratios, video_size, ratio_w):
        """Render a two-column (category, ratio) table image in BGR.

        Args:
            ratios: ``[name, value]`` rows (e.g. from ``get_ratio_all``); a
                header row is prepended and the first 14 rows are drawn.
            video_size: its first element is used as the image height.
            ratio_w: image width in pixels.
        """
        ratio_list = ratios.copy()
        ratio_list.insert(0, ['등급', '비율'])

        RATIO_IMG_W = ratio_w
        RATIO_IMG_H = int(video_size[0])
        TEXT_MARGIN_H = 20
        TEXT_MARGIN_W = 10
        row_count = 14

        ratio_img = np.full((RATIO_IMG_H, RATIO_IMG_W, 3), 255, np.uint8)
        row_h = RATIO_IMG_H / row_count
        center_w = RATIO_IMG_W / 2

        # Category color swatches in the left column (row 0 is the header).
        for i in range(1, row_count):
            p_y = int(i * row_h)
            p_y_n = int((i + 1) * row_h)
            for label in HYUNDAI_POC_CATEGORIES:
                if label['id'] == i:
                    cv2.rectangle(ratio_img, (0, p_y), (int(center_w), p_y_n), label['color'], cv2.FILLED)

        # Grid lines.
        for i in range(1, row_count):
            p_y = int(i * row_h)
            cv2.line(ratio_img, (0, p_y), (RATIO_IMG_W, p_y), (0, 0, 0))
        cv2.line(ratio_img, (int(center_w), 0), (int(center_w), RATIO_IMG_H), (0, 0, 0))

        # Text is drawn with PIL (for the TrueType font); load the font and
        # convert the image once instead of once per row.
        ratio_pil = Image.fromarray(ratio_img)
        font = ImageFont.truetype("NanumGothic.ttf", 15)
        draw = ImageDraw.Draw(ratio_pil)
        color = (0, 0, 0)
        p_w = int(center_w) + TEXT_MARGIN_W
        for i in range(row_count):
            p_y = int(i * row_h) + TEXT_MARGIN_H
            draw.text((0, p_y), ratio_list[i][0], font=font, fill=color)
            value = ratio_list[i][1]
            if isinstance(value, str):
                draw.text((p_w, p_y), value, font=font, fill=color)
            else:
                draw.text((p_w, p_y), "{:.02f}".format(value), font=font, fill=color)

        ratio_img = np.array(ratio_pil)
        return cv2.cvtColor(ratio_img, cv2.COLOR_RGB2BGR)
class Tester(object):
    """Extract per-pixel offset maps from a segmentation network and save them.

    Runs the network over the split named by ``test.eval_set``. The direction
    logits are converted to offset vectors (``DTOffsetHelper.label_to_vector``
    of the argmax direction), zeroed wherever the edge channel of the mask
    logits is <= 0.5. One ``.mat`` file is written per image under
    ``test.out_dir``.
    """

    def __init__(self, configer):
        self.crop_size = configer.get('train', 'data_transformer')['input_size']
        # Evaluation must be deterministic: drop random augmentation steps.
        val_trans_seq = [x for x in configer.get('val_trans', 'trans_seq') if 'random' not in x]
        configer.update(('val_trans', 'trans_seq'), val_trans_seq)
        configer.get('val', 'data_transformer')['input_size'] = configer.get(
            'test', 'data_transformer').get('input_size', None)
        configer.update(('train', 'data_transformer'), configer.get('val', 'data_transformer'))
        configer.update(('val', 'batch_size'), int(os.environ.get('batch_size', 16)))
        configer.update(('test', 'batch_size'), int(os.environ.get('batch_size', 16)))

        self.save_dir = configer.get('test', 'out_dir')
        self.dataset_name = configer.get('test', 'eval_set')
        self.sscrop = configer.get('test', 'sscrop')

        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

        pprint.pprint(configer.params_root)

    def _init_model(self):
        """Build and load the segmentor, then create the evaluation loader."""
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        assert self.dataset_name in ('train', 'val', 'test'), 'Cannot infer dataset name'

        self.size_mode = self.configer.get(self.dataset_name, 'data_transformer')['size_mode']
        if self.dataset_name != 'test':
            self.test_loader = self.seg_data_loader.get_valloader(self.dataset_name)
        else:
            self.test_loader = self.seg_data_loader.get_testloader(self.dataset_name)
        self.test_size = len(self.test_loader) * self.configer.get('val', 'batch_size')

    def test(self, data_loader=None):
        """Extract offsets for every batch of ``self.test_loader`` and save them.

        Each image's offset map is cropped to ``meta['border_size']``, resized
        (nearest neighbour) to ``meta['ori_img_size']`` and saved as
        ``<save_dir>/<name without extension>.mat`` under key ``'mat'``. If the
        compressed file cannot be read back, it is rewritten uncompressed.

        ``data_loader`` is unused; kept for interface compatibility.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)
        dest_dir = self.save_dir

        print('Total batches', len(self.test_loader))
        for data_dict in self.test_loader:
            inputs = [data_dict['img']]
            names = data_dict['name']
            metas = data_dict['meta']

            with torch.no_grad():
                offsets, logits = self.extract_offset(inputs)
                print([x.shape for x in logits])
                for k in range(len(inputs[0])):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    offset = offsets[k].squeeze().cpu().numpy()
                    # Remove padding, then restore the original resolution.
                    offset = cv2.resize(offset[:border_size[1], :border_size[0]],
                                        tuple(ori_img_size),
                                        interpolation=cv2.INTER_NEAREST)
                    print(image_id)

                    os.makedirs(dest_dir, exist_ok=True)
                    stem = names[k].rpartition('.')[0]
                    dest_name = os.path.join(dest_dir, (stem or names[k]) + '.mat')
                    print('Shape:', offset.shape, 'Saving to', dest_name)

                    mat_dict = {'mat': offset}
                    scipy.io.savemat(dest_name, mat_dict, do_compression=True)
                    try:
                        # Verify the compressed file is readable; fall back to
                        # an uncompressed write if it is not.
                        scipy.io.loadmat(dest_name)
                    except Exception as e:
                        print(e)
                        scipy.io.savemat(dest_name, mat_dict, do_compression=False)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))

    def extract_offset(self, inputs):
        """Run inference and convert predictions to per-image offset maps.

        Returns:
            ``(offsets, logits)``: one offset tensor per image (from
            ``_get_offset``) and one edge-channel logit map per image.

        Raises:
            ValueError: if neither ``test.sscrop`` nor ``test.mode == 'ss_test'``.
        """
        if self.sscrop:
            outputs = self.sscrop_test(inputs, self.crop_size)
        elif self.configer.get('test', 'mode') == 'ss_test':
            outputs = self.ss_test(inputs)
        else:
            raise ValueError('Unsupported test mode: {}'.format(self.configer.get('test', 'mode')))

        offsets = []
        logits = []
        for mask_logits, dir_logits, img in zip(*outputs[:2], inputs[0]):
            h, w = img.shape[1:]
            mask_logits = F.interpolate(mask_logits.unsqueeze(0), size=(h, w),
                                        mode='bilinear', align_corners=True)
            dir_logits = F.interpolate(dir_logits.unsqueeze(0), size=(h, w),
                                       mode='bilinear', align_corners=True)
            logits.append(mask_logits[:, 1])
            offsets.append(self._get_offset(mask_logits, dir_logits))
        print([x.shape for x in offsets])
        return offsets, logits

    def _get_offset(self, mask_logits, dir_logits):
        """Convert direction logits to offset vectors, zeroed off the edge mask.

        The edge mask is ``mask_logits[:, 1] > 0.5``. The result has the
        vector components in the last dimension.
        """
        edge_mask = mask_logits[:, 1] > 0.5
        dir_logits = torch.softmax(dir_logits, dim=1)
        dir_label = torch.argmax(dir_logits, dim=1).float()
        offset = DTOffsetHelper.label_to_vector(dir_label)
        offset = offset.permute(0, 2, 3, 1)
        offset[~edge_mask, :] = 0
        return offset

    def _flip(self, x, dim=-1):
        """Reverse tensor ``x`` along ``dim``."""
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
        return x[tuple(indices)]

    def _flip_offset(self, x):
        """Horizontally flip a direction map and permute its direction channels."""
        x = self._flip(x, dim=-1)
        if len(x.shape) == 4:
            return x[:, DTOffsetHelper.flipping_indices()]
        return x[DTOffsetHelper.flipping_indices()]

    def _flip_inputs(self, inputs):
        """Horizontally flip inputs (batched tensors, or lists per item)."""
        if self.size_mode == 'fix_size':
            return [self._flip(x, -1) for x in inputs]
        return [[self._flip(x, -1) for x in xs] for xs in inputs]

    def _flip_outputs(self, outputs):
        """Undo a horizontal flip on (mask, direction) outputs."""
        funcs = [self._flip, self._flip_offset]
        if self.size_mode == 'fix_size':
            return [f(x) for f, x in zip(funcs, outputs)]
        return [[f(x) for x in xs] for f, xs in zip(funcs, outputs)]

    def _tuple_sum(self, tup1, tup2, tup2_weight=1):
        """Element-wise ``tup1 + tup2 * tup2_weight`` (``tup1`` may be None).

        tup1 / tup2: tuple of tensors or tuple of list of tensors.
        """
        if tup1 is None:
            if self.size_mode == 'fix_size':
                return [y * tup2_weight for y in tup2]
            return [[y * tup2_weight for y in ys] for ys in tup2]
        if self.size_mode == 'fix_size':
            return [x + y * tup2_weight for x, y in zip(tup1, tup2)]
        return [[x + y * tup2_weight for x, y in zip(xs, ys)] for xs, ys in zip(tup1, tup2)]

    def _scale_ss_inputs(self, inputs, scale):
        """Resize the first input by ``scale``; also return its original (h, w)."""
        n, c, h, w = inputs[0].shape
        size = (int(h * scale), int(w * scale))
        return [
            F.interpolate(inputs[0], size=size, mode="bilinear", align_corners=True),
        ], (h, w)

    def sscrop_test(self, inputs, crop_size, scale=1):
        """Sliding-window inference; overlapping windows are averaged.

        Returns two maps: 2-channel mask logits and 8-channel direction logits.
        Currently, sscrop_test does not support diverse_size testing.
        """
        scaled_inputs = inputs
        img = scaled_inputs[0]
        n, h, w = img.size(0), img.size(2), img.size(3)
        ori_h, ori_w = h, w
        full_probs = [torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)]
        count_predictions = [torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)]
        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])

        for height in height_starts:
            for width in width_starts:
                h_end, w_end = height + crop_size[0], width + crop_size[1]
                crop_inputs = [x[..., height:h_end, width:w_end] for x in scaled_inputs]
                prediction = self.ss_test(crop_inputs)

                for j in range(2):
                    count_predictions[j][:, :, height:h_end, width:w_end] += 1
                    full_probs[j][:, :, height:h_end, width:w_end] += prediction[j]
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        for j in range(2):
            full_probs[j] /= count_predictions[j]
            full_probs[j] = F.interpolate(full_probs[j], size=(ori_h, ori_w),
                                          mode='bilinear', align_corners=True)
        return full_probs

    def _scale_ss_outputs(self, outputs, size):
        """Bilinearly resize every output tensor to ``size``."""
        return [F.interpolate(x, size=size, mode="bilinear", align_corners=True) for x in outputs]

    def ss_test(self, inputs, scale=1):
        """Single-scale inference returning ``[mask_logits, dir_logits]``.

        If the network returns three outputs, the first and third are kept.
        Otherwise the first output is passed through a softmax. In non
        ``fix_size`` mode, images are processed one per GPU in parallel, and
        the result is a pair of per-image lists.
        """
        if self.size_mode == 'fix_size':
            scaled_inputs, orig_size = self._scale_ss_inputs(inputs, scale)
            print([x.shape for x in scaled_inputs])
            outputs = list(self.seg_net.forward(*scaled_inputs))
            if len(outputs) == 3:
                outputs = (outputs[0], outputs[2])
            else:
                outputs[0] = F.softmax(outputs[0], dim=1)
            torch.cuda.synchronize()
            return self._scale_ss_outputs(outputs, orig_size)

        device_ids = self.configer.get('gpu')
        replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
        scaled_inputs, ori_sizes, outputs = [], [], []
        for *i, d in zip(*inputs, device_ids):
            scaled_i, ori_size_i = self._scale_ss_inputs([x.unsqueeze(0) for x in i], scale)
            scaled_inputs.append([x.cuda(d, non_blocking=True) for x in scaled_i])
            ori_sizes.append(ori_size_i)
        scaled_outputs = nn.parallel.parallel_apply(replicas[:len(scaled_inputs)], scaled_inputs)
        for o, ori_size in zip(scaled_outputs, ori_sizes):
            o = self._scale_ss_outputs(o, ori_size)
            if len(o) == 3:
                o = (o[0], o[2])
            outputs.append([x.squeeze(0) for x in o])
        # Transpose [per-image][per-output] into [per-output][per-image].
        return list(map(list, zip(*outputs)))

    def _decide_intersection(self, total_length, crop_length, crop_stride_ratio=1 / 3):
        """Return crop start offsets covering ``total_length``.

        The stride is ``crop_length * crop_stride_ratio`` (at least 1). When
        the image is not longer than the crop, a single window at 0 is
        returned (the crop slice is then clipped to the image).
        """
        if total_length <= crop_length:
            return [0]
        # Set the stride as the paper does; clamp to 1 so tiny crops cannot
        # produce a zero stride.
        stride = max(1, int(crop_length * crop_stride_ratio))
        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * i for i in range(times)]
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image
        return cropped_starting
class Tester(object):
    """Inference driver for a semantic-segmentation network.

    Loads the network and either the test or the validation loader, runs the
    inference mode selected by ``test.mode`` on every batch and writes label
    maps, colorized maps and (optionally) probability maps under
    ``test.out_dir``.
    """

    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        self.save_dir = self.configer.get('test', 'out_dir')
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        """Build and load the network, then choose the loader to run on.

        An output directory whose path contains ``'test'`` selects the test
        loader; any other path selects the validation loader.  ``test_size``
        is the loader length times the batch size.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        if 'test' in self.save_dir:
            self.test_loader = self.seg_data_loader.get_testloader()
            self.test_size = len(self.test_loader) * self.configer.get('test', 'batch_size')
        else:
            self.test_loader = self.seg_data_loader.get_valloader()
            self.test_size = len(self.test_loader) * self.configer.get('val', 'batch_size')

        self.seg_net.eval()

    def __relabel(self, label_map):
        """Translate a 2-D map of training ids into ``data.label_list`` ids.

        Pixel value ``i`` (for ``i < data.num_classes``) becomes
        ``label_list[i]``; every other value becomes 0.
        """
        height, width = label_map.shape
        label_list = self.configer.get('data', 'label_list')
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = label_list[i]
        return label_dst

    @staticmethod
    def _is_sequence(inputs):
        """Return True if ``inputs`` is a ``collections.abc.Sequence``."""
        # ``collections.Sequence`` was removed in Python 3.10; the ABC only
        # lives in ``collections.abc`` now.
        from collections.abc import Sequence
        return isinstance(inputs, Sequence)

    @staticmethod
    def _softmax(X, axis=0):
        """Numerically stable softmax of ``X`` along ``axis``.

        Works on a shifted copy, so the caller's array is left untouched (the
        previous in-place ``X -= max`` silently rewrote the caller's logits).
        """
        X = X - np.max(X, axis=axis, keepdims=True)
        X = np.exp(X)
        return X / np.sum(X, axis=axis, keepdims=True)

    def _get_colors(self):
        """Return the color palette for ``configer.get('dataset')``.

        Raises:
            RuntimeError: if the dataset has no known palette.
        """
        dataset = self.configer.get('dataset')
        if dataset in ['cityscapes', 'gta5']:
            return get_cityscapes_colors()
        if dataset == 'ade20k':
            return get_ade_colors()
        if dataset == 'lip':
            return get_lip_colors()
        if dataset == 'pascal_context':
            return get_pascal_context_colors()
        if dataset == 'pascal_voc':
            return get_pascal_voc_colors()
        if dataset == 'coco_stuff':
            return get_cocostuff_colors()
        raise RuntimeError("Unsupport colors")

    def _predict(self, inputs, names, data_dict):
        """Run the inference routine selected by the config on one batch.

        Raises:
            RuntimeError: for an unknown ``test.mode``.  Previously such a mode
                left ``outputs`` undefined, or silently reused the previous
                batch's predictions.
        """
        if self.configer.exists('data', 'use_offset') and \
                self.configer.get('data', 'use_offset') == 'offline':
            return self.offset_test(inputs, data_dict['offsetmap_h'], data_dict['offsetmap_w'])

        mode = self.configer.get('test', 'mode')
        if mode == 'ss_test':
            return self.ss_test(inputs)
        if mode == 'ms_test':
            return self.ms_test(inputs)
        if mode == 'ms_test_depth':
            return self.ms_test_depth(inputs, names)
        if mode == 'sscrop_test':
            return self.sscrop_test(inputs, self.configer.get('test', 'crop_size'))
        if mode == 'mscrop_test':
            return self.mscrop_test(inputs, self.configer.get('test', 'crop_size'))
        if mode == 'crf_ss_test':
            outputs = self.ss_test(inputs)
            return self.dense_crf_process(inputs, outputs)
        raise RuntimeError('Unsupported test mode: {}'.format(mode))

    def test(self, data_loader=None):
        """Run inference over ``self.test_loader`` and save the results.

        For every image this writes ``label/<name>.png`` and a colorized map to
        ``vis/<name>.png``, or to ``gt_vis/<name>.png`` when the
        ``save_gt_label`` environment variable is set.  With ``test.save_prob``
        it also writes the per-pixel softmax to ``prob/<name>.npy``.

        Args:
            data_loader: unused; kept for interface compatibility.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        colors = self._get_colors()
        save_gt_label = os.environ.get('save_gt_label')
        reduce_zero_label = (self.configer.exists('data', 'reduce_zero_label')
                             and self.configer.get('data', 'reduce_zero_label'))

        for data_dict in self.test_loader:
            inputs = data_dict['img']
            names = data_dict['name']
            metas = data_dict['meta']
            labels = None
            if 'val' in self.save_dir and save_gt_label:
                labels = data_dict['labelmap']

            with torch.no_grad():
                outputs = self._predict(inputs, names, data_dict)

                if isinstance(outputs, torch.Tensor):
                    # (N, C, H, W) tensor -> (N, H, W, C) numpy array.
                    outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    n = outputs.shape[0]
                else:
                    # Per-image tensors (batch dim presumably 1) -> (H, W, C) arrays.
                    outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                               for output in outputs]
                    n = len(outputs)

                for k in range(n):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    # Keep the (width, height) region given by border_size
                    # (presumably the un-padded image), then resize it back to
                    # the original image size.
                    logits = cv2.resize(outputs[k][:border_size[1], :border_size[0]],
                                        tuple(ori_img_size), interpolation=cv2.INTER_CUBIC)

                    # save the logits map
                    if self.configer.get('test', 'save_prob'):
                        prob_path = os.path.join(self.save_dir, "prob/", '{}.npy'.format(names[k]))
                        FileHelper.make_dirs(prob_path, is_file=True)
                        np.save(prob_path, self._softmax(logits, axis=-1))

                    label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
                    if reduce_zero_label:
                        label_img = (label_img + 1).astype(np.uint8)
                    if self.configer.exists('data', 'label_list'):
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img
                    label_img_ = Image.fromarray(label_img_, 'P')
                    Log.info('{:4d}/{:4d} label map generated'.format(image_id, self.test_size))
                    label_path = os.path.join(self.save_dir, "label/", '{}.png'.format(names[k]))
                    FileHelper.make_dirs(label_path, is_file=True)
                    ImageHelper.save(label_img_, label_path)

                    # colorize the label-map
                    if save_gt_label:
                        # NOTE(review): the ground truth replaces the prediction
                        # only when data.reduce_zero_label is set; otherwise the
                        # prediction is what lands in gt_vis/.  Kept as-is; confirm.
                        if reduce_zero_label:
                            label_img = np.asarray(labels[k] + 1, dtype=np.uint8)
                        vis_dir = "gt_vis/"
                    else:
                        vis_dir = "vis/"
                    color_img_ = Image.fromarray(label_img)
                    color_img_.putpalette(colors)
                    vis_path = os.path.join(self.save_dir, vis_dir, '{}.png'.format(names[k]))
                    FileHelper.make_dirs(vis_path, is_file=True)
                    ImageHelper.save(color_img_, save_path=vis_path)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))

    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        """Inference for networks that also take offline offset maps.

        The prediction is resized back to the input resolution.  ``scale`` is
        accepted but unused.

        Raises:
            RuntimeError: if ``inputs`` is not a tensor.
        """
        if not isinstance(inputs, torch.Tensor):
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

        h, w = inputs.size(2), inputs.size(3)
        outputs = self.seg_net.forward(inputs, offset_h_maps, offset_w_maps)
        torch.cuda.synchronize()
        loss_type = self.configer.get('loss', 'loss_type')
        if loss_type in ("fs_auxce_loss", "triple_auxce_loss"):
            outputs = outputs[-1]
        elif loss_type == "pyramid_auxce_loss":
            outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]
        return F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)

    def ss_test(self, inputs, scale=1):
        """Single-scale inference.

        Accepts either a batched (N, C, H, W) tensor or a sequence of
        (C, H, W) tensors.  A sequence is spread over ``configer.get('gpu')``,
        one image per device; ``zip`` drops images beyond the number of
        devices.  Inputs are resized by ``scale``, and the last network output
        is resized back to each input's size.

        Raises:
            RuntimeError: for any other input type.
        """
        if isinstance(inputs, torch.Tensor):
            h, w = inputs.size(2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs, size=(int(h * scale), int(w * scale)),
                                          mode="bilinear", align_corners=True)
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            outputs = outputs[-1]
            return F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)

        if self._is_sequence(inputs):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size = [], []
            for img, device in zip(inputs, device_ids):
                h, w = img.size(1), img.size(2)
                ori_size.append((h, w))
                img = F.interpolate(img.unsqueeze(0), size=(int(h * scale), int(w * scale)),
                                    mode="bilinear", align_corners=True)
                scaled_inputs.append(img.cuda(device, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(replicas[:len(scaled_inputs)], scaled_inputs)
            return [F.interpolate(output[-1], size=size, mode='bilinear', align_corners=True)
                    for output, size in zip(scaled_outputs, ori_size)]

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def flip(self, x, dim):
        """Return ``x`` reversed along dimension ``dim``."""
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
        return x[tuple(indices)]

    def sscrop_test(self, inputs, crop_size, scale=1):
        """Sliding-window inference at a single scale.

        ``inputs`` is resized by ``scale`` and tiled with ``crop_size`` windows
        (see ``_decide_intersection``).  Each window is predicted with
        ``ss_test``, overlapping predictions are averaged, and the result is
        resized back to the input resolution.

        Currently, sscrop_test does not support diverse_size testing.
        """
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs, size=(int(ori_h * scale), int(ori_w * scale)),
                                      mode="bilinear", align_corners=True)
        n, h, w = scaled_inputs.size(0), scaled_inputs.size(2), scaled_inputs.size(3)
        num_classes = self.configer.get('data', 'num_classes')
        full_probs = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)
        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])
        for height in height_starts:
            for width in width_starts:
                rows = slice(height, height + crop_size[0])
                cols = slice(width, width + crop_size[1])
                prediction = self.ss_test(scaled_inputs[:, :, rows, cols])
                count_predictions[:, :, rows, cols] += 1
                full_probs[:, :, rows, cols] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        full_probs /= count_predictions
        return F.interpolate(full_probs, size=(ori_h, ori_w), mode='bilinear', align_corners=True)

    def _scale_weight_pairs(self):
        """Pair each ``test.scale_search`` entry with its weight (1 when no ``test.scale_weights``)."""
        scales = self.configer.get('test', 'scale_search')
        if self.configer.exists('test', 'scale_weights'):
            return list(zip(scales, self.configer.get('test', 'scale_weights')))
        return [(scale, 1) for scale in scales]

    def ms_test(self, inputs):
        """Multi-scale inference with horizontal-flip augmentation.

        For every scale, the prediction on the input and the un-flipped
        prediction on the flipped input are summed and accumulated, weighted
        by ``test.scale_weights`` when that key exists.

        Raises:
            RuntimeError: for unsupported input types.
        """
        if isinstance(inputs, torch.Tensor):
            n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale, weight in self._scale_weight_pairs():
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                full_probs += weight * (probs + self.flip(flip_probs, 3))
            return full_probs

        if self._is_sequence(inputs):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            img.size(1), img.size(2)).cuda(device_ids[index], non_blocking=True)
                for index, img in enumerate(inputs)
            ]
            # Sequence items are (C, H, W), so the width axis is dim 2 here;
            # the per-image outputs are (1, C, H, W), flipped on dim 3.
            flip_inputs = [self.flip(img, 2) for img in inputs]
            for scale, weight in self._scale_weight_pairs():
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(flip_inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += weight * (probs[i] + self.flip(flip_probs[i], 3))
            return full_probs

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ms_test_depth(self, inputs, names):
        """Multi-scale + flip inference, fused with per-pixel depth weights.

        See ``fuse_with_depth``.  Only batched tensors are supported.

        Raises:
            RuntimeError: for unsupported input types.
        """
        if not isinstance(inputs, torch.Tensor):
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

        prob_list = []
        scale_list = []
        for scale in self.configer.get('test', 'scale_search'):
            probs = self.ss_test(inputs, scale)
            flip_probs = self.ss_test(self.flip(inputs, 3), scale)
            prob_list.append(probs + self.flip(flip_probs, 3))
            scale_list.append(scale)
        return self.fuse_with_depth(prob_list, scale_list, names)

    def fuse_with_depth(self, probs, scales, names):
        """Weight each scale's prediction per pixel using a stereo-derived depth map.

        The stereo PNG ``<stereo_path><name>.png`` is read with
        ``cv2.imread(..., -1)`` and converted to a depth-like value
        ``250 / (stereo / 256)``, clipped to ``[0, MAX_DEPTH]`` and quantized
        into ``MAX_DEPTH // len(scales)`` bins.  Scale number ``s`` gets weight
        ``POWER_BASE ** |bin - s|``.

        Raises:
            FileNotFoundError: if a stereo map cannot be read.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, h, w = probs[0].size(0), probs[0].size(2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            stereo_file = stereo_path + name + '.png'
            stereo_map = cv2.imread(stereo_file, -1)
            if stereo_map is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError('Cannot read stereo map: {}'.format(stereo_file))
            depth_map = stereo_map / 256.0
            depth_map = 0.5 / depth_map  # zero disparity -> inf, clipped below
            depth_map = 500 * depth_map
            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                weight_map = np.power(POWER_BASE, np.abs(depth_map - scale_index))
                weight_map = cv2.resize(weight_map, (w, h))
                full_probs[index, :, :, :] += torch.from_numpy(
                    np.expand_dims(weight_map, axis=0)).type(torch.cuda.FloatTensor) * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        """Return the first index of ``scale`` in ``scales``, or 0 if it is absent."""
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0

    def ms_test_wo_flip(self, inputs):
        """Multi-scale inference without flip augmentation (unweighted sum over scales).

        Raises:
            RuntimeError: for unsupported input types.
        """
        if isinstance(inputs, torch.Tensor):
            n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                full_probs += self.ss_test(inputs, scale)
            return full_probs

        if self._is_sequence(inputs):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            img.size(1), img.size(2)).cuda(device_ids[index], non_blocking=True)
                for index, img in enumerate(inputs)
            ]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def mscrop_test(self, inputs, crop_size):
        """Multi-scale + flip inference.

        Scales below 1 use whole-image ``ss_test``; scales of 1 or more use
        sliding-window ``sscrop_test``.

        Currently, mscrop_test does not support diverse_size testing.
        """
        n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))
            if scale < 1:
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
            else:
                probs = self.sscrop_test(inputs, crop_size, scale)
                flip_probs = self.sscrop_test(self.flip(inputs, 3), crop_size, scale)
            full_probs += probs + self.flip(flip_probs, 3)
        return full_probs

    def _decide_intersection(self, total_length, crop_length):
        """Return crop start offsets that cover ``[0, total_length)``.

        Crops are placed back to back (stride == ``crop_length``).  If they
        fall short of the end, one extra crop aligned with the end is appended.
        When the extent fits in a single crop, ``[0]`` is returned.
        """
        if total_length <= crop_length:
            # A single (possibly partial) window covers everything.  The
            # general formula below yields no starts here and then crashed on
            # ``cropped_starting[-1]``.
            return [0]
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * i for i in range(times)]
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image
        return cropped_starting

    def dense_crf_process(self, images, outputs):
        """Refine the per-image class scores in ``outputs`` with a dense CRF.

        ``outputs`` are softmax-ed over dim 1, and each image gets a Gaussian
        and a bilateral pairwise term.  The result overwrites ``outputs[i]``.

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py
        """
        # hyperparameters of the dense crf
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        iter_max = 10
        pos_w = 3
        pos_xy_std = 1
        bi_w = 4
        bi_xy_std = 10
        bi_rgb_std = 3

        b = images.size(0)
        # Per-channel mean added back to every image; presumably the mean
        # subtracted during preprocessing (TODO confirm channel order).
        mean_vector = np.array([102.9801, 115.9465, 122.7717]).reshape(3, 1, 1)
        outputs = F.softmax(outputs, dim=1)
        for i in range(b):
            unary = outputs[i].data.cpu().numpy()
            C, H, W = unary.shape
            unary = dcrf_utils.unary_from_softmax(unary)
            unary = np.ascontiguousarray(unary)

            # .detach().cpu() also works for CUDA tensors, which
            # np.ascontiguousarray cannot convert directly.
            image = images[i].detach().cpu().numpy() + mean_vector
            image = image.astype(np.ubyte)
            image = np.ascontiguousarray(image.transpose(1, 2, 0))

            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
            d.addPairwiseBilateral(sxy=bi_xy_std, srgb=bi_rgb_std, rgbim=image, compat=bi_w)
            out_crf = np.array(d.inference(iter_max))
            outputs[i] = torch.from_numpy(out_crf).cuda().view(C, H, W)

        return outputs
class Trainer(object):
    """Training and validation driver for a semantic-segmentation network.

    Builds the network, optimizer/scheduler, data loaders and pixel loss from
    ``configer``.  Trains until ``solver.max_iters`` iterations, validates
    every ``solver.test_interval`` iterations, and applies SWA when
    ``lr.lr_policy`` contains ``'swa'``.
    """

    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.foward_time = AverageMeter()
        self.backward_time = AverageMeter()
        self.loss_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.data_helper = DataHelper(configer, self)
        self.evaluator = get_evaluator(configer, self)

        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.running_score = None

        self._init_model()

    def _init_model(self):
        """Build the network, its optimizer/scheduler, the loaders and the pixel loss."""
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert group_method is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        self.pixel_loss = self.loss_manager.get_seg_loss()
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(self.pixel_loss)

    @staticmethod
    def group_weight(module):
        """Split ``module``'s parameters into a weight-decay and a no-decay group.

        Weights of ``nn.Linear`` and convolution layers get weight decay.
        Their biases, and the ``weight``/``bias`` of every other layer (e.g.
        normalization layers), go into a group with ``weight_decay=0``.

        Returns:
            A list of two optimizer param-group dicts.
        """
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                # Layers such as affine-free norms register weight/bias as
                # None.  Appending those Nones broke the count check below.
                weight = getattr(m, 'weight', None)
                if weight is not None:
                    group_no_decay.append(weight)
                bias = getattr(m, 'bias', None)
                if bias is not None:
                    group_no_decay.append(bias)

        assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
        return [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]

    def _get_parameters(self):
        """Return two param groups.

        Parameters whose name contains ``'backbone'`` use ``lr.base_lr``; all
        others use ``lr.base_lr * lr.nbb_mult``.
        """
        bb_lr = []
        nbb_lr = []
        for key, value in self.seg_net.named_parameters():
            if 'backbone' in key:
                bb_lr.append(value)
            else:
                nbb_lr.append(value)

        base_lr = self.configer.get('lr', 'base_lr')
        return [{'params': bb_lr, 'lr': base_lr},
                {'params': nbb_lr, 'lr': base_lr * self.configer.get('lr', 'nbb_mult')}]

    @staticmethod
    def _reduce_tensor(inp):
        """Sum ``inp`` over all processes onto rank 0.

        Returns ``inp`` unchanged when there is a single process.  Otherwise
        the reduction runs on a clone, so the caller's tensor (here the loss
        that is back-propagated afterwards) is not modified in place.  Only
        rank 0's result holds the sum.
        """
        world_size = get_world_size()
        if world_size < 2:
            return inp
        import torch.distributed as dist
        with torch.no_grad():
            reduced_inp = inp.clone()
            dist.reduce(reduced_inp, dst=0)
        return reduced_inp

    def __train(self):
        """Train for one pass over ``self.train_loader``.

        The pass stops early once ``solver.max_iters`` is reached.  It may
        validate every ``solver.test_interval`` iterations, and it increments
        the ``epoch`` counter when it ends.
        """
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        max_iters = self.configer.get('solver', 'max_iters')
        use_swa = "swa" in self.configer.get('lr', 'lr_policy')
        if use_swa:
            # SWA snapshots are taken after the first 75% of iterations, every
            # swa_step_max_iters steps, plus at the final iteration.
            normal_max_iters = int(max_iters * 0.75)
            swa_step_max_iters = (max_iters - normal_max_iters) // 5 + 1

        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for data_dict in self.train_loader:
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0, ]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if is_distributed():
                backward_loss = self.pixel_loss(outputs, targets)
                # The reduced value is only used for logging.
                display_loss = self._reduce_tensor(backward_loss) / get_world_size()
            else:
                backward_loss = display_loss = self.pixel_loss(
                    outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')
            iters = self.configer.get('iters')

            # Print the log info & reset the states.
            if iters % self.configer.get('solver', 'display_iter') == 0 and \
                    (not is_distributed() or get_rank() == 0):
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                             self.configer.get('epoch'), iters,
                             self.configer.get('solver', 'display_iter'),
                             self.module_runner.get_lr(self.optimizer),
                             batch_time=self.batch_time, foward_time=self.foward_time,
                             backward_time=self.backward_time, loss_time=self.loss_time,
                             data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # save checkpoints for swa
            if use_swa and iters > normal_max_iters and \
                    ((iters - normal_max_iters) % swa_step_max_iters == 0 or iters == max_iters):
                self.optimizer.update_swa()

            if iters == max_iters:
                break

            # Check to val the current model.
            if iters % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """Evaluate on ``data_loader`` (defaults to ``self.val_loader``).

        Stores the mean loss as ``val_loss`` in the configer and saves
        checkpoints in ``'performance'`` and ``'val_loss'`` modes.  Logs the
        scores on rank 0, then puts the network and loss back into training
        mode.

        Raises:
            AssertionError: re-raised from ``pixel_loss`` after logging the
                output/target counts.  Previously it was swallowed, which left
                ``loss`` undefined or stale from the previous batch.
        """
        self.seg_net.eval()
        self.pixel_loss.eval()
        start_time = time.time()
        replicas = self.evaluator.prepare_validaton()

        data_loader = self.val_loader if data_loader is None else data_loader
        is_lip = self.configer.get('dataset') == 'lip'
        for j, data_dict in enumerate(data_loader):
            if j % 10 == 0:
                Log.info('{} images processed\n'.format(j))

            if is_lip:
                (inputs, targets, inputs_rev, targets_rev), batch_size = self.data_helper.prepare_data(
                    data_dict, want_reverse=True)
            else:
                (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

            with torch.no_grad():
                if is_lip:
                    # Predict the original and the mirrored batch in one pass.
                    inputs = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                    outputs_ = self.module_runner.gather(self.seg_net(inputs))
                    if isinstance(outputs_, (list, tuple)):
                        outputs_ = outputs_[-1]
                    half = outputs_.size(0) // 2
                    outputs = outputs_[:half].clone()
                    outputs_rev = outputs_[half:].clone()
                    if outputs_rev.shape[1] == 20:
                        # Swap channel pairs (presumably LIP's left/right part
                        # classes) before un-mirroring the prediction.
                        for left, right in ((14, 15), (16, 17), (18, 19)):
                            outputs_rev[:, left] = outputs_[half:, right]
                            outputs_rev[:, right] = outputs_[half:, left]
                    outputs_rev = torch.flip(outputs_rev, [3])
                    outputs = (outputs + outputs_rev) / 2.
                    self.evaluator.update_score(outputs, data_dict['meta'])

                elif self.data_helper.conditions.diverse_size:
                    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)
                    for i, outputs_i in enumerate(outputs):
                        loss = self.pixel_loss(outputs_i, targets[i])
                        self.val_losses.update(loss.item(), 1)
                        if isinstance(outputs_i, torch.Tensor):
                            outputs_i = [outputs_i]
                        self.evaluator.update_score(outputs_i, data_dict['meta'][i:i + 1])

                else:
                    outputs = self.seg_net(*inputs)
                    try:
                        loss = self.pixel_loss(
                            outputs, targets,
                            gathered=self.configer.get('network', 'gathered')
                        )
                    except AssertionError:
                        Log.error('pixel_loss assertion failed: {} outputs, {} targets'.format(
                            len(outputs), len(targets)))
                        raise

                    if not is_distributed():
                        outputs = self.module_runner.gather(outputs)
                    self.val_losses.update(loss.item(), batch_size)
                    self.evaluator.update_score(outputs, data_dict['meta'])

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.evaluator.update_performance()

        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')
        cudnn.benchmark = True

        # Print the log info & reset the states.
        if not is_distributed() or get_rank() == 0:
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(
                    batch_time=self.batch_time, loss=self.val_losses))
            self.evaluator.print_scores()

        self.batch_time.reset()
        self.val_losses.reset()
        self.evaluator.reset()
        self.seg_net.train()
        self.pixel_loss.train()

    def train(self):
        """Entry point: validate only when resuming for evaluation, otherwise train.

        When ``network.resume`` is set, ``network.resume_val`` evaluates on
        the val split and ``network.resume_train`` evaluates on the train
        split, and the method returns without training.  Otherwise it trains
        to ``solver.max_iters``, swaps in the SWA weights if enabled, and runs
        a final validation.
        """
        if self.configer.get('network', 'resume') is not None:
            if self.configer.get('network', 'resume_val'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
                return
            elif self.configer.get('network', 'resume_train'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))
                return

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        # use swa to average the model
        if 'swa' in self.configer.get('lr', 'lr_policy'):
            self.optimizer.swap_swa_sgd()
            self.optimizer.bn_update(self.train_loader, self.seg_net)

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))

    def summary(self):
        """Print a model summary computed on the first image of the first training batch."""
        from lib.utils.summary import get_model_summary
        self.seg_net.eval()
        for data_dict in self.train_loader:
            print(get_model_summary(self.seg_net, data_dict['img'][0:1]))
            return
class MultiTaskClassifier(object):
    """Train and validate an image classifier on one or more data sources.

    Each iteration draws one batch from every source.  Keys are prefixed with
    ``src<source>_`` and merged into one dict that is fed to the network.

    NOTE(review): ``Trainer.init`` / ``Trainer.update`` must resolve to a
    helper providing these functions; the segmentation ``Trainer`` class in
    this file defines neither.  Confirm which ``Trainer`` is in scope here.
    """

    def __init__(self, configer):
        self.configer = configer
        self.runner_state = dict(iters=0, last_iters=0, epoch=0, last_epoch=0,
                                 performance=0, val_loss=0, max_performance=0, min_val_loss=0)
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        self.cls_net = self.cls_model_manager.get_model()
        self.solver_dict = self.configer.get(self.configer.get('train', 'solver'))
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(), self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(self, self.cls_net, self.optimizer)

        self.train_loaders = dict()
        self.val_loaders = dict()
        for source in range(self.configer.get('data', 'num_data_sources')):
            self.train_loaders[source] = self.cls_data_loader.get_trainloader(source=source)
            self.val_loaders[source] = self.cls_data_loader.get_valloader(source=source)
        if self.configer.get('data', 'mixup'):
            assert (self.configer.get('data', 'num_data_sources') == 2), \
                "mixup only support src0 and src1 load the same dataset"

        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        """Build optimizer param groups for backbone and non-backbone parameters.

        Backbone parameters use ``base_lr * network.bb_lr_scale``; a scale of
        exactly 0.0 freezes them (``requires_grad = False``) instead.  All
        other trainable parameters use ``base_lr``.  Parameters that already
        have ``requires_grad == False`` are left out.
        """
        bb_lr_scale = self.configer.get('network', 'bb_lr_scale')
        base_lr = self.solver_dict['lr']['base_lr']
        lr_1 = []
        lr_2 = []
        for key, value in self.cls_net.named_parameters():
            if not value.requires_grad:
                continue
            if 'backbone' not in key:
                lr_2.append(value)
            elif bb_lr_scale == 0.0:
                value.requires_grad = False
            else:
                lr_1.append(value)

        return [{'params': lr_1, 'lr': base_lr * bb_lr_scale},
                {'params': lr_2, 'lr': base_lr}]

    def run(self):
        """Train until ``max_iters`` iterations, then validate once more.

        Source 0's loader defines the epoch: its exhaustion increments
        ``epoch``.  Every loader restarts as soon as it is exhausted.
        """
        if self.configer.get('network', 'resume_val'):
            self.val()

        self.cls_net.train()
        train_loaders = {source: iter(loader) for source, loader in self.train_loaders.items()}

        start_time = time.time()
        # Adjust the learning rate after every epoch.
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            data_dict = dict()
            for source in train_loaders:
                try:
                    tmp_data_dict = next(train_loaders[source])
                except StopIteration:
                    if source == 0 or source == '0':
                        self.runner_state['epoch'] += 1
                    train_loaders[source] = iter(self.train_loaders[source])
                    tmp_data_dict = next(train_loaders[source])

                for k, v in tmp_data_dict.items():
                    data_dict['src{}_{}'.format(source, k)] = v

            if self.configer.get('data', 'multiscale') is not None:
                scale_ratios = self.configer.get('data', 'multiscale')
                scale_ratio = random.uniform(scale_ratios[0], scale_ratios[-1])
                for key in data_dict:
                    if key.endswith('_img'):
                        data_dict[key] = F.interpolate(
                            data_dict[key], scale_factor=[scale_ratio, scale_ratio],
                            mode='bilinear', align_corners=True)

            if self.configer.get('data', 'mixup'):
                # Paste a downscaled copy of the src0 images into a random
                # corner of the src1 images.
                src0_resize = F.interpolate(
                    data_dict['src0_img'],
                    scale_factor=[random.uniform(0.4, 0.6), random.uniform(0.4, 0.6)],
                    mode='bilinear', align_corners=True)
                _, _, h, w = src0_resize.shape
                pos = random.randint(0, 3)
                if pos == 0:  # top-left
                    data_dict['src1_img'][:, :, 0:h, 0:w] = src0_resize
                elif pos == 1:  # top-right
                    data_dict['src1_img'][:, :, 0:h, -w:] = src0_resize
                elif pos == 2:  # bottom-left
                    data_dict['src1_img'][:, :, -h:, 0:w] = src0_resize
                else:  # bottom-right
                    data_dict['src1_img'][:, :, -h:, -w:] = src0_resize

            data_dict = RunnerHelper.to_device(self, data_dict)
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] * self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, loss_weight_dict = self.loss(out)

            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            if self.configer.get('dtype') == 'fp16':
                with amp.scale_loss(loss, self.optimizer) as scaled_losses:
                    scaled_losses.backward()
            else:
                loss.backward()

            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {5}\tLoss = {3}\nLossWeight = {4}\n'.format(
                        self.runner_state['epoch'], self.runner_state['iters'],
                        self.solver_dict['display_iter'], self.train_losses.info(),
                        loss_weight_dict, RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time, data_time=self.data_time))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.solver_dict['lr']['metric'] == 'iters' and \
                    self.runner_state['iters'] == self.solver_dict['max_iters']:
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(self, self.cls_net, postfix='final')
                break

            if self.runner_state['iters'] % self.solver_dict['save_iters'] == 0 and \
                    self.configer.get('local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                self.val()
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(self, self.cls_net,
                                          performance=self.runner_state['performance'])

        self.val()

    def val(self):
        """Validate on every source until each val loader has been exhausted at least once.

        An exhausted loader is restarted, so shorter sources are revisited
        while longer ones finish.  Sets ``runner_state['performance']`` to the
        top-1 accuracy of the first key containing ``'src0_label0'``.
        """
        self.cls_net.eval()
        start_time = time.time()
        val_loaders = dict()
        val_to_end = dict()
        for source in self.val_loaders:
            val_loaders[source] = iter(self.val_loaders[source])
            val_to_end[source] = False

        all_to_end = False
        with torch.no_grad():
            while not all_to_end:
                data_dict = dict()
                for source in val_loaders:
                    try:
                        tmp_data_dict = next(val_loaders[source])
                    except StopIteration:
                        val_to_end[source] = True
                        val_loaders[source] = iter(self.val_loaders[source])
                        tmp_data_dict = next(val_loaders[source])

                    for k, v in tmp_data_dict.items():
                        data_dict['src{}_{}'.format(source, k)] = v

                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                loss_dict, loss_weight_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                # Compute the loss of the val batch.
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {key: value.mean().item() for key, value in loss_dict.items()},
                    data_dict['src0_img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                # check whether scan over all data sources
                all_to_end = all(val_to_end.values())

        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
        Log.info('TestLoss = {}'.format(self.val_losses.info()))
        Log.info('TestLossWeight = {}'.format(loss_weight_dict))
        Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
        Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
        Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
        # yaml.load() without a Loader raises TypeError on PyYAML >= 6, and the
        # full loader is unnecessary for plain accuracy data.
        # NOTE(review): assumes get_top1_acc() returns plain YAML/JSON without
        # Python-specific tags; confirm against RunningScore.
        top1_acc = yaml.safe_load(self.running_score.get_top1_acc())
        for key in top1_acc:
            if 'src0_label0' in key:
                self.runner_state['performance'] = top1_acc[key]
                Log.info('Use acc of {} to compare performace'.format(key))
                break

        self.running_score.reset()
        self.val_losses.reset()
        self.batch_time.reset()
        self.cls_net.train()
class ImageClassifier(object):
    """
    The class for the training phase of Image classification.

    Builds the classification network, optimizer/scheduler and data loaders
    from ``configer`` and runs the iteration-based train / validation loop.
    Every batch dict coming from a loader is re-keyed with a ``src0_`` prefix
    before it is handed to the network and the loss.
    """

    def __init__(self, configer):
        self.configer = configer
        self.runner_state = dict(iters=0, last_iters=0, epoch=0, last_epoch=0,
                                 performance=0, val_loss=0,
                                 max_performance=0, min_val_loss=0)
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)
        self.cls_net = self.cls_model_manager.get_model()
        self.solver_dict = self.configer.get(
            self.configer.get('train', 'solver'))
        # NOTE(review): `Trainer` must be the optimizer helper exposing
        # `init`/`update`; the segmentation `Trainer` class defined in this
        # file has neither method -- confirm which name is in scope here.
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(),
                                                      self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(
            self, self.cls_net, self.optimizer)
        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        """Split trainable parameters into backbone / non-backbone groups.

        Backbone parameters (name contains ``'backbone'``) get
        ``base_lr * network.bb_lr_scale``; when that scale is exactly 0.0
        they are frozen (``requires_grad = False``) instead. All other
        trainable parameters get ``base_lr``.

        Returns:
            A list of two optimizer param-group dicts (backbone first).
        """
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{
            'params': lr_1,
            'lr': self.solver_dict['lr']['base_lr'] *
                  self.configer.get('network', 'bb_lr_scale')
        }, {
            'params': lr_2,
            'lr': self.solver_dict['lr']['base_lr']
        }]
        return params

    def run(self):
        """Optionally validate on resume, train until max_iters, then validate."""
        if self.configer.get('network', 'resume_val'):
            self.val()

        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.train()

        self.val()

    def train(self):
        """
        Train function of every epoch during train phase.

        Runs one pass over ``self.train_loader``. It stops early when the lr
        metric is ``'iters'`` and ``max_iters`` is reached, and periodically
        logs, saves (local_rank 0 only) and validates.
        """
        self.cls_net.train()
        start_time = time.time()
        # Count the epoch; the learning rate itself is updated per iteration
        # through Trainer.update below.
        self.runner_state['epoch'] += 1
        for data_dict in self.train_loader:
            data_dict = {'src0_{}'.format(k): v for k, v in data_dict.items()}
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            data_dict = RunnerHelper.to_device(self, data_dict)
            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, _ = self.loss(out)
            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        self.runner_state['iters'],
                        self.solver_dict['display_iter'],
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if (self.solver_dict['lr']['metric'] == 'iters'
                    and self.runner_state['iters'] == self.solver_dict['max_iters']):
                break

            if (self.runner_state['iters'] % self.solver_dict['save_iters'] == 0
                    and self.configer.get('local_rank') == 0):
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                self.val()

    def val(self):
        """
        Validation function during the train phase.

        Evaluates ``self.val_loader`` without gradients, logs loss and
        top-1/3/5 accuracy, resets the meters and restores train mode.
        """
        self.cls_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for data_dict in self.val_loader:
                data_dict = {
                    'src0_{}'.format(k): v
                    for k, v in data_dict.items()
                }
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                # self.loss returns (loss_dict, loss_weight_dict), exactly as
                # unpacked in train(); binding the tuple directly would make
                # the `.items()` call below fail.
                loss_dict, _ = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {key: value.item() for key, value in loss_dict.items()},
                    data_dict['src0_img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))
        Log.info('TestLoss = {}'.format(self.val_losses.info()))
        Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
        Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
        Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.cls_net.train()
class Trainer(object):
    """
    Trainer for semantic segmentation.

    Builds the segmentation network (``model_manager.semantic_segmentor``),
    its optimizer/scheduler and data loaders, and runs the train / validation
    loop, tracking loss, mean IoU and pixel accuracy.
    """

    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.running_score = RunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None

        self._init_model()

    def _init_model(self):
        """Create/load the network, optimizer, scheduler, loaders and loss."""
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert group_method is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        self.pixel_loss = self.loss_manager.get_seg_loss()

    @staticmethod
    def group_weight(module):
        """Split ``module``'s parameters into decay / no-decay groups.

        Linear and conv weights get weight decay. Their biases, and every
        other module's ``weight``/``bias`` parameters (e.g. normalization
        layers), go into a group with ``weight_decay=0``. Attributes that are
        ``None`` or not ``nn.Parameter`` are skipped, so layers such as
        ``BatchNorm2d(affine=False)`` do not inject ``None`` into a group.

        Returns:
            ``[dict(params=decay), dict(params=no_decay, weight_decay=0.)]``

        Raises:
            AssertionError: if the grouped parameter count differs from
            ``module.parameters()`` (i.e. some parameter was missed).
        """
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                weight = getattr(m, 'weight', None)
                if isinstance(weight, nn.Parameter):
                    group_no_decay.append(weight)
                bias = getattr(m, 'bias', None)
                if isinstance(bias, nn.Parameter):
                    group_no_decay.append(bias)

        assert len(list(module.parameters())) == \
            len(group_decay) + len(group_no_decay), \
            'group_weight did not assign every parameter to a group'
        groups = [
            dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=.0)
        ]
        return groups

    def _get_parameters(self):
        """Return backbone (base_lr) and non-backbone (base_lr * nbb_mult) groups."""
        bb_lr = []
        nbb_lr = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{
            'params': bb_lr,
            'lr': self.configer.get('lr', 'base_lr')
        }, {
            'params': nbb_lr,
            'lr': self.configer.get('lr', 'base_lr') *
                  self.configer.get('lr', 'nbb_mult')
        }]
        return params

    def __train(self):
        """
        Train function of every epoch during train phase.

        Steps the scheduler by ``iters`` or ``epoch`` (per ``lr.metric``),
        optionally warms up the learning rate, and stops early at
        ``solver.max_iters``. Validates every ``solver.test_interval``
        iterations.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(self.configer.get('iters'),
                                           self.scheduler,
                                           self.optimizer,
                                           backbone_list=[
                                               0,
                                           ])

            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)
            # NOTE(review): unlike __val, inputs are not passed through
            # module_runner.to_device and outputs are not gathered here;
            # presumably the network wrapper handles placement -- confirm.
            # Forward pass.
            outputs = self.seg_net(inputs)

            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs,
                                   targets,
                                   gathered=self.configer.get(
                                       'network', 'gathered'))
            if self.configer.exists('train', 'loader') and self.configer.get(
                    'train', 'loader') == 'ade20k':
                batch_size = self.configer.get(
                    'train', 'batch_size') * self.configer.get(
                        'train', 'batch_per_gpu')
                self.train_losses.update(loss.item(), batch_size)
            else:
                self.train_losses.update(loss.item(), inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.module_runner.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.configer.get('iters') == self.configer.get(
                    'solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """
        Validation function during the train phase.

        Args:
            data_loader: loader to evaluate; defaults to ``self.val_loader``.

        Records ``performance`` (mean IoU) and ``val_loss`` in the configer,
        asks module_runner to save the net for both criteria, logs the
        metrics, resets them and restores train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        data_loader = self.val_loader if data_loader is None else data_loader
        for data_dict in data_loader:
            inputs = data_dict['img']
            targets = data_dict['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = self.module_runner.to_device(inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs,
                                       targets,
                                       gathered=self.configer.get(
                                           'network', 'gathered'))
                outputs = self.module_runner.gather(outputs)

            self.val_losses.update(loss.item(), inputs.size(0))
            self._update_running_score(outputs[-1], data_dict['meta'])

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.configer.update(['performance'],
                             self.running_score.get_mean_iou())
        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(self.running_score.get_mean_iou()))
        Log.info('Pixel ACC: {}\n'.format(self.running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """Resize each prediction back to its original size and score it.

        ``pred`` is permuted with (0, 2, 3, 1) and indexed ``[i, :h, :w]``;
        presumably it arrives as (N, C, H, W) logits -- confirm. Each crop is
        bounded by ``meta['border_size']`` (used as (w, h)), resized to
        ``meta['ori_img_size']`` with cubic interpolation, argmax'd over the
        last axis and compared with ``meta['ori_target']``.
        """
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            ori_img_size = metas[i]['ori_img_size']
            border_size = metas[i]['border_size']
            ori_target = metas[i]['ori_target']
            total_logits = cv2.resize(
                pred[i, :border_size[1], :border_size[0]].cpu().numpy(),
                tuple(ori_img_size),
                interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.running_score.update(labelmap[None], ori_target[None])

    def train(self):
        """Run training to ``solver.max_iters``, then validate on val and train sets."""
        if self.configer.get('network', 'resume') is not None and self.configer.get(
                'network', 'resume_val'):
            self.__val()

        while self.configer.get('iters') < self.configer.get(
                'solver', 'max_iters'):
            self.__train()

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
        self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))