Example #1
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # different amount of classes
        init_last_layers(state_dict, args['n_clusters'])

        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()
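        # NOTE: the code below still reads the single corr_set_config, so the
        # 'both' option would need these two configs merged to work end to end.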

    ref_image_lists = corr_set_config.reference_image_list

    # corr_im_paths = [corr_set_config.correspondence_im_path]
    # ref_featurs_pos = [corr_set_config.reference_feature_poitions]

    input_transform = model_config.input_transform

    #corr_set_train = correspondences.Correspondences(corr_set_config.correspondence_path,
    #                                                 corr_set_config.correspondence_im_path,
    #                                                 input_size=(713, 713),
    #                                                 input_transform=input_transform,
    #                                                 joint_transform=train_joint_transform_corr,
    #                                                 listfile=corr_set_config.correspondence_train_list_file)
    scales = [0, 1, 2, 3]


    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])

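    # 713x713 sliding windows with a 2/3 stride rate; 255 marks padded /
    # ignored pixels in the cropped labels.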
    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)


    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,  #train_joint_transform,
        transform_before_sliding=None  #sliding_crop
    )
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)
    # Persistent iterator over the correspondence loader; re-created inside
    # the training loop whenever a pass over the data is exhausted.
    corr_iter = iter(corr_loader_train)

    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

    # Optimizer setup
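    # Two parameter groups, following common PSPNet training practice: biases
    # train at twice the base learning rate and without weight decay; all
    # other parameters use the base rate with weight decay.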
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
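    # DeepCluster-style alternation: k-means over features extracted from the
    # reference images yields pseudo-labels, which the segmentation head is
    # then trained against below.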
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    # clean_log_before_continuing(os.path.join(save_folder, 'log.log'), start_iter)
    # f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:

        net.eval()
        net.output_features = True

        # features, _ = extract_features_for_reference(net, model_config, ref_image_lists,
        #                                              corr_im_paths, ref_featurs_pos,
        #                                              max_num_features_per_image=args['max_features_per_image'],
        #                                              fraction_correspondeces=0.5)
        print('Number of reference images: %d' % len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net,
            model_config,
            corr_set_train,
            10,
            max_num_features_per_image=args['max_features_per_image'])

        cluster_features = np.vstack(features)
        del features

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
            cluster_features, verbose=True, use_gpu=False)

        # save cluster centroids
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + \
            np.array2string(cluster_distribution, formatter={
                            'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # set last layer weight to a normal distribution
        reinit_last_layers(net)

        # make a copy of current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

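    # Polynomial ("poly") learning-rate schedule: both groups are scaled by
    # (1 - curr_iter / max_iter) ** lr_decay, preserving the 2x bias rate.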
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        seg_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:

            # First extract cluster labels using saved network checkpoint
            print('Extracting cluster labels with the frozen network copy...')
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                # Draw the next batch, restarting the iterator when a pass
                # over the data is exhausted.
                try:
                    img_ref, img_other, pts_ref, pts_other, _ = next(corr_iter)
                except StopIteration:
                    corr_iter = iter(corr_loader_train)
                    img_ref, img_other, pts_ref, pts_other, _ = next(corr_iter)


                # Transfer data to device
                img_ref = img_ref.to(device)

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                outputs_ref, aux_ref = net(img_ref)

                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    # The corr/feature losses are not computed in this loop,
                    # so only the seg loss and lr are reported.
                    str2write = '[iter %d / %d], [train seg loss %.5f], [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
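
# --- Hypothetical usage sketch (not part of the original code) ---
# Every key below is one train_with_clustering() actually reads; all values
# and paths are illustrative placeholders.
if __name__ == '__main__':
    example_args = {
        'n_clusters': 100,             # k for k-means / number of output classes
        'use_original_base': False,
        'corr_set': 'cmu',             # one of 'rc', 'pola', 'cmu', 'both'
        'n_workers': 4,
        'lr': 2.5e-4,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'lr_decay': 0.9,
        'max_iter': 30000,
        'cluster_interval': 1000,
        'chunk_size': 20,
        'val_interval': 500,
        'max_features_per_image': 500,
        'seg_loss_weight': 1.0,
        'print_freq': 10,
    }
    train_with_clustering('ckpt/run1', 'ckpt/run1/tmp_seg',
                          'weights/pspnet_cityscapes.pth', example_args)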
Example #2
def main():
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
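        # The snapshot filename appears to encode training metrics, e.g.
        # epoch_<e>_loss_<l>_acc_<a>_acc-cls_<ac>_mean-iu_<iu>_fwavacc_<f>.pth,
        # so the odd indices of the '_' split recover them below.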
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(
            os.path.join(ckpt_path, exp_name,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
Example #3
def test_(img_path):
    net = PSPNet(num_classes=cityscapes.num_classes)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0:  [30, xxx] -> [15, ...], [15, ...] on 2 GPUs
        net = nn.DataParallel(net)

    print('load model ' + args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, args['exp_name'],
                                args['snapshot'])))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                cityscapes.ignore_label)

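    # Sliding-window inference: tile the image into overlapping crops, run the
    # network on each crop, and fuse the logits back at full resolution.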
    img = Image.open(img_path).convert('RGB')
    img_slices, _, slices_info = sliding_crop(img, img.copy())
    img_slices = [val_input_transform(e) for e in img_slices]
    # Stack the crops into [num_slices, 1, C, H, W] so each element of the
    # leading dimension is a single-image batch for the fusion loop below,
    # and keep the crop coordinates as a tensor.
    img = torch.stack(img_slices, 0).unsqueeze(1).cuda()
    slices_info = torch.LongTensor(slices_info)

    count = torch.zeros(args['longer_size'] // 2, args['longer_size']).cuda()
    output = torch.zeros(cityscapes.num_classes, args['longer_size'] // 2,
                         args['longer_size']).cuda()

    prediction = np.zeros((args['longer_size'] // 2, args['longer_size']),
                          dtype=int)

    with torch.no_grad():
        for input_slice, info in zip(img, slices_info):
            output_slice = net(input_slice)
            assert output_slice.size()[1] == cityscapes.num_classes
            output[:, info[0]:info[1],
                   info[2]:info[3]] += output_slice[0, :, :info[4], :info[5]]
            count[info[0]:info[1], info[2]:info[3]] += 1

    # Average overlapping logits by their per-pixel coverage count.
    output /= count
    prediction[:, :] = output.max(0)[1].cpu().numpy()

    test_dir = os.path.join(ckpt_path, args['exp_name'], 'test')
    img_name = os.path.basename(img_path)
    img_name = os.path.splitext(img_name)[0]
    print(img_name)
    if train_args['val_save_to_img_file']:
        check_mkdir(test_dir)

    val_visual = []
    prediction_pil = cityscapes.colorize_mask(prediction)
    if train_args['val_save_to_img_file']:
        prediction_pil.save(
            os.path.join(test_dir, '%s_prediction.png' % img_name))