import os
import math
import json
import shutil
import datetime
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# The remaining names used below (logger, Timer, TrainParams, TestParams,
# ListDataParallel, net_utils, meter_utils, mkdir, get_learning_rates,
# resnet_preprocess, get_joint_list, plot_result, gaussian, crop) are
# project-local helpers; their import paths depend on the repository layout.


class Trainer(object):

    TrainParams = TrainParams

    # hooks
    on_start_epoch_hooks = []
    on_end_epoch_hooks = []

    def __init__(self, model, train_params, batch_processor, train_data, val_data=None):
        assert isinstance(train_params, TrainParams)
        self.params = train_params

        # Data loaders
        self.train_data = train_data
        self.val_data = val_data
        # self.val_data = sDataLoader.copy(val_data) if isinstance(val_data, DataLoader) else val_data
        # self.val_stream = self.val_data.get_stream() if self.val_data else None
        self.batch_processor = batch_processor
        self.batch_per_epoch = len(self.train_data)

        # set CUDA_VISIBLE_DEVICES=gpus
        gpus = ','.join([str(x) for x in self.params.gpus])
        os.environ['CUDA_VISIBLE_DEVICES'] = gpus
        self.params.gpus = tuple(range(len(self.params.gpus)))
        logger.info('Set CUDA_VISIBLE_DEVICES to {}...'.format(gpus))

        # Optimizer and learning rate
        self.last_epoch = 0
        self.optimizer = self.params.optimizer  # type: Optimizer
        if not isinstance(self.optimizer, Optimizer):
            logger.error('optimizer should be an instance of Optimizer, '
                         'but got {}'.format(type(self.optimizer)))
            raise ValueError('invalid optimizer')
        self.lr_scheduler = self.params.lr_scheduler  # type: ReduceLROnPlateau or _LRScheduler
        if self.lr_scheduler and not isinstance(
                self.lr_scheduler, (ReduceLROnPlateau, _LRScheduler)):
            logger.error('lr_scheduler should be an instance of _LRScheduler or '
                         'ReduceLROnPlateau, but got {}'.format(type(self.lr_scheduler)))
            raise ValueError('invalid lr_scheduler')
        logger.info('Set lr_scheduler to {}'.format(type(self.lr_scheduler)))

        self.log_values = OrderedDict()
        self.batch_timer = Timer()
        self.data_timer = Timer()

        # load model
        self.model = model
        ckpt = self.params.ckpt

        if not self.params.save_dir:
            self.params.save_dir = os.path.join('outputs', self.params.exp_name)
        mkdir(self.params.save_dir)
        logger.info('Set output dir to {}'.format(self.params.save_dir))

        if ckpt is None:
            # resume from the latest checkpoint in save_dir, if any
            ckpts = [fname for fname in os.listdir(self.params.save_dir)
                     if os.path.splitext(fname)[-1] == '.h5']
            ckpt = os.path.join(
                self.params.save_dir,
                sorted(ckpts, key=lambda name: int(
                    os.path.splitext(name)[0].split('_')[-1]))[-1]
            ) if len(ckpts) > 0 else None

        if ckpt is not None and not self.params.re_init:
            self._load_ckpt(ckpt)
            logger.info('Load ckpt from {}'.format(ckpt))

        self.model = ListDataParallel(self.model, device_ids=self.params.gpus)
        self.model = self.model.cuda(self.params.gpus[0])
        self.model.train()
        if self.params.subnet_name != 'keypoint_subnet':
            # keep nn.BatchNorm2d in eval mode unless training the keypoint subnet
            self.model.module.freeze_bn()

    def train(self):
        best_loss = np.inf
        for epoch in range(self.last_epoch, self.params.max_epoch):
            self.last_epoch += 1
            logger.info('Start training epoch {}'.format(self.last_epoch))

            for fun in self.on_start_epoch_hooks:
                fun(self)

            # adjust learning rate
            if isinstance(self.lr_scheduler, _LRScheduler):
                cur_lrs = get_learning_rates(self.optimizer)
                self.lr_scheduler.step(self.last_epoch)
                logger.info('Set learning rates from {} to {}'.format(
                    cur_lrs, get_learning_rates(self.optimizer)))

            train_loss = self._train_one_epoch()

            for fun in self.on_end_epoch_hooks:
                fun(self)

            # save model (always save the final epoch)
            save_name = 'ckpt_{}.h5'.format(self.last_epoch)
            save_to = os.path.join(self.params.save_dir, save_name)
            if (self.last_epoch % self.params.save_freq_epoch == 0) or \
                    (self.last_epoch == self.params.max_epoch):
                self._save_ckpt(save_to)

            # find best model
            if self.params.val_nbatch_end_epoch > 0:
                val_loss = self._val_one_epoch(self.params.val_nbatch_end_epoch)
                if val_loss < best_loss:
                    if not os.path.isfile(save_to):
                        # make sure this epoch's weights are on disk before copying
                        self._save_ckpt(save_to)
                    best_file = os.path.join(
                        self.params.save_dir,
                        'ckpt_{}_{:.5f}.h5.best'.format(self.last_epoch, val_loss))
                    shutil.copyfile(save_to, best_file)
                    logger.info('Found a better ckpt ({:.5f} -> {:.5f}), '
                                'saved to {}'.format(best_loss, val_loss, best_file))
                    best_loss = val_loss

                if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                    self.lr_scheduler.step(val_loss, self.last_epoch)

    def _save_ckpt(self, save_to):
        model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        net_utils.save_net(save_to, model,
                           epoch=self.last_epoch,
                           optimizers=[self.optimizer],
                           rm_prev_opt=True,
                           max_n_ckpts=self.params.save_nckpt_max)
        logger.info('Save ckpt to {}'.format(save_to))

    def _load_ckpt(self, ckpt):
        epoch, state_dicts = net_utils.load_net(ckpt, self.model, load_state_dict=True)
        if not self.params.ignore_opt_state and not self.params.zero_epoch and epoch >= 0:
            self.last_epoch = epoch
            logger.info('Set last epoch to {}'.format(self.last_epoch))
            if state_dicts is not None:
                self.optimizer.load_state_dict(state_dicts[0])
                net_utils.set_optimizer_state_devices(self.optimizer.state,
                                                      self.params.gpus[0])
                logger.info('Load optimizer state from checkpoint, '
                            'new learning rate: {}'.format(
                                get_learning_rates(self.optimizer)))

    def _train_one_epoch(self):
        self.batch_timer.clear()
        self.data_timer.clear()
        self.batch_timer.tic()
        self.data_timer.tic()
        total_loss = meter_utils.AverageValueMeter()
        for step, batch in enumerate(self.train_data):
            inputs, gts, _ = self.batch_processor(self, batch)
            self.data_timer.toc()

            # forward
            output, saved_for_loss = self.model(*inputs)
            loss, saved_for_log = self.model.module.build_loss(saved_for_loss, *gts)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            total_loss.add(loss.item())

            # clip grad (clip_grad_norm was renamed clip_grad_norm_ in PyTorch 0.4)
            if not np.isinf(self.params.max_grad_norm):
                max_norm = nn.utils.clip_grad_norm_(self.model.parameters(),
                                                    self.params.max_grad_norm,
                                                    norm_type=float('inf'))
                saved_for_log['max_grad'] = float(max_norm)

            self.optimizer.step()

            self._process_log(saved_for_log, self.log_values)
            self.batch_timer.toc()

            # print log
            reset = False
            if step % self.params.print_freq == 0:
                self._print_log(step, self.log_values, title='Training',
                                max_n_batch=self.batch_per_epoch)
                reset = True

            if step % self.params.save_freq_step == 0 and step > 0:
                save_to = os.path.join(
                    self.params.save_dir,
                    'ckpt_{}.h5.ckpt'.format(
                        (self.last_epoch - 1) * self.batch_per_epoch + step))
                self._save_ckpt(save_to)

            if reset:
                self._reset_log(self.log_values)

            self.data_timer.tic()
            self.batch_timer.tic()

        mean_loss, std = total_loss.value()
        return mean_loss

    def _val_one_epoch(self, n_batch):
        training_mode = self.model.training
        self.model.eval()

        logs = OrderedDict()
        sum_loss = meter_utils.AverageValueMeter()
        logger.info('Val on validation set...')

        self.batch_timer.clear()
        self.data_timer.clear()
        self.batch_timer.tic()
        self.data_timer.tic()
        for step, batch in enumerate(self.val_data):
            self.data_timer.toc()
            if step >= n_batch:
                break
            inputs, gts, _ = self.batch_processor(self, batch)

            _, saved_for_loss = self.model(*inputs)
            self.batch_timer.toc()
            loss, saved_for_log = self.model.module.build_loss(saved_for_loss, *gts)
            sum_loss.add(loss.item())
            self._process_log(saved_for_log, logs)

            if step % self.params.print_freq == 0 or step == len(self.val_data) - 1:
                self._print_log(step, logs, 'Validation',
                                max_n_batch=min(n_batch, len(self.val_data)))

            self.data_timer.tic()
            self.batch_timer.tic()

        mean, std = sum_loss.value()
        logger.info('Validation loss: mean: {}, std: {}'.format(mean, std))

        self.model.train(mode=training_mode)
        if self.params.subnet_name != 'keypoint_subnet':
            self.model.module.freeze_bn()
        return mean

    def _process_log(self, src_dict, dest_dict):
        for k, v in src_dict.items():
            if isinstance(v, (int, float)):
                dest_dict.setdefault(k, meter_utils.AverageValueMeter())
                dest_dict[k].add(float(v))
            else:
                dest_dict[k] = v

    def _print_log(self, step, log_values, title='', max_n_batch=None):
        log_str = '{}\n'.format(self.params.exp_name)
        log_str += '{}: epoch {}'.format(title, self.last_epoch)
        if max_n_batch:
            log_str += ' [{}/{}], lr: {}'.format(step, max_n_batch,
                                                 get_learning_rates(self.optimizer))

        # global_step = step + (self.last_epoch - 1) * self.batch_per_epoch
        for k, v in log_values.items():
            if isinstance(v, meter_utils.AverageValueMeter):
                mean, std = v.value()
                log_str += '\n\t{}: {:.10f}'.format(k, mean)

        if max_n_batch:
            # print time
            data_time = self.data_timer.duration + 1e-6
            batch_time = self.batch_timer.duration + 1e-6
            rest_seconds = int((max_n_batch - step) * batch_time)
            log_str += '\n\t({:.2f}/{:.2f}s, fps:{:.1f}, rest: {})'.format(
                data_time, batch_time,
                self.params.batch_size / batch_time,
                str(datetime.timedelta(seconds=rest_seconds)))
            self.batch_timer.clear()
            self.data_timer.clear()

        logger.info(log_str)

    def _reset_log(self, log_values):
        for k, v in log_values.items():
            if isinstance(v, meter_utils.AverageValueMeter):
                v.reset()
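
# ---------------------------------------------------------------------------
# Minimal usage sketch for Trainer (illustration only; not part of the
# original module). The TrainParams fields shown are assumptions based on how
# Trainer reads them above (gpus, optimizer, lr_scheduler, max_epoch, ...);
# `build_model`, `train_loader`, and `val_loader` are hypothetical.
#
#   model = build_model()
#   params = Trainer.TrainParams()
#   params.gpus = (0,)
#   params.exp_name = 'keypoint_exp'
#   params.subnet_name = 'keypoint_subnet'
#   params.max_epoch = 100
#   params.optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
#   params.lr_scheduler = torch.optim.lr_scheduler.StepLR(params.optimizer,
#                                                         step_size=30)
#
#   def batch_processor(trainer, batch):
#       # unpack a loader batch into (model inputs, loss targets, meta)
#       imgs, heat_gt, mask = batch
#       return (imgs.cuda(),), (heat_gt.cuda(), mask.cuda()), None
#
#   trainer = Trainer(model, params, batch_processor, train_loader, val_loader)
#   trainer.train()
# ---------------------------------------------------------------------------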

class Tester(object):

    TestParams = TestParams

    def __init__(self, model, test_params, batch_processor=None, val_data=None):
        assert isinstance(test_params, TestParams)
        self.params = test_params
        self.batch_timer = Timer()
        self.data_timer = Timer()
        self.val_data = val_data
        self.batch_processor = batch_processor

        # load model
        self.model = model
        ckpt = self.params.ckpt
        if ckpt is not None:
            self._load_ckpt(ckpt)
            logger.info('Load ckpt from {}'.format(ckpt))

        self.model = nn.DataParallel(self.model, device_ids=self.params.gpus)
        self.model = self.model.cuda(device=self.params.gpus[0])
        self.model.eval()
        self.model.module.freeze_bn()

    def coco_eval(self):
        coco_val = os.path.join(self.params.coco_root,
                                'annotations/person_keypoints_val2017.json')
        coco = COCO(coco_val)
        img_ids = coco.getImgIds(catIds=[1])

        multipose_results = []
        # maps our joint order to the COCO keypoint order
        coco_order = [0, 14, 13, 16, 15, 4, 1, 5, 2, 6, 3, 10, 7, 11, 8, 12, 9]

        for img_id in tqdm(img_ids):
            img_name = coco.loadImgs(img_id)[0]['file_name']

            img = cv2.imread(os.path.join(self.params.coco_root, 'images/val2017/',
                                          img_name)).astype(np.float32)
            shape_dst = np.max(img.shape)
            scale = float(shape_dst) / self.params.inp_size
            pad_size = np.abs(img.shape[1] - img.shape[0])
            # pad to square, then resize to the network input size
            img_resized = np.pad(img, ([0, pad_size], [0, pad_size], [0, 0]),
                                 'constant')[:shape_dst, :shape_dst]
            img_resized = cv2.resize(img_resized,
                                     (self.params.inp_size, self.params.inp_size))
            img_input = resnet_preprocess(img_resized)
            img_input = torch.from_numpy(np.expand_dims(img_input, 0))

            with torch.no_grad():
                img_input = img_input.cuda(device=self.params.gpus[0])
                heatmaps, [scores, classification, transformed_anchors] = \
                    self.model([img_input, self.params.subnet_name])

            heatmaps = heatmaps.cpu().detach().numpy()
            heatmaps = np.squeeze(heatmaps, 0)
            heatmaps = np.transpose(heatmaps, (1, 2, 0))
            # segment_map = heatmaps[:, :, 17]
            param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
            joint_list = get_joint_list(img_resized, param, heatmaps[:, :, :17], scale)
            del img_resized

            # bounding box from retinanet
            scores = scores.cpu().detach().numpy()
            classification = classification.cpu().detach().numpy()
            transformed_anchors = transformed_anchors.cpu().detach().numpy()
            idxs = np.where(scores > 0.5)
            bboxs = []
            for j in range(idxs[0].shape[0]):
                bbox = transformed_anchors[idxs[0][j], :] * scale
                if int(classification[idxs[0][j]]) == 0:  # class 0 = person
                    bboxs.append(bbox.tolist())

            prn_result = self.prn_process(joint_list.tolist(), bboxs, img_name, img_id)
            for result in prn_result:
                keypoints = result['keypoints']
                coco_keypoint = []
                for i in range(17):
                    coco_keypoint.append(keypoints[coco_order[i] * 3])
                    coco_keypoint.append(keypoints[coco_order[i] * 3 + 1])
                    coco_keypoint.append(keypoints[coco_order[i] * 3 + 2])
                result['keypoints'] = coco_keypoint
                multipose_results.append(result)

        ann_filename = self.params.coco_result_filename
        with open(ann_filename, 'w') as f:
            json.dump(multipose_results, f, indent=4)

        # load results in COCO evaluation tool
        coco_pred = coco.loadRes(ann_filename)

        # run COCO evaluation
        coco_eval = COCOeval(coco, coco_pred, 'keypoints')
        coco_eval.params.imgIds = img_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        if not self.params.testresult_write_json:
            os.remove(ann_filename)

    def test(self):
        img_list = os.listdir(self.params.testdata_dir)
        multipose_results = []

        for img_name in tqdm(img_list):
            img = cv2.imread(os.path.join(self.params.testdata_dir,
                                          img_name)).astype(np.float32)
            shape_dst = np.max(img.shape)
            scale = float(shape_dst) / self.params.inp_size
            pad_size = np.abs(img.shape[1] - img.shape[0])
            img_resized = np.pad(img, ([0, pad_size], [0, pad_size], [0, 0]),
                                 'constant')[:shape_dst, :shape_dst]
            img_resized = cv2.resize(img_resized,
                                     (self.params.inp_size, self.params.inp_size))
            img_input = resnet_preprocess(img_resized)
            img_input = torch.from_numpy(np.expand_dims(img_input, 0))

            with torch.no_grad():
                img_input = img_input.cuda(device=self.params.gpus[0])
                heatmaps, [scores, classification, transformed_anchors] = \
                    self.model([img_input, self.params.subnet_name])

            heatmaps = heatmaps.cpu().detach().numpy()
            heatmaps = np.squeeze(heatmaps, 0)
            heatmaps = np.transpose(heatmaps, (1, 2, 0))
            heatmap_max = np.max(heatmaps[:, :, :17], 2)  # kept for visualization below
            # segment_map = heatmaps[:, :, 17]
            param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
            joint_list = get_joint_list(img_resized, param, heatmaps[:, :, :17], scale)
            del img_resized

            # bounding box from retinanet
            scores = scores.cpu().detach().numpy()
            classification = classification.cpu().detach().numpy()
            transformed_anchors = transformed_anchors.cpu().detach().numpy()
            idxs = np.where(scores > 0.5)
            bboxs = []
            for j in range(idxs[0].shape[0]):
                bbox = transformed_anchors[idxs[0][j], :] * scale
                if int(classification[idxs[0][j]]) == 0:  # class 0 = person
                    bboxs.append(bbox.tolist())

            prn_result = self.prn_process(joint_list.tolist(), bboxs, img_name)
            for result in prn_result:
                multipose_results.append(result)

            if self.params.testresult_write_image:
                canvas = plot_result(img, prn_result)
                cv2.imwrite(os.path.join(self.params.testresult_dir,
                                         img_name.split('.', 1)[0] + '_1heatmap.png'),
                            heatmap_max * 256)
                cv2.imwrite(os.path.join(self.params.testresult_dir,
                                         img_name.split('.', 1)[0] + '_2canvas.png'),
                            canvas)

        if self.params.testresult_write_json:
            with open(os.path.join(self.params.testresult_dir,
                                   'multipose_results.json'), 'w') as f:
                json.dump(multipose_results, f)

    def prn_process(self, kps, bbox_list, file_name, image_id=0):
        prn_result = []

        idx = 0
        ks = []
        for j in range(17):  # joint type
            t = []
            for k in kps:
                if k[-1] == j:  # this keypoint belongs to joint type j
                    x = k[0]
                    y = k[1]
                    v = 1  # k[2]
                    if v > 0:
                        t.append([x, y, 1, idx])
                        idx += 1
            ks.append(t)
        peaks = ks

        w = int(18 * self.params.coeff)
        h = int(28 * self.params.coeff)

        bboxes = []
        for bbox_item in bbox_list:
            # convert [x1, y1, x2, y2] to [x, y, w, h]
            bboxes.append([bbox_item[0], bbox_item[1],
                           bbox_item[2] - bbox_item[0], bbox_item[3] - bbox_item[1]])

        if len(bboxes) == 0 or len(peaks) == 0:
            return prn_result

        weights_bbox = np.zeros((len(bboxes), h, w, 4, 17))

        for joint_id, peak in enumerate(peaks):  # joint_id: which joint
            for instance_id, instance in enumerate(peak):  # instance_id: which person
                p_x = instance[0]
                p_y = instance[1]
                for bbox_id, b in enumerate(bboxes):
                    is_inside = p_x > b[0] - b[2] * self.params.in_thres and \
                                p_y > b[1] - b[3] * self.params.in_thres and \
                                p_x < b[0] + b[2] * (1.0 + self.params.in_thres) and \
                                p_y < b[1] + b[3] * (1.0 + self.params.in_thres)
                    if is_inside:
                        x_scale = float(w) / math.ceil(b[2])
                        y_scale = float(h) / math.ceil(b[3])
                        x0 = int((p_x - b[0]) * x_scale)
                        y0 = int((p_y - b[1]) * y_scale)
                        # clamp the projected peak into the PRN grid
                        x0 = min(max(x0, 0), w - 1)
                        y0 = min(max(y0, 0), h - 1)
                        p = 1e-9
                        weights_bbox[bbox_id, y0, x0, :, joint_id] = \
                            [1, instance[2], instance[3], p]

        old_weights_bbox = np.copy(weights_bbox)

        for j in range(weights_bbox.shape[0]):
            for t in range(17):
                weights_bbox[j, :, :, 0, t] = gaussian(weights_bbox[j, :, :, 0, t])

        output_bbox = []
        for j in range(weights_bbox.shape[0]):
            inp = weights_bbox[j, :, :, 0, :]
            prn_input = torch.from_numpy(np.expand_dims(inp, axis=0)).cuda().float()
            output, _ = self.model([prn_input, 'prn_subnet'])
            temp = np.reshape(output.data.cpu().numpy(), (h, w, 17))  # (56, 36, 17) when coeff == 2
            output_bbox.append(temp)

        output_bbox = np.array(output_bbox)

        keypoints_score = []
        for t in range(17):
            indexes = np.argwhere(old_weights_bbox[:, :, :, 0, t] == 1)
            keypoint = []
            for i in indexes:
                cr = crop(output_bbox[i[0], :, :, t], (i[1], i[2]), N=15)
                score = np.sum(cr)

                kp_id = old_weights_bbox[i[0], i[1], i[2], 2, t]
                kp_score = old_weights_bbox[i[0], i[1], i[2], 1, t]
                p_score = old_weights_bbox[i[0], i[1], i[2], 3, t]  # ?? (stored but unused)
                bbox_id = i[0]

                score = kp_score * score

                s = [kp_id, bbox_id, kp_score, score]
                keypoint.append(s)
            keypoints_score.append(keypoint)

        bbox_keypoints = np.zeros((weights_bbox.shape[0], 17, 3))
        bbox_ids = np.arange(len(bboxes)).tolist()

        # each entry: kp_id, bbox_id, kp_score, my_score
        for i in range(17):
            joint_keypoints = keypoints_score[i]
            if len(joint_keypoints) > 0:
                # some keypoints of this joint type were detected
                kp_ids = list(set([x[0] for x in joint_keypoints]))

                table = np.zeros((len(bbox_ids), len(kp_ids), 4))
                for b_id, bbox in enumerate(bbox_ids):
                    for k_id, kp in enumerate(kp_ids):
                        own = [x for x in joint_keypoints
                               if x[0] == kp and x[1] == bbox]
                        if len(own) > 0:
                            table[b_id, k_id] = own[0]
                        else:
                            table[b_id, k_id] = [0] * 4

                for b_id, bbox in enumerate(bbox_ids):  # all bboxes, from 0 to ...
                    # within this bbox, sort keypoints by score, big to small
                    row = np.argsort(-table[bbox, :, 3])
                    if table[bbox, row[0], 3] > 0:  # at least one scored keypoint
                        for r in row:  # all keypoints
                            if table[bbox, r, 3] > 0:
                                # sort all bboxes for keypoint r, big to small
                                column = np.argsort(-table[:, r, 3])
                                if bbox == column[0]:
                                    # best bbox, best keypoint
                                    bbox_keypoints[bbox, i, :] = \
                                        [x[:3] for x in peaks[i]
                                         if x[3] == table[bbox, r, 0]][0]
                                    break
                                else:
                                    # for bbox column[0], the worst keypoint is row2[0]
                                    row2 = np.argsort(table[column[0], :, 3])
                                    if row2[0] == r:
                                        bbox_keypoints[bbox, i, :] = \
                                            [x[:3] for x in peaks[i]
                                             if x[3] == table[bbox, r, 0]][0]
                                        break
            else:
                # no keypoint of this joint type: fall back to the PRN output peak
                for j in range(weights_bbox.shape[0]):
                    b = bboxes[j]
                    x_scale = float(w) / math.ceil(b[2])
                    y_scale = float(h) / math.ceil(b[3])
                    for t in range(17):
                        indexes = np.argwhere(old_weights_bbox[j, :, :, 0, t] == 1)
                        if len(indexes) == 0:
                            max_index = np.argwhere(
                                output_bbox[j, :, :, t] == np.max(output_bbox[j, :, :, t]))
                            bbox_keypoints[j, t, :] = [
                                max_index[0][1] / x_scale + b[0],
                                max_index[0][0] / y_scale + b[1], 0]

        my_keypoints = []
        for i in range(bbox_keypoints.shape[0]):
            k = np.zeros(51)
            k[0::3] = bbox_keypoints[i, :, 0]
            k[1::3] = bbox_keypoints[i, :, 1]
            k[2::3] = bbox_keypoints[i, :, 2]

            pose_score = 0
            count = 0  # number of located joints (not used in the score below)
            for f in range(17):
                if bbox_keypoints[i, f, 0] != 0 and bbox_keypoints[i, f, 1] != 0:
                    count += 1
                    pose_score += bbox_keypoints[i, f, 2]
            pose_score /= 17.0

            my_keypoints.append(k)

            image_data = {
                'image_id': image_id,
                'file_name': file_name,
                'category_id': 1,
                'bbox': bboxes[i],
                'score': pose_score,
                'keypoints': k.tolist()
            }
            prn_result.append(image_data)

        return prn_result

    def val(self):
        self.model.eval()

        logs = OrderedDict()
        sum_loss = meter_utils.AverageValueMeter()
        logger.info('Val on validation set...')

        self.batch_timer.clear()
        self.data_timer.clear()
        self.batch_timer.tic()
        self.data_timer.tic()
        for step, batch in enumerate(self.val_data):
            self.data_timer.toc()
            inputs, gts, _ = self.batch_processor(self, batch)

            _, saved_for_loss = self.model(*inputs)
            self.batch_timer.toc()
            loss, saved_for_log = self.model.module.build_loss(saved_for_loss, *gts)
            sum_loss.add(loss.item())
            self._process_log(saved_for_log, logs)

            if step % self.params.print_freq == 0:
                self._print_log(step, logs, 'Validation',
                                max_n_batch=len(self.val_data))

            self.data_timer.tic()
            self.batch_timer.tic()

        mean, std = sum_loss.value()
        logger.info('\n\nValidation loss: mean: {}, std: {}'.format(mean, std))

    def _load_ckpt(self, ckpt):
        _, _ = net_utils.load_net(ckpt, self.model, load_state_dict=True)

    def _process_log(self, src_dict, dest_dict):
        for k, v in src_dict.items():
            if isinstance(v, (int, float)):
                dest_dict.setdefault(k, meter_utils.AverageValueMeter())
                dest_dict[k].add(float(v))
            else:
                dest_dict[k] = v

    def _print_log(self, step, log_values, title='', max_n_batch=None):
        log_str = '{}\n'.format(self.params.exp_name)
        log_str += '{}: epoch {}'.format(title, 0)  # Tester has no epoch counter
        log_str += ' [{}/{}]'.format(step, max_n_batch)

        for k, v in log_values.items():
            if isinstance(v, meter_utils.AverageValueMeter):
                mean, std = v.value()
                log_str += '\n\t{}: {:.10f}'.format(k, mean)

        if max_n_batch:
            # print time
            data_time = self.data_timer.duration + 1e-6
            batch_time = self.batch_timer.duration + 1e-6
            rest_seconds = int((max_n_batch - step) * batch_time)
            log_str += '\n\t({:.2f}/{:.2f}s, fps:{:.1f}, rest: {})'.format(
                data_time, batch_time,
                self.params.batch_size / batch_time,
                str(datetime.timedelta(seconds=rest_seconds)))
            self.batch_timer.clear()
            self.data_timer.clear()

        logger.info(log_str)
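
# ---------------------------------------------------------------------------
# Minimal usage sketch for Tester (illustration only; not part of the
# original module). Field names mirror what Tester reads above; `build_model`
# and the concrete values are hypothetical.
#
#   model = build_model()
#   params = Tester.TestParams()
#   params.gpus = (0,)
#   params.ckpt = 'outputs/keypoint_exp/ckpt_100.h5'
#   params.coco_root = '/data/coco'
#   params.inp_size = 480
#   params.subnet_name = 'keypoint_subnet'
#   params.coco_result_filename = 'multipose_coco2017_results.json'
#
#   tester = Tester(model, params)
#   tester.coco_eval()   # COCO keypoint AP on val2017
#   # tester.test()      # or run on a directory of images (params.testdata_dir)
# ---------------------------------------------------------------------------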