def train_epoch(self, train_db, epoch):
    syn_db = synthesis_loader(train_db)
    loader = DataLoader(syn_db,
        batch_size=self.cfg.batch_size, shuffle=True,
        num_workers=self.cfg.num_workers, pin_memory=True)

    errors_list = []
    if self.cfg.cuda and self.cfg.parallel:
        net = self.net.module
    else:
        net = self.net
    self.net.train()
    net.lossnet.eval()

    for cnt, batched in enumerate(loader):
        ##################################################################
        ## Batched data
        ##################################################################
        proposals, gt_images, gt_labels = self.batch_data(batched)
        gt_images = gt_images.permute(0, 3, 1, 2)
        weights = None
        if self.cfg.weighted_synthesis:
            weights = proposals[:, :, :, -4].clone().detach()
            weights = 0.5 * (1.0 + weights)

        ##################################################################
        ## Train one step
        ##################################################################
        self.net.zero_grad()
        synthesized_images, synthesized_labels, synthesized_features, gt_features = \
            self.net(proposals, True, gt_images)
        loss, losses = self.compute_loss(
            synthesized_images, gt_images,
            synthesized_features, gt_features,
            synthesized_labels, gt_labels, weights)
        loss.backward()
        self.optimizer.step()

        ##################################################################
        ## Collect info
        ##################################################################
        errors_list.append(losses.cpu().data.numpy().flatten())

        ##################################################################
        ## Print info
        ##################################################################
        if cnt % self.cfg.log_per_steps == 0:
            tmp = np.stack(errors_list, 0)
            print('Epoch %03d, iter %07d:' % (epoch, cnt))
            print(np.mean(tmp[:, 0]), np.mean(tmp[:, 1]), np.mean(tmp[:, 2]))
            print(np.mean(tmp[:, 3]), np.mean(tmp[:, 4]), np.mean(tmp[:, 5]),
                  np.mean(tmp[:, 6]), np.mean(tmp[:, 7]))
            print('-------------------------')

    return np.array(errors_list)
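# A minimal sketch (not part of the original trainer) of how train_epoch could be
# driven across epochs. The helper `save_checkpoint` and the config fields
# `n_epochs` / `n_samples` are assumptions for illustration only; `sample_for_vis`
# is defined further below in this file.
def train(self, train_db, val_db):
    for epoch in range(self.cfg.n_epochs):
        # one full pass over the training set; returns per-step loss vectors
        train_errors = self.train_epoch(train_db, epoch)
        # dump qualitative samples for visual inspection
        self.sample_for_vis(epoch, val_db, self.cfg.n_samples)
        # keep the averaged losses around for later plotting
        np.save(osp.join(self.cfg.model_dir, 'errors_%03d.npy' % epoch),
                np.mean(train_errors, 0))
        # persist the weights (assumed helper, not shown here)
        self.save_checkpoint(epoch)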
def test_syn_model(config):
    synthesizer = SynthesisModel(config)
    print(get_n_params(synthesizer))

    db = coco(config, 'train', '2017')
    syn_loader = synthesis_loader(db)
    loader = DataLoader(syn_loader,
        batch_size=1, shuffle=False,
        num_workers=config.num_workers)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['input_vol'].float()
        y = batched['gt_image'].float()
        z = batched['gt_label'].long()
        y = y.permute(0, 3, 1, 2)
        image, label, syn_feats, gt_feats = synthesizer(x, True, y)
        print(image.size(), label.size())
        for v in syn_feats:
            print(v.size())
        print('------------')
        for v in gt_feats:
            print(v.size())
        break
def test_syn_encoder(config):
    img_encoder = SynthesisEncoder(config)
    print(get_n_params(img_encoder))

    db = coco(config, 'train', '2017')
    syn_loader = synthesis_loader(db)
    loader = DataLoader(syn_loader,
        batch_size=1, shuffle=False,
        num_workers=config.num_workers)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['input_vol'].float()
        y = img_encoder(x)
        for z in y:
            print(z.size())
        break
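# `get_n_params` is used throughout these tests but defined elsewhere in the repo.
# A minimal sketch, assuming it simply counts the scalar parameters of a module:
def get_n_params(model):
    # sum the number of elements over all registered parameters
    return sum(p.numel() for p in model.parameters())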
def test_perceptual_loss_network(config):
    img_encoder = VGG19LossNetwork(config).eval()
    print(get_n_params(img_encoder))

    db = coco(config, 'train', '2017')
    syn_loader = synthesis_loader(db)
    loader = DataLoader(syn_loader,
        batch_size=1, shuffle=False,
        num_workers=config.num_workers)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['gt_image'].float()
        y = img_encoder(x.permute(0, 3, 1, 2))
        for z in y:
            print(z.size())
        break
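# VGG19LossNetwork is defined elsewhere in the project; the sketch below only
# illustrates what a typical VGG19-based perceptual-loss feature extractor looks
# like: it returns the activations of a few intermediate ReLU layers of a frozen,
# pretrained VGG19. The tap-layer indices and the use of torchvision's pretrained
# weights are assumptions, not the project's actual implementation.
import torch
import torch.nn as nn
import torchvision

class SimpleVGG19LossNetwork(nn.Module):
    def __init__(self, tap_layers=(3, 8, 17, 26, 35)):
        super().__init__()
        self.vgg = torchvision.models.vgg19(pretrained=True).features.eval()
        for p in self.vgg.parameters():
            p.requires_grad = False            # the loss network stays frozen
        self.tap_layers = set(tap_layers)      # relu1_2, relu2_2, relu3_4, relu4_4, relu5_4

    def forward(self, x):
        feats = []
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if idx in self.tap_layers:
                feats.append(x)
        return feats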
def test_syn_decoder(config):
    img_encoder = SynthesisEncoder(config)
    img_decoder = SynthesisDecoder(config)
    print(get_n_params(img_encoder))
    print(get_n_params(img_decoder))

    db = coco(config, 'train', '2017')
    syn_loader = synthesis_loader(db)
    loader = DataLoader(syn_loader,
        batch_size=1, shuffle=False,
        num_workers=config.num_workers)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['input_vol'].float()
        x0, x1, x2, x3, x4, x5, x6 = img_encoder(x)
        inputs = (x0, x1, x2, x3, x4, x5, x6)
        image, label = img_decoder(inputs)
        print(image.size(), label.size())
        break
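# SynthesisEncoder returns seven feature maps (x0 ... x6) that SynthesisDecoder
# consumes as a tuple, which suggests a U-Net style skip-connection arrangement.
# The class below is only an illustrative sketch of that wiring with made-up
# channel counts; the real layer definitions live in the project's model code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySkipDecoder(nn.Module):
    """Walks encoder features from coarsest to finest, upsampling and fusing skips."""
    def __init__(self, enc_channels=(16, 32, 64, 128, 256, 256, 256), out_channels=3):
        super().__init__()
        chans = list(enc_channels)[::-1]          # start from the deepest feature map
        blocks = []
        prev = chans[0]
        for skip in chans[1:]:
            # fuse the upsampled state with the next (finer) skip connection
            blocks.append(nn.Conv2d(prev + skip, skip, 3, padding=1))
            prev = skip
        self.blocks = nn.ModuleList(blocks)
        self.to_image = nn.Conv2d(prev, out_channels, 3, padding=1)

    def forward(self, feats):
        feats = list(feats)[::-1]
        x = feats[0]
        for block, skip in zip(self.blocks, feats[1:]):
            x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
            x = torch.relu(block(torch.cat([x, skip], dim=1)))
        return self.to_image(x)

# Shape check with a fake feature pyramid (batch of 1, 64x64 input):
#   feats = [torch.randn(1, c, 64 // 2**i, 64 // 2**i)
#            for i, c in enumerate((16, 32, 64, 128, 256, 256, 256))]
#   TinySkipDecoder()(feats).shape   # -> torch.Size([1, 3, 64, 64])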
def test_syn_dataloader(config):
    db = coco(config, 'train', '2017')
    syn_loader = synthesis_loader(db)

    output_dir = osp.join(config.model_dir, 'test_syn_dataloader')
    maybe_create(output_dir)

    loader = DataLoader(syn_loader,
        batch_size=config.batch_size, shuffle=False,
        num_workers=config.num_workers)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['input_vol'].float()
        y = batched['gt_image'].float()
        z = batched['gt_label'].float()
        if config.use_color_volume:
            x = batch_color_volumn_preprocess(x, len(db.classes))
        else:
            x = batch_onehot_volumn_preprocess(x, len(db.classes))
        print('input_vol', x.size())
        print('gt_image', y.size())
        print('gt_label', z.size())

        # cv2.imwrite('mask0.png', x[0,:,:,-4].cpu().data.numpy())
        # cv2.imwrite('mask1.png', x[1,:,:,-4].cpu().data.numpy())
        # cv2.imwrite('mask2.png', x[2,:,:,-4].cpu().data.numpy())
        # cv2.imwrite('mask3.png', x[3,:,:,-4].cpu().data.numpy())
        # cv2.imwrite('label0.png', x[0,:,:,3].cpu().data.numpy())
        # cv2.imwrite('label1.png', x[1,:,:,3].cpu().data.numpy())
        # cv2.imwrite('label2.png', x[2,:,:,3].cpu().data.numpy())
        # cv2.imwrite('label3.png', x[3,:,:,3].cpu().data.numpy())
        # cv2.imwrite('color0.png', x[0,:,:,-3:].cpu().data.numpy())
        # cv2.imwrite('color1.png', x[1,:,:,-3:].cpu().data.numpy())
        # cv2.imwrite('color2.png', x[2,:,:,-3:].cpu().data.numpy())
        # cv2.imwrite('color3.png', x[3,:,:,-3:].cpu().data.numpy())

        x = (x - 128.0).permute(0, 3, 1, 2)
        plt.switch_backend('agg')
        x = tensors_to_vols(x)
        for i in range(x.shape[0]):
            image_idx = batched['image_index'][i]
            name = '%03d_' % i + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name + '.png')
            if config.use_color_volume:
                proposal = x[i, :, :, 12:15]
                mask = x[i, :, :, :3]
                person = x[i, :, :, 9:12]
                other = x[i, :, :, 15:18]
                gt_color = y[i]
                gt_label = z[i]
                gt_label = np.repeat(gt_label[..., None], 3, -1)
            else:
                proposal = x[i, :, :, -3:]
                mask = x[i, :, :, -4]
                mask = np.repeat(mask[..., None], 3, -1)
                person = x[i, :, :, 3]
                person = np.repeat(person[..., None], 3, -1)
                other = x[i, :, :, 5]
                other = np.repeat(other[..., None], 3, -1)
                gt_color = y[i]
                gt_label = z[i]
                gt_label = np.repeat(gt_label[..., None], 3, -1)
            r1 = np.concatenate((proposal, mask, person), 1)
            r2 = np.concatenate((gt_color, gt_label, other), 1)
            out = np.concatenate((r1, r2), 0).astype(np.uint8)
            fig = plt.figure(figsize=(32, 32))
            plt.imshow(out[:, :, :])
            plt.axis('off')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
        if cnt == 1:
            break
    print("Time", time() - start)
def sample_for_vis(self, epoch, test_db, N, random_or_not=False):
    ##############################################################
    # Output prefix
    ##############################################################
    plt.switch_backend('agg')
    output_dir = osp.join(self.cfg.model_dir, '%03d' % epoch, 'vis')
    maybe_create(output_dir)

    ##############################################################
    # Main loop
    ##############################################################
    syn_db = synthesis_loader(test_db)
    loader = DataLoader(syn_db,
        batch_size=self.cfg.batch_size,
        shuffle=random_or_not, pin_memory=True)
    if self.cfg.cuda and self.cfg.parallel:
        net = self.net.module
    else:
        net = self.net

    max_cnt = min(N, len(test_db.scenedb))
    self.net.eval()
    for cnt, batched in enumerate(loader):
        ##################################################################
        ## Batched data
        ##################################################################
        proposals, gt_images, gt_labels = self.batch_data(batched)
        image_indices = batched['image_index'].cpu().data.numpy()

        ##################################################################
        ## Sample one batch
        ##################################################################
        with torch.no_grad():
            synthesized_images, synthesized_labels, _, _ = \
                self.net(proposals, False, None)

        for i in range(synthesized_images.size(0)):
            synthesized_image = synthesized_images[i].cpu().data.numpy()
            synthesized_image = synthesized_image.transpose((1, 2, 0))
            gt_image = gt_images[i].cpu().data.numpy()

            synthesized_label = torch.max(synthesized_labels[i], 0)[-1]
            synthesized_label = synthesized_label.cpu().data.numpy()
            synthesized_label = test_db.decode_semantic_map(synthesized_label)
            gt_label = gt_labels[i].cpu().data.numpy()
            gt_label = test_db.decode_semantic_map(gt_label)

            fig = plt.figure(figsize=(32, 32))
            plt.subplot(2, 2, 1)
            plt.imshow(clamp_array(synthesized_image, 0, 255).astype(np.uint8))
            plt.axis('off')
            plt.subplot(2, 2, 2)
            plt.imshow(clamp_array(gt_image, 0, 255).astype(np.uint8))
            plt.axis('off')
            plt.subplot(2, 2, 3)
            plt.imshow(clamp_array(synthesized_label, 0, 255).astype(np.uint8))
            plt.axis('off')
            plt.subplot(2, 2, 4)
            plt.imshow(clamp_array(gt_label, 0, 255).astype(np.uint8))
            plt.axis('off')

            image_idx = image_indices[i]
            name = '%03d_' % cnt + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name + '.png')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
            print('sampling: %d, %d, %d' % (epoch, cnt, i))

        if (cnt + 1) * self.cfg.batch_size >= max_cnt:
            break
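# `clamp_array` and `maybe_create` are small helpers defined elsewhere in the repo.
# Plausible sketches, assuming the obvious semantics; the real definitions may differ.
import os

def clamp_array(arr, lo, hi):
    # clip values into [lo, hi] before casting to uint8 for display
    return np.clip(arr, lo, hi)

def maybe_create(path):
    # create the output directory if it does not already exist
    os.makedirs(path, exist_ok=True)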