def find_non_stationary_clusters(args):
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    network_folder_name = args['folder'].split('/')[-1]
    tmp = re.search(r"cn(\d+)-", network_folder_name)
    n_clusters = int(tmp.group(1))

    save_folder = os.path.join(args['dest_root'], network_folder_name)
    if os.path.exists(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')):
        print('{} already exists. skipping'.format(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')))
        return

    check_mkdir(save_folder)

    with open(os.path.join(args['folder'], 'bestval.txt')) as f:
        best_val_dict_str = f.read()
        bestval = eval(best_val_dict_str.rstrip())

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=n_clusters,
        for_clustering=False,
        output_features=False,
        use_original_base=args['use_original_base']).to(device)

    net.load_state_dict(
        torch.load(os.path.join(args['folder'],
                                bestval['snapshot'] + '.pth')))  # load weights
    net.eval()

    # copy network file to save location
    copyfile(os.path.join(args['folder'], bestval['snapshot'] + '.pth'),
             os.path.join(save_folder, 'weights.pth'))

    if args['only_copy_weights']:
        print('Only copying weights')
        return

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config1.correspondence_val_list_file)
        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config2.correspondence_val_list_file)
        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config.correspondence_val_list_file)

    # Segmentor
    segmentor = Segmentor(net, n_clusters, n_slices_per_pass=4)

    # save args
    open(os.path.join(save_folder, str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    cluster_histogram_for_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)
    cluster_histogram_non_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)

    for i in range(0, len(corr_set_val), args['step']):
        img1, img2, pts1, pts2, _ = corr_set_val[i]
        seg1 = segmentor.run_and_save(
            img1,
            None,
            pre_sliding_crop_transform=pre_validation_transform,
            input_transform=input_transform,
            sliding_crop=sliding_crop_im,
            use_gpu=args['use_gpu'])
        seg1 = np.array(seg1)

        # mark all pixel locations in image 1 that have a correspondence
        corr_loc_mask = np.zeros(seg1.shape, dtype=bool)
        valid_inds = (pts1[0, :] >= 0) & (pts1[0, :] < seg1.shape[1]) & (
            pts1[1, :] >= 0) & (pts1[1, :] < seg1.shape[0])
        pts1 = pts1[:, valid_inds]
        for j in range(pts1.shape[1]):
            pt = pts1[:, j]
            corr_loc_mask[pt[1], pt[0]] = True

        # histograms of cluster ids at correspondence / non-correspondence pixels
        cluster_ids_corr = seg1[corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_corr, np.arange(n_clusters + 1))
        cluster_histogram_for_correspondences += hist_tmp

        cluster_ids_no_corr = seg1[~corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_no_corr,
                                   np.arange(n_clusters + 1))
        cluster_histogram_non_correspondences += hist_tmp

        if ((i + 1) % 100) < args['step']:
            print('{}/{}'.format(i + 1, len(corr_set_val)))

    np.save(os.path.join(save_folder, 'cluster_histogram_for_corr.npy'),
            cluster_histogram_for_correspondences)
    np.save(os.path.join(save_folder, 'cluster_histogram_non_corr.npy'),
            cluster_histogram_non_correspondences)

    # a cluster counts as stationary if more than 1% of its pixels are covered
    # by correspondences
    frac = cluster_histogram_for_correspondences / \
        (cluster_histogram_for_correspondences +
         cluster_histogram_non_correspondences)
    stationary_inds = np.argwhere(frac > 0.01)
    np.save(os.path.join(save_folder, 'stationary_inds.npy'), stationary_inds)
    print('{} stationary clusters out of {}'.format(
        len(stationary_inds), len(cluster_histogram_for_correspondences)))
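
# Example invocation of find_non_stationary_clusters() (a sketch only: the keys
# below are exactly those the function reads from `args`, but every value and
# path is an illustrative placeholder, not a repository default).
example_cluster_args = {
    'use_gpu': True,
    'folder': '/path/to/trained/cn100-example-run',  # folder name must contain "cn<n_clusters>-"
    'dest_root': '/path/to/output',
    'use_original_base': False,
    'only_copy_weights': False,
    'corr_set': 'cmu',       # 'rc', 'cmu' or 'both'
    'stride_rate': 2 / 3.,
    'step': 5,               # evaluate every 5th correspondence sample
}
# find_non_stationary_clusters(example_cluster_args)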
def extract_features_for_reference(net, net_config, im_list_file, im_root,
                                   ref_pts_dir,
                                   max_num_features_per_image=500,
                                   fraction_correspondeces=0.5):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    features = []
    # store mapping from original image representation to linear feature index
    # for correspondence coordinate with index i in image with index f
    # feature_ind = mapping[f]['coords'][i][0]
    # x_coord = mapping[f]['coords'][i][1]
    # y_coord = mapping[f]['coords'][i][2]
    #
    # features[feature_ind] was taken from
    # output_feature_map[:, y_coord, x_coord] of image f
    mapping = []

    # data loading
    input_transform = net_config.input_transform
    pre_inference_transform_with_corrs = net_config.pre_inference_transform_with_corrs

    # make sure crop size and stride are the same as during training
    sliding_crop = joint_transforms.SlidingCropImageOnly(713, 2 / 3)

    t0 = time.time()

    # get all file names
    filenames_ims = list()
    filenames_pts = list()
    for im_l, im_r, ref_pts_d in zip(im_list_file, im_root, ref_pts_dir):
        with open(im_l) as f:
            for line in f:
                filename_im = line.strip()
                filenames_ims.append(os.path.join(im_r, filename_im))
                filenames_pts.append(
                    os.path.join(ref_pts_d, im_to_ext_name(filename_im, 'h5')))

    # filenames_ims = [filenames_ims[0]]
    # filenames_pts = [filenames_pts[0]]

    iii = 0
    feature_count = 0
    for filename_im, filename_pts in zip(filenames_ims, filenames_pts):
        img = Image.open(filename_im).convert('RGB')
        f = h5py.File(filename_pts, 'r')
        pts = np.swapaxes(np.array(f['corr_pts']), 0, 1)

        # creating sliding crop windows and transform them
        if pre_inference_transform_with_corrs is not None:
            img, pts = pre_inference_transform_with_corrs(img, pts)
        img_size = img.size
        img_slices, slices_info = sliding_crop(img)
        if input_transform is not None:
            img_slices = [input_transform(e) for e in img_slices]
        img_slices = torch.stack(img_slices, 0)
        slices_info = torch.LongTensor(slices_info)
        slices_info.squeeze_(0)

        # run network on all slices, 10 at a time
        img_slices = img_slices.to(device)
        for ind in range(0, img_slices.size(0), 10):
            max_ind = min(ind + 10, img_slices.size(0))
            with torch.no_grad():
                out_tmp = net(img_slices[ind:max_ind, :, :, :])
            if ind == 0:
                n_features = out_tmp.size(1)
                oh = out_tmp.size(2)
                ow = out_tmp.size(3)
                scale_h = 713 / oh
                scale_w = 713 / ow
                output_slices = torch.zeros(img_slices.size(0), n_features,
                                            oh, ow).to(device)
            output_slices[ind:max_ind] = out_tmp

        # stitch the per-slice outputs back into one feature map
        outsizeh = round(slices_info[:, 0].max().item() / scale_h) + oh
        outsizew = round(slices_info[:, 2].max().item() / scale_w) + ow
        count = torch.zeros(outsizeh, outsizew).to(device)
        output = torch.zeros(n_features, outsizeh, outsizew).to(device)

        sliding_transform_step = (2 / 3, 2 / 3)
        interpol_weight = create_interpol_weights((oh, ow),
                                                  sliding_transform_step)
        interpol_weight = interpol_weight.to(device)

        for output_slice, info in zip(output_slices, slices_info):
            hs = round(info[0].item() / scale_h)
            ws = round(info[2].item() / scale_w)
            output[:, hs:hs + oh, ws:ws + ow] += (
                interpol_weight * output_slice[:, :oh, :ow]).data
            count[hs:hs + oh, ws:ws + ow] += interpol_weight
        output /= count

        # Scale correspondence coordinates to output size
        pts[0, :] = pts[0, :] * outsizew / img_size[0]
        pts[1, :] = pts[1, :] * outsizeh / img_size[1]
        pts = (pts + 0.5).astype(int)
        valid_inds = (pts[0, :] < output.size(2)) & (pts[0, :] >= 0) & (
            pts[1, :] < output.size(1)) & (pts[1, :] >= 0)
        pts = pts[:, valid_inds]
        pts = np.unique(pts, axis=1)
        if pts.shape[1] > int(
                max_num_features_per_image * fraction_correspondeces + 0.5):
            inds = np.random.choice(
                pts.shape[1],
                int(max_num_features_per_image * fraction_correspondeces + 0.5),
                replace=False)
            pts = pts[:, inds]

        # add some random points as well
        n_non_corr = int((1 - fraction_correspondeces) /
                         fraction_correspondeces * pts.shape[1] + 0.5)
        lin_inds = np.random.choice(
            output.size(2) * output.size(1), n_non_corr, replace=False)
        ptstoadd = np.concatenate(
            (np.expand_dims(lin_inds // output.size(1), axis=0),
             np.expand_dims(lin_inds % output.size(1), axis=0)),
            axis=0)
        pts = np.hstack((pts, ptstoadd))

        # Save features
        thismap = {}
        thismap['coords'] = []
        thismap['size'] = (outsizeh, outsizew)
        for pti in range(pts.shape[1]):
            features.append(
                output[:, pts[1, pti], pts[0, pti]].cpu().numpy().astype(np.float32))
            thismap['coords'].append([feature_count, pts[0, pti], pts[1, pti]])
            feature_count += 1
        mapping.append(thismap)

        if (iii % 100) == 0:
            print('computing features: %d/%d' % (iii, len(filenames_ims)))
        iii += 1

    tend = time.time()
    print('Time: %f' % (tend - t0))
    return features, mapping
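
# Example invocation of extract_features_for_reference() (a sketch only; all
# paths are placeholders and `net` is assumed to be a PSPNet set up to return
# feature maps, initialized via the same model_config.init_network(...) used by
# the other entry points in this repository). The three list arguments are
# parallel: one image list file, image root and reference-point directory per
# dataset.
#
# model_config = model_configs.PspnetCityscapesConfig()
# net = model_config.init_network().to(
#     torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
# net.eval()
# features, mapping = extract_features_for_reference(
#     net, model_config,
#     im_list_file=['/path/to/reference_image_list.txt'],
#     im_root=['/path/to/images'],
#     ref_pts_dir=['/path/to/reference_points'])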
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')
        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
            args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # save args
    open(os.path.join(save_folder, str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    # MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        # polynomial learning rate decay
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        # SEGMENTATION UPDATE STEP
        #######################################################################
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss
            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        # CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False
            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args[
                            'corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]

            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]

            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)

            loss_corr.backward()
            optimizer.step()

            train_corr_loss.update(loss_corr.item())

        #######################################################################
        # LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
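
# Example `args` for train_with_correspondences() (a sketch only: the keys are
# exactly those read by the function above, but every value is an illustrative
# placeholder rather than a tuned or repository-default setting; 'best_record'
# is filled in by the function itself).
example_train_args = {
    'snapshot': '',                      # '' = start from `startnet`, 'latest' = resume
    'corr_set': 'cmu',                   # 'rc' or 'cmu'
    'include_vistas': False,
    'stride_rate': 2 / 3.,
    'train_batch_size': 1,
    'n_workers': 4,
    'max_iter': 30000,
    'lr': 1e-4,
    'lr_decay': 0.9,                     # exponent of the polynomial decay
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'seg_loss_weight': 1.0,
    'corr_loss_weight': 1.0,
    'corr_loss_type': 'hingeF',          # 'class', 'hingeF', 'hingeC', ...
    'feat_dist_threshold_match': 0.8,
    'feat_dist_threshold_nomatch': 0.2,
    'n_iterations_before_corr_loss': 0,
    'remove_same_class': False,
    'classes_to_ignore': [],
    'print_freq': 10,
    'val_interval': 1000,
}
# train_with_correspondences('/path/to/save_folder',
#                            '/path/to/pretrained_pspnet.pth',
#                            example_train_args)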
def segment_images_in_folder(network_file, img_folder, save_folder, args):
    # get current available device
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    if 'n_classes' in args:
        print('Initializing model with %d classes' % args['n_classes'])
        net = model_config.init_network(
            n_classes=args['n_classes'],
            for_clustering=False,
            output_features=False,
            use_original_base=args['use_original_base']).to(device)
    else:
        net = model_config.init_network().to(device)

    print('load model ' + network_file)
    state_dict = torch.load(network_file,
                            map_location=lambda storage, loc: storage)
    # needed since we slightly changed the structure of the network in pspnet
    state_dict = rename_keys_to_match(state_dict)
    net.load_state_dict(state_dict)
    net.eval()

    # data loading
    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform
    # make sure crop size and stride are the same as during training
    sliding_crop = joint_transforms.SlidingCropImageOnly(
        713, args['sliding_transform_step'])

    check_mkdir(save_folder)
    t0 = time.time()

    # get all file names
    filenames_ims = list()
    filenames_segs = list()
    print('Scanning %s for images to segment.' % img_folder)
    for root, subdirs, files in os.walk(img_folder):
        filenames = [f for f in files if f.endswith(args['img_ext'])]
        if len(filenames) > 0:
            print('Found %d images in %s' % (len(filenames), root))
            seg_path = root.replace(img_folder, save_folder)
            check_mkdir(seg_path)
            filenames_ims += [os.path.join(root, f) for f in filenames]
            filenames_segs += [
                os.path.join(seg_path, f.replace(args['img_ext'], '.png'))
                for f in filenames
            ]

    # Create segmentor
    if net.n_classes == 19:  # This could be the 19 cityscapes classes
        segmentor = Segmentor(net,
                              net.n_classes,
                              colorize_fcn=cityscapes.colorize_mask,
                              n_slices_per_pass=args['n_slices_per_pass'])
    else:
        segmentor = Segmentor(net,
                              net.n_classes,
                              colorize_fcn=None,
                              n_slices_per_pass=args['n_slices_per_pass'])

    count = 1
    for im_file, save_path in zip(filenames_ims, filenames_segs):
        tnow = time.time()
        print("[%d/%d (%.1fs/%.1fs)] %s" %
              (count, len(filenames_ims), tnow - t0,
               (tnow - t0) / count * len(filenames_ims), im_file))
        segmentor.run_and_save(
            im_file,
            save_path,
            pre_sliding_crop_transform=pre_validation_transform,
            sliding_crop=sliding_crop,
            input_transform=input_transform,
            skip_if_seg_exists=True,
            use_gpu=args['use_gpu'])
        count += 1

    tend = time.time()
    print('Time: %f' % (tend - t0))
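
# Example invocation of segment_images_in_folder() (a sketch; the keys mirror
# what the function reads from `args` above, all values and paths are
# placeholders).
example_seg_args = {
    'use_gpu': True,
    # 'n_classes': 100,            # only for cluster networks; omit for the 19-class model
    # 'use_original_base': False,  # only read when 'n_classes' is given
    'sliding_transform_step': 2 / 3.,
    'img_ext': '.png',
    'n_slices_per_pass': 4,
}
# segment_images_in_folder('/path/to/weights.pth', '/path/to/images',
#                          '/path/to/segmentations', example_seg_args)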
def run_and_save(
        self,
        img_path,
        save_path,
        pre_sliding_crop_transform=None,
        sliding_crop=joint_transforms.SlidingCropImageOnly(713, 2 / 3.),
        input_transform=standard_transforms.ToTensor(),
        verbose=False,
        skip_if_seg_exists=False,
        use_gpu=True,
):
    """
    img_path - Path of input image
    save_path - Path of output image (feature map)
    sliding_crop - Transform that returns set of image slices
    input_transform - Transform to apply to image before inputting to network
    skip_if_seg_exists - Whether to overwrite or skip if segmentation exists already
    """
    if save_path is not None:
        if os.path.exists(save_path):
            if skip_if_seg_exists:
                if verbose:
                    print("Segmentation already exists, skipping: {}".format(
                        save_path))
                return
            else:
                if verbose:
                    print("Segmentation already exists, overwriting: {}".format(
                        save_path))

    if isinstance(img_path, str):
        try:
            img = Image.open(img_path).convert('RGB')
        except OSError:
            print("Error reading input image, skipping: {}".format(img_path))
            return
    else:
        img = img_path

    # creating sliding crop windows and transform them
    img_size_orig = img.size
    if pre_sliding_crop_transform is not None:  # might reshape image
        img = pre_sliding_crop_transform(img)

    img_slices, slices_info = sliding_crop(img)
    img_slices = [input_transform(e) for e in img_slices]
    img_slices = torch.stack(img_slices, 0)
    slices_info = torch.LongTensor(slices_info)
    slices_info.squeeze_(0)

    # temporarily switch the network to feature-output mode
    of_pre, oa_pre = self.net.output_features, self.net.output_all
    self.net.output_features, self.net.output_all = True, False

    feature_map = self.run_on_slices(
        img_slices,
        slices_info,
        sliding_transform_step=sliding_crop.stride_rate,
        use_gpu=use_gpu)

    # restore previous settings
    self.net.output_features, self.net.output_all = of_pre, oa_pre

    if save_path is not None:
        check_mkdir(os.path.dirname(save_path))
        ext = save_path.split('.')[-1]
        if ext == 'mat':
            matdict = {
                "features": np.transpose(feature_map, [1, 2, 0]),
                "original_image_size": (img_size_orig[1], img_size_orig[0])
            }
            sio.savemat(save_path, matdict, appendmat=False)
        elif ext == 'npy':
            np.save(save_path, feature_map)
        else:
            raise ValueError(
                'invalid file extension for save_path, only mat and npy are supported')

    return feature_map
def run_and_save(
        self,
        img_path,
        seg_path,
        pre_sliding_crop_transform=None,
        sliding_crop=joint_transforms.SlidingCropImageOnly(713, 2 / 3.),
        input_transform=standard_transforms.ToTensor(),
        verbose=False,
        skip_if_seg_exists=False,
        use_gpu=True,
):
    """
    img_path - Path of input image
    seg_path - Path of output image (segmentation)
    sliding_crop - Transform that returns set of image slices
    input_transform - Transform to apply to image before inputting to network
    skip_if_seg_exists - Whether to overwrite or skip if segmentation exists already
    """
    if seg_path is not None:
        if os.path.exists(seg_path):
            if skip_if_seg_exists:
                if verbose:
                    print("Segmentation already exists, skipping: {}".format(
                        seg_path))
                return
            else:
                if verbose:
                    print("Segmentation already exists, overwriting: {}".format(
                        seg_path))

    if isinstance(img_path, str):
        try:
            img = Image.open(img_path).convert('RGB')
        except OSError:
            print("Error reading input image, skipping: {}".format(img_path))
            return
    else:
        img = img_path

    # creating sliding crop windows and transform them
    img_size_orig = img.size
    if pre_sliding_crop_transform is not None:  # might reshape image
        img = pre_sliding_crop_transform(img)

    img_slices, slices_info = sliding_crop(img)
    img_slices = [input_transform(e) for e in img_slices]
    img_slices = torch.stack(img_slices, 0)
    slices_info = torch.LongTensor(slices_info)
    slices_info.squeeze_(0)

    prediction_logits = self.run_on_slices(
        img_slices,
        slices_info,
        sliding_transform_step=sliding_crop.stride_rate,
        use_gpu=use_gpu,
        return_logits=True)
    prediction_orig = prediction_logits.max(0)[1].squeeze_(0).numpy()
    prediction_logits = prediction_logits.numpy()

    if self.colorize_fcn:
        prediction_colorized = self.colorize_fcn(prediction_orig)
    else:
        prediction_colorized = Image.fromarray(
            prediction_orig.astype(np.int32)).convert('I')

    if prediction_colorized.size != img_size_orig:
        prediction_colorized = F.resize(prediction_colorized,
                                        img_size_orig[::-1],
                                        interpolation=Image.NEAREST)

    if seg_path is not None:
        check_mkdir(os.path.dirname(seg_path))
        prediction_colorized.save(seg_path)

    return prediction_colorized