def evaluate_segmented_images_for_experiments(network_folder, validation_metric, dataset):
    if dataset == 'cityscapes':
        dset = dataset_configs.CityscapesConfig()
        truth_folder = dset.val_seg_folder
        result_folder_name = 'cityscapes-val-results'
    elif dataset == 'wilddash':
        dset = dataset_configs.WilddashConfig()
        truth_folder = dset.val_seg_folder
        result_folder_name = 'wilddash_val_results'
    elif dataset == 'cmu':
        dset = dataset_configs.CmuConfig()
        truth_folder = dset.test_seg_folder
        result_folder_name = 'cmu-annotated-test-images'
    elif dataset == 'rc':
        dset = dataset_configs.RobotcarConfig()
        truth_folder = dset.test_seg_folder
        result_folder_name = 'robotcar-test-results'
    elif dataset == 'vistas':
        dset = dataset_configs.VistasConfig()
        truth_folder = dset.val_seg_folder
        result_folder_name = 'vistas-validation'
    else:
        raise ValueError('Unknown dataset: %s' % dataset)

    if len(validation_metric) > 0:
        result_folder_name += '_' + validation_metric

    evaluate_segmented_images(os.path.join(network_folder, result_folder_name),
                              truth_folder,
                              dset.im_file_ending.replace('jpg', 'png'),
                              dset.seg_file_ending,
                              dset.id_to_trainid,
                              n_classes=dset.n_classes)

def write_lists_for_corr(corr_set):
    if corr_set == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif corr_set == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    else:
        raise ValueError('Unknown correspondence set: %s' % corr_set)

    n_samples_to_include = 1000000
    fraction_training = 0.7
    np.random.seed(0)

    f_name_list = [
        fn for fn in os.listdir(corr_set_config.correspondence_path)
        if fn.endswith('mat')
    ]
    n_samples = min(n_samples_to_include, len(f_name_list))
    n_training = ceil(fraction_training * n_samples)
    n_validation = n_samples - n_training

    # sample without replacement so no file appears twice in the same list
    training_ids = np.random.choice(len(f_name_list), n_training, replace=False)
    ids_left_for_validation = set(range(len(f_name_list))) - set(training_ids)
    validation_ids = np.random.choice(list(ids_left_for_validation),
                                      n_validation,
                                      replace=False)

    with open(corr_set_config.correspondence_train_list_file, 'w') as f_train:
        for i in training_ids:
            f_train.write(f_name_list[i] + '\n')
    with open(corr_set_config.correspondence_val_list_file, 'w') as f_val:
        for i in validation_ids:
            f_val.write(f_name_list[i] + '\n')

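# Illustrative usage sketch (not part of the original pipeline): rebuilding
# both split files and reading one back. It only assumes the config
# attributes that write_lists_for_corr above already uses; the printout is
# for illustration.
def example_rebuild_correspondence_splits():
    for corr_set in ('cmu', 'rc'):
        # writes correspondence_train_list_file / correspondence_val_list_file
        # with a 70/30 split, reproducible via np.random.seed(0)
        write_lists_for_corr(corr_set)
    cfg = data_configs.CmuConfig()
    with open(cfg.correspondence_train_list_file) as f:
        train_files = [line.strip() for line in f]
    print('%d CMU training correspondence files' % len(train_files))
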
def write_reference_im_list(corr_set):
    if corr_set == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif corr_set == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    else:
        raise ValueError('Unknown correspondence set: %s' % corr_set)

    # Read the list of correspondence files used for training
    f_name_list = []
    with open(corr_set_config.correspondence_train_list_file) as f:
        for line in f:
            f_name_list.append(line.strip())

    ref_file_names = set()  # use set to automatically handle duplicates

    # Collect the reference image path stored in each correspondence file
    it = 0
    for f_name in f_name_list:
        mat_content = {}
        with h5py.File(
                os.path.join(corr_set_config.correspondence_path, f_name),
                'r') as ff:
            for k, v in ff.items():
                mat_content[k] = np.array(v)
        # paths are stored as arrays of character codes; convert to string
        im1name = ''.join(chr(a) for a in mat_content['im_i_path'])
        ref_file_names.add(im1name)
        it += 1
        if (it % 100) == 0:
            print('%d/%d' % (it, len(f_name_list)))

    print('Writing file %s' % corr_set_config.reference_image_list)
    with open(corr_set_config.reference_image_list, 'w') as f:
        for fname in ref_file_names:
            f.write('%s\n' % fname)

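# Standalone sketch of the path-decoding convention used above: the .mat
# correspondence files are MATLAB v7.3 (HDF5), where strings are stored as
# arrays of character codes, so each path is recovered with chr(). The file
# name below is a hypothetical placeholder.
def example_decode_reference_path(mat_file='correspondence_0001.mat'):
    with h5py.File(mat_file, 'r') as ff:
        # flatten() mirrors the decoding loop above while also handling
        # MATLAB char arrays stored as Nx1 datasets
        codes = np.array(ff['im_i_path']).flatten()
    return ''.join(chr(int(a)) for a in codes)
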
def find_non_stationary_clusters(args):
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    network_folder_name = args['folder'].split('/')[-1]
    tmp = re.search(r"cn(\d+)-", network_folder_name)
    n_clusters = int(tmp.group(1))

    save_folder = os.path.join(args['dest_root'], network_folder_name)
    if os.path.exists(os.path.join(save_folder, 'cluster_histogram_for_corr.npy')):
        print('{} already exists. skipping'.format(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')))
        return

    check_mkdir(save_folder)

    with open(os.path.join(args['folder'], 'bestval.txt')) as f:
        best_val_dict_str = f.read()
    bestval = eval(best_val_dict_str.rstrip())

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=n_clusters,
        for_clustering=False,
        output_features=False,
        use_original_base=args['use_original_base']).to(device)

    net.load_state_dict(
        torch.load(os.path.join(args['folder'],
                                bestval['snapshot'] + '.pth')))  # load weights
    net.eval()

    # copy network file to save location
    copyfile(os.path.join(args['folder'], bestval['snapshot'] + '.pth'),
             os.path.join(save_folder, 'weights.pth'))

    if args['only_copy_weights']:
        print('Only copying weights')
        return

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])
    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config1.correspondence_val_list_file)
        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config2.correspondence_val_list_file)
        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config.correspondence_val_list_file)

    # Segmentor
    segmentor = Segmentor(net, n_clusters, n_slices_per_pass=4)

    # save args
    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    cluster_histogram_for_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)
    cluster_histogram_non_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)

    for i in range(0, len(corr_set_val), args['step']):
        img1, img2, pts1, pts2, _ = corr_set_val[i]
        seg1 = segmentor.run_and_save(
            img1,
            None,
            pre_sliding_crop_transform=pre_validation_transform,
            input_transform=input_transform,
            sliding_crop=sliding_crop_im,
            use_gpu=args['use_gpu'])
        seg1 = np.array(seg1)

        # mark pixels that have a correspondence in the other image
        corr_loc_mask = np.zeros(seg1.shape, dtype=bool)
        valid_inds = (pts1[0, :] >= 0) & (pts1[0, :] < seg1.shape[1]) & \
                     (pts1[1, :] >= 0) & (pts1[1, :] < seg1.shape[0])
        pts1 = pts1[:, valid_inds]
        for j in range(pts1.shape[1]):
            pt = pts1[:, j]
            corr_loc_mask[pt[1], pt[0]] = True

        cluster_ids_corr = seg1[corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_corr, np.arange(n_clusters + 1))
        cluster_histogram_for_correspondences += hist_tmp

        cluster_ids_no_corr = seg1[~corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_no_corr,
                                   np.arange(n_clusters + 1))
        cluster_histogram_non_correspondences += hist_tmp

        if ((i + 1) % 100) < args['step']:
            print('{}/{}'.format(i + 1, len(corr_set_val)))

    np.save(os.path.join(save_folder, 'cluster_histogram_for_corr.npy'),
            cluster_histogram_for_correspondences)
    np.save(os.path.join(save_folder, 'cluster_histogram_non_corr.npy'),
            cluster_histogram_non_correspondences)

    frac = cluster_histogram_for_correspondences / \
        (cluster_histogram_for_correspondences +
         cluster_histogram_non_correspondences)
    stationary_inds = np.argwhere(frac > 0.01)
    np.save(os.path.join(save_folder, 'stationary_inds.npy'), stationary_inds)
    print('{} stationary clusters out of {}'.format(
        len(stationary_inds), len(cluster_histogram_for_correspondences)))

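# Hedged example of driving find_non_stationary_clusters. Every key is one
# the function actually reads; the folder paths are placeholders. Note the
# folder name must contain 'cn<N>-' so the cluster count can be parsed out.
def example_find_non_stationary_clusters():
    args = {
        'use_gpu': True,
        'folder': '/path/to/logs/cn100-experiment',  # parsed for cn(\d+)-
        'dest_root': '/path/to/cluster-stats',
        'use_original_base': False,
        'only_copy_weights': False,
        'corr_set': 'cmu',   # 'rc', 'cmu' or 'both'
        'stride_rate': 2 / 3.,
        'step': 1,           # evaluate every step-th validation pair
    }
    find_non_stationary_clusters(args)
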
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')
        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform
    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])
    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # save args
    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    # MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE
    # LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        # SEGMENTATION UPDATE STEP
        #######################################################################
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(), slice_batch_pixel_size)

        #######################################################################
        # CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args['n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(img_ref)
                # the img_other pass must run with gradients enabled so the
                # derivative is backpropagated correctly
                output_feat_other, aux_feat_other, output_other, aux_other = net(img_other)
                net.output_all = False
            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args['corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # the img_other pass must run with gradients enabled so the
                # derivative is backpropagated correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args['corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = \
                correspondences.refine_correspondence_sample(
                    output_ref,
                    output_other,
                    pts_ref,
                    pts_other,
                    weights,
                    remove_same_class=args['remove_same_class'],
                    remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = \
                correspondences.refine_correspondence_sample(
                    aux_ref,
                    aux_other,
                    pts_ref,
                    pts_other,
                    weights,
                    remove_same_class=args['remove_same_class'],
                    remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                # use output from img_ref as "reference"
                loss_corr_aux = corr_loss_fct(aux_vals_ref, aux_vals_other,
                                              pts_ref_aux, pts_other_aux,
                                              weights_aux)
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)

            loss_corr.backward()
            optimizer.step()

            train_corr_loss.update(loss_corr.item())

        #######################################################################
        # LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg, curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg, curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg, curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f], [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f], [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()

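# Hedged sketch of a call into train_with_correspondences. The keys mirror
# the args the function reads above; the values and paths are placeholders,
# not the settings used in the original experiments.
def example_train_with_correspondences():
    args = {
        'snapshot': '',                  # '' = start from startnet, 'latest' = resume
        'corr_set': 'cmu',               # 'cmu' or 'rc'
        'include_vistas': False,
        'stride_rate': 2 / 3.,
        'train_batch_size': 1,
        'n_workers': 4,
        'corr_loss_type': 'class',       # 'class', 'hingeF' or 'hingeC'
        'feat_dist_threshold_match': 0.8,
        'feat_dist_threshold_nomatch': 0.2,
        'seg_loss_weight': 1.0,
        'corr_loss_weight': 1.0,
        'n_iterations_before_corr_loss': 0,
        'remove_same_class': False,
        'classes_to_ignore': [],
        'lr': 1e-4,
        'lr_decay': 0.9,                 # exponent of the poly LR schedule
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'max_iter': 30000,
        'print_freq': 10,
        'val_interval': 1000,
    }
    train_with_correspondences('/path/to/save-folder',
                               '/path/to/startnet.pth', args)
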
def segment_images_in_folder_for_experiments(network_folder, args):
    # Predefined image sets and paths
    if len(args['img_set']) > 0:
        if args['img_set'] == 'cmu':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cmu-annotated-test-images'
        elif args['img_set'] == 'wilddash':
            dset = dataset_configs.WilddashConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'wilddash_val_results'
        elif args['img_set'] == 'rc':
            dset = dataset_configs.RobotcarConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'robotcar-test-results'
        elif args['img_set'] == 'cityscapes':
            dset = dataset_configs.CityscapesConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cityscapes-val-results'
        elif args['img_set'] == 'cmu-vis':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.vis_test_im_folder
            args['img_ext'] = 'jpg'
            args['save_folder_name'] = 'cmu-visual-test-images'
        elif args['img_set'] == 'rc-vis':
            dset = dataset_configs.RobotcarConfig()
            args['img_path'] = dset.vis_test_im_folder
            args['img_ext'] = 'jpg'
            args['save_folder_name'] = 'robotcar-visual-test-images'
        elif args['img_set'] == 'vistas':
            dset = dataset_configs.VistasConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'vistas-validation'
        else:
            raise ValueError('Unknown image set: %s' % args['img_set'])

    if len(args['network_file']) < 1:
        print("Loading best network according to specified validation metric")
        if args['validation_metric'] == 'miou':
            with open(os.path.join(network_folder, 'bestval.txt')) as f:
                best_val_dict_str = f.read()
            bestval = eval(best_val_dict_str.rstrip())
            print("Network file %s - val miou %s" %
                  (bestval['snapshot'], bestval['mean_iu']))
        elif args['validation_metric'] == 'acc':
            with open(os.path.join(network_folder, 'bestval_acc.txt')) as f:
                best_val_dict_str = f.read()
            bestval = eval(best_val_dict_str.rstrip())
            print("Network file %s - val acc %s" %
                  (bestval['snapshot'], bestval['acc']))
        net_to_load = bestval['snapshot'] + '.pth'
        network_file = os.path.join(network_folder, net_to_load)
    else:
        print("Loading specified network")
        network_folder = os.path.dirname(args['network_file'])
        network_file = args['network_file']

    # folder should have same name as for trained network
    if args['validation_metric'] == 'miou':
        save_folder = os.path.join(network_folder, args['save_folder_name'])
    elif args['validation_metric'] == 'acc':
        save_folder = os.path.join(network_folder,
                                   args['save_folder_name'] + '_acc')

    segment_images_in_folder(network_file, args['img_path'], save_folder, args)

def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        init_last_layers(state_dict, args['n_clusters'])  # different amount of classes
        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    ref_image_lists = corr_set_config.reference_image_list

    input_transform = model_config.input_transform

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])
    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,
        transform_before_sliding=None)
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')
    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        net.eval()
        net.output_features = True

        print('reference image list length: %d' % len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net,
            model_config,
            corr_set_train,
            10,
            max_num_features_per_image=args['max_features_per_image'])
        cluster_features = np.vstack(features)
        del features

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = \
            deepcluster.cluster_imfeatures(cluster_features,
                                           verbose=True,
                                           use_gpu=False)

        # save cluster centroids
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + \
            np.array2string(cluster_distribution,
                            formatter={'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # set last layer weight to a normal distribution
        reinit_last_layers(net)

        # make a copy of current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args['cluster_interval'] and curr_iter <= args['max_iter']:
            # First extract cluster labels using saved network checkpoint
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']) and (
                        val_iter + extract_label_count < args['val_interval']
                    ) and (extract_label_count + curr_iter <= args['max_iter']):
                # samples come out of the loader already transformed by the
                # dataset's input_transform
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)
                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))
                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                outputs_ref, aux_ref = net(img_ref)
                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)
                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()

                cluster_training_count += 1

                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ##############################################################
                # LOGGING ETC
                ##############################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg, curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        optimizer.param_groups[1]['lr'])
                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()

def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        if 'resnet101' in startnet:
            load_resnet101_weights(net, state_dict)
        else:
            # needed since we slightly changed the structure of the network
            # in pspnet
            state_dict = rename_keys_to_match(state_dict)
            init_last_layers(state_dict, args['n_clusters'])  # different amount of classes
            net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss_feat': 1e10,
            'val_loss_out': 1e10,
            'val_loss_cluster': 1e10
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')
        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

        if start_iter >= args['max_iter']:
            return

        if (start_iter % args['cluster_interval']) == 0:
            skip_clustering = False
        else:
            skip_clustering = True
            last_cluster_network_snapshot_iter = (
                start_iter // args['cluster_interval']) * args['cluster_interval']

            # load cluster info
            cluster_info = {}
            f = h5py.File(
                os.path.join(
                    save_folder, 'centroids_{}.h5'.format(
                        last_cluster_network_snapshot_iter)), 'r')
            for k, v in f.items():
                cluster_info[k] = np.array(v)
            cluster_centroids = cluster_info['cluster_centroids']
            pca_info = [
                cluster_info['pca_transform_Amat'],
                cluster_info['pca_transform_bvec']
            ]

            # load network that was used for last clustering
            net_for_clustering = model_config.init_network(
                n_classes=args['n_clusters'],
                for_clustering=True,
                output_features=True,
                use_original_base=args['use_original_base'])
            if last_cluster_network_snapshot_iter == 0:
                state_dict = torch.load(
                    startnet, map_location=lambda storage, loc: storage)
                if 'resnet101' in startnet:
                    load_resnet101_weights(net_for_clustering, state_dict)
                else:
                    # needed since we slightly changed the structure of the
                    # network in pspnet
                    state_dict = rename_keys_to_match(state_dict)
                    init_last_layers(state_dict, args['n_clusters'])  # different amount of classes
                    net_for_clustering.load_state_dict(state_dict)  # load original weights
            else:
                cluster_network_weights = get_network_name_from_iteration(
                    save_folder, last_cluster_network_snapshot_iter)
                net_for_clustering.load_state_dict(
                    torch.load(os.path.join(save_folder,
                                            cluster_network_weights),
                               map_location=lambda storage, loc: storage))  # load weights

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    if args['corr_set'] == 'both':
        ref_image_lists = [
            corr_set_config1.reference_image_list,
            corr_set_config2.reference_image_list
        ]
        corr_im_paths = [
            corr_set_config1.correspondence_im_path,
            corr_set_config2.correspondence_im_path
        ]
        ref_feature_positions = [
            corr_set_config1.reference_feature_poitions,
            corr_set_config2.reference_feature_poitions
        ]
    else:
        ref_image_lists = [corr_set_config.reference_image_list]
        corr_im_paths = [corr_set_config.correspondence_im_path]
        ref_feature_positions = [corr_set_config.reference_feature_poitions]

    input_transform = model_config.input_transform
    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # Correspondences for training
    if args['corr_set'] == 'both':
        corr_set_train1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_train_list_file)
        corr_set_train2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_train_list_file)
        corr_set_train = merged.Merged([corr_set_train1, corr_set_train2])
    else:
        corr_set_train = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_train_list_file)
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    # Correspondences for validation
    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_val_list_file)
        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_val_list_file)
        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_val_list_file)
    corr_loader_val = DataLoader(corr_set_val,
                                 batch_size=1,
                                 num_workers=args['n_workers'],
                                 shuffle=False)

    # Loss setup
    val_corr_loss_fct_feat = FeatureLoss(
        input_size=[713, 713],
        loss_type=args['feature_distance_measure'],
        feat_dist_threshold_match=0.8,
        feat_dist_threshold_nomatch=0.2,
        n_not_matching=0)
    val_corr_loss_fct_out = FeatureLoss(input_size=[713, 713],
                                        loss_type='KL',
                                        feat_dist_threshold_match=0.8,
                                        feat_dist_threshold_nomatch=0.2,
                                        n_not_matching=0)
    loss_fct = ClusterCorrespondenceLoss(input_size=[713, 713],
                                         size_average=True).to(device)
    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')
    if args['feature_hinge_loss_weight'] > 0:
        feature_loss_fct = FeatureLoss(input_size=[713, 713],
                                       loss_type='hingeF',
                                       feat_dist_threshold_match=0.8,
                                       feat_dist_threshold_nomatch=0.2,
                                       n_not_matching=0)

    # Validator
    corr_validator = CorrValidator(corr_loader_val,
                                   val_corr_loss_fct_feat,
                                   val_corr_loss_fct_out,
                                   loss_fct,
                                   save_snapshot=True,
                                   extra_name_str='Corr')

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # save args
    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        if not skip_clustering:
            # Extract image features from reference images
            net.eval()
            net.output_features = True
            features, _ = extract_features_for_reference(
                net,
                model_config,
                ref_image_lists,
                corr_im_paths,
                ref_feature_positions,
                max_num_features_per_image=args['max_features_per_image'],
                fraction_correspondeces=0.5)
            cluster_features = np.vstack(features)
            del features

            # cluster the features
            cluster_indices, clustering_loss, cluster_centroids, pca_info = \
                deepcluster.cluster_imfeatures(cluster_features,
                                               verbose=True,
                                               use_gpu=False)

            # save cluster centroids
            h5f = h5py.File(
                os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
            h5f.create_dataset('cluster_centroids', data=cluster_centroids)
            h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
            h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
            h5f.close()

            # Print distribution of clusters
            cluster_distribution, _ = np.histogram(
                cluster_indices,
                bins=np.arange(args['n_clusters'] + 1),
                density=True)
            str2write = 'cluster distribution ' + \
                np.array2string(cluster_distribution,
                                formatter={'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
            print(str2write)
            f_handle.write(str2write + "\n")

            reinit_last_layers(net)  # set last layer weight to a normal distribution

            # make a copy of current network state to do cluster assignment
            net_for_clustering = copy.deepcopy(net)
        else:
            skip_clustering = False

        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args['cluster_interval'] and curr_iter <= args['max_iter']:
            # First extract cluster labels using saved network checkpoint
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True
            if args['feature_hinge_loss_weight'] > 0:
                net_for_clustering.output_all = False

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']) and (
                        val_iter + extract_label_count < args['val_interval']
                    ) and (extract_label_count + curr_iter <= args['max_iter']):
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))
                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                img_other = img_other.cpu()
                pts_ref = [p.cpu() for p in pts_ref]
                pts_other = [p.cpu() for p in pts_other]
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, img_other, pts_ref, pts_other,
                                     cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, img_other, pts_ref, pts_other, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = True

                # Randomization to decide if reference or target image should
                # be used for training
                if args['fraction_reference_bp'] is None:  # use both
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)
                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(img_other)
                    else:
                        outputs_other, aux_other = net(img_other)
                elif np.random.rand(1)[0] < args['fraction_reference_bp']:  # use reference
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)
                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_other, aux_feat_other, outputs_other, aux_other = net(img_other)
                        else:
                            outputs_other, aux_other = net(img_other)
                else:  # use target
                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(img_ref)
                        else:
                            outputs_ref, aux_ref = net(img_ref)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(img_other)
                    else:
                        outputs_other, aux_other = net(img_other)
                    seg_main_loss = 0.
                    seg_aux_loss = 0.

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = False

                main_loss, _ = loss_fct(outputs_ref,
                                        outputs_other,
                                        None,
                                        pts_ref,
                                        pts_other,
                                        cluster_labels=cluster_labels)
                aux_loss, _ = loss_fct(aux_ref,
                                       aux_other,
                                       None,
                                       pts_ref,
                                       pts_other,
                                       cluster_labels=cluster_labels)

                if args['feature_hinge_loss_weight'] > 0:
                    feature_loss = feature_loss_fct(out_feat_ref,
                                                    out_feat_other, pts_ref,
                                                    pts_other, None)
                    feature_loss_aux = feature_loss_fct(aux_feat_ref,
                                                        aux_feat_other,
                                                        pts_ref, pts_other,
                                                        None)
                    loss = args['corr_loss_weight'] * (main_loss + 0.4 * aux_loss) + \
                        args['seg_loss_weight'] * (seg_main_loss + 0.4 * seg_aux_loss) + \
                        args['feature_hinge_loss_weight'] * (feature_loss + 0.4 * feature_loss_aux)
                else:
                    feature_loss = 0.
                    loss = args['corr_loss_weight'] * (main_loss + 0.4 * aux_loss) + \
                        args['seg_loss_weight'] * (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()

                cluster_training_count += 1

                corr_train_loss.update(main_loss.item(), 1)
                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)
                if isinstance(feature_loss, torch.Tensor):
                    feature_train_loss.update(feature_loss.item(), 1)

                ##############################################################
                # LOGGING ETC
                ##############################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_corr_loss', corr_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_feature_loss',
                                  feature_train_loss.avg, curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f], [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])
                    print(str2write)
                    f_handle.write(str2write + "\n")

                if val_iter >= args['val_interval']:
                    val_iter = 0
                    net_for_clustering.to(device)
                    corr_validator.run(net,
                                       net_for_clustering,
                                       (deepcluster, pca_info),
                                       optimizer,
                                       args,
                                       curr_iter,
                                       save_folder,
                                       f_handle,
                                       writer=writer)
                    net_for_clustering.to("cpu")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()

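# Hedged sketch of resuming the clustering-based training above. All keys
# are ones train_with_clustering reads; the values and paths are
# placeholders. With 'snapshot' set to 'latest', resuming mid-interval
# reloads the centroids and the clustering network from the last
# 'centroids_<iter>.h5' checkpoint automatically.
def example_resume_train_with_clustering():
    args = {
        'snapshot': 'latest',            # '' = fresh start from startnet
        'n_clusters': 100,
        'use_original_base': False,
        'corr_set': 'both',              # 'rc', 'cmu' or 'both'
        'max_features_per_image': 500,
        'feature_distance_measure': 'L2',
        'feature_hinge_loss_weight': 0.0,
        'fraction_reference_bp': None,   # None = backprop through both images
        'corr_loss_weight': 1.0,
        'seg_loss_weight': 1.0,
        'cluster_interval': 10000,       # re-cluster every N iterations
        'chunk_size': 100,
        'lr': 1e-4,
        'lr_decay': 0.9,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'n_workers': 4,
        'max_iter': 60000,
        'print_freq': 10,
        'val_interval': 1000,
    }
    train_with_clustering('/path/to/save-folder', '/path/to/tmp-seg-folder',
                          '/path/to/startnet.pth', args)
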
def cluster_images_in_folder_for_experiments(network_folder, args):
    # Predefined image sets and paths
    if len(args['img_set']) > 0:
        if args['img_set'] == 'cmu':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cmu-annotated-test-images'
        elif args['img_set'] == 'cmu-train':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.train_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cmu-annotated-train-images'
        elif args['img_set'] == 'rc':
            dset = dataset_configs.RobotcarConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'robotcar-test-results'
        elif args['img_set'] == 'cityscapes':
            dset = dataset_configs.CityscapesConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cityscapes-val-results'
        elif args['img_set'] == 'vistas':
            dset = dataset_configs.VistasConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'vistas-validation'
        else:
            raise ValueError('Unknown image set: %s' % args['img_set'])

    if len(args['network_file']) < 1:
        print("Loading best network according to specified validation metric")
        if args['validation_metric'] == 'miou':
            with open(os.path.join(network_folder, 'bestval.txt')) as f:
                best_val_dict_str = f.read()
            bestval = eval(best_val_dict_str.rstrip())
            print("Network file %s" % (bestval['snapshot']))
        elif args['validation_metric'] == 'acc':
            with open(os.path.join(network_folder, 'bestval_acc.txt')) as f:
                best_val_dict_str = f.read()
            bestval = eval(best_val_dict_str.rstrip())
            print("Network file %s - val acc %s" %
                  (bestval['snapshot'], bestval['acc']))
        net_to_load = bestval['snapshot'] + '.pth'
        network_file = os.path.join(network_folder, net_to_load)
    else:
        print("Loading specified network")
        network_folder = os.path.dirname(args['network_file'])
        network_file = args['network_file']

    # folder should have same name as for trained network
    save_folder = os.path.join(network_folder, args['save_folder_name'])

    network_folder_name = network_folder.split('/')[-1]
    tmp = re.search(r"cn(\d+)-", network_folder_name)
    n_clusters = int(tmp.group(1))

    cluster_images_in_folder(network_file, args['img_path'], save_folder,
                             n_clusters, args)

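# Small sketch of the folder-naming convention relied on above (and in
# find_non_stationary_clusters): the number of clusters is encoded in the
# network folder name as 'cn<N>-'. The folder name below is hypothetical.
def example_parse_n_clusters(network_folder_name='cn100-cmu-baseline'):
    match = re.search(r"cn(\d+)-", network_folder_name)
    if match is None:
        raise ValueError('folder name does not encode a cluster count: %s' %
                         network_folder_name)
    return int(match.group(1))  # 'cn100-cmu-baseline' -> 100
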