Example #1
def find_non_stationary_clusters(args):
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    network_folder_name = args['folder'].split('/')[-1]
    tmp = re.search(r"cn(\d+)-", network_folder_name)
    n_clusters = int(tmp.group(1))

    save_folder = os.path.join(args['dest_root'], network_folder_name)
    if os.path.exists(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')):
        print('{} already exists. skipping'.format(
            os.path.join(save_folder, 'cluster_histogram_for_corr.npy')))
        return

    check_mkdir(save_folder)

    with open(os.path.join(args['folder'], 'bestval.txt')) as f:
        best_val_dict_str = f.read()
    bestval = eval(best_val_dict_str.rstrip())  # bestval.txt stores a dict literal

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=n_clusters,
        for_clustering=False,
        output_features=False,
        use_original_base=args['use_original_base']).to(device)
    net.load_state_dict(
        torch.load(os.path.join(args['folder'],
                                bestval['snapshot'] + '.pth')))  # load weights
    net.eval()

    # copy network file to save location
    copyfile(os.path.join(args['folder'], bestval['snapshot'] + '.pth'),
             os.path.join(save_folder, 'weights.pth'))

    if args['only_copy_weights']:
        print('Only copying weights')
        return

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config1.correspondence_val_list_file)
        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config2.correspondence_val_list_file)
        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=corr_set_config.correspondence_val_list_file)

    # Segmentor
    segmentor = Segmentor(net, n_clusters, n_slices_per_pass=4)

    # save args
    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    cluster_histogram_for_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)
    cluster_histogram_non_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)

    for i in range(0, len(corr_set_val), args['step']):
        img1, img2, pts1, pts2, _ = corr_set_val[i]
        seg1 = segmentor.run_and_save(
            img1,
            None,
            pre_sliding_crop_transform=pre_validation_transform,
            input_transform=input_transform,
            sliding_crop=sliding_crop_im,
            use_gpu=args['use_gpu'])

        seg1 = np.array(seg1)
        corr_loc_mask = np.zeros(seg1.shape, dtype=bool)  # np.bool was removed in NumPy >= 1.24

        valid_inds = (pts1[0, :] >= 0) & (pts1[0, :] < seg1.shape[1]) & (
            pts1[1, :] >= 0) & (pts1[1, :] < seg1.shape[0])

        pts1 = pts1[:, valid_inds]
        for j in range(pts1.shape[1]):
            pt = pts1[:, j]
            corr_loc_mask[pt[1], pt[0]] = True

        cluster_ids_corr = seg1[corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_corr, np.arange(n_clusters + 1))
        cluster_histogram_for_correspondences += hist_tmp

        cluster_ids_no_corr = seg1[~corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_no_corr,
                                   np.arange(n_clusters + 1))
        cluster_histogram_non_correspondences += hist_tmp

        if ((i + 1) % 100) < args['step']:  # progress roughly every 100 samples
            print('{}/{}'.format(i + 1, len(corr_set_val)))

    np.save(os.path.join(save_folder, 'cluster_histogram_for_corr.npy'),
            cluster_histogram_for_correspondences)
    np.save(os.path.join(save_folder, 'cluster_histogram_non_corr.npy'),
            cluster_histogram_non_correspondences)
    frac = cluster_histogram_for_correspondences / (
        cluster_histogram_for_correspondences +
        cluster_histogram_non_correspondences)
    stationary_inds = np.argwhere(frac > 0.01)
    np.save(os.path.join(save_folder, 'stationary_inds.npy'), stationary_inds)
    print('{} stationary clusters out of {}'.format(
        len(stationary_inds), len(cluster_histogram_for_correspondences)))
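
A minimal invocation sketch; every key and path below is an assumption inferred from the lookups inside the function, not taken from the source:

# Hypothetical usage of find_non_stationary_clusters; adjust paths and values.
args = {
    'use_gpu': True,
    'folder': '/path/to/runs/cn16-example',  # folder name must contain "cn<N>-"
    'dest_root': '/path/to/output',
    'use_original_base': False,
    'only_copy_weights': False,
    'corr_set': 'cmu',       # 'rc', 'cmu', or 'both'
    'stride_rate': 2 / 3.,
    'step': 10,              # evaluate every 10th correspondence sample
}
find_non_stationary_clusters(args)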
Example #2
def extract_features_for_reference(net, net_config, im_list_file, im_root,
                                   ref_pts_dir, max_num_features_per_image=500,
                                   fraction_correspondeces=0.5):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    features = []

    # store mapping from original image representation to linear feature index
    # for correspondence coordinate with index i in image with index f
    # feature_ind = mapping[f][i][0]
    # x_coord = mapping[f][i][1]
    # y_coord = mapping[f][i][2]
    #
    # features[feature_ind] was taken from output_feature_map[:, y_coord, x_coord] of image f
    mapping = []

    # data loading
    input_transform = net_config.input_transform
    pre_inference_transform_with_corrs = net_config.pre_inference_transform_with_corrs
    # keep crop size and stride the same as during training
    sliding_crop = joint_transforms.SlidingCropImageOnly(713, 2/3)

    t0 = time.time()

    # get all file names

    filenames_ims = list()
    filenames_pts = list()
    for im_l, im_r, ref_pts_d in zip(im_list_file, im_root, ref_pts_dir):
        with open(im_l) as f:
            for line in f:
                filename_im = line.strip()

                filenames_ims.append(os.path.join(im_r, filename_im))
                filenames_pts.append(os.path.join(
                    ref_pts_d, im_to_ext_name(filename_im, 'h5')))


    feature_count = 0
    for im_idx, (filename_im, filename_pts) in enumerate(
            zip(filenames_ims, filenames_pts)):

        img = Image.open(filename_im).convert('RGB')
        with h5py.File(filename_pts, 'r') as f:
            pts = np.swapaxes(np.array(f['corr_pts']), 0, 1)

        # creating sliding crop windows and transform them
        if pre_inference_transform_with_corrs is not None:
            img, pts = pre_inference_transform_with_corrs(img, pts)

        img_size = img.size
        img_slices, slices_info = sliding_crop(img)
        if input_transform is not None:
            img_slices = [input_transform(e) for e in img_slices]

        img_slices = torch.stack(img_slices, 0)
        slices_info = torch.LongTensor(slices_info)
        slices_info.squeeze_(0)

        # run network on all slices
        img_slices = img_slices.to(device)

        for ind in range(0, img_slices.size(0), 10):
            max_ind = min(ind + 10, img_slices.size(0))
            with torch.no_grad():
                out_tmp = net(img_slices[ind:max_ind, :, :, :])
            if ind == 0:
                n_features = out_tmp.size(1)
                oh = out_tmp.size(2)
                ow = out_tmp.size(3)
                scale_h = 713/oh
                scale_w = 713/ow
                output_slices = torch.zeros(
                    img_slices.size(0), n_features, oh, ow).to(device)
            output_slices[ind:max_ind] = out_tmp

        outsizeh = round(slices_info[:, 0].max().item()/scale_h) + oh
        outsizew = round(slices_info[:, 2].max().item()/scale_w) + ow
        count = torch.zeros(outsizeh, outsizew).to(device)
        output = torch.zeros(n_features, outsizeh, outsizew).to(device)
        sliding_transform_step = (2/3, 2/3)
        interpol_weight = create_interpol_weights(
            (oh, ow), sliding_transform_step)
        interpol_weight = interpol_weight.to(device)

        for output_slice, info in zip(output_slices, slices_info):
            hs = round(info[0].item()/scale_h)
            ws = round(info[2].item()/scale_w)
            output[:, hs:hs+oh, ws:ws +
                   ow] += (interpol_weight*output_slice[:, :oh, :ow]).data
            count[hs:hs+oh, ws:ws+ow] += interpol_weight
        output /= count

        # Scale correspondences coordinates to output size
        pts[0, :] = pts[0, :]*outsizew/img_size[0]
        pts[1, :] = pts[1, :]*outsizeh/img_size[1]
        pts = (pts + 0.5).astype(int)
        valid_inds = (pts[0, :] < output.size(2)) & (pts[0, :] >= 0) & (
            pts[1, :] < output.size(1)) & (pts[1, :] >= 0)
        pts = pts[:, valid_inds]
        pts = np.unique(pts, axis=1)

        if pts.shape[1] > int(max_num_features_per_image*fraction_correspondeces + 0.5):
            inds = np.random.choice(pts.shape[1], int(
                max_num_features_per_image*fraction_correspondeces + 0.5), replace=False)
            pts = pts[:, inds]

        # add some random points as well
        n_non_corr = int((1 - fraction_correspondeces) /
                         fraction_correspondeces * pts.shape[1] + 0.5)
        lin_inds = np.random.choice(output.size(
            2)*output.size(1), n_non_corr, replace=False)
        ptstoadd = np.concatenate((np.expand_dims(
            lin_inds // output.size(1), axis=0), np.expand_dims(lin_inds % output.size(1), axis=0)), axis=0)
        pts = np.hstack((pts, ptstoadd))

        # Save features
        thismap = {}
        thismap['coords'] = []
        thismap['size'] = (outsizeh, outsizew)
        for pti in range(pts.shape[1]):
            features.append(
                output[:, pts[1, pti], pts[0, pti]].cpu().numpy().astype(np.float32))
            thismap['coords'].append([feature_count, pts[0, pti], pts[1, pti]])
            feature_count += 1

        mapping.append(thismap)

        if (im_idx % 100) == 0:
            print('computing features: %d/%d' % (im_idx, len(filenames_ims)))

    tend = time.time()
    print('Time: %f' % (tend-t0))

    return features, mapping
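
A short sketch of how the return values fit together, assuming net and the config objects are prepared as in the surrounding examples; the index values are placeholders:

# features/mapping follow the layout documented at the top of the function:
# mapping[f]['coords'][i] == [feature_ind, x_coord, y_coord] for image f.
features, mapping = extract_features_for_reference(
    net, net_config, [im_list_file], [im_root], [ref_pts_dir])
f, i = 0, 0
feature_ind, x_coord, y_coord = mapping[f]['coords'][i]
feat = features[feature_ind]  # taken from output_feature_map[:, y_coord, x_coord]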
Example #3
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean',  # 'elementwise_mean' was renamed to 'mean' in PyTorch 1.0
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
    optimizer = optim.SGD(
        [{
            'params': [
                param for name, param in net.named_parameters()
                if name[-4:] == 'bias' and param.requires_grad
            ],
            'lr': 2 * args['lr']
        }, {
            'params': [
                param for name, param in net.named_parameters()
                if name[-4:] != 'bias' and param.requires_grad
            ],
            'lr': args['lr'],
            'weight_decay': args['weight_decay']
        }],
        momentum=args['momentum'],
        nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # one random batch per step; fresh iterators work because the
            # loaders shuffle, though a persistent iterator would be cheaper
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))  # one random correspondence batch per step

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # only this pass tracks gradients; the img_ref pass ran under no_grad
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeC':  # hingeF handled above
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # only this pass tracks gradients; the img_ref pass ran under no_grad
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeC':  # hingeF handled above
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed (use the aux mask so the
                # features stay aligned with pts_ref_aux / pts_other_aux)
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_aux]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # the img_ref side serves as the reference
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f], [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f], [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
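
A hedged invocation sketch; the keys below are inferred from the args lookups in the function, and all values are illustrative placeholders:

# Hypothetical training call; tune the values to your setup.
args = {
    'snapshot': '',                    # '' starts from startnet, 'latest' resumes
    'corr_set': 'cmu',                 # 'rc' or 'cmu'
    'include_vistas': False,
    'stride_rate': 2 / 3.,
    'train_batch_size': 2,
    'n_workers': 4,
    'corr_loss_type': 'hingeF',        # 'class', 'hingeF', 'hingeC', ...
    'feat_dist_threshold_match': 0.8,
    'feat_dist_threshold_nomatch': 0.2,
    'remove_same_class': False,
    'classes_to_ignore': [],
    'seg_loss_weight': 1.0,
    'corr_loss_weight': 1.0,
    'n_iterations_before_corr_loss': 0,
    'lr': 1e-4,
    'lr_decay': 0.9,
    'weight_decay': 1e-4,
    'momentum': 0.9,
    'max_iter': 60000,
    'print_freq': 10,
    'val_interval': 1000,
}
train_with_correspondences('/path/to/save_folder', '/path/to/startnet.pth', args)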
Example #4
def segment_images_in_folder(network_file, img_folder, save_folder, args):
    # get current available device
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    if 'n_classes' in args:
        print('Initializing model with %d classes' % args['n_classes'])
        net = model_config.init_network(
            n_classes=args['n_classes'],
            for_clustering=False,
            output_features=False,
            use_original_base=args['use_original_base']).to(device)
    else:
        net = model_config.init_network().to(device)

    print('load model ' + network_file)
    state_dict = torch.load(network_file,
                            map_location=lambda storage, loc: storage)
    # needed since we slightly changed the structure of the network in pspnet
    state_dict = rename_keys_to_match(state_dict)
    net.load_state_dict(state_dict)
    net.eval()

    # data loading
    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform
    # keep crop size and stride the same as during training
    sliding_crop = joint_transforms.SlidingCropImageOnly(
        713, args['sliding_transform_step'])

    check_mkdir(save_folder)
    t0 = time.time()

    # get all file names
    filenames_ims = list()
    filenames_segs = list()
    print('Scanning %s for images to segment.' % img_folder)
    for root, subdirs, files in os.walk(img_folder):
        filenames = [f for f in files if f.endswith(args['img_ext'])]
        if len(filenames) > 0:
            print('Found %d images in %s' % (len(filenames), root))
            seg_path = root.replace(img_folder, save_folder)
            check_mkdir(seg_path)
            filenames_ims += [os.path.join(root, f) for f in filenames]
            filenames_segs += [
                os.path.join(seg_path, f.replace(args['img_ext'], '.png'))
                for f in filenames
            ]

    # Create segmentor
    if net.n_classes == 19:  # assume the standard 19 Cityscapes classes; colorize the mask
        segmentor = Segmentor(net,
                              net.n_classes,
                              colorize_fcn=cityscapes.colorize_mask,
                              n_slices_per_pass=args['n_slices_per_pass'])
    else:
        segmentor = Segmentor(net,
                              net.n_classes,
                              colorize_fcn=None,
                              n_slices_per_pass=args['n_slices_per_pass'])

    count = 1
    for im_file, save_path in zip(filenames_ims, filenames_segs):
        tnow = time.time()
        print("[%d/%d (%.1fs/%.1fs)] %s" %
              (count, len(filenames_ims), tnow - t0,
               (tnow - t0) / count * len(filenames_ims), im_file))
        segmentor.run_and_save(
            im_file,
            save_path,
            pre_sliding_crop_transform=pre_validation_transform,
            sliding_crop=sliding_crop,
            input_transform=input_transform,
            skip_if_seg_exists=True,
            use_gpu=args['use_gpu'])
        count += 1

    tend = time.time()
    print('Time: %f' % (tend - t0))
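
A usage sketch with assumed argument values; only keys actually read by the function are shown:

# Hypothetical call; 'n_classes' is optional and only needed for clustering nets.
args = {
    'use_gpu': True,
    'img_ext': '.png',
    'n_slices_per_pass': 4,
    'sliding_transform_step': 2 / 3.,
    # 'n_classes': 16,              # uncomment for a non-19-class network
    # 'use_original_base': False,   # read only when 'n_classes' is given
}
segment_images_in_folder('/path/to/weights.pth', '/path/to/images',
                         '/path/to/segmentations', args)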
Example #5
    def run_and_save(
        self,
        img_path,
        save_path,
        pre_sliding_crop_transform=None,
        sliding_crop=joint_transforms.SlidingCropImageOnly(713, 2 / 3.),
        input_transform=standard_transforms.ToTensor(),
        verbose=False,
        skip_if_seg_exists=False,
        use_gpu=True,
    ):
        """
        img                  - Path of input image
        save_path             - Path of output image (feature map)
        sliding_crop         - Transform that returns set of image slices
        input_transform      - Transform to apply to image before inputting to network
        skip_if_seg_exists   - Whether to overwrite or skip if segmentation exists already
        """

        if save_path is not None:
            if os.path.exists(save_path):
                if skip_if_seg_exists:
                    if verbose:
                        print(
                            "Segmentation already exists, skipping: {}".format(
                                save_path))
                    return
                else:
                    if verbose:
                        print("Segmentation already exists, overwriting: {}".
                              format(save_path))

        if isinstance(img_path, str):
            try:
                img = Image.open(img_path).convert('RGB')
            except OSError:
                print(
                    "Error reading input image, skipping: {}".format(img_path))
                return
        else:
            img = img_path

        # creating sliding crop windows and transform them
        img_size_orig = img.size
        if pre_sliding_crop_transform is not None:  # might reshape image
            img = pre_sliding_crop_transform(img)

        img_slices, slices_info = sliding_crop(img)
        img_slices = [input_transform(e) for e in img_slices]
        img_slices = torch.stack(img_slices, 0)
        slices_info = torch.LongTensor(slices_info)
        slices_info.squeeze_(0)

        of_pre, oa_pre = self.net.output_features, self.net.output_all
        self.net.output_features, self.net.output_all = True, False
        feature_map = self.run_on_slices(
            img_slices,
            slices_info,
            sliding_transform_step=sliding_crop.stride_rate,
            use_gpu=use_gpu)
        # restore previous settings
        self.net.output_features, self.net.output_all = of_pre, oa_pre

        if save_path is not None:
            check_mkdir(os.path.dirname(save_path))
            ext = save_path.split('.')[-1]
            if ext == 'mat':
                matdict = {
                    "features": np.transpose(feature_map, [1, 2, 0]),
                    "original_image_size": (img_size_orig[1], img_size_orig[0])
                }
                sio.savemat(save_path, matdict, appendmat=False)
            elif ext == 'npy':
                np.save(save_path, feature_map)
            else:
                raise ValueError(
                    'invalid file extension for save_path, only .mat and .npy are supported'
                )

        return feature_map
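
A sketch of calling this variant, assuming a Segmentor instance named segmentor built as in the earlier examples; the paths are placeholders:

# Save the aggregated feature map of one image; .mat and .npy are supported.
feature_map = segmentor.run_and_save(
    '/path/to/image.png',
    '/path/to/output/features.npy',
    skip_if_seg_exists=True,
    use_gpu=True)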
Example #6
    def run_and_save(
        self,
        img_path,
        seg_path,
        pre_sliding_crop_transform=None,
        sliding_crop=joint_transforms.SlidingCropImageOnly(713, 2 / 3.),
        input_transform=standard_transforms.ToTensor(),
        verbose=False,
        skip_if_seg_exists=False,
        use_gpu=True,
    ):
        """
        img                  - Path of input image
        seg_path             - Path of output image (segmentation)
        sliding_crop         - Transform that returns set of image slices
        input_transform      - Transform to apply to image before inputting to network
        skip_if_seg_exists   - Whether to overwrite or skip if segmentation exists already
        """

        if seg_path is not None:
            if os.path.exists(seg_path):
                if skip_if_seg_exists:
                    if verbose:
                        print(
                            "Segmentation already exists, skipping: {}".format(
                                seg_path))
                    return
                else:
                    if verbose:
                        print("Segmentation already exists, overwriting: {}".
                              format(seg_path))

        if isinstance(img_path, str):
            try:
                img = Image.open(img_path).convert('RGB')
            except OSError:
                print(
                    "Error reading input image, skipping: {}".format(img_path))
                return
        else:
            img = img_path

        # creating sliding crop windows and transform them
        img_size_orig = img.size
        if pre_sliding_crop_transform is not None:  # might reshape image
            img = pre_sliding_crop_transform(img)

        img_slices, slices_info = sliding_crop(img)
        img_slices = [input_transform(e) for e in img_slices]
        img_slices = torch.stack(img_slices, 0)
        slices_info = torch.LongTensor(slices_info)
        slices_info.squeeze_(0)

        prediction_logits = self.run_on_slices(
            img_slices,
            slices_info,
            sliding_transform_step=sliding_crop.stride_rate,
            use_gpu=use_gpu,
            return_logits=True)
        prediction_orig = prediction_logits.max(0)[1].squeeze_(0).numpy()
        prediction_logits = prediction_logits.numpy()

        if self.colorize_fcn:
            prediction_colorized = self.colorize_fcn(prediction_orig)
        else:
            prediction_colorized = Image.fromarray(
                prediction_orig.astype(np.int32)).convert('I')

        if prediction_colorized.size != img_size_orig:
            prediction_colorized = F.resize(prediction_colorized,
                                            img_size_orig[::-1],
                                            interpolation=Image.NEAREST)

        if seg_path is not None:
            check_mkdir(os.path.dirname(seg_path))
            prediction_colorized.save(seg_path)

        return prediction_colorized
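
A sketch of the segmentation variant under the same assumptions (a Segmentor named segmentor, placeholder paths):

# Segment one image and save the (optionally colorized) mask as PNG.
prediction = segmentor.run_and_save(
    '/path/to/image.png',
    '/path/to/output/segmentation.png',
    skip_if_seg_exists=True,
    use_gpu=True)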