def im_show_1(image=None, title='', rows=1, cols=1, index=1):
    """ Show image (convert tensor to numpy first) """
    image = normalize_image(image, forward=False)
    image = image.permute(1, 2, 0).cpu().numpy()
    ax = plt.subplot(rows, cols, index)
    ax.set_title(title)
    ax.imshow(image.clip(0, 1))
    return ax

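# Hedged usage sketch (not part of the original module): shows how im_show_1 can
# tile several tensors in one matplotlib figure. The tensors here are random
# placeholders standing in for real source / warped / target images; it assumes
# normalize_image accepts a single (3, H, W) tensor, as in vis_feature below.
def _demo_im_show_1():
    fake = [torch.rand(3, 240, 240) for _ in range(3)]  # placeholder images
    plt.figure(figsize=(12, 4))
    for k, (img, name) in enumerate(zip(fake, ['Source', 'Warped', 'Target'])):
        im_show_1(img, title=name, rows=1, cols=3, index=k + 1)
    plt.show()
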
def vis_feature(vis, model, dataloader, use_cuda=True):
    # Visualize feature map of watch image
    h = 40
    num = 1024
    id = 4
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)
        theta, feature_A, feature_B, correlation = model(batch)
        if batch_idx == id:
            break

    watch_feature_A = F.interpolate(feature_A, size=(h, h), mode='bilinear', align_corners=True).transpose(0, 1)[0:num, :, :, :]
    watch_feature_B = F.interpolate(feature_B, size=(h, h), mode='bilinear', align_corners=True).transpose(0, 1)[0:num, :, :, :]

    opts = dict(jpgquality=100, title='source image')
    image_A = normalize_image(batch['source_image'][0], forward=False) * 255.0
    vis.image(image_A, opts=opts)

    nrow = 32
    padding = 3
    opts = dict(jpgquality=100, title='feature map A')
    # vis.images(watch_feature_A * 255.0, nrow=nrow, padding=padding, opts=opts)
    vis.image(torchvision.utils.make_grid(watch_feature_A * 255.0, nrow=nrow, padding=padding), opts=opts)
    # vis.image(watch_feature_A[0], opts=opts)

    opts = dict(jpgquality=100, title='target image')
    image_B = normalize_image(batch['target_image'][0], forward=False) * 255.0
    vis.image(image_B, opts=opts)

    opts = dict(jpgquality=100, title='feature map B')
    vis.image(torchvision.utils.make_grid(watch_feature_B * 255.0, nrow=nrow, padding=padding), opts=opts)
    # vis.images(watch_feature_B * 255.0, nrow=nrow, padding=padding, opts=opts)
    # vis.image(watch_feature_B[0], opts=opts)

    # opts = dict(title='correlation')
    # vis.heatmap(correlation[0, 0, :, :], opts=opts)

def draw_image(images, names, flows):
    # Note: group_size and fnt are expected to be defined at module level.
    images[:, :, 0:240, :] = normalize_image(images[:, :, 0:240, :], forward=False)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    for i in range(images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_flow = (0, 0)
        else:
            pos_flow = (70, 275)
        cv2.putText(images[i], names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(images[i], flows[i], pos_flow, fnt, 0.5, (0, 0, 0), 1)

    images = torch.Tensor(images.astype(np.float32))
    images = images.permute(0, 3, 1, 2)
    return images

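# Hedged sketch (assumption, not the module's imported helper): batch_cuda is
# called throughout this file to move a batch onto the GPU. A minimal stand-in
# consistent with how it is used here would look like the following.
def _batch_cuda_sketch(batch):
    # Move every tensor entry of the batch dict to CUDA; leave other values as-is.
    return {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
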
def vis_pf(vis, dataloader, theta, theta_weak, theta_inver, theta_weak_inver, results, results_weak, dataset_name, use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    pt = PointTnf(use_cuda=use_cuda)

    watch_images = torch.ones(len(dataloader) * 6, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * 6, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * 6).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_aff = theta['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff = theta_weak['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0)

        theta_aff_inver = theta_inver['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps_inver = theta_inver['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff_inver = theta_weak_inver['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps_inver = theta_weak_inver['aff_tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)
        warped_weak_aff = affTnf(batch['source_image'], theta_weak_aff)
        warped_weak_aff_tps = tpsTnf(warped_weak_aff, theta_weak_aff_tps)

        watch_images[batch_idx * 6, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * 6 + 1, :, 0:240, :] = warped_aff
        watch_images[batch_idx * 6 + 2, :, 0:240, :] = warped_aff_tps
        watch_images[batch_idx * 6 + 3, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * 6 + 4, :, 0:240, :] = warped_weak_aff
        watch_images[batch_idx * 6 + 5, :, 0:240, :] = warped_weak_aff_tps

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points, im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points, im_size=target_im_size)

        warped_points_aff_norm = pt.affPointTnf(theta=theta_aff_inver, points=source_points_norm)
        warped_points_aff = PointsToPixelCoords(P=warped_points_aff_norm, im_size=target_im_size)
        pck_aff, index_aff, N_pts = pck(target_points, warped_points_aff, dataset_name)
        warped_points_aff = relocate(warped_points_aff, target_im_size)

        warped_points_aff_tps_norm = pt.tpsPointTnf(theta=theta_aff_tps_inver, points=source_points_norm)
        warped_points_aff_tps_norm = pt.affPointTnf(theta=theta_aff_inver, points=warped_points_aff_tps_norm)
        warped_points_aff_tps = PointsToPixelCoords(P=warped_points_aff_tps_norm, im_size=target_im_size)
        pck_aff_tps, index_aff_tps, _ = pck(target_points, warped_points_aff_tps, dataset_name)
        warped_points_aff_tps = relocate(warped_points_aff_tps, target_im_size)

        warped_points_weak_aff_norm = pt.affPointTnf(theta=theta_weak_aff_inver, points=source_points_norm)
        warped_points_weak_aff = PointsToPixelCoords(P=warped_points_weak_aff_norm, im_size=target_im_size)
        pck_weak_aff, index_weak_aff, _ = pck(target_points, warped_points_weak_aff, dataset_name)
        warped_points_weak_aff = relocate(warped_points_weak_aff, target_im_size)

        warped_points_weak_aff_tps_norm = pt.tpsPointTnf(theta=theta_weak_aff_tps_inver, points=source_points_norm)
        warped_points_weak_aff_tps_norm = pt.affPointTnf(theta=theta_weak_aff_inver, points=warped_points_weak_aff_tps_norm)
        warped_points_weak_aff_tps = PointsToPixelCoords(P=warped_points_weak_aff_tps_norm, im_size=target_im_size)
        pck_weak_aff_tps, index_weak_aff_tps, _ = pck(target_points, warped_points_weak_aff_tps, dataset_name)
        warped_points_weak_aff_tps = relocate(warped_points_weak_aff_tps, target_im_size)

        watch_keypoints[batch_idx * 6, :, :N_pts] = relocate(batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 1, :, :N_pts] = warped_points_aff[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 2, :, :N_pts] = warped_points_aff_tps[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 3, :, :N_pts] = relocate(batch['target_points'], target_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 4, :, :N_pts] = warped_points_weak_aff[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 5, :, :N_pts] = warped_points_weak_aff_tps[:, :, :N_pts]

        num_points[batch_idx * 6:batch_idx * 6 + 6] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_aff)
        correct_index.append(index_aff_tps)
        correct_index.append(np.arange(N_pts))
        correct_index.append(index_weak_aff)
        correct_index.append(index_weak_aff_tps)

        image_names.append('Source')
        image_names.append('Aff')
        image_names.append('Aff_tps')
        image_names.append('Target')
        image_names.append('Rocco_aff')
        image_names.append('Rocco_aff_tps')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_aff))
        metrics.append('PCK: {:.2%}'.format(pck_aff_tps))
        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_weak_aff))
        metrics.append('PCK: {:.2%}'.format(pck_weak_aff_tps))

    opts = dict(jpgquality=100, title=dataset_name)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % 6 == 1 or (i + 1) % 6 == 4:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0), 1)

        if (i + 1) % 6 == 4:
            for j in range(num_points[i]):
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i], (watch_keypoints[i + 3 - (i % 6), 0, j], watch_keypoints[i + 3 - (i % 6), 1, j]), colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3), opts=opts)

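# Hedged sketch (illustration, not the pck() helper imported above): PCK as it is
# typically computed for PF-PASCAL/PF-WILLOW style evaluation. A warped keypoint
# counts as correct when its distance to the ground-truth target keypoint is below
# alpha times a reference length L_pck (e.g. the larger side of the annotated box
# or image). The helper used in this module also returns the indices of the
# correct points and the number of valid points, which this sketch mirrors; the
# assumption that padded keypoints are negative and stored last is mine.
def _pck_sketch(target_points, warped_points, L_pck, alpha=0.1):
    # target_points / warped_points: (1, 2, N) tensors; padded entries assumed negative.
    valid = (target_points[0, 0, :] >= 0) & (target_points[0, 1, :] >= 0)
    n_pts = int(valid.sum().item())
    dist = torch.pow(target_points[0, :, :n_pts] - warped_points[0, :, :n_pts], 2).sum(dim=0).sqrt()
    correct = torch.nonzero(dist <= alpha * L_pck, as_tuple=False).squeeze(1)
    return float(correct.numel()) / max(n_pts, 1), correct.cpu().numpy(), n_pts
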
def vis_caltech(vis, dataloader, theta, theta_weak, results, results_weak, title, use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)

    watch_images = torch.ones(len(dataloader) * 6, 3, 280, 240)
    if use_cuda:
        watch_images = watch_images.cuda()
    image_names = list()
    lt_acc = list()
    iou = list()
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        theta_aff = theta['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff = theta_weak['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)
        warped_weak_aff = affTnf(batch['source_image'], theta_weak_aff)
        warped_weak_aff_tps = tpsTnf(warped_weak_aff, theta_weak_aff_tps)

        watch_images[batch_idx * 6, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * 6 + 1, :, 0:240, :] = warped_aff
        watch_images[batch_idx * 6 + 2, :, 0:240, :] = warped_aff_tps
        watch_images[batch_idx * 6 + 3, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * 6 + 4, :, 0:240, :] = warped_weak_aff
        watch_images[batch_idx * 6 + 5, :, 0:240, :] = warped_weak_aff_tps

        image_names.append('Source')
        image_names.append('Aff')
        image_names.append('Aff_tps')
        image_names.append('Target')
        image_names.append('Rocco_aff')
        image_names.append('Rocco_aff_tps')

        lt_acc.append('')
        lt_acc.append('LT-ACC: {:.2f}'.format(float(results['aff']['label_transfer_accuracy'][batch_idx])))
        lt_acc.append('LT-ACC: {:.2f}'.format(float(results['aff_tps']['label_transfer_accuracy'][batch_idx])))
        lt_acc.append('')
        lt_acc.append('LT-ACC: {:.2f}'.format(float(results_weak['aff']['label_transfer_accuracy'][batch_idx])))
        lt_acc.append('LT-ACC: {:.2f}'.format(float(results_weak['aff_tps']['label_transfer_accuracy'][batch_idx])))

        iou.append('')
        iou.append('IoU: {:.2f}'.format(float(results['aff']['intersection_over_union'][batch_idx])))
        iou.append('IoU: {:.2f}'.format(float(results['aff_tps']['intersection_over_union'][batch_idx])))
        iou.append('')
        iou.append('IoU: {:.2f}'.format(float(results_weak['aff']['intersection_over_union'][batch_idx])))
        iou.append('IoU: {:.2f}'.format(float(results_weak['aff_tps']['intersection_over_union'][batch_idx])))

    opts = dict(jpgquality=100, title=title)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % 6 == 1 or (i + 1) % 6 == 4:
            pos_lt_ac = (0, 0)
            pos_iou = (0, 0)
        else:
            pos_lt_ac = (10, 275)
            pos_iou = (140, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], lt_acc[i], pos_lt_ac, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], iou[i], pos_iou, fnt, 0.5, (0, 0, 0), 1)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=5), opts=opts)

def vis_fn(vis, train_loss, val_pck, train_lr, epoch, num_epochs, dataloader, theta, thetai, results, geometric_model='tps', use_cuda=True):
    # Visualize watch images
    if geometric_model == 'tps':
        geoTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    elif geometric_model == 'affine':
        geoTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    pt = PointTPS(use_cuda=use_cuda)

    group_size = 3
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * group_size, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * group_size).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    theta, thetai = swap(theta, thetai)

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        batch['source_image'], batch['target_image'] = swap(batch['source_image'], batch['target_image'])
        batch['source_im_info'], batch['target_im_info'] = swap(batch['source_im_info'], batch['target_im_info'])
        batch['source_points'], batch['target_points'] = swap(batch['source_points'], batch['target_points'])

        # Theta and thetai
        if geometric_model == 'tps':
            theta_batch = theta['tps'][batch_idx].unsqueeze(0)
            theta_batch_inver = thetai['tps'][batch_idx].unsqueeze(0)
        elif geometric_model == 'affine':
            theta_batch = theta['aff'][batch_idx].unsqueeze(0)
            theta_batch_inver = thetai['aff'][batch_idx].unsqueeze(0)

        # Warped image
        warped_image = geoTnf(batch['source_image'], theta_batch)

        watch_images[batch_idx * group_size, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = warped_image
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = batch['target_image']

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points, im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points, im_size=target_im_size)

        if geometric_model == 'tps':
            warped_points_norm = pt.tpsPointTnf(theta=theta_batch_inver, points=source_points_norm)
        elif geometric_model == 'affine':
            warped_points_norm = pt.affPointTnf(theta=theta_batch_inver, points=source_points_norm)
        warped_points = PointsToPixelCoords(P=warped_points_norm, im_size=target_im_size)
        _, index_correct, N_pts = pck(target_points, warped_points)

        watch_keypoints[batch_idx * group_size, :, :N_pts] = relocate(batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 1, :, :N_pts] = relocate(warped_points, target_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 2, :, :N_pts] = relocate(batch['target_points'], target_im_size)[:, :, :N_pts]

        num_points[batch_idx * group_size:batch_idx * group_size + group_size] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_correct)
        correct_index.append(np.arange(N_pts))

        image_names.append('Source')
        if geometric_model == 'tps':
            image_names.append('TPS')
        elif geometric_model == 'affine':
            image_names.append('Affine')
        image_names.append('Target')

        metrics.append('')
        if geometric_model == 'tps':
            metrics.append('PCK: {:.2%}'.format(float(results['tps']['pck'][batch_idx])))
        elif geometric_model == 'affine':
            metrics.append('PCK: {:.2%}'.format(float(results['aff']['pck'][batch_idx])))
        metrics.append('')

    opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped target')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0), 1)

        if (i + 1) % group_size == 0:
            for j in range(num_points[i]):
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i], (watch_keypoints[i + (group_size - 1) - (i % group_size), 0, j], watch_keypoints[i + (group_size - 1) - (i % group_size), 1, j]), colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=6, padding=3), opts=opts)

    if epoch == num_epochs:
        if geometric_model == 'affine':
            sub_str = 'Affine'
        elif geometric_model == 'tps':
            sub_str = 'TPS'
        epochs = np.arange(1, num_epochs + 1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch', ylabel='Loss', title='GM ResNet101 ' + sub_str + ' Training Loss', legend=['Loss'], width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)
        # Visualize val pck
        opts_pck = dict(xlabel='Epoch', ylabel='Val PCK', title='GM ResNet101 ' + sub_str + ' Val PCK', legend=['PCK'], width=2000)
        vis.line(val_pck, epochs, opts=opts_pck)
        # Visualize train lr
        opts_lr = dict(xlabel='Epoch', ylabel='Learning Rate', title='GM ResNet101 ' + sub_str + ' Training Learning Rate', legend=['LR'], width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)

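# Hedged sketch (assumption, not the module's imported helper): swap is used above
# to exchange the roles of source and target (images, im_info, points, and theta
# vs. theta-inverse). The simplest implementation consistent with that usage just
# returns its two arguments in reverse order.
def _swap_sketch(a, b):
    # Return the two values with their order exchanged.
    return b, a
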
def train_fn(epoch, model, loss_fn, loss_cycle_fn, loss_jitter_fn, lambda_c, lambda_j, optimizer, dataloader, triple_generation, geometric_model='tps', use_cuda=True, log_interval=100, vis=None):
    """
    Train the model with synthetically generated training triples {source image, target image, refer image (warped source image), theta_GT} from PF-PASCAL.
    1. Train the transformation parameters theta_st from source image to target image;
    2. Train the transformation parameters theta_tr from target image to refer image;
    3. Combine theta_st and theta_tr to obtain theta from source image to refer image, and compute the loss between theta and theta_GT.
    """

    geoTnf = GeometricTnf(geometric_model=geometric_model, use_cuda=use_cuda)
    epoch_loss = 0
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        stride_images = len(dataloader) / 3
        group_size = 9
        watch_images = torch.ones(group_size * 4, 3, 340, 340).cuda()
        watch_theta = torch.zeros(8, 36).cuda()
        fnt = cv2.FONT_HERSHEY_COMPLEX
    stride_loss = len(dataloader) / 105
    iter_loss = np.zeros(106)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        # batch['source_image'].shape & batch['target_image'].shape: (batch_size, 3, 240, 240)
        # batch['theta'].shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        if use_cuda:
            batch = batch_cuda(batch)

        ''' Get the training triple {source image, target image, refer image (warped source image), theta_GT} '''
        batch_triple = triple_generation(batch)

        ''' Train the model '''
        optimizer.zero_grad()
        loss = 0
        # Predict tps parameters between images
        # theta.shape: (batch_size, 18) for tps, theta.shape: (batch_size, 6) for affine
        # Predict both directions: source<->target (theta_st, theta_ts) and target<->refer (theta_tr, theta_rt)
        theta_st, theta_ts, theta_tr, theta_rt = model(batch_triple)
        loss_match = loss_fn(theta_st=theta_st, theta_tr=theta_tr, theta_GT=batch_triple['theta_GT'])
        loss_cycle_st = loss_cycle_fn(theta_AB=theta_st, theta_BA=theta_ts)
        loss_cycle_ts = loss_cycle_fn(theta_AB=theta_ts, theta_BA=theta_st)
        loss_cycle_tr = loss_cycle_fn(theta_AB=theta_tr, theta_BA=theta_rt)
        loss_cycle_rt = loss_cycle_fn(theta_AB=theta_rt, theta_BA=theta_tr)
        loss_jitter = loss_jitter_fn(theta_st=theta_st, theta_tr=theta_tr)
        loss = loss_match + lambda_c * (loss_cycle_st + loss_cycle_ts + loss_cycle_tr + loss_cycle_rt) / 4 + lambda_j * loss_jitter
        # loss = loss_match + lambda_c * (loss_cycle_st + loss_cycle_ts + loss_cycle_tr + loss_cycle_rt) / 4
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        if (batch_idx + 1) % log_interval == 0:
            end = time.time()
            print('Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                  .format(epoch, batch_idx + 1, len(dataloader), (batch_idx + 1) / len(dataloader), loss.item(), batch_idx + 1, end - begin))

        if (epoch % 5 == 0 or epoch == 1) and vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images, watch_theta = add_watch(watch_images, watch_theta, batch_triple, geoTnf, theta_st, theta_tr, int((batch_idx + 1) / stride_images) * group_size, int((batch_idx + 1) / stride_images))
            # if batch_idx <= 19:
            #     watch_images, image_names = add_watch(watch_images, watch_theta, batch_st, batch_tr, geoTnf, theta_st, theta_tr, batch_idx * group_size, batch_idx)
            # if batch_idx == 19:
            #     opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_sr target warped_tr refer warped_sr')
            #     watch_images[:, :, 50:290, 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290], forward=False)
            #     watch_images *= 255.0
            #     watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            #     for i in range(watch_images.shape[0]):
            #         if i % group_size == 0:
            #             cp_norm = watch_theta[int(i / group_size), :18].view(1, 2, -1)
            #             watch_images[i] = draw_grid(watch_images[i], cp_norm)
            #
            #             cp_norm = watch_theta[int(i / group_size), 18:].view(1, 2, -1)
            #             watch_images[i + 1] = draw_grid(watch_images[i + 1], cp_norm)
            #
            #             cp_norm = watch_theta[int(i / group_size) + 1, :18].view(1, 2, -1)
            #             watch_images[i + 3] = draw_grid(watch_images[i + 3], cp_norm)
            #
            #             cp_norm = watch_theta[int(i / group_size) + 1, 18:].view(1, 2, -1)
            #             watch_images[i + 4] = draw_grid(watch_images[i + 4], cp_norm)
            #     watch_images = torch.Tensor(watch_images.astype(np.float32))
            #     watch_images = watch_images.permute(0, 3, 1, 2)
            #     vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3), opts=opts)

            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) / stride_loss)] = epoch_loss / (batch_idx + 1)

            # watch_images = normalize_image(batch_tr['target_image'], forward=False) * 255.0
            # vis.images(watch_images, nrow=8, padding=3)

    end = time.time()

    # Visualize watch images & train loss
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_sr target warped_tr refer warped_sr')
        watch_images[:, :, 50:290, 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290], forward=False)
        watch_images *= 255.0
        watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        for i in range(watch_images.shape[0]):
            if i % group_size == 0:
                cp_norm = watch_theta[int(i / group_size), :18].view(1, 2, -1)
                watch_images[i] = draw_grid(watch_images[i], cp_norm)

                cp_norm = watch_theta[int(i / group_size), 18:].view(1, 2, -1)
                watch_images[i + 1] = draw_grid(watch_images[i + 1], cp_norm)

                cp_norm = watch_theta[int(i / group_size) + 1, :18].view(1, 2, -1)
                watch_images[i + 3] = draw_grid(watch_images[i + 3], cp_norm)

                cp_norm = watch_theta[int(i / group_size) + 1, 18:].view(1, 2, -1)
                watch_images[i + 4] = draw_grid(watch_images[i + 4], cp_norm)

        watch_images = torch.Tensor(watch_images.astype(np.float32))
        watch_images = watch_images.permute(0, 3, 1, 2)
        vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3), opts=opts)

        opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')', ylabel='Loss', title='GM ResNet101 ' + geometric_model + ' Training Loss in Epoch ' + str(epoch), legend=['Loss'], width=2000)
        vis.line(iter_loss, np.arange(106), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(epoch_loss, end - begin))
    return epoch_loss, end - begin

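# Hedged usage sketch (illustrative only): how a training script might drive the
# triple-training train_fn defined directly above (in the original project these
# functions live in separate modules, so the name train_fn is unambiguous there).
# The arguments model, loss_fn, loss_cycle_fn, loss_jitter_fn, optimizer,
# scheduler, dataloader, triple_generation and vis are placeholders for objects
# built elsewhere; the per-epoch scheduler step and the lambda values are assumptions.
def _train_loop_sketch(model, loss_fn, loss_cycle_fn, loss_jitter_fn, optimizer, scheduler, dataloader, triple_generation, vis, num_epochs=30):
    train_loss = np.zeros(num_epochs)
    for epoch in range(1, num_epochs + 1):
        model.train()
        loss, _ = train_fn(epoch, model, loss_fn, loss_cycle_fn, loss_jitter_fn,
                           lambda_c=1.0, lambda_j=1.0,  # placeholder loss weights
                           optimizer=optimizer, dataloader=dataloader,
                           triple_generation=triple_generation,
                           geometric_model='tps', vis=vis)
        train_loss[epoch - 1] = loss
        scheduler.step()  # assumed per-epoch LR schedule
    return train_loss
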
def train_fn(epoch, model, loss_fn, optimizer, dataloader, use_cuda=True, log_interval=100, vis=None):
    """
    Train the cosegmentation model with pairs {source image, target image} from PF-PASCAL.
    1. Predict the co-object masks on the source image and the target image;
    2. Compute the loss.
    """

    epoch_loss = 0
    # if (epoch % 5 == 0 or epoch == 1) and vis is not None:
    if vis is not None:
        stride_images = len(dataloader) / 4
        group_size = 6
        watch_images = torch.ones(group_size * 5, 3, 260, 240).cuda()
        image_names = list()
        fnt = cv2.FONT_HERSHEY_COMPLEX
        # means for normalize of caffe resnet and vgg
        # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)).cuda()
    stride_loss = len(dataloader) / 92
    iter_loss = np.zeros(93)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        if use_cuda:
            batch = batch_cuda(batch)

        ''' Train the model '''
        optimizer.zero_grad()
        mask_A, mask_B = model(batch)
        loss = loss_fn(mask_A, mask_B, batch['source_image'], batch['target_image'])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        if (batch_idx + 1) % log_interval == 0:
            end = time.time()
            print('Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                  .format(epoch, batch_idx + 1, len(dataloader), (batch_idx + 1) / len(dataloader), loss.item(), batch_idx + 1, end - begin))

        # if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        if vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images, image_names = add_watch(watch_images, image_names, batch, mask_A, mask_B, int((batch_idx + 1) / stride_images) * group_size)
            # if batch_idx <= 4:
            #     watch_images, image_names = add_watch(watch_images, image_names, batch, mask_A, mask_B, batch_idx * group_size)
            # if batch_idx == 4:
            #     opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' image mask')
            #     watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
            #     watch_images *= 255.0
            #     watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            #     for i in range(len(image_names)):
            #         cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1)
            #     watch_images = torch.Tensor(watch_images.astype(np.float32))
            #     watch_images = watch_images.permute(0, 3, 1, 2)
            #     vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts)

            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) / stride_loss)] = epoch_loss / (batch_idx + 1)

    end = time.time()

    # Visualize watch images & train loss
    # if (epoch % 5 == 0 or epoch == 1) and vis is not None:
    if vis is not None:
        opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' image mask')
        watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
        watch_images *= 255.0
        watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        for i in range(watch_images.shape[0]):
            cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1)
        watch_images = torch.Tensor(watch_images.astype(np.float32))
        watch_images = watch_images.permute(0, 3, 1, 2)
        vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts)
        # Un-normalize for caffe resnet and vgg
        # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
        # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)

        opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')', ylabel='Loss', title='CoSegmentation Training Loss in Epoch ' + str(epoch), legend=['Loss'], width=2000)
        vis.line(iter_loss, np.arange(93), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(epoch_loss, end - begin))
    return epoch_loss, end - begin

def vis_fn(vis, train_loss, val_iou, train_lr, epoch, num_epochs, dataloader, results, masks_A, masks_B, use_cuda=True):
    # Visualize watch images
    group_size = 6
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    if use_cuda:
        watch_images = watch_images.cuda()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        watch_images[batch_idx * group_size, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = torch.mul(batch['source_image'], batch['source_mask'])
        # watch_images[batch_idx * group_size + 2, :, 0:240, :] = torch.mul(batch['source_image'], masks_A[batch_idx])
        mask_A = masks_A[batch_idx].gt(0.5).float()
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = torch.mul(batch['source_image'], mask_A)
        watch_images[batch_idx * group_size + 3, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * group_size + 4, :, 0:240, :] = torch.mul(batch['target_image'], batch['target_mask'])
        # watch_images[batch_idx * group_size + 5, :, 0:240, :] = torch.mul(batch['target_image'], masks_B[batch_idx])
        mask_B = masks_B[batch_idx].gt(0.5).float()
        watch_images[batch_idx * group_size + 5, :, 0:240, :] = torch.mul(batch['target_image'], mask_B)

        image_names.append('Source')
        image_names.append('Mask_GT')
        image_names.append('Mask')
        image_names.append('Target')
        image_names.append('Mask_GT')
        image_names.append('Mask')

        metrics.append('')
        metrics.append('')
        metrics.append('IoU: {:.2%}'.format(float(results[batch_idx, 0])))
        metrics.append('')
        metrics.append('')
        metrics.append('IoU: {:.2%}'.format(float(results[batch_idx, 1])))

    opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' image mask_gt mask')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 3 or (i + 1) % group_size == 0:
            pos_iou = (70, 275)
        else:
            pos_iou = (0, 0)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_iou, fnt, 0.5, (0, 0, 0), 1)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=6, padding=3), opts=opts)

    if epoch == num_epochs:
        epochs = np.arange(1, num_epochs + 1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch', ylabel='Loss', title='CoSegmentation Training Loss', legend=['Loss'], width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)
        # Visualize val IoU
        opts_iou = dict(xlabel='Epoch', ylabel='Val IoU', title='CoSegmentation Val IoU', legend=['IoU'], width=2000)
        vis.line(val_iou, epochs, opts=opts_iou)
        # Visualize train lr
        opts_lr = dict(xlabel='Epoch', ylabel='Learning Rate', title='CoSegmentation Training Learning Rate', legend=['LR'], width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)

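# Hedged sketch (illustration, not the project's evaluation code): the IoU figures
# shown above compare a predicted soft mask, thresholded at 0.5 as in vis_fn,
# against the ground-truth mask. A minimal computation looks like this.
def _mask_iou_sketch(pred_mask, gt_mask, threshold=0.5):
    # pred_mask, gt_mask: tensors with values in [0, 1] and identical shapes.
    pred = pred_mask.gt(threshold).float()
    gt = gt_mask.gt(threshold).float()
    inter = torch.sum(pred * gt)
    union = torch.sum(pred) + torch.sum(gt) - inter
    return (inter / union.clamp(min=1e-6)).item()
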
def vis_fn(vis, train_loss, val_pck, train_lr, epoch, num_epochs, dataloader, theta, thetai, results, geometric_model='tps', use_cuda=True):
    geoTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    group_size = 3
    watch_images = torch.ones(len(dataloader) * group_size, 3, 340, 340)
    if use_cuda:
        watch_images = watch_images.cuda()

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and thetai
        theta_batch = theta['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_image = geoTnf(batch['source_image'], theta_batch)

        watch_images[batch_idx * group_size, :, 50:290, 50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 50:290, 50:290] = warped_image
        watch_images[batch_idx * group_size + 2, :, 50:290, 50:290] = batch['target_image']

    opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped target')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 50:290, 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    for i in range(watch_images.shape[0]):
        if i % group_size == 0:
            cp_norm = theta['tps'][int(i / group_size)][:18].view(1, 2, -1)
            watch_images[i] = draw_grid(watch_images[i], cp_norm)

            cp_norm = theta['tps'][int(i / group_size)][18:].view(1, 2, -1)
            watch_images[i + 1] = draw_grid(watch_images[i + 1], cp_norm)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3), opts=opts)

    if epoch == num_epochs:
        if geometric_model == 'affine':
            sub_str = 'Affine'
        elif geometric_model == 'tps':
            sub_str = 'TPS'
        epochs = np.arange(1, num_epochs + 1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch', ylabel='Loss', title='GM ResNet101 ' + sub_str + ' Training Loss', legend=['Loss'], width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)
        # Visualize val pck
        opts_pck = dict(xlabel='Epoch', ylabel='Val PCK', title='GM ResNet101 ' + sub_str + ' Val PCK', legend=['PCK'], width=2000)
        vis.line(val_pck, epochs, opts=opts_pck)
        # Visualize train lr
        opts_lr = dict(xlabel='Epoch', ylabel='Learning Rate', title='GM ResNet101 ' + sub_str + ' Training Learning Rate', legend=['LR'], width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)

def train_fn_detect(epoch, model, faster_rcnn, aff_theta, loss_fn, optimizer, dataloader, triple_generation, use_cuda=True, log_interval=100, vis=None, show=False):
    """
    Train the model with synthetically generated training triples {source image, target image, refer image (warped source image), theta_GT} from PF-PASCAL.
    1. Train the transformation parameters theta_st from source image to target image;
    2. Train the transformation parameters theta_tr from target image to refer image;
    3. Combine theta_st and theta_tr to obtain theta from source image to refer image, and compute the loss between theta and theta_GT.
    """

    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    epoch_loss = 0
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        stride_images = len(dataloader) / 3
        watch_images = torch.Tensor(24, 3, 240, 240)
        # means for normalize of caffe resnet and vgg
        # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)).cuda()
    stride_loss = len(dataloader) / 35
    iter_loss = np.zeros(36)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        # batch['source_image'].shape & batch['target_image'].shape: (batch_size, 3, 240, 240)
        # batch['theta'].shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        if use_cuda:
            batch = batch_cuda(batch)

        ''' Get the training triple {source image, target image, refer image (warped source image), theta_GT} '''
        batch_triple = triple_generation(batch)

        ''' Train the model '''
        optimizer.zero_grad()
        # Predict tps parameters between images
        box_info_s = faster_rcnn(im_data=batch_triple['source_im'], im_info=batch_triple['source_im_info'][:, 3:], gt_boxes=batch_triple['source_gt_boxes'], num_boxes=batch_triple['source_num_boxes'])[0:3]
        box_info_t = faster_rcnn(im_data=batch_triple['target_im'], im_info=batch_triple['target_im_info'][:, 3:], gt_boxes=batch_triple['target_gt_boxes'], num_boxes=batch_triple['target_num_boxes'])[0:3]
        box_info_r = faster_rcnn(im_data=batch_triple['refer_im'], im_info=batch_triple['refer_im_info'][:, 3:], gt_boxes=batch_triple['refer_gt_boxes'], num_boxes=batch_triple['refer_num_boxes'])[0:3]
        all_box_s = select_boxes(rois=box_info_s[0], cls_prob=box_info_s[1], bbox_pred=box_info_s[2], im_infos=batch_triple['source_im_info'][:, 3:])
        all_box_t = select_boxes(rois=box_info_t[0], cls_prob=box_info_t[1], bbox_pred=box_info_t[2], im_infos=batch_triple['target_im_info'][:, 3:])
        # The refer image is a warped copy of the source image, so it shares the source im_info
        all_box_r = select_boxes(rois=box_info_r[0], cls_prob=box_info_r[1], bbox_pred=box_info_r[2], im_infos=batch_triple['source_im_info'][:, 3:])
        box_s, box_t, box_r = select_box(all_box_s, all_box_t, all_box_r)
        theta_st = aff_theta(boxes_s=box_s, boxes_t=box_t)
        theta_tr = aff_theta(boxes_s=box_t, boxes_t=box_r)
        # theta.shape: (batch_size, 18) for tps
        batch_st = {'source_image': batch_triple['source_image'], 'target_image': batch_triple['target_image']}
        batch_tr = {'source_image': batch_triple['target_image'], 'target_image': batch_triple['refer_image']}
        theta_aff_tps_st, theta_aff_st = model(batch_st, theta_st)  # from source image to target image
        theta_aff_tps_tr, theta_aff_tr = model(batch_tr, theta_tr)  # from target image to refer image
        # show_images(batch_st=batch_st, batch_tr=batch_tr, box_s=box_s, box_t=box_t, box_r=box_r)
        loss = loss_fn(theta_st=theta_aff_tps_st, theta_tr=theta_aff_tps_tr, theta_GT=batch_triple['theta_GT'])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        if (batch_idx + 1) % log_interval == 0:
            end = time.time()
            print('Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                  .format(epoch, batch_idx + 1, len(dataloader), (batch_idx + 1) / len(dataloader), loss.item(), batch_idx + 1, end - begin))

        if (epoch % 5 == 0 or epoch == 1) and vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images = add_watch(watch_images, batch_st, batch_tr, tpsTnf, theta_aff_tps_st, theta_aff_tps_tr, int((batch_idx + 1) / stride_images) * 6)
            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) / stride_loss)] = epoch_loss / (batch_idx + 1)
            # tmp_images = batch_triple['target_im'].permute(0, 2, 3, 1) + pixel_means
            # tmp_images = tmp_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
            # vis.image(torchvision.utils.make_grid(tmp_images, nrow=1, padding=3))

        if show:
            # if dual:
            #     warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st_1)
            #     warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1)
            #     warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr_1)
            #     show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
            #
            #     warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st)
            #     warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr)
            #     warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr)
            #     show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
            #
            #     warped_image_aff_tps = tpsTnf(batch_st['source_image'], theta_aff_tps_st)
            #     warped_image_aff_tps_2 = tpsTnf(batch_tr['source_image'], theta_aff_tps_tr)
            #     warped_image_aff_tps_3 = tpsTnf(warped_image_aff_tps, theta_aff_tps_tr)
            #     show_images(batch_st, batch_tr, warped_image_aff_tps.detach(), warped_image_aff_tps_2.detach(), warped_image_aff_tps_3.detach())
            #
            #     warped_image = affTnf(batch_st['source_image'], theta_aff_st_1)
            #     warped_image = affTnf(warped_image, theta_aff_st)
            #     warped_image = tpsTnf(warped_image, theta_aff_tps_st)
            #     warped_image_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1)
            #     warped_image_2 = affTnf(warped_image_2, theta_aff_tr)
            #     warped_image_2 = tpsTnf(warped_image_2, theta_aff_tps_tr)
            #     warped_image_3 = affTnf(warped_image, theta_aff_tr_1)
            #     warped_image_3 = affTnf(warped_image_3, theta_aff_tr)
            #     warped_image_3 = tpsTnf(warped_image_3, theta_aff_tps_tr)
            #     show_images(batch_st, batch_tr, warped_image.detach(), warped_image_2.detach(), warped_image_3.detach())
            # else:
            #     warped_image_aff = affTnf(batch_st['source_image'], theta_st)
            #     warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_tr)
            #     warped_image_aff_3 = affTnf(warped_image_aff, theta_tr)
            #     show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
            warped_image_tps = tpsTnf(batch_st['source_image'], theta_aff_tps_st)
            warped_image_tps_2 = tpsTnf(batch_tr['source_image'], theta_aff_tps_tr)
            warped_image_tps_3 = tpsTnf(warped_image_tps, theta_aff_tps_tr)
            show_images(batch_triple, warped_image_tps.detach(), warped_image_tps_2.detach(), warped_image_tps_3.detach(), box_s, box_t, box_r)

    end = time.time()

    # Visualize watch images & train loss
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_sr target warped_tr refer warped_sr')
        # Un-normalize for caffe resnet and vgg
        # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
        # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
        watch_images = normalize_image(watch_images, forward=False) * 255.0
        vis.image(torchvision.utils.make_grid(watch_images, nrow=6, padding=3), opts=opts)
        # vis.images(watch_images, nrow=6, padding=3, opts=opts)

        opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')', ylabel='Loss', title='GM ResNet101 Detect&Affine&TPS Training Loss in Epoch ' + str(epoch), legend=['Loss'], width=2000)
        vis.line(iter_loss, np.arange(36), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(epoch_loss, end - begin))
    return epoch_loss, end - begin

def vis_control2(vis, dataloader, theta_1, theta_2, dataset_name, use_cuda=True):
    # Visualize watch images
    tpsTnf = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)

    group_size = 5
    watch_images = torch.ones(len(dataloader) * group_size, 3, 340, 340)
    if use_cuda:
        watch_images = watch_images.cuda()

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :, 50:290, 50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 50:290, 50:290] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 50:290, 50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 3, :, 50:290, 50:290] = warped_tps_2
        watch_images[batch_idx * group_size + 4, :, 50:290, 50:290] = batch['target_image']

    opts = dict(jpgquality=100, title=dataset_name)
    watch_images[:, :, 50:290, 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    im_size = torch.Tensor([[240, 240]]).cuda()
    for i in range(watch_images.shape[0]):
        if i % group_size < 4:
            if i % group_size == 0:
                cp_norm = theta_1['tps'][int(i / group_size)][:18].view(1, 2, -1)
            if i % group_size == 1:
                cp_norm = theta_1['tps'][int(i / group_size)][18:].view(1, 2, -1)
            if i % group_size == 2:
                cp_norm = theta_2['tps'][int(i / group_size)][:18].view(1, 2, -1)
            if i % group_size == 3:
                cp_norm = theta_2['tps'][int(i / group_size)][18:].view(1, 2, -1)
            watch_images[i] = draw_grid(watch_images[i], cp_norm)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=5, padding=5), opts=opts)

def vis_pf(vis, dataloader, theta_1, theta_2, theta_inver_1, theta_inver_2, results_1, results_2, dataset_name, use_cuda=True):
    # Visualize watch images
    tpsTnf_1 = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    tpsTnf_2 = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)
    pt_1 = PointTnf(use_cuda=use_cuda)
    pt_2 = PointTPS(use_cuda=use_cuda)

    group_size = 4
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * group_size, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * group_size).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)
        thetai_tps_1 = theta_inver_1['tps'][batch_idx].unsqueeze(0)
        thetai_tps_2 = theta_inver_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf_1(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf_2(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = warped_tps_2
        watch_images[batch_idx * group_size + 3, :, 0:240, :] = batch['target_image']

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points, im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points, im_size=target_im_size)

        warped_points_tps_norm_1 = pt_1.tpsPointTnf(theta=thetai_tps_1, points=source_points_norm)
        warped_points_tps_1 = PointsToPixelCoords(P=warped_points_tps_norm_1, im_size=target_im_size)
        pck_tps_1, index_tps_1, N_pts = pck(target_points, warped_points_tps_1, dataset_name)
        warped_points_tps_1 = relocate(warped_points_tps_1, target_im_size)

        warped_points_tps_norm_2 = pt_2.tpsPointTnf(theta=thetai_tps_2, points=source_points_norm)
        warped_points_tps_2 = PointsToPixelCoords(P=warped_points_tps_norm_2, im_size=target_im_size)
        pck_tps_2, index_tps_2, _ = pck(target_points, warped_points_tps_2, dataset_name)
        warped_points_tps_2 = relocate(warped_points_tps_2, target_im_size)

        watch_keypoints[batch_idx * group_size, :, :N_pts] = relocate(batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 1, :, :N_pts] = warped_points_tps_1[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 2, :, :N_pts] = warped_points_tps_2[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 3, :, :N_pts] = relocate(batch['target_points'], target_im_size)[:, :, :N_pts]

        num_points[batch_idx * group_size:batch_idx * group_size + group_size] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_tps_1)
        correct_index.append(index_tps_2)
        correct_index.append(np.arange(N_pts))

        image_names.append('Source')
        image_names.append('TPS')
        image_names.append('TPS_Jitter')
        image_names.append('Target')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_tps_1))
        metrics.append('PCK: {:.2%}'.format(pck_tps_2))
        metrics.append('')

    opts = dict(jpgquality=100, title=dataset_name)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0), 1)

        if (i + 1) % group_size == 0:
            for j in range(num_points[i]):
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i], (watch_keypoints[i + (group_size - 1) - (i % group_size), 0, j], watch_keypoints[i + (group_size - 1) - (i % group_size), 1, j]), colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=4, padding=5), opts=opts)

def vis_control(vis, dataloader, theta_1, theta_2, dataset_name, use_cuda=True):
    # Visualize watch images
    tpsTnf_1 = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    tpsTnf_2 = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)

    group_size = 5
    watch_images = torch.ones(len(dataloader) * group_size, 3, 340, 340)
    if use_cuda:
        watch_images = watch_images.cuda()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf_1(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf_2(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :, 50:290, 50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 50:290, 50:290] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 50:290, 50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 3, :, 50:290, 50:290] = warped_tps_2
        watch_images[batch_idx * group_size + 4, :, 50:290, 50:290] = batch['target_image']

    opts = dict(jpgquality=100, title=dataset_name)
    watch_images[:, :, 50:290, 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    im_size = torch.Tensor([[240, 240]]).cuda()
    for i in range(watch_images.shape[0]):
        if i % group_size == 0:
            cp_norm = theta_1['tps'][int(i / group_size)].view(1, 2, -1)
            cp = PointsToPixelCoords(P=cp_norm, im_size=im_size)
            cp = cp.squeeze().cpu().numpy() + 50
            for j in range(9):
                cv2.drawMarker(watch_images[i], (cp[0, j], cp[1, j]), (0, 0, 255), cv2.MARKER_TILTED_CROSS, 12, 2, cv2.LINE_AA)
            for j in range(2):
                for k in range(3):
                    # vertical grid
                    cv2.line(watch_images[i], (cp[0, j + k * 3], cp[1, j + k * 3]), (cp[0, j + k * 3 + 1], cp[1, j + k * 3 + 1]), (0, 0, 255), 2, cv2.LINE_AA)
                    # horizontal grid
                    cv2.line(watch_images[i], (cp[0, j * 3 + k], cp[1, j * 3 + k]), (cp[0, j * 3 + k + 3], cp[1, j * 3 + k + 3]), (0, 0, 255), 2, cv2.LINE_AA)

        if i % group_size == 1:
            cp_norm = torch.Tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1]).cuda().view(1, 2, -1)
            cp = PointsToPixelCoords(P=cp_norm, im_size=im_size)
            cp = cp.squeeze().cpu().numpy() + 50
            for j in range(9):
                cv2.drawMarker(watch_images[i], (cp[0, j], cp[1, j]), (0, 0, 255), cv2.MARKER_TILTED_CROSS, 12, 2, cv2.LINE_AA)
            for j in range(1):
                for k in range(3):
                    # vertical grid
                    cv2.line(watch_images[i], (cp[0, j + k * 3], cp[1, j + k * 3]), (cp[0, j + k * 3 + 1], cp[1, j + k * 3 + 1]), (0, 0, 255), 2, cv2.LINE_AA)
                    # horizontal grid
                    cv2.line(watch_images[i], (cp[0, j * 3 + k], cp[1, j * 3 + k]), (cp[0, j * 3 + k + 3], cp[1, j * 3 + k + 3]), (0, 0, 255), 2, cv2.LINE_AA)

        if i % group_size == 2:
            cp_norm = theta_2['tps'][int(i / group_size)][:18].view(1, 2, -1)
            cp = PointsToPixelCoords(P=cp_norm, im_size=im_size)
            cp = cp.squeeze().cpu().numpy() + 50
            for j in range(9):
                cv2.drawMarker(watch_images[i], (cp[0, j], cp[1, j]), (0, 0, 255), cv2.MARKER_TILTED_CROSS, 12, 2, cv2.LINE_AA)
            for j in range(2):
                for k in range(3):
                    # vertical grid
                    cv2.line(watch_images[i], (cp[0, j + k * 3], cp[1, j + k * 3]), (cp[0, j + k * 3 + 1], cp[1, j + k * 3 + 1]), (0, 0, 255), 2, cv2.LINE_AA)
                    # horizontal grid
                    cv2.line(watch_images[i], (cp[0, j * 3 + k], cp[1, j * 3 + k]), (cp[0, j * 3 + k + 3], cp[1, j * 3 + k + 3]), (0, 0, 255), 2, cv2.LINE_AA)

        if i % group_size == 3:
            cp_norm = theta_2['tps'][int(i / group_size)][18:].view(1, 2, -1)
            cp = PointsToPixelCoords(P=cp_norm, im_size=im_size)
            cp = cp.squeeze().cpu().numpy() + 50
            for j in range(9):
                cv2.drawMarker(watch_images[i], (cp[0, j], cp[1, j]), (0, 0, 255), cv2.MARKER_TILTED_CROSS, 12, 2, cv2.LINE_AA)
            for j in range(2):
                for k in range(3):
                    # vertical grid
                    cv2.line(watch_images[i], (cp[0, j + k * 3], cp[1, j + k * 3]), (cp[0, j + k * 3 + 1], cp[1, j + k * 3 + 1]), (0, 0, 255), 2, cv2.LINE_AA)
                    # horizontal grid
                    cv2.line(watch_images[i], (cp[0, j * 3 + k], cp[1, j * 3 + k]), (cp[0, j * 3 + k + 3], cp[1, j * 3 + k + 3]), (0, 0, 255), 2, cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=5, padding=5), opts=opts)

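# Hedged sketch (illustration): the marker-and-line drawing inlined in vis_control
# above, and the draw_grid helper used elsewhere in this module, both render a 3x3
# grid of TPS control points on an image. The version below simply factors that
# inline logic into a small function; the +50 offset matches the 50-pixel border of
# the 340x340 canvases, PointsToPixelCoords is the same helper the surrounding code
# imports, and the .cuda() call mirrors the surrounding code's GPU assumption.
def _draw_control_grid_sketch(image, cp_norm, offset=50, color=(0, 0, 255)):
    im_size = torch.Tensor([[240, 240]]).cuda()
    cp = PointsToPixelCoords(P=cp_norm, im_size=im_size)
    cp = cp.squeeze().cpu().numpy() + offset
    for j in range(9):  # the nine control points
        cv2.drawMarker(image, (int(cp[0, j]), int(cp[1, j])), color, cv2.MARKER_TILTED_CROSS, 12, 2, cv2.LINE_AA)
    for j in range(2):
        for k in range(3):
            # vertical grid lines
            cv2.line(image, (int(cp[0, j + k * 3]), int(cp[1, j + k * 3])), (int(cp[0, j + k * 3 + 1]), int(cp[1, j + k * 3 + 1])), color, 2, cv2.LINE_AA)
            # horizontal grid lines
            cv2.line(image, (int(cp[0, j * 3 + k]), int(cp[1, j * 3 + k])), (int(cp[0, j * 3 + k + 3]), int(cp[1, j * 3 + k + 3])), color, 2, cv2.LINE_AA)
    return image
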
def train_fn_dual(epoch, model, loss_fn, optimizer, dataloader, triple_generation, use_cuda=True, log_interval=100, vis=None, show=False): """ Train the model with synthetically training triple: {source image, target image, refer image (warped source image), theta_GT} from PF-PASCAL. 1. Train the transformation parameters theta_st from source image to target image; 2. Train the transformation parameters theta_tr from target image to refer image; 3. Combine theta_st and theta_st to obtain theta from source image to refer image, and compute loss between theta and theta_GT. """ geoTnf = ComposedGeometricTnf(use_cuda=use_cuda) affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda) tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda) epoch_loss = 0 if (epoch % 5 == 0 or epoch == 1) and vis is not None: stride_images = len(dataloader) / 3 group_size = 6 watch_images = torch.ones(group_size * 4, 3, 260, 240).cuda() # watch_images = torch.ones(group_size * 20, 3, 260, 240).cuda() image_names = list() fnt = cv2.FONT_HERSHEY_COMPLEX # means for normalize of caffe resnet and vgg # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)).cuda() stride_loss = len(dataloader) / 105 iter_loss = np.zeros(106) begin = time.time() for batch_idx, batch in enumerate(dataloader): ''' Move input batch to gpu ''' # batch['source_image'].shape & batch['target_image'].shape: (batch_size, 3, 240, 240) # batch['theta'].shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv) if use_cuda: batch = batch_cuda(batch) ''' Get the training triple {source image, target image, refer image (warped source image), theta_GT}''' batch_triple = triple_generation(batch) ''' Train the model ''' optimizer.zero_grad() # Predict tps parameters between images # theta.shape: (batch_size, 18) for tps batch_st = {'source_image': batch_triple['source_image'], 'target_image': batch_triple['target_image']} batch_tr = {'source_image': batch_triple['target_image'], 'target_image': batch_triple['refer_image']} theta_aff_tps_st, theta_aff_st = model(batch_st) # from source image to target image theta_aff_tps_tr, theta_aff_tr = model(batch_tr) # from target image to refer image # show_images(batch_st=batch_st, batch_tr=batch_tr, box_s=box_s, box_t=box_t, box_r=box_r) # loss = loss_fn(theta_st=theta_aff_tps_st, theta_tr=theta_aff_tps_tr, theta_GT=batch_triple['theta_GT']) loss = loss_fn(theta_aff_tps_st=theta_aff_tps_st, theta_aff_st=theta_aff_st, theta_aff_tps_tr=theta_aff_tps_tr, theta_aff_tr=theta_aff_tr, theta_aff_GT=batch_triple['theta_aff_GT'], theta_tps_GT=batch_triple['theta_tps_GT']) loss.backward() optimizer.step() epoch_loss += loss.item() if (batch_idx+1) % log_interval == 0: end = time.time() print('Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s' .format(epoch, batch_idx+1, len(dataloader), (batch_idx+1) / len(dataloader), loss.item(), batch_idx + 1, end - begin)) if (epoch % 5 == 0 or epoch == 1) and vis is not None: if (batch_idx + 1) % stride_images == 0 or batch_idx == 0: watch_images, image_names = add_watch(watch_images, image_names, batch_st, batch_tr, affTnf, tpsTnf, geoTnf, theta_aff_tps_st, theta_aff_st, theta_aff_tps_tr, theta_aff_tr, int((batch_idx + 1) / stride_images) * group_size) # if batch_idx <= 19: # watch_images, image_names = add_watch(watch_images, image_names, batch_st, batch_tr, affTnf, tpsTnf, geoTnf, theta_aff_tps_st, theta_aff_st, theta_aff_tps_tr, theta_aff_tr, batch_idx * 
group_size, batch_triple) # if batch_idx == 19: # opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_st target warped_tr refer warped_sr') # watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False) # watch_images *= 255.0 # watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) # for i in range(len(image_names)): # cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1) # watch_images = torch.Tensor(watch_images.astype(np.float32)) # watch_images = watch_images.permute(0, 3, 1, 2) # vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts) if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0: iter_loss[int((batch_idx + 1) / stride_loss)] = epoch_loss / (batch_idx + 1) # tmp_images = normalize_image(batch_tr['target_image'], forward=False) * 255.0 # vis.images(tmp_images, nrow=8, padding=3) # if show: # if dual: # warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st_1) # warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1) # warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr_1) # show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach()) # # warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st) # warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr) # warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr) # show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach()) # # warped_image_aff_tps = tpsTnf(batch_st['source_image'], theta_aff_tps_st) # warped_image_aff_tps_2 = tpsTnf(batch_tr['source_image'], theta_aff_tps_tr) # warped_image_aff_tps_3 = tpsTnf(warped_image_aff_tps, theta_aff_tps_tr) # show_images(batch_st, batch_tr, warped_image_aff_tps.detach(), warped_image_aff_tps_2.detach(), warped_image_aff_tps_3.detach()) # # warped_image = affTnf(batch_st['source_image'], theta_aff_st_1) # warped_image = affTnf(warped_image, theta_aff_st) # warped_image = tpsTnf(warped_image, theta_aff_tps_st) # warped_image_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1) # warped_image_2 = affTnf(warped_image_2, theta_aff_tr) # warped_image_2 = tpsTnf(warped_image_2, theta_aff_tps_tr) # warped_image_3 = affTnf(warped_image, theta_aff_tr_1) # warped_image_3 = affTnf(warped_image_3, theta_aff_tr) # warped_image_3 = tpsTnf(warped_image_3, theta_aff_tps_tr) # show_images(batch_st, batch_tr, warped_image.detach(), warped_image_2.detach(), warped_image_3.detach()) # else: # warped_image_aff = affTnf(batch_st['source_image'], theta_st) # warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_tr) # warped_image_aff_3 = affTnf(warped_image_aff, theta_tr) # show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach()) # warped_image_tps = tpsTnf(batch_st['source_image'], theta_st) # warped_image_tps_2 = tpsTnf(batch_tr['source_image'], theta_tr) # warped_image_tps_3 = tpsTnf(warped_image_tps, theta_tr) # show_images(batch_st, batch_tr, warped_image_tps.detach(), warped_image_tps_2.detach(), warped_image_tps_3.detach()) end = time.time() # Visualize watch images & train loss if (epoch % 5 == 0 or epoch == 1) and vis is not None: opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_st target warped_tr refer warped_sr') watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False) 
watch_images *= 255.0 watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) for i in range(watch_images.shape[0]): cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1) watch_images = torch.Tensor(watch_images.astype(np.float32)) watch_images = watch_images.permute(0, 3, 1, 2) vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts) # Un-normalize for caffe resnet and vgg # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2) opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')', ylabel='Loss', title='GM ResNet101 AffTPS Training Loss in Epoch ' + str(epoch), legend=['Loss'], width=2000) vis.line(iter_loss, np.arange(106), opts=opts_loss) epoch_loss /= len(dataloader) print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(epoch_loss, end - begin)) return epoch_loss, end - begin
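# The loss in train_fn_dual compares the composition of the two predicted warps
# (source -> target, then target -> refer) against the ground-truth transformation.
# A minimal sketch for the affine part only (hypothetical helpers, not the project's
# loss_fn; it assumes theta is the 6-dim row-major [a11, a12, tx, a21, a22, ty]
# parameterization and ignores the sampling-grid vs. point-mapping convention of
# GeometricTnf, which may flip the composition order):
import torch
import torch.nn.functional as F

def compose_affine(theta_st, theta_tr):
    # Chain two (B, 6) affine parameter vectors into one source -> refer affine.
    A_st = theta_st.view(-1, 2, 3)
    A_tr = theta_tr.view(-1, 2, 3)
    bottom = theta_st.new_tensor([0.0, 0.0, 1.0]).expand(A_st.size(0), 1, 3)
    H_st = torch.cat([A_st, bottom], dim=1)  # 3x3 homogeneous form
    H_tr = torch.cat([A_tr, bottom], dim=1)
    return torch.bmm(H_tr, H_st)[:, :2, :].reshape(-1, 6)

def affine_composition_loss(theta_aff_st, theta_aff_tr, theta_aff_GT):
    return F.mse_loss(compose_affine(theta_aff_st, theta_aff_tr), theta_aff_GT.view(-1, 6))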
def vis_fn_detect(vis, model, faster_rcnn, aff_theta, train_loss, val_pck, train_lr, epoch, num_epochs, dataloader, use_cuda=True): # Visualize watch images affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda) tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda) watch_images = torch.Tensor(len(dataloader) * 5, 3, 240, 240) # means for normalize of caffe resnet and vgg # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)) for batch_idx, batch in enumerate(dataloader): if use_cuda: batch = batch_cuda(batch) box_info_s = faster_rcnn(im_data=batch['source_im'], im_info=batch['source_im_info'][:, 3:], gt_boxes=batch['source_gt_boxes'], num_boxes=batch['source_num_boxes'])[0:3] box_info_t = faster_rcnn(im_data=batch['target_im'], im_info=batch['target_im_info'][:, 3:], gt_boxes=batch['target_gt_boxes'], num_boxes=batch['target_num_boxes'])[0:3] all_box_s = select_boxes(rois=box_info_s[0], cls_prob=box_info_s[1], bbox_pred=box_info_s[2], im_infos=batch['source_im_info'][:, 3:]) all_box_t = select_boxes(rois=box_info_t[0], cls_prob=box_info_t[1], bbox_pred=box_info_t[2], im_infos=batch['target_im_info'][:, 3:]) box_s, box_t = select_box_st(all_box_s, all_box_t) theta_det = aff_theta(boxes_s=box_s, boxes_t=box_t) theta_aff_tps, theta_aff = model(batch, theta_det) warped_image_1 = affTnf(batch['source_image'], theta_det) warped_image_2 = affTnf(warped_image_1, theta_aff) warped_image_3 = tpsTnf(warped_image_2, theta_aff_tps) watch_images[batch_idx * 5] = batch['source_image'][0] watch_images[batch_idx * 5 + 1] = warped_image_1[0] watch_images[batch_idx * 5 + 2] = warped_image_2[0] watch_images[batch_idx * 5 + 3] = warped_image_3[0] watch_images[batch_idx * 5 + 4] = batch['target_image'][0] opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped target') # Un-normalize for caffe resnet and vgg # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2) watch_images = normalize_image(watch_images, forward=False) * 255.0 vis.image(torchvision.utils.make_grid(watch_images, nrow=5, padding=5), opts=opts) # vis.images(watch_images, nrow=5, padding=3, opts=opts) if epoch == num_epochs: epochs = np.arange(1, num_epochs + 1) # Visualize train loss opts_loss = dict(xlabel='Epoch', ylabel='Loss', title='GM ResNet101 Detect&Affine&TPS Training Loss', legend=['Loss'], width=2000) vis.line(train_loss, epochs, opts=opts_loss) # Visualize val pck opts_pck = dict(xlabel='Epoch', ylabel='Val PCK', title='GM ResNet101 Detect&Affine&TPS Val PCK', legend=['PCK'], width=2000) vis.line(val_pck, epochs, opts=opts_pck) # Visualize train lr opts_lr = dict( xlabel='Epoch', ylabel='Learning Rate', title='GM ResNet101 Detect&Affine&TPS Training Learning Rate', legend=['LR'], width=2000) vis.line(train_lr, epochs, opts=opts_lr)
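# aff_theta above turns a detected source-box / target-box pair into an initial affine
# warp. A minimal sketch of one way such a box-to-box affine can be built (an assumption
# about what aff_theta does, not its actual implementation): scale and translate in the
# normalized [-1, 1] coordinates used by affine grid samplers so that the source box is
# mapped onto the target box. Boxes are (B, 4) = (x_min, y_min, x_max, y_max) in pixels
# of a 240x240 image.
import torch

def box_to_box_affine(box_s, box_t, im_size=240.0):
    ws, hs = box_s[:, 2] - box_s[:, 0], box_s[:, 3] - box_s[:, 1]
    wt, ht = box_t[:, 2] - box_t[:, 0], box_t[:, 3] - box_t[:, 1]
    cxs = (box_s[:, 0] + box_s[:, 2]) / im_size - 1.0  # box centres in [-1, 1]
    cys = (box_s[:, 1] + box_s[:, 3]) / im_size - 1.0
    cxt = (box_t[:, 0] + box_t[:, 2]) / im_size - 1.0
    cyt = (box_t[:, 1] + box_t[:, 3]) / im_size - 1.0
    theta = torch.zeros(box_s.size(0), 2, 3, device=box_s.device)
    theta[:, 0, 0] = ws / wt               # x scale (output grid samples the source)
    theta[:, 1, 1] = hs / ht               # y scale
    theta[:, 0, 2] = cxs - cxt * ws / wt   # x translation
    theta[:, 1, 2] = cys - cyt * hs / ht   # y translation
    return theta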
def __call__(self, batch=None): """ Generate a synthetic training triple: 1. Use the given image pair as the source image and target image; 2. Pad the source image, and warp the padded image with the given transformation to generate the refer image; 3. The training triple consists of {source image, target image, refer image, theta_GT}; 4. The input for fasterRCNN consists of {source_im, target_im, refer_im, source_im_info, target_im_info, refer_im_info, source_gt_boxes, target_gt_boxes, refer_gt_boxes, source_num_boxes, target_num_boxes, refer_num_boxes}. """ # image_batch.shape: (batch_size, 3, H, W) # theta_batch.shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv) # boxes.shape: (batch_size, 4), 4: (x_min, y_min, x_max, y_max) img_A_batch = batch['source_image'] img_B_batch = batch['target_image'] theta_aff = batch['theta_GT'][:, :6].contiguous() theta_tps = batch['theta_GT'][:, 6:] # Generate a symmetrically padded image to provide a bigger sampling region for warping the source image padded_image_batch = self.symmetricImagePad( image_batch=img_A_batch, padding_factor=self.padding_factor) img_A_batch.requires_grad = False img_B_batch.requires_grad = False padded_image_batch.requires_grad = False theta_aff.requires_grad = False theta_tps.requires_grad = False # Get the refer image by warping the padded image with the given transformation # warped_image_batch = self.affTnf(image_batch=padded_image_batch, theta_batch=theta_aff, padding_factor=0.5, crop_factor=self.crop_factor) # warped_image_batch = self.symmetricImagePad(image_batch=warped_image_batch, padding_factor=self.padding_factor) # warped_image_batch = self.tpsTnf(image_batch=warped_image_batch, theta_batch=theta_tps, padding_factor=0.5, crop_factor=self.crop_factor) # warped_image_aff = self.affTnf(image_batch=padded_image_batch, theta_batch=theta_aff, padding_factor=self.padding_factor, crop_factor=self.crop_factor) # warped_image_tps = self.tpsTnf(image_batch=padded_image_batch, theta_batch=theta_tps, padding_factor=self.padding_factor, crop_factor=self.crop_factor) # warped_image_batch = self.affTpsTnf(source_image=padded_image_batch, theta_aff=theta_aff, theta_aff_tps=theta_tps, use_cuda=self.use_cuda) warped_image_batch = self.geometricTnf(image_batch=padded_image_batch, theta_aff=theta_aff, theta_aff_tps=theta_tps) # Get the refer im for extracting rois tmp_image_batch = normalize_image(image=warped_image_batch, forward=False) * 255.0 tmp_image_batch = tmp_image_batch.cpu().numpy().transpose((0, 2, 3, 1)) tmp_image_batch = tmp_image_batch[:, :, :, ::-1] # RGB -> BGR warped_im_batch = torch.zeros_like(warped_image_batch, dtype=torch.float32) for i in range(tmp_image_batch.shape[0]): warped_im_batch[i] = roi_data(tmp_image_batch[i])[0] warped_im_batch = warped_im_batch.cuda() warped_im_batch.requires_grad = False # img_A_batch.shape, img_B_batch.shape, warped_image_batch.shape: (batch_size, 3, 240, 240) # theta_batch.shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv) # theta_batch.shape-affine: (batch_size, 2, 3) # return {'source_image': img_A_batch, 'target_image': img_B_batch, 'refer_image': warped_image_batch, # 'source_im': batch['source_im'], 'target_im': batch['target_im'], 'refer_im': warped_im_batch, # 'source_im_info': batch['source_im_info'], 'target_im_info': batch['target_im_info'], 'refer_im_info': batch['source_im_info'], # 'source_gt_boxes': batch['source_gt_boxes'], 'target_gt_boxes': batch['target_gt_boxes'], 'refer_gt_boxes': batch['source_gt_boxes'], #
'source_num_boxes': batch['source_num_boxes'], 'target_num_boxes': batch['target_num_boxes'], 'refer_num_boxes': batch['source_num_boxes'], # 'theta_aff_GT': theta_aff, 'theta_tps_GT': theta_tps} return { 'source_image': img_A_batch, 'target_image': img_B_batch, 'refer_image': warped_image_batch, 'source_im_info': batch['source_im_info'], 'target_im_info': batch['target_im_info'], 'refer_im_info': batch['source_im_info'], 'theta_aff_GT': theta_aff, 'theta_tps_GT': theta_tps }
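# symmetricImagePad above enlarges the source image before warping so that the refer
# image has no empty borders. A minimal sketch of the idea (hypothetical helper using
# reflection padding; the project's own symmetricImagePad may pad differently):
import torch.nn.functional as F

def symmetric_image_pad(image_batch, padding_factor=0.5):
    # image_batch: (B, 3, H, W); pad each side by padding_factor * (H or W)
    b, c, h, w = image_batch.size()
    pad_h, pad_w = int(h * padding_factor), int(w * padding_factor)
    return F.pad(image_batch, (pad_w, pad_w, pad_h, pad_h), mode='reflect')

# e.g. a 240x240 batch with padding_factor=0.5 becomes 480x480 before geometricTnf
# crops the warped result back to 240x240.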
def vis_tss(vis, dataloader, theta, theta_weak, csv_file, title, use_cuda=True): # Visualize watch images dataframe = pd.read_csv(csv_file) scores_det = dataframe.iloc[:, 5] scores_det_aff = dataframe.iloc[:, 6] scores_det_aff_tps = dataframe.iloc[:, 7] scores_aff = dataframe.iloc[:, 8] scores_aff_tps = dataframe.iloc[:, 9] affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda) tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda) watch_images = torch.ones(len(dataloader) * 8, 3, 280, 240) if use_cuda: watch_images = watch_images.cuda() image_names = list() flow = list() fnt = cv2.FONT_HERSHEY_COMPLEX # means for normalize of caffe resnet and vgg # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)) for batch_idx, batch in enumerate(dataloader): if use_cuda: batch = batch_cuda(batch) theta_det = theta['det'][batch_idx].unsqueeze(0) theta_det_aff = theta['det_aff'][batch_idx].unsqueeze(0) theta_det_aff_tps = theta['det_aff_tps'][batch_idx].unsqueeze(0) theta_aff = theta_weak['aff'][batch_idx].unsqueeze(0) theta_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0) # Warped image warped_det = affTnf(batch['source_image'], theta_det) warped_det_aff = affTnf(warped_det, theta_det_aff) warped_det_aff_tps = tpsTnf(warped_det_aff, theta_det_aff_tps) warped_aff = affTnf(batch['source_image'], theta_aff) warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps) watch_images[batch_idx * 8, :, 0:240, :] = batch['source_image'] watch_images[batch_idx * 8 + 1, :, 0:240, :] = warped_det watch_images[batch_idx * 8 + 2, :, 0:240, :] = warped_det_aff watch_images[batch_idx * 8 + 3, :, 0:240, :] = warped_det_aff_tps watch_images[batch_idx * 8 + 4, :, 0:240, :] = batch['target_image'] watch_images[batch_idx * 8 + 6, :, 0:240, :] = warped_aff watch_images[batch_idx * 8 + 7, :, 0:240, :] = warped_aff_tps image_names.append('Source') image_names.append('Det') image_names.append('Det_aff') image_names.append('Det_aff_tps') image_names.append('Target') image_names.append('') image_names.append('Rocco_aff') image_names.append('Rocco_aff_tps') flow.append('') flow.append('Flow: {:.3f}'.format(scores_det[batch_idx])) flow.append('Flow: {:.3f}'.format(scores_det_aff[batch_idx])) flow.append('Flow: {:.3f}'.format(scores_det_aff_tps[batch_idx])) flow.append('') flow.append('') flow.append('Flow: {:.3f}'.format(scores_aff[batch_idx])) flow.append('Flow: {:.3f}'.format(scores_aff_tps[batch_idx])) opts = dict(jpgquality=100, title=title) # Un-normalize for caffe resnet and vgg # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2) # watch_images = normalize_image(watch_images, forward=False) * 255.0 watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False) watch_images *= 255.0 watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) for i in range(watch_images.shape[0]): if (i + 1) % 8 != 6: pos_name = (80, 255) if (i + 1) % 8 == 1 or (i + 1) % 8 == 5: pos_lt_ac = (0, 0) pos_flow = (0, 0) else: pos_flow = (70, 275) cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1) cv2.putText(watch_images[i], flow[i], pos_flow, fnt, 0.5, (0, 0, 0), 1) else: watch_images[i] = 255 watch_images = torch.Tensor(watch_images.astype(np.float32)) watch_images = watch_images.permute(0, 3, 1, 2) vis.image(torchvision.utils.make_grid(watch_images, nrow=4, padding=5), opts=opts)
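# Hypothetical usage sketch for vis_tss: the server settings, env name, csv path and the
# result dicts are placeholders, and the csv is assumed to hold the flow scores in
# columns 5-9 exactly as read above (det, det_aff, det_aff_tps, aff, aff_tps).
from visdom import Visdom

vis = Visdom(server='http://localhost', port=8097, env='tss_results')
vis_tss(vis, dataloader, theta, theta_weak,
        csv_file='results/tss_flow.csv',
        title='TSS: Det/Aff/TPS vs. Rocco Aff/TPS',
        use_cuda=True)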