def evaluate_segmented_images_for_experiments(network_folder,
                                              validation_metric, dataset):
    if dataset == 'cityscapes':
        dset = dataset_configs.CityscapesConfig()
        truth_folder = dset.val_seg_folder
        result_folder_name = 'cityscapes-val-results'
    elif dataset == 'wilddash':
        dset = dataset_configs.WilddashConfig()
        truth_folder = dset.val_seg_folder
        result_folder_name = 'wilddash_val_results'
    elif dataset == 'cmu':
        dset = dataset_configs.CmuConfig()
        truth_folder = dset.test_seg_folder
        result_folder_name = 'cmu-annotated-test-images'
    elif dataset == 'rc':
        dset = dataset_configs.RobotcarConfig()
        truth_folder = dset.test_seg_folder
        result_folder_name = 'robotcar-test-results'
    elif dataset == 'vistas':
        dset = dataset_configs.VistasConfig()
        truth_folder = dset.val_seg_folder
        result_folder_name = 'vistas-validation'

    if len(validation_metric) > 0:
        result_folder_name += '_' + validation_metric

    evaluate_segmented_images(os.path.join(network_folder, result_folder_name),
                              truth_folder,
                              dset.im_file_ending.replace('jpg', 'png'),
                              dset.seg_file_ending,
                              dset.id_to_trainid,
                              n_classes=dset.n_classes)
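# Usage sketch (hypothetical wrapper; the network_folder path and its result
# subfolder layout are assumptions): evaluates predictions already written by
# the segmentation step against the Cityscapes validation annotations.
def _example_evaluate_cityscapes(network_folder):
    evaluate_segmented_images_for_experiments(network_folder,
                                              validation_metric='',
                                              dataset='cityscapes')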
def write_lists_for_corr(corr_set):
    if corr_set == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif corr_set == 'rc':
        corr_set_config = data_configs.RobotcarConfig()

    n_samples_to_include = 1000000
    fraction_training = 0.7

    np.random.seed(0)
    f_name_list = [fn for fn in os.listdir(corr_set_config.correspondence_path) if fn.endswith('mat')]

    n_samples = min(n_samples_to_include, len(f_name_list))

    n_training = ceil(fraction_training*n_samples)
    n_validation = n_samples - n_training

    training_ids = np.random.choice(len(f_name_list), n_training,
                                    replace=False)
    ids_left_for_validation = set(range(len(f_name_list))) - set(training_ids)
    validation_ids = np.random.choice(list(ids_left_for_validation),
                                      n_validation,
                                      replace=False)

    with open(corr_set_config.correspondence_train_list_file, 'w') as f_train, \
            open(corr_set_config.correspondence_val_list_file, 'w') as f_val:
        for i in training_ids:
            f_train.write(f_name_list[i] + '\n')
        for i in validation_ids:
            f_val.write(f_name_list[i] + '\n')
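# Usage sketch (hypothetical wrapper): splits the CMU correspondence .mat
# files 70/30 into the train/val list files named by CmuConfig.
def _example_write_cmu_corr_lists():
    write_lists_for_corr('cmu')  # writes correspondence_train_list_file
                                 # and correspondence_val_list_file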
def write_reference_im_list(corr_set):
    if corr_set == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif corr_set == 'rc':
        corr_set_config = data_configs.RobotcarConfig()

    # FILES FROM LIST
    f_name_list = []
    with open(corr_set_config.correspondence_train_list_file) as f:
        for line in f:
            f_name_list.append(line.strip())

    # other
    ref_file_names = set()  # use set to automatically handle duplicates

    # run
    it = 0
    for f_name in f_name_list:
        mat_content = {}
        with h5py.File(
                os.path.join(corr_set_config.correspondence_path, f_name),
                'r') as ff:
            for k, v in ff.items():
                mat_content[k] = np.array(v)

        im1name = ''.join(
            chr(a) for a in mat_content['im_i_path'])  # convert to string
        ref_file_names.add(im1name)
        it += 1

        if (it % 100) == 0:
            print('%d/%d' % (it, len(f_name_list)))

    print('Writing file %s ' % corr_set_config.reference_image_list)
    with open(corr_set_config.reference_image_list, 'w') as f:
        for fname in ref_file_names:
            f.write('%s\n' % fname)
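# The chr()-join in write_reference_im_list decodes MATLAB v7.3 strings, which
# h5py exposes as arrays of uint16 code units. A self-contained sketch of that
# decoding step on synthetic data (no .mat file required):
def _demo_decode_mat_string():
    codes = np.array([ord(c) for c in 'img/ref_001.jpg'], dtype=np.uint16)
    decoded = ''.join(chr(a) for a in codes)
    assert decoded == 'img/ref_001.jpg'
    return decoded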
def find_non_stationary_clusters(args):
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    network_folder_name = args['folder'].split('/')[-1]
    tmp = re.search(r"cn(\d+)-", network_folder_name)
    n_clusters = int(tmp.group(1))

    save_folder = os.path.join(args['dest_root'], network_folder_name)
    if os.path.exists(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')):
        print('{} already exists. skipping'.format(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')))
        return

    check_mkdir(save_folder)

    with open(os.path.join(args['folder'], 'bestval.txt')) as f:
        best_val_dict_str = f.read()
        bestval = eval(best_val_dict_str.rstrip())

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=n_clusters,
        for_clustering=False,
        output_features=False,
        use_original_base=args['use_original_base']).to(device)
    net.load_state_dict(
        torch.load(os.path.join(args['folder'],
                                bestval['snapshot'] + '.pth')))  # load weights
    net.eval()

    # copy network file to save location
    copyfile(os.path.join(args['folder'], bestval['snapshot'] + '.pth'),
             os.path.join(save_folder, 'weights.pth'))

    if args['only_copy_weights']:
        print('Only copying weights')
        return

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config1.correspondence_val_list_file)
        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config2.correspondence_val_list_file)
        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config.correspondence_val_list_file)

    # Segmentor
    segmentor = Segmentor(net, n_clusters, n_slices_per_pass=4)

    # save args
    with open(
            os.path.join(save_folder,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    cluster_histogram_for_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)
    cluster_histogram_non_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)

    for i in range(0, len(corr_set_val), args['step']):
        img1, img2, pts1, pts2, _ = corr_set_val[i]
        seg1 = segmentor.run_and_save(
            img1,
            None,
            pre_sliding_crop_transform=pre_validation_transform,
            input_transform=input_transform,
            sliding_crop=sliding_crop_im,
            use_gpu=args['use_gpu'])

        seg1 = np.array(seg1)
        corr_loc_mask = np.zeros(seg1.shape, dtype=bool)

        valid_inds = (pts1[0, :] >= 0) & (pts1[0, :] < seg1.shape[1]) & (
            pts1[1, :] >= 0) & (pts1[1, :] < seg1.shape[0])

        pts1 = pts1[:, valid_inds]
        for j in range(pts1.shape[1]):
            pt = pts1[:, j]
            corr_loc_mask[pt[1], pt[0]] = True

        cluster_ids_corr = seg1[corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_corr, np.arange(n_clusters + 1))
        cluster_histogram_for_correspondences += hist_tmp

        cluster_ids_no_corr = seg1[~corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_no_corr,
                                   np.arange(n_clusters + 1))
        cluster_histogram_non_correspondences += hist_tmp

        if ((i + 1) % 100) < args['step']:
            print('{}/{}'.format(i + 1, len(corr_set_val)))

    np.save(os.path.join(save_folder, 'cluster_histogram_for_corr.npy'),
            cluster_histogram_for_correspondences)
    np.save(os.path.join(save_folder, 'cluster_histogram_non_corr.npy'),
            cluster_histogram_non_correspondences)
    frac = cluster_histogram_for_correspondences / \
        (cluster_histogram_for_correspondences +
         cluster_histogram_non_correspondences)
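    # frac is NaN for clusters that received no pixels; NaN > 0.01 evaluates
    # to False, so empty clusters count as non-stationary below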
    stationary_inds = np.argwhere(frac > 0.01)
    np.save(os.path.join(save_folder, 'stationary_inds.npy'), stationary_inds)
    print('{} stationary clusters out of {}'.format(
        len(stationary_inds), len(cluster_histogram_for_correspondences)))
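# Self-contained sketch (synthetic data, hypothetical sizes) of the
# bookkeeping done in find_non_stationary_clusters: pixels at correspondence
# locations vote into one histogram, the remaining pixels into another, and a
# cluster counts as stationary when more than 1% of its pixels fall on
# correspondences.
def _demo_stationary_fraction(n_clusters=8, seed=0):
    rng = np.random.RandomState(seed)
    seg = rng.randint(0, n_clusters, size=(32, 32))  # fake cluster image
    corr_loc_mask = np.zeros(seg.shape, dtype=bool)
    corr_loc_mask[rng.randint(0, 32, 50), rng.randint(0, 32, 50)] = True
    hist_corr, _ = np.histogram(seg[corr_loc_mask], np.arange(n_clusters + 1))
    hist_non, _ = np.histogram(seg[~corr_loc_mask], np.arange(n_clusters + 1))
    frac = hist_corr / (hist_corr + hist_non)
    return np.argwhere(frac > 0.01)  # indices of "stationary" clusters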
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean',  # 'elementwise_mean' in older PyTorch
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    with open(
            os.path.join(save_folder,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    # hingeF is handled above, so only hingeC skips the softmax
                    if args['corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed; use the aux filtering so the
                # kept tensors line up with pts_ref_aux/pts_other_aux
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_aux]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f], [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f], [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
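# The training loop above anneals both parameter groups with the "poly"
# learning-rate schedule. A sketch of the rule in isolation (argument names
# mirror the args dict used above):
def _poly_lr(base_lr, curr_iter, max_iter, lr_decay):
    return base_lr * (1.0 - float(curr_iter) / max_iter) ** lr_decay
# Bias parameters (param_groups[0]) use 2 * _poly_lr(...); all other
# parameters (param_groups[1]) use _poly_lr(...) plus weight decay.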
def segment_images_in_folder_for_experiments(network_folder, args):
    # Predefined image sets and paths
    if len(args['img_set']) > 0:
        if args['img_set'] == 'cmu':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cmu-annotated-test-images'
        elif args['img_set'] == 'wilddash':
            dset = dataset_configs.WilddashConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'wilddash_val_results'
        elif args['img_set'] == 'rc':
            dset = dataset_configs.RobotcarConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'robotcar-test-results'
        elif args['img_set'] == 'cityscapes':
            dset = dataset_configs.CityscapesConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cityscapes-val-results'
        elif args['img_set'] == 'cmu-vis':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.vis_test_im_folder
            args['img_ext'] = 'jpg'
            args['save_folder_name'] = 'cmu-visual-test-images'
        elif args['img_set'] == 'rc-vis':
            dset = dataset_configs.RobotcarConfig()
            args['img_path'] = dset.vis_test_im_folder
            args['img_ext'] = 'jpg'
            args['save_folder_name'] = 'robotcar-visual-test-images'
        elif args['img_set'] == 'vistas':
            dset = dataset_configs.VistasConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'vistas-validation'

    if len(args['network_file']) < 1:
        print("Loading best network according to specified validation metric")
        if args['validation_metric'] == 'miou':
            with open(os.path.join(network_folder, 'bestval.txt')) as f:
                best_val_dict_str = f.read()
                bestval = eval(best_val_dict_str.rstrip())

            print("Network file %s - val miou %s" %
                  (bestval['snapshot'], bestval['mean_iu']))
        elif args['validation_metric'] == 'acc':
            with open(os.path.join(network_folder, 'bestval_acc.txt')) as f:
                best_val_dict_str = f.read()
                bestval = eval(best_val_dict_str.rstrip())

            print("Network file %s - val acc %s" %
                  (bestval['snapshot'], bestval['acc']))

        net_to_load = bestval['snapshot'] + '.pth'
        network_file = os.path.join(network_folder, net_to_load)

    else:
        print("Loading specified network")
        network_folder = os.path.dirname(args['network_file'])
        network_file = args['network_file']

    # folder should have same name as for trained network
    if args['validation_metric'] == 'miou':
        save_folder = os.path.join(network_folder, args['save_folder_name'])
    elif args['validation_metric'] == 'acc':
        save_folder = os.path.join(network_folder,
                                   args['save_folder_name'] + '_acc')

    segment_images_in_folder(network_file, args['img_path'], save_folder, args)
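# Invocation sketch (hypothetical values; only keys read above are shown, and
# segment_images_in_folder may require additional keys such as a stride rate):
def _example_segment_cmu_test_set(network_folder):
    args = {
        'img_set': 'cmu',        # selects paths and endings from CmuConfig
        'network_file': '',      # '' -> load best checkpoint via bestval.txt
        'validation_metric': 'miou',
    }
    segment_images_in_folder_for_experiments(network_folder, args)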
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # different amount of classes
        init_last_layers(state_dict, args['n_clusters'])

        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    ref_image_lists = corr_set_config.reference_image_list
    # corr_im_paths = [corr_set_config.correspondence_im_path]        # unused here
    # ref_featurs_pos = [corr_set_config.reference_feature_poitions]  # unused here

    input_transform = model_config.input_transform

    scales = [0, 1, 2, 3]  # currently unused below

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])

    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    # (earlier drafts built corr_set_train from correspondences.Correspondences
    # with a correspondence list file; this variant trains on single images)
    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,
        transform_before_sliding=None)
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean')  # 'elementwise_mean' in older PyTorch

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    with open(
            os.path.join(save_folder,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    # (this variant always starts a fresh log; the full version below supports
    # resuming via clean_log_before_continuing and append mode)
    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:

        # Extract image features from the reference images
        net.eval()
        net.output_features = True

        # (the correspondence-based variant below uses
        # extract_features_for_reference with reference feature positions)
        print('number of entries in ref_image_lists: %d' % len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net,
            model_config,
            corr_set_train,
            10,
            max_num_features_per_image=args['max_features_per_image'])

        cluster_features = np.vstack(features)
        del features

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
            cluster_features, verbose=True, use_gpu=False)

        # save cluster centroids
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + \
            np.array2string(cluster_distribution, formatter={
                            'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # set last layer weight to a normal distribution
        reinit_last_layers(net)

        # make a copy of current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:

            # First extract cluster labels using saved network checkpoint
            print('extracting cluster labels using saved network checkpoint')
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                outputs_ref, aux_ref = net(img_ref)

                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f], [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
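# Conceptual stand-in (plain numpy, synthetic shapes) for the per-pixel
# cluster assignment above, where preprocess_features + deepcluster.assign map
# each feature vector to its nearest centroid; the real code also applies a
# PCA transform first, omitted here.
def _demo_assign_pixels_to_centroids(feats, centroids):
    # feats: (C, H, W) feature map; centroids: (K, C)
    c, h, w = feats.shape
    flat = feats.reshape(c, -1).T                                 # (H*W, C)
    d2 = ((flat[:, None, :] - centroids[None, :, :])**2).sum(-1)  # (H*W, K)
    return d2.argmin(axis=1).reshape(h, w)                        # cluster ids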
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)
    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        if 'resnet101' in startnet:
            load_resnet101_weights(net, state_dict)
        else:
            # needed since we slightly changed the structure of the network in pspnet
            state_dict = rename_keys_to_match(state_dict)
            init_last_layers(state_dict,
                             args['n_clusters'])  # different amount of classes

            net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss_feat': 1e10,
            'val_loss_out': 1e10,
            'val_loss_cluster': 1e10
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

        if start_iter >= args['max_iter']:
            return

        if (start_iter % args['cluster_interval']) == 0:
            skip_clustering = False
        else:
            skip_clustering = True
            last_cluster_network_snapshot_iter = (
                start_iter //
                args['cluster_interval']) * args['cluster_interval']

            # load cluster info
            cluster_info = {}
            f = h5py.File(
                os.path.join(
                    save_folder, 'centroids_{}.h5'.format(
                        last_cluster_network_snapshot_iter)), 'r')
            for k, v in f.items():
                cluster_info[k] = np.array(v)
            cluster_centroids = cluster_info['cluster_centroids']
            pca_info = [
                cluster_info['pca_transform_Amat'],
                cluster_info['pca_transform_bvec']
            ]

            # load network that was used for last clustering
            net_for_clustering = model_config.init_network(
                n_classes=args['n_clusters'],
                for_clustering=True,
                output_features=True,
                use_original_base=args['use_original_base'])

            if last_cluster_network_snapshot_iter == 0:
                state_dict = torch.load(
                    startnet, map_location=lambda storage, loc: storage)
                if 'resnet101' in startnet:
                    load_resnet101_weights(net_for_clustering, state_dict)
                else:
                    # needed since we slightly changed the structure of the network in pspnet
                    state_dict = rename_keys_to_match(state_dict)
                    init_last_layers(
                        state_dict,
                        args['n_clusters'])  # different amount of classes

                    net_for_clustering.load_state_dict(
                        state_dict)  # load original weights
            else:
                cluster_network_weights = get_network_name_from_iteration(
                    save_folder, last_cluster_network_snapshot_iter)
                net_for_clustering.load_state_dict(
                    torch.load(os.path.join(save_folder,
                                            cluster_network_weights),
                               map_location=lambda storage, loc: storage)
                )  # load weights

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    if args['corr_set'] == 'both':
        ref_image_lists = [
            corr_set_config1.reference_image_list,
            corr_set_config2.reference_image_list
        ]
        corr_im_paths = [
            corr_set_config1.correspondence_im_path,
            corr_set_config2.correspondence_im_path
        ]
        ref_featurs_pos = [
            corr_set_config1.reference_feature_poitions,
            corr_set_config2.reference_feature_poitions
        ]
    else:
        ref_image_lists = [corr_set_config.reference_image_list]
        corr_im_paths = [corr_set_config.correspondence_im_path]
        ref_featurs_pos = [corr_set_config.reference_feature_poitions]

    input_transform = model_config.input_transform

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # Correspondences for training
    if args['corr_set'] == 'both':
        corr_set_train1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_train_list_file)
        corr_set_train2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_train_list_file)

        corr_set_train = merged.Merged([corr_set_train1, corr_set_train2])
    else:
        corr_set_train = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_train_list_file)

    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    # Correspondences for validation
    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_val_list_file)

        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_val_list_file)

        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_val_list_file)

    corr_loader_val = DataLoader(corr_set_val,
                                 batch_size=1,
                                 num_workers=args['n_workers'],
                                 shuffle=False)

    # Loss setup
    val_corr_loss_fct_feat = FeatureLoss(
        input_size=[713, 713],
        loss_type=args['feature_distance_measure'],
        feat_dist_threshold_match=0.8,
        feat_dist_threshold_nomatch=0.2,
        n_not_matching=0)

    val_corr_loss_fct_out = FeatureLoss(input_size=[713, 713],
                                        loss_type='KL',
                                        feat_dist_threshold_match=0.8,
                                        feat_dist_threshold_nomatch=0.2,
                                        n_not_matching=0)

    loss_fct = ClusterCorrespondenceLoss(input_size=[713, 713],
                                         size_average=True).to(device)
    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean')  # 'elementwise_mean' in older PyTorch

    if args['feature_hinge_loss_weight'] > 0:
        feature_loss_fct = FeatureLoss(input_size=[713, 713],
                                       loss_type='hingeF',
                                       feat_dist_threshold_match=0.8,
                                       feat_dist_threshold_nomatch=0.2,
                                       n_not_matching=0)

    # Validator
    corr_validator = CorrValidator(corr_loader_val,
                                   val_corr_loss_fct_feat,
                                   val_corr_loss_fct_out,
                                   loss_fct,
                                   save_snapshot=True,
                                   extra_name_str='Corr')

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)
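        # set_index presumably restores the nearest-centroid search structure
        # (e.g. a faiss index) from the centroids loaded with the snapshot.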

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # Dump the run configuration to a timestamped text file
    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        if not skip_clustering:
            # Extract image features from reference images
            net.eval()
            net.output_features = True

            features, _ = extract_features_for_reference(
                net,
                model_config,
                ref_image_lists,
                corr_im_paths,
                ref_featurs_pos,
                max_num_features_per_image=args['max_features_per_image'],
                fraction_correspondeces=0.5)

            cluster_features = np.vstack(features)
            del features

            # cluster the features
            cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
                cluster_features, verbose=True, use_gpu=False)

            # save cluster centroids
            h5f = h5py.File(
                os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
            h5f.create_dataset('cluster_centroids', data=cluster_centroids)
            h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
            h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
            h5f.close()
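            # Sketch of how these centroids can be restored when resuming
            # (assumes the dataset names written above):
            #   with h5py.File(centroid_file, 'r') as h5f:
            #       cluster_centroids = h5f['cluster_centroids'][()]
            #       pca_info = (h5f['pca_transform_Amat'][()],
            #                   h5f['pca_transform_bvec'][()])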

            # Print distribution of clusters
            cluster_distribution, _ = np.histogram(
                cluster_indices,
                bins=np.arange(args['n_clusters'] + 1),
                density=True)
            str2write = 'cluster distribution ' + \
                np.array2string(cluster_distribution, formatter={'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
            print(str2write)
            f_handle.write(str2write + "\n")

            # Re-initialize the last-layer weights from a normal distribution
            reinit_last_layers(net)

            # make a copy of current network state to do cluster assignment
            net_for_clustering = copy.deepcopy(net)
        else:
            skip_clustering = False

        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
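        # Polynomial ('poly') learning-rate decay; the bias group keeps its
        # 2x multiplier throughout.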

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while (cluster_training_count < args['cluster_interval']
               and curr_iter <= args['max_iter']):

            # First extract cluster labels using the saved network copy.
            # The two models swap places so that only one of them occupies
            # GPU memory at a time.
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True
            if args['feature_hinge_loss_weight'] > 0:
                net_for_clustering.output_all = False
            data_samples = []
            extract_label_count = 0
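            # Gather up to chunk_size samples, stopping early if the cluster
            # interval, the validation interval, or max_iter would be
            # exceeded.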
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))
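                # NOTE: next(iter(loader)) builds a fresh, re-shuffled
                # iterator (and restarts its workers) on every call; a
                # persistent iterator would avoid that cost, e.g.:
                #   corr_iter = iter(corr_loader_train)
                #   img_ref, img_other, pts_ref, pts_other, _ = next(corr_iter)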

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))
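                # Yields one pseudo-label per correspondence point,
                # presumably by sampling the feature map at the given points,
                # applying the PCA transform, and taking the nearest centroid.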

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                img_other = img_other.cpu()
                pts_ref = [p.cpu() for p in pts_ref]
                pts_other = [p.cpu() for p in pts_other]
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, img_other, pts_ref, pts_other,
                                     cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, img_other, pts_ref, pts_other, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = True

                # Randomly decide whether the reference image, the target
                # image, or both receive gradients in this step.
                if args['fraction_reference_bp'] is None:  # use both
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                            img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                            img_other)
                    else:
                        outputs_other, aux_other = net(img_other)

                elif np.random.rand(
                        1)[0] < args['fraction_reference_bp']:  # use reference
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                            img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                                img_other)
                        else:
                            outputs_other, aux_other = net(img_other)
                else:  # use target
                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                                img_ref)
                        else:
                            outputs_ref, aux_ref = net(img_ref)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                            img_other)
                    else:
                        outputs_other, aux_other = net(img_other)

                    seg_main_loss = 0.
                    seg_aux_loss = 0.

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = False

                main_loss, _ = loss_fct(outputs_ref,
                                        outputs_other,
                                        None,
                                        pts_ref,
                                        pts_other,
                                        cluster_labels=cluster_labels)
                aux_loss, _ = loss_fct(aux_ref,
                                       aux_other,
                                       None,
                                       pts_ref,
                                       pts_other,
                                       cluster_labels=cluster_labels)

                if args['feature_hinge_loss_weight'] > 0:
                    feature_loss = feature_loss_fct(out_feat_ref,
                                                    out_feat_other, pts_ref,
                                                    pts_other, None)
                    feature_loss_aux = feature_loss_fct(
                        aux_feat_ref, aux_feat_other, pts_ref, pts_other, None)
                    loss = args['corr_loss_weight']*(main_loss + 0.4 * aux_loss) + args['seg_loss_weight']*(seg_main_loss + 0.4 * seg_aux_loss) \
                        + args['feature_hinge_loss_weight']*(feature_loss + 0.4 * feature_loss_aux)
                else:
                    feature_loss = 0.
                    loss = args['corr_loss_weight']*(main_loss + 0.4 * aux_loss) + \
                        args['seg_loss_weight']*(seg_main_loss + 0.4 * seg_aux_loss)
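                # Combined objective:
                #   L = w_corr * (L_main + 0.4 * L_aux)
                #     + w_seg  * (L_seg_main + 0.4 * L_seg_aux)
                #     [+ w_feat * (L_feat + 0.4 * L_feat_aux)]
                # with 0.4 the customary auxiliary-head weight from PSPNet.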

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                corr_train_loss.update(main_loss.item(), 1)
                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)
                if isinstance(feature_loss, torch.Tensor):
                    feature_train_loss.update(feature_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_corr_loss', corr_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_feature_loss', feature_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f], [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if val_iter >= args['val_interval']:
                    val_iter = 0
                    net_for_clustering.to(device)
                    corr_validator.run(net,
                                       net_for_clustering,
                                       (deepcluster, pca_info),
                                       optimizer,
                                       args,
                                       curr_iter,
                                       save_folder,
                                       f_handle,
                                       writer=writer)
                    net_for_clustering.to("cpu")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
def cluster_images_in_folder_for_experiments(network_folder, args):
    # Predefined image sets and paths
    if len(args['img_set']) > 0:
        if args['img_set'] == 'cmu':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cmu-annotated-test-images'
        elif args['img_set'] == 'cmu-train':
            dset = dataset_configs.CmuConfig()
            args['img_path'] = dset.train_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cmu-annotated-train-images'
        elif args['img_set'] == 'rc':
            dset = dataset_configs.RobotcarConfig()
            args['img_path'] = dset.test_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'robotcar-test-results'
        elif args['img_set'] == 'cityscapes':
            dset = dataset_configs.CityscapesConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'cityscapes-val-results'
        elif args['img_set'] == 'vistas':
            dset = dataset_configs.VistasConfig()
            args['img_path'] = dset.val_im_folder
            args['img_ext'] = dset.im_file_ending
            args['save_folder_name'] = 'vistas-validation'
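    # A table-driven mapping could condense the chain above (sketch, using
    # the same config classes and attribute names as in this function):
    #   IMG_SETS = {
    #       'cmu': (dataset_configs.CmuConfig, 'test_im_folder',
    #               'cmu-annotated-test-images'),
    #       'rc': (dataset_configs.RobotcarConfig, 'test_im_folder',
    #              'robotcar-test-results'),
    #   }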

    if len(args['network_file']) < 1:
        print("Loading best network according to specified validation metric")
        if args['validation_metric'] == 'miou':
            with open(os.path.join(network_folder, 'bestval.txt')) as f:
                best_val_dict_str = f.read()
                # The file stores a Python dict literal; eval assumes trusted
                # content (ast.literal_eval would be a safer alternative).
                bestval = eval(best_val_dict_str.rstrip())

            print("Network file %s" % (bestval['snapshot']))
        elif args['validation_metric'] == 'acc':
            with open(os.path.join(network_folder, 'bestval_acc.txt')) as f:
                best_val_dict_str = f.read()
                bestval = eval(best_val_dict_str.rstrip())

            print("Network file %s - val acc %s" %
                  (bestval['snapshot'], bestval['acc']))
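        # NOTE: bestval is only assigned for 'miou' and 'acc'; any other
        # validation_metric value would hit a NameError below.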

        net_to_load = bestval['snapshot'] + '.pth'
        network_file = os.path.join(network_folder, net_to_load)

    else:
        print("Loading specified network")
        # Take everything before the last '/' as the network folder
        network_folder = os.path.dirname(args['network_file'])
        network_file = args['network_file']

    # folder should have same name as for trained network
    save_folder = os.path.join(network_folder, args['save_folder_name'])

    # The trained-network folder name encodes the cluster count as 'cn<N>-'
    network_folder_name = os.path.basename(network_folder)
    match = re.search(r"cn(\d+)-", network_folder_name)
    n_clusters = int(match.group(1))
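    # e.g. a folder named 'run-cn100-2019-05-03' gives n_clusters == 100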
    cluster_images_in_folder(network_file, args['img_path'], save_folder,
                             n_clusters, args)