def test(self, task_name=''):
    """Greedy planning: try every enumerated action on every object, keep the most
    confident one per sample, then verify the chosen actions in the simulator."""
    if self.pred_model is not None:
        self.pred_model.eval()
    pos, num_sample = 0, 0
    # sample actions
    acts = self.enumerate_actions()
    for batch_idx, (data, boxes, gt_acts, labels, _, _, _, _) in enumerate(self.val_loader):
        batch_size = data.shape[0]
        num_objs = gt_acts.shape[1]
        pos_feat = xyxy_to_posf(boxes, data.shape)
        rois = xyxy_to_rois(boxes, batch_size, data.shape[1], self.num_gpus)
        gt_rois = boxes.cpu().numpy().copy()
        pred_acts = np.zeros((batch_size, num_objs, 3))
        conf_acts = -np.inf * np.ones((batch_size,))
        # if self.random_policy:
        #     for i in range(batch_size):
        #         obj_id = np.random.randint(num_objs)
        #         pred_acts[i, obj_id] = acts[np.random.randint(len(acts))]
        for idx, (act, obj_id) in enumerate(itertools.product(acts, range(num_objs))):
            tprint(f'current batch: {idx} / {acts.shape[0] * num_objs}' + ' ' * 10)
            act_array = torch.zeros((batch_size, num_objs, 3), dtype=torch.float32)
            act_array[:, obj_id, :] = torch.from_numpy(act)
            traj_array = self.generate_trajs(data, rois, pos_feat, act_array, boxes)
            conf = self.get_act_conf(traj_array, gt_rois, obj_id)
            # keep this action for every sample where it beats the current best
            pred_acts[conf > conf_acts] = act_array[conf > conf_acts]
            conf_acts[conf > conf_acts] = conf[conf > conf_acts]
            # for i in range(C.SOLVER.BATCH_SIZE):
            #     plot_image_idx = C.SOLVER.BATCH_SIZE * batch_idx + i
            #     video_idx, img_idx = self.val_loader.dataset.video_info[plot_image_idx]
            #     video_name = self.val_loader.dataset.video_list[video_idx]
            #     search_suffix = self.val_loader.dataset.search_suffix
            #     image_list = sorted(glob(f'{video_name}/{search_suffix}'))
            #     im_name = image_list[img_idx]
            #     video_id, image_id = im_name.split('.')[0].split('/')[-2:]
            #     output_name = f'{video_id}_{image_id}_{idx}'
            #     im_data = get_im_data(im_name, gt_rois[[i], 0:1], C.DATA_ROOT, False)
            #     plt.axis('off')
            #     plt.imshow(im_data[..., ::-1])
            #     _plot_bbox_traj(traj_array[i], size=160, alpha=1.0)
            #     x = gt_rois[i, 0, obj_id, 0] + self.ball_radius
            #     y = gt_rois[i, 0, obj_id, 1] + self.ball_radius
            #     dy, dx = act[0] * act[2], act[1] * act[2]
            #     plt.arrow(x, y, dx, dy, color=(0.99, 0.99, 0.99), linewidth=5)
            #     os.makedirs(f'{self.output_dir}/plan', exist_ok=True)
            #     kwargs = {'format': 'svg', 'bbox_inches': 'tight', 'pad_inches': 0}
            #     plt.savefig(f'{self.output_dir}/plan/pred_{output_name}.svg', **kwargs)
            #     plt.close()
        sim_rst, debug_gt_traj_array = self.simulate_action(gt_rois, pred_acts)
        pos += sim_rst.sum()
        num_sample += sim_rst.shape[0]
        pprint(f'{task_name} {batch_idx}/{len(self.val_loader)}: {pos / num_sample:.4f}' + ' ' * 10)
    pprint(f'{task_name}: {pos / num_sample:.4f}' + ' ' * 10)
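
# The two masked assignments above implement a per-sample argmax over candidate
# actions. The helper below is a minimal standalone sketch of that bookkeeping;
# it is illustrative only and not called anywhere in this file, and its names,
# shapes, and random confidences are assumptions standing in for get_act_conf(...).
def _demo_greedy_action_update():
    import numpy as np
    batch_size, num_objs = 4, 2
    rng = np.random.default_rng(0)
    pred_acts = np.zeros((batch_size, num_objs, 3))  # best action found so far
    conf_acts = -np.inf * np.ones((batch_size,))     # confidence of that action
    for _ in range(5):                               # loop over candidate actions
        act_array = rng.random((batch_size, num_objs, 3))
        conf = rng.random((batch_size,))             # stand-in for get_act_conf(...)
        better = conf > conf_acts                    # samples where this candidate wins
        pred_acts[better] = act_array[better]
        conf_acts[better] = conf[better]
    return pred_acts, conf_acts
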
def generate_trajs(self, data, boxes):
    with torch.no_grad():
        num_objs = boxes.shape[2]
        g_idx = np.array([[i, j, 1] for i in range(num_objs) for j in range(num_objs) if j != i])
        g_idx = torch.from_numpy(g_idx[None].repeat(data.shape[0], 0))
        rois = xyxy_to_rois(boxes, batch=data.shape[0], time_step=data.shape[1],
                            num_devices=self.num_gpus)
        pos_feat = xyxy_to_posf(rois, data.shape)
        outputs = self.model(data, rois, pos_feat, num_rollouts=self.pred_rollout, g_idx=g_idx)
        outputs = {
            'boxes': outputs['boxes'].cpu().numpy(),
            'masks': outputs['masks'].cpu().numpy(),
        }
        outputs['boxes'][..., 0::2] *= self.input_width
        outputs['boxes'][..., 1::2] *= self.input_height
        outputs['boxes'] = xywh2xyxy(
            outputs['boxes'].reshape(-1, 4)
        ).reshape((data.shape[0], -1, num_objs, 4))
    return outputs['boxes'], outputs['masks']
def generate_trajs(self, data, rois, pos_feat, acts, boxes):
    """Roll out the action model (or the simulator, for the oracle actor), re-render the
    predicted boxes, and extend the trajectory with the prediction model."""
    all_pred_rois = np.zeros((data.shape[0], 0, acts.shape[1], 4))
    with torch.no_grad():
        data, rois, pos_feat, acts = \
            data.to(self.device), rois.to(self.device), pos_feat.to(self.device), acts.to(self.device)
        if self.oracle_actor:
            sim_len_backup = self.sim_rollout_length
            self.sim_rollout_length = self.act_rollout + 1
            _, pred_rois = self.simulate_action(rois[..., 1:].cpu().numpy(), acts.cpu().numpy(),
                                                return_rst=False)
            self.sim_rollout_length = sim_len_backup
            all_pred_rois = np.concatenate([all_pred_rois, pred_rois], axis=1)
        else:
            data = data[:, [0]]
            rois = xyxy_to_rois(boxes, data.shape[0], data.shape[1], self.num_gpus)
            coor_features = pos_feat[:, [0]]
            outputs = self.act_model(data, rois, None, act_features=acts,
                                     num_rollouts=self.act_rollout)
            pred_rois = xcyc_to_xyxy(torch.clamp(outputs['bbox'], 0, 1).cpu().numpy()[..., 2:],
                                     self.input_height, self.input_width, self.ball_radius)
            pred_rois = np.concatenate([rois[..., 1:].cpu().numpy(), pred_rois], axis=1)
            all_pred_rois = np.concatenate([all_pred_rois, pred_rois], axis=1)
        # re-render the latest predicted boxes and normalize the rendered images
        pred_rois = all_pred_rois[:, -self.input_size:].copy()
        data = sim_rendering(pred_rois, self.input_height, self.input_width, self.ball_radius)
        for c in range(3):
            data[..., c] -= C.INPUT.IMAGE_MEAN[c]
            data[..., c] /= C.INPUT.IMAGE_STD[c]
        data = data.permute(0, 1, 4, 2, 3)  # data is (batch x time_step x 3 x h x w)
        boxes = torch.from_numpy(pred_rois.astype(np.float32))
        pos_feat = xyxy_to_posf(boxes, data.shape)
        rois = xyxy_to_rois(boxes, data.shape[0], data.shape[1], self.num_gpus)
        if self.roi_masking:
            # expand to (batch x time_step x num_objs x 3 x h x w)
            data = data[:, :, None].repeat(1, 1, self.num_objs, 1, 1, 1)
            for b, t, o in itertools.product(range(data.shape[0]), range(data.shape[1]),
                                             range(self.num_objs)):
                box = boxes[b, t, o].numpy()
                # np.int is removed in recent NumPy; use the builtin int instead
                x1, y1 = np.floor([box[0], box[1]]).astype(int)
                x2, y2 = np.ceil([box[2], box[3]]).astype(int)
                data[b, t, o, :, :, :x1] = 0
                data[b, t, o, :, :y1, :] = 0
                data[b, t, o, :, :, x2:] = 0
                data[b, t, o, :, y2:, :] = 0
        if self.roi_cropping:
            data_c = np.zeros((data.shape[0], data.shape[1], self.num_objs,) + data.shape[2:])
            for b, t, o in itertools.product(range(data.shape[0]), range(data.shape[1]),
                                             range(self.num_objs)):
                box = boxes[b, t, o].numpy()
                x_c = 0.5 * (box[0] + box[2])
                y_c = 0.5 * (box[1] + box[3])
                r = self.roi_crop_r
                d = 2 * r
                data_c_ = np.zeros((d, d))
                image = data[b, t].cpu().numpy().transpose((1, 2, 0))
                image_pad = np.pad(image, ((d, d), (d, d), (0, 0)))
                if x_c > -r or y_c > -r or x_c < self.input_width + r or y_c < self.input_height + r:
                    x_c += d
                    y_c += d
                    data_c_ = image_pad[int(y_c - r):int(y_c + r), int(x_c - r):int(x_c + r), :]
                data_c_ = cv2.resize(data_c_, (self.input_width, self.input_height))
                data_c[b, t, o] = data_c_.transpose((2, 0, 1))
            data = torch.from_numpy(data_c.astype(np.float32))
        if self.roi_masking or self.roi_cropping:
            data = data.permute((0, 2, 1, 3, 4, 5))
            data = data.reshape((data.shape[0] * data.shape[1],) + data.shape[2:])
        outputs = self.pred_model(data, rois, pos_feat,
                                  num_rollouts=self.pred_rollout + self.cons_size)
        bbox_rollouts = outputs['bbox'].cpu().numpy()[..., 2:]
        pred_rois = xcyc_to_xyxy(bbox_rollouts, self.input_height, self.input_width,
                                 self.ball_radius)
        pred_rois = pred_rois[:, -(1 + self.pred_rollout):]
        all_pred_rois = np.concatenate([all_pred_rois, pred_rois], axis=1)
    return all_pred_rois
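
# The roi_masking branch above blanks out everything outside each object's box
# before the re-rendered images are fed to the prediction model. Below is a minimal
# standalone sketch of that masking on a single (c, h, w) image; the shapes and the
# dummy box are illustrative assumptions and the helper is not called by this file.
def _demo_roi_masking():
    import numpy as np
    c, h, w = 3, 8, 8
    image = np.ones((c, h, w), dtype=np.float32)
    x1, y1, x2, y2 = 2, 3, 6, 7   # xyxy box, assumed already clipped to the image
    image[:, :, :x1] = 0          # left of the box
    image[:, :y1, :] = 0          # above the box
    image[:, :, x2:] = 0          # right of the box
    image[:, y2:, :] = 0          # below the box
    return image
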
def train_epoch(self):
    for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes, gt_masks, valid,
                    module_valid, g_idx, seq_l, objinfo, gtindicator) in enumerate(self.train_loader):
        self._adjust_learning_rate()
        if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
            # data should be (b x t x o x c x h x w)
            data = data.permute((0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
            data = data.reshape((data.shape[0] * data.shape[1],) + data.shape[2:])  # (b*o, t, c, h, w)
        data, data_t = data.to(self.device), data_t.to(self.device)
        pos_feat = xyxy_to_posf(rois, data.shape)
        rois = xyxy_to_rois(rois, batch=data.shape[0], time_step=data.shape[1],
                            num_devices=self.num_gpus)
        self.optim.zero_grad()
        outputs = self.model(data, rois, pos_feat, valid, num_rollouts=self.ptrain_size,
                             g_idx=g_idx, x_t=data_t, phase='train')
        labels = {
            'boxes': gt_boxes.to(self.device),
            'masks': gt_masks.to(self.device),
            'valid': valid.to(self.device),
            'module_valid': module_valid.to(self.device),
            'seq_l': seq_l.to(self.device),
            'gt_indicators': gtindicator.to(self.device),
        }
        loss = self.loss(outputs, labels, 'train')
        loss.backward()
        self.optim.step()
        # this is an approximation for printing; the dataset size may not divide the batch size
        self.iterations += self.batch_size

        print_msg = ""
        print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
        print_msg += " | "
        mean_loss = np.mean(np.array(self.box_p_step_losses[:self.ptrain_size]) / self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "
        print_msg += " | ".join(
            ["{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt) for name in self.loss_name])
        print_msg += " || {:.4f}".format(self.loss_ind)
        speed = self.loss_cnt / (timer() - self.time)
        eta = (self.max_iters - self.iterations) / speed / 3600
        print_msg += f" | speed: {speed:.1f} | eta: {eta:.2f} h"
        print_msg += " " * (os.get_terminal_size().columns - len(print_msg) - 10)
        tprint(print_msg)

        if self.iterations % self.val_interval == 0:
            self.snapshot()
            self.val()
            self._init_loss()
            self.model.train()
        if self.iterations >= self.max_iters:
            print('\r', end='')
            print(f'{self.best_mean:.3f}')
            break
def val(self):
    self.model.eval()
    self._init_loss()
    if C.RIN.VAE:
        losses = dict.fromkeys(self.loss_name, 0.0)
        box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
        masks_step_losses = [0.0 for _ in range(self.ptest_size)]
    for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes, gt_masks, valid,
                    module_valid, g_idx, seq_l, objinfo, gtindicator) in enumerate(self.val_loader):
        tprint(f'eval: {batch_idx}/{len(self.val_loader)}')
        with torch.no_grad():
            if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                data = data.permute((0, 2, 1, 3, 4, 5))
                data = data.reshape((data.shape[0] * data.shape[1],) + data.shape[2:])
            data = data.to(self.device)
            pos_feat = xyxy_to_posf(rois, data.shape)
            rois = xyxy_to_rois(rois, batch=data.shape[0], time_step=data.shape[1],
                                num_devices=self.num_gpus)
            labels = {
                'boxes': gt_boxes.to(self.device),
                'masks': gt_masks.to(self.device),
                'valid': valid.to(self.device),
                'module_valid': module_valid.to(self.device),
                'seq_l': seq_l.to(self.device),
                'gt_indicators': gtindicator.to(self.device),
            }
            outputs = self.model(data, rois, pos_feat, valid, num_rollouts=self.ptest_size,
                                 g_idx=g_idx, phase='test')
            self.loss(outputs, labels, 'test')
            # VAE multiple runs: keep the statistics of the best of 1 + 9 sampled rollouts
            if C.RIN.VAE:
                vae_best_mean = np.mean(
                    np.array(self.box_p_step_losses[:self.ptest_size]) / self.loss_cnt) * 1e3
                losses_t = self.losses.copy()
                box_p_step_losses_t = self.box_p_step_losses.copy()
                masks_step_losses_t = self.masks_step_losses.copy()
                for i in range(9):
                    outputs = self.model(data, rois, pos_feat, valid, num_rollouts=self.ptest_size,
                                         g_idx=g_idx, phase='test')
                    self.loss(outputs, labels, 'test')
                    mean_loss = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) / self.loss_cnt) * 1e3
                    if mean_loss < vae_best_mean:
                        losses_t = self.losses.copy()
                        box_p_step_losses_t = self.box_p_step_losses.copy()
                        masks_step_losses_t = self.masks_step_losses.copy()
                        vae_best_mean = mean_loss
                    self._init_loss()
                for k, v in losses.items():
                    losses[k] += losses_t[k]
                for i in range(len(box_p_step_losses)):
                    box_p_step_losses[i] += box_p_step_losses_t[i]
                    masks_step_losses[i] += masks_step_losses_t[i]
    if C.RIN.VAE:
        self.losses = losses.copy()
        self.box_p_step_losses = box_p_step_losses.copy()
        self.loss_cnt = len(self.val_loader)
    print('\r', end='')
    print_msg = ""
    print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
    print_msg += " | "
    mean_loss = np.mean(np.array(self.box_p_step_losses[:self.ptest_size]) / self.loss_cnt) * 1e3
    print_msg += f"{mean_loss:.3f} | "
    if mean_loss < self.best_mean:
        self.snapshot('ckpt_best.path.tar')
        self.best_mean = mean_loss
    print_msg += " | ".join(
        ["{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt) for name in self.loss_name])
    print_msg += " || {:.4f}".format(self.loss_ind)
    print_msg += " " * (os.get_terminal_size().columns - len(print_msg) - 10)
    self.logger.info(print_msg)
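
# For VAE models, val() scores each batch with 1 + 9 stochastic rollouts and keeps
# the run with the lowest mean box error. The helper below is a minimal standalone
# sketch of that best-of-N bookkeeping; the random "losses" are made-up stand-ins
# for box_p_step_losses and the helper is not used by the trainer.
def _demo_best_of_n_selection():
    import numpy as np
    rng = np.random.default_rng(0)
    best_mean, best_stats = np.inf, None
    for _ in range(10):                    # 1 + 9 sampled rollouts
        step_losses = rng.random(5)        # stand-in for per-step box losses
        mean_loss = step_losses.mean() * 1e3
        if mean_loss < best_mean:
            best_mean, best_stats = mean_loss, step_losses.copy()
    return best_mean, best_stats
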
def test(self):
    """Evaluation over the validation loader; optionally samples the VAE multiple times
    per batch and plots predicted vs. ground-truth rollouts for selected sequences."""
    self.model.eval()
    if C.RIN.VAE:
        losses = dict.fromkeys(self.loss_name, 0.0)
        box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
        masks_step_losses = [0.0 for _ in range(self.ptest_size)]
    for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes, gt_masks, valid,
                    module_valid, g_idx, _) in enumerate(self.val_loader):
        with torch.no_grad():
            # decide module_valid here for evaluation
            mid = 0  # ball-only
            module_valid = module_valid[:, mid, :, :]
            if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                # data should be (b x t x o x c x h x w)
                data = data.permute((0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                data = data.reshape((data.shape[0] * data.shape[1],) + data.shape[2:])  # (b*o, t, c, h, w)
            data = data.to(self.device)
            pos_feat = xyxy_to_posf(rois, data.shape)
            rois = xyxy_to_rois(rois, batch=data.shape[0], time_step=data.shape[1],
                                num_devices=self.num_gpus)
            labels = {
                'boxes': gt_boxes.to(self.device),
                'masks': gt_masks.to(self.device),
                'valid': valid.to(self.device),
                'module_valid': module_valid.to(self.device),
            }
            outputs = self.model(data, rois, pos_feat, valid, num_rollouts=self.ptest_size,
                                 g_idx=g_idx, phase='test')
            # *********************************************************************************
            # VISUALIZATION - generate input image and GT outputs + model outputs
            # input_data = data.cpu().detach().numpy()                  # (128, 1, 3, 128, 128)
            # gt_data = data_pred.cpu().detach().numpy()                # (128, 10, 3, 128, 128)
            # data_t = data_t.cpu().detach().numpy()                    # (128, 1, 128, 128)
            # validity = valid.cpu().detach().numpy()                   # (128, 6)
            # outputs_boxes = outputs['boxes'].cpu().detach().numpy()   # (128, 10, 6, 4)
            # outputs_masks = outputs['masks'].cpu().detach().numpy()   # (128, 10, 6, 21, 21)
            # np.save('save/' + str(batch_idx) + 'input_data.npy', input_data)
            # np.save('save/' + str(batch_idx) + 'gt_data.npy', gt_data)
            # np.save('save/' + str(batch_idx) + 'data_t.npy', data_t)
            # np.save('save/' + str(batch_idx) + 'validity.npy', validity)
            # np.save('save/' + str(batch_idx) + 'outputs_boxes.npy', outputs_boxes)
            # np.save('save/' + str(batch_idx) + 'outputs_masks.npy', outputs_masks)
            # self.visualize_results(input_data, gt_data, outputs, data_t, validity, env_name)
            # *********************************************************************************
            self.loss(outputs, labels, 'test')
            # VAE multiple runs: keep the best of 1 + 99 sampled rollouts
            if C.RIN.VAE:
                vae_best_mean = np.mean(
                    np.array(self.box_p_step_losses[:self.ptest_size]) / self.loss_cnt) * 1e3
                losses_t = self.losses.copy()
                box_p_step_losses_t = self.box_p_step_losses.copy()
                masks_step_losses_t = self.masks_step_losses.copy()
                for i in range(99):
                    outputs = self.model(data, rois, None, num_rollouts=self.ptest_size,
                                         g_idx=g_idx, phase='test')
                    self.loss(outputs, labels, 'test')
                    mean_loss = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) / self.loss_cnt) * 1e3
                    if mean_loss < vae_best_mean:
                        losses_t = self.losses.copy()
                        box_p_step_losses_t = self.box_p_step_losses.copy()
                        masks_step_losses_t = self.masks_step_losses.copy()
                        vae_best_mean = mean_loss
                    self._init_loss()
                for k, v in losses.items():
                    losses[k] += losses_t[k]
                for i in range(len(box_p_step_losses)):
                    box_p_step_losses[i] += box_p_step_losses_t[i]
                    masks_step_losses[i] += masks_step_losses_t[i]
        tprint(f'eval: {batch_idx}/{len(self.val_loader)}:' + ' ' * 20)

        if self.plot_image > 0:
            outputs = {
                'boxes': outputs['boxes'].cpu().numpy(),
                'masks': outputs['masks'].cpu().numpy() if C.RIN.MASK_LOSS_WEIGHT else None,
            }
            outputs['boxes'][..., 0::2] *= self.input_width
            outputs['boxes'][..., 1::2] *= self.input_height
            outputs['boxes'] = xywh2xyxy(
                outputs['boxes'].reshape(-1, 4)).reshape((data.shape[0], -1, C.RIN.NUM_OBJS, 4))
            labels = {
                'boxes': labels['boxes'].cpu().numpy(),
                'masks': labels['masks'].cpu().numpy(),
            }
            labels['boxes'][..., 0::2] *= self.input_width
            labels['boxes'][..., 1::2] *= self.input_height
            labels['boxes'] = xywh2xyxy(
                labels['boxes'].reshape(-1, 4)).reshape((data.shape[0], -1, C.RIN.NUM_OBJS, 4))
            for i in range(rois.shape[0]):
                batch_size = C.SOLVER.BATCH_SIZE if not C.RIN.VAE else 1
                plot_image_idx = batch_size * batch_idx + i
                if plot_image_idx < self.plot_image:
                    tprint(f'plotting: {plot_image_idx}' + ' ' * 20)
                    video_idx, img_idx = self.val_loader.dataset.video_info[plot_image_idx]
                    video_name = self.val_loader.dataset.video_list[video_idx]
                    # np.bool is removed in recent NumPy; use the builtin bool instead
                    v = valid[i].numpy().astype(bool)
                    pred_boxes_i = outputs['boxes'][i][:, v]
                    gt_boxes_i = labels['boxes'][i][:, v]
                    if 'PHYRE' in C.DATA_ROOT:
                        # [::-1] makes it consistent with the other datasets, where opencv is used
                        im_data = phyre.observations_to_float_rgb(
                            np.load(video_name).astype(np.uint8))[..., ::-1]
                        a, b, c = video_name.split('/')[5:8]
                        output_name = f'{a}_{b}_{c.replace(".npy", "")}'
                        bg_image = np.load(video_name).astype(np.uint8)
                        for fg_id in [1, 2, 3, 5]:
                            bg_image[bg_image == fg_id] = 0
                        bg_image = phyre.observations_to_float_rgb(bg_image)
                        # if f'{a}_{b}' not in [
                        #     '00014_123', '00014_528', '00015_257', '00015_337', '00019_273', '00019_296'
                        # ]:
                        #     continue
                        # if f'{a}_{b}' not in [
                        #     '00000_069', '00001_000', '00002_185', '00003_064', '00004_823',
                        #     '00005_111', '00006_033', '00007_090', '00008_177', '00009_930',
                        #     '00010_508', '00011_841', '00012_071', '00013_074', '00014_214',
                        #     '00015_016', '00016_844', '00017_129', '00018_192', '00019_244',
                        #     '00020_010', '00021_115', '00022_537', '00023_470', '00024_048'
                        # ]:
                        #     continue
                    else:
                        bg_image = None
                        image_list = sorted(
                            glob(f'{video_name}/*{self.val_loader.dataset.image_ext}'))
                        im_name = image_list[img_idx]
                        output_name = '_'.join(im_name.split('.')[0].split('/')[-2:])
                    # deal with the image data here; plot_rollouts only takes care of the plt usage
                    # if output_name not in ['009_015', '009_031', '009_063', '039_038', '049_011', '059_033']:
                    #     continue
                    # if output_name not in ['00002_00037', '00008_00047', '00011_00048', '00013_00036',
                    #                        '00014_00033', '00020_00054', '00021_00013', '00024_00011']:
                    #     continue
                    if output_name not in ['0016_000', '0045_000', '0120_000', '0163_000']:
                        continue
                    gt_boxes_i = labels['boxes'][i][:, v]
                    im_data = get_im_data(im_name, gt_boxes_i[None, 0:1], C.DATA_ROOT,
                                          self.high_resolution_plot)
                    if self.high_resolution_plot:
                        scale_w = im_data.shape[1] / self.input_width
                        scale_h = im_data.shape[0] / self.input_height
                        pred_boxes_i[..., [0, 2]] *= scale_w
                        pred_boxes_i[..., [1, 3]] *= scale_h
                        gt_boxes_i[..., [0, 2]] *= scale_w
                        gt_boxes_i[..., [1, 3]] *= scale_h
                    pred_masks_i = None
                    if C.RIN.MASK_LOSS_WEIGHT:
                        pred_masks_i = outputs['masks'][i][:, v]
                    plot_rollouts(im_data, pred_boxes_i, gt_boxes_i,
                                  pred_masks_i, labels['masks'][i][:, v],
                                  output_dir=self.output_dir, output_name=output_name,
                                  bg_image=bg_image)

    if C.RIN.VAE:
        self.losses = losses.copy()
        self.box_p_step_losses = box_p_step_losses.copy()
        self.loss_cnt = len(self.val_loader)
    print('\r', end='')
    print_msg = ""
    mean_loss = np.mean(np.array(self.box_p_step_losses[:self.ptest_size]) / self.loss_cnt) * 1e3
    print_msg += f"{mean_loss:.3f} | "
    print_msg += " | ".join(
        ["{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt) for name in self.loss_name])
    pprint(print_msg)