def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])

    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # re-initialize the last layers for the different number of classes
        init_last_layers(state_dict, args['n_clusters'])
        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        # note: this branch only builds the two configs; corr_set_config itself
        # is never assigned, so the lookup below fails for 'both'
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    ref_image_lists = corr_set_config.reference_image_list

    input_transform = model_config.input_transform

    scales = [0, 1, 2, 3]

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])
    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,
        transform_before_sliding=None)

    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    # 'elementwise_mean' is the deprecated spelling of 'mean'
    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

    # Optimizer setup: biases get twice the base learning rate and no weight
    # decay, all other parameters get the base learning rate plus weight decay
    optimizer = optim.SGD(
        [{
            'params': [
                param for name, param in net.named_parameters()
                if name[-4:] == 'bias' and param.requires_grad
            ],
            'lr': 2 * args['lr']
        }, {
            'params': [
                param for name, param in net.named_parameters()
                if name[-4:] != 'bias' and param.requires_grad
            ],
            'lr': args['lr'],
            'weight_decay': args['weight_decay']
        }],
        momentum=args['momentum'],
        nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        # only valid when resuming with previously saved centroids
        deepcluster.set_index(cluster_centroids)

    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')
    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        net.eval()
        net.output_features = True

        print('reference image list length: %d' % len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net,
            model_config,
            corr_set_train,
            10,
            max_num_features_per_image=args['max_features_per_image'])
        cluster_features = np.vstack(features)
        del features

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = \
            deepcluster.cluster_imfeatures(cluster_features, verbose=True,
                                           use_gpu=False)

        # save cluster centroids
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + np.array2string(
            cluster_distribution,
            formatter={'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # re-initialize the last layer weights with a normal distribution
        reinit_last_layers(net)

        # make a copy of the current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

        # polynomial learning-rate decay for both parameter groups
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args['cluster_interval'] \
                and curr_iter <= args['max_iter']:
            # First extract cluster labels using the saved network checkpoint
            print('extracting cluster labels')
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']) and (
                        val_iter + extract_label_count <
                        args['val_interval']) and (
                            extract_label_count + curr_iter <= args['max_iter']):
                # note: next(iter(...)) spawns a fresh iterator (and workers)
                # on every call; the unpacking below also assumes the loader
                # yields correspondence 5-tuples, which should be verified for
                # Poladata.MonoDataset. The original additionally re-assigned
                # corr_loader_train = input_transform(corr_loader_train) here,
                # which is removed: the dataset already applies input_transform.
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)
                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign features to clusters for the entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(
                        np.swapaxes(out_f, 0, 1), pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))
                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign clusters to the correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data back to cpu
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()

                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                outputs_ref, aux_ref = net(img_ref)

                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()

                cluster_training_count += 1

                if type(seg_main_loss) == torch.Tensor:
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ############################################################
                # LOGGING ETC
                ############################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    # the original format string had six placeholders but only
                    # four arguments; log all three loss meters so it formats
                    str2write = ('[iter %d / %d], [train seg loss %.5f], '
                                 '[train corr loss %.5f], '
                                 '[train feature loss %.5f], [lr %.10f]') % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])
                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
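# --- Hedged usage sketch (added for illustration, not part of the original
# script). The keys below are exactly the ones train_with_clustering reads
# from `args`; every value and every path is an assumption, not the authors'
# setting.
example_args = {
    'n_clusters': 100,            # assumed value
    'use_original_base': False,   # assumed value
    'corr_set': 'pola',           # one of 'rc', 'pola', 'cmu' ('both' is broken, see above)
    'n_workers': 4,
    'lr': 1e-4,
    'lr_decay': 0.9,
    'weight_decay': 1e-4,
    'momentum': 0.9,
    'max_iter': 30000,
    'cluster_interval': 10000,
    'chunk_size': 100,
    'val_interval': 1000,
    'max_features_per_image': 500,
    'seg_loss_weight': 1.0,
    'print_freq': 10,
}
# train_with_clustering('logs/cluster-train',    # hypothetical save_folder
#                       'logs/tmp-seg',          # hypothetical tmp_seg_folder
#                       'weights/startnet.pth',  # hypothetical start checkpoint
#                       example_args)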
def main():
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # snapshot filenames are assumed to encode metrics as
        # epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<iu>_fwavacc_<f>.pth
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    # biases get twice the base learning rate and no weight decay
    optimizer = optim.SGD(
        [{
            'params': [
                param for name, param in net.named_parameters()
                if name[-4:] == 'bias'
            ],
            'lr': 2 * args['lr']
        }, {
            'params': [
                param for name, param in net.named_parameters()
                if name[-4:] != 'bias'
            ],
            'lr': args['lr'],
            'weight_decay': args['weight_decay']
        }],
        momentum=args['momentum'],
        nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
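# --- Minimal self-contained sketch (added for illustration) of the bias /
# non-bias parameter split used by both optimizers above: biases get twice
# the base learning rate and no weight decay, weights get the base learning
# rate plus weight decay. The toy model and hyperparameters are assumptions.
#
# import torch.nn as nn
# import torch.optim as optim
#
# toy = nn.Conv2d(3, 8, 3)          # stand-in for the real network
# base_lr, wd = 1e-2, 1e-4          # assumed values
# toy_optimizer = optim.SGD(
#     [
#         {'params': [p for n, p in toy.named_parameters()
#                     if n.endswith('bias')],
#          'lr': 2 * base_lr},                  # biases: double LR, no decay
#         {'params': [p for n, p in toy.named_parameters()
#                     if not n.endswith('bias')],
#          'lr': base_lr, 'weight_decay': wd},  # weights: base LR + decay
#     ],
#     momentum=0.9, nesterov=True)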
def test_(img_path):
    # .cuda() added: the original built the net on CPU but fed it CUDA tensors
    net = PSPNet(num_classes=cityscapes.num_classes).cuda()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0: [30, xxx] -> [15, ...], [15, ...] on 2 GPUs
        net = nn.DataParallel(net)

    print('load model ' + args['snapshot'])
    net.load_state_dict(
        torch.load(
            os.path.join(ckpt_path, args['exp_name'], args['snapshot'])))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                cityscapes.ignore_label)

    img = Image.open(img_path).convert('RGB')
    img_slices, _, slices_info = sliding_crop(img, img.copy())
    img_slices = [val_input_transform(e) for e in img_slices]
    # stack to (n_slices, 1, C, H, W) so each element is a batch of one;
    # slices_info is assumed to be a list of [sy, ey, sx, ex, h, w] entries.
    # This folds the dead full-batch forward pass and the transpose_/squeeze_
    # bookkeeping of the original into the stacking itself.
    img = torch.stack(img_slices, 0).unsqueeze(1).cuda()
    slices_info = torch.LongTensor(slices_info)

    # accumulation canvases; the 2:1 aspect ratio matches Cityscapes frames
    count = torch.zeros(args['longer_size'] // 2, args['longer_size']).cuda()
    output = torch.zeros(cityscapes.num_classes, args['longer_size'] // 2,
                         args['longer_size']).cuda()
    prediction = np.zeros((args['longer_size'] // 2, args['longer_size']),
                          dtype=int)

    # torch.no_grad() replaces the deprecated volatile=True Variables; the
    # original called torch.no_grad() as a bare statement, which has no
    # effect outside a `with` block
    with torch.no_grad():
        for input_slice, info in zip(img, slices_info):
            input_slice = input_slice.cuda()
            output_slice = net(input_slice)
            assert output_slice.size()[1] == cityscapes.num_classes
            output[:, info[0]:info[1], info[2]:info[3]] += \
                output_slice[0, :, :info[4], :info[5]].data
            count[info[0]:info[1], info[2]:info[3]] += 1

    # average overlapping logits, then take the per-pixel argmax
    output /= count
    prediction[:, :] = output.max(0)[1].squeeze_(0).cpu().numpy()

    test_dir = os.path.join(ckpt_path, args['exp_name'], 'test')
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    print(img_name)

    prediction_pil = cityscapes.colorize_mask(prediction)
    # the original read train_args['val_save_to_img_file'], but train_args is
    # not defined in this function; args holds every other setting here
    if args['val_save_to_img_file']:
        check_mkdir(test_dir)
        prediction_pil.save(
            os.path.join(test_dir, '%s_prediction.png' % img_name))
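# --- Self-contained sketch (added for illustration) of the sliding-window
# stitching that test_ performs: per-window logits are accumulated into a
# full-size canvas and normalized by how often each pixel was covered, then
# argmaxed per pixel. All shapes and the random logits are assumptions.
def _sliding_window_stitch_demo():
    import torch

    n_classes, H, W, win, stride = 19, 16, 24, 8, 4
    logits_canvas = torch.zeros(n_classes, H, W)
    count = torch.zeros(H, W)
    for y0 in range(0, H - win + 1, stride):
        for x0 in range(0, W - win + 1, stride):
            # stand-in for the network's output on one crop
            window_logits = torch.randn(n_classes, win, win)
            logits_canvas[:, y0:y0 + win, x0:x0 + win] += window_logits
            count[y0:y0 + win, x0:x0 + win] += 1
    # average overlapping logits, then take the per-pixel argmax
    prediction = (logits_canvas / count).argmax(0)
    return prediction  # (H, W) tensor of class ids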