Пример #1
0
    def __init__(self, log_path, resume=True):
        """Remember the log file location and emit a timestamped banner.

        Args:
            log_path: Path of the log file. Its parent directory is handed to
                ``check_mkdir`` before the file itself is touched.
            resume: When false, the log file is truncated (or created empty)
                first. When true, the file is not opened by this method.
        """
        self.log_path = log_path
        parent_dir = os.path.dirname(self.log_path)
        check_mkdir(dir_path=parent_dir)

        if not resume:
            # Opening in "w" mode and leaving the block right away empties
            # the file without writing anything.
            with open(self.log_path, encoding="utf-8", mode="w"):
                pass
        self.record_str(f"=== {datetime.now()} ===")
Пример #2
0
def main():
    """Create the output directories and launch training.

    Depending on ``user_config["is_distributed"]`` this either spawns one
    ``main_worker`` process per GPU on this node or calls ``main_worker``
    once in the current process. The TensorBoard recorder is closed after
    the worker(s) return.
    """
    for config_key in ("save", "pth"):
        check_mkdir(path_config[config_key])
    # shutil.copy(f"{user_config['proj_root']}/config.py", path_config["cfg_log"])
    # shutil.copy(f"{user_config['proj_root']}/train.py", path_config["trainer_log"])

    if not user_config["is_distributed"]:
        construct_print("We will not use the distributed training.")
        main_worker(local_rank=0, ngpus_per_node=1, world_size=1)
    else:
        construct_print("We will use the distributed training.")
        # Total number of processes across all nodes.
        args.world_size = args.ngpus_per_node * args.nodes
        worker_args = (args.ngpus_per_node, args.world_size)
        torch.multiprocessing.spawn(
            main_worker,
            nprocs=args.ngpus_per_node,
            args=worker_args,
        )
    tb_recorder.close_tb()
Пример #3
0
    def __init__(self, tb_path):
        """Ensure ``tb_path`` exists and open a ``SummaryWriter`` on it.

        Args:
            tb_path: Directory handed to ``check_mkdir`` and then used as
                the ``SummaryWriter`` target.
        """
        check_mkdir(tb_path)

        summary_writer = SummaryWriter(tb_path)
        self.tb = summary_writer
Пример #4
0
def validate(epoch):
    """Evaluate ``net`` on ``valloader`` and snapshot it on a new best mean IoU.

    The network is switched to eval mode, every validation batch is run
    through it, and the predictions are scored with ``evaluate``. When the
    resulting mean IoU beats ``best_record['mean_iu']`` the record is
    updated, the weights are saved under ``ckpt_path/exp_name`` and, if
    ``args.val_save_to_img_file`` is set, the sampled inputs, labels and
    predictions are written as PNG files into a per-epoch directory.

    Args:
        epoch: Epoch number; used in the snapshot name, the image output
            directory and the printed summary.

    Returns:
        The average validation loss accumulated in the ``AverageMeter``.
    """
    net.eval()
    val_loss = AverageMeter()
    inputs_all, labels_all, predictions_all = [], [], []

    # Nothing here calls backward(), so do not build an autograd graph for
    # the validation forward passes.
    with torch.no_grad():
        for inputs, labels in valloader:
            if args.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            N = inputs.size(0)
            outputs = net(inputs)
            # predictions holds the class index chosen for every pixel
            # (argmax over dim 1 of the network output).
            predictions = outputs.data.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()

            loss = criterion(outputs, labels) / N
            val_loss.update(loss.item(), N)

            # Keep only a random subset of input images (sampled with
            # probability args.valImgSampleRate); None marks a skipped one.
            if random.random() > args.valImgSampleRate:
                inputs_all.append(None)
            else:
                inputs_all.append(inputs.data.squeeze_(0).cpu())

            labels_all.append(labels.data.squeeze_(0).cpu().numpy())
            predictions_all.append(predictions)

    # Compute accuracy and the other evaluation metrics on the validation
    # set for this epoch.
    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, labels_all,
                                              num_classes)

    if mean_iu > best_record['mean_iu']:
        best_record['val_loss'] = val_loss.avg
        best_record['epoch'] = epoch
        best_record['acc'] = acc
        best_record['acc_cls'] = acc_cls
        best_record['mean_iu'] = mean_iu
        best_record['fwavacc'] = fwavacc
        snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
            epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, args.lr)
        torch.save(net.state_dict(),
                   os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        # NOTE(review): the optimizer state is not saved here; the call
        # below is disabled in the original code.
        # torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

        # The PIL conversions are only needed for the image dump, so skip
        # them entirely when no images are written.
        if args.val_save_to_img_file:
            to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
            check_mkdir(to_save_dir)

            samples = zip(inputs_all, labels_all, predictions_all)
            for idx, (image, label, prediction) in enumerate(samples):
                if image is None:
                    continue
                input_pil = restore_transform(image)
                labels_pil = colorize_mask(label)
                predictions_pil = colorize_mask(prediction)
                input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx))
                predictions_pil.save(
                    os.path.join(to_save_dir, '%d_prediction.png' % idx))
                labels_pil.save(os.path.join(to_save_dir,
                                             '%d_label.png' % idx))

    print(
        '--------------------------------------------------------------------')
    print(
        '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]'
        % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print(
        'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]'
        %
        (best_record['val_loss'], best_record['acc'], best_record['acc_cls'],
         best_record['mean_iu'], best_record['fwavacc'], best_record['epoch']))

    print(
        '--------------------------------------------------------------------')

    return val_loss.avg
Пример #5
0
# betas=(args.momentum, 0.999))

# Learning-rate scheduler in 'min' mode: patience comes from the command
# line and the learning rate is never lowered below 1e-10.
# (presumably torch.optim.lr_scheduler.ReduceLROnPlateau - confirm import)
scheduler = ReduceLROnPlateau(optimizer,
                              'min',
                              patience=args.lr_patience,
                              min_lr=1e-10,
                              verbose=True)

# When resuming, load the optimizer state and then override the learning
# rates of the first two parameter groups (group 0 gets twice args.lr).
# NOTE(review): this expects a file named 'opt_' + args.resume under
# ckpt_path/exp_name - confirm that the training code actually saves it.
if len(args.resume) > 0:
    optimizer.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args.resume)))
    optimizer.param_groups[0]['lr'] = 2 * args.lr
    optimizer.param_groups[1]['lr'] = args.lr

# Make sure the checkpoint root and the per-experiment directory exist.
check_mkdir(ckpt_path)
check_mkdir(os.path.join(ckpt_path, exp_name))
#open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n')


def train(epoch):
    net.train()
    # 计算平均损失,每个epoch更新为0
    train_loss = AverageMeter()
    #每次迭代调用 _getitem_ 方法,进行transform变换。
    curr_iter = (epoch - 1) * len(trainloader)
    for i, (inputs, labels) in enumerate(trainloader):
        if args.cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        N = inputs.size(0)
        # 清空梯度
Пример #6
0
def find_non_stationary_clusters(args):
    """Histogram predicted cluster ids on and off correspondence points.

    The best snapshot of the network stored in ``args['folder']`` is loaded
    and copied to the destination folder. Unless ``only_copy_weights`` is
    set, every ``args['step']``-th sample of the correspondence validation
    set is segmented; the predicted cluster ids are counted separately for
    pixels that carry a correspondence point and for all other pixels.
    Both histograms and the indices of clusters whose correspondence
    fraction exceeds 0.01 are saved as ``.npy`` files.

    Args:
        args: Dict of settings. Keys read here: ``use_gpu``, ``folder``
            (its last path component must contain ``cn<number>-``, which
            gives the number of clusters), ``dest_root``,
            ``use_original_base``, ``only_copy_weights``, ``corr_set``
            (``'rc'``, ``'cmu'`` or ``'both'``), ``stride_rate`` and
            ``step``.

    Raises:
        ValueError: If the cluster count cannot be parsed from the folder
            name, or ``args['corr_set']`` is not a supported value.
    """
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    # Strip a trailing slash so the last path component is never empty.
    network_folder_name = args['folder'].rstrip('/').split('/')[-1]
    match = re.search(r"cn(\d+)-", network_folder_name)
    if match is None:
        raise ValueError(
            "cannot parse the number of clusters ('cn<number>-') from "
            "folder name {!r}".format(network_folder_name))
    n_clusters = int(match.group(1))

    save_folder = os.path.join(args['dest_root'], network_folder_name)
    histogram_path = os.path.join(save_folder,
                                  'cluster_histogram_for_corr.npy')
    if os.path.exists(histogram_path):
        print('{} already exists. skipping'.format(histogram_path))
        return

    check_mkdir(save_folder)

    with open(os.path.join(args['folder'], 'bestval.txt')) as f:
        best_val_dict_str = f.read()
    # SECURITY NOTE: eval() executes whatever bestval.txt contains. Only
    # point this at folders produced by trusted training runs.
    bestval = eval(best_val_dict_str.rstrip())

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=n_clusters,
        for_clustering=False,
        output_features=False,
        use_original_base=args['use_original_base']).to(device)
    snapshot_path = os.path.join(args['folder'], bestval['snapshot'] + '.pth')
    # map_location lets weights saved on a GPU load on a CPU-only device.
    net.load_state_dict(torch.load(snapshot_path, map_location=device))
    net.eval()

    # copy network file to save location
    copyfile(snapshot_path, os.path.join(save_folder, 'weights.pth'))

    if args['only_copy_weights']:
        print('Only copying weights')
        return

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_configs = [data_configs.RobotcarConfig()]
    elif args['corr_set'] == 'cmu':
        corr_set_configs = [data_configs.CmuConfig()]
    elif args['corr_set'] == 'both':
        corr_set_configs = [
            data_configs.CmuConfig(),
            data_configs.RobotcarConfig()
        ]
    else:
        raise ValueError(
            "unsupported corr_set {!r}; expected 'rc', 'cmu' or "
            "'both'".format(args['corr_set']))

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    # One validation correspondence set per configuration; no transforms
    # are applied by the dataset itself.
    corr_sets = [
        correspondences.Correspondences(
            config.correspondence_path,
            config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=None,
            joint_transform=None,
            listfile=config.correspondence_val_list_file)
        for config in corr_set_configs
    ]
    if args['corr_set'] == 'both':
        corr_set_val = merged.Merged(corr_sets)
    else:
        corr_set_val = corr_sets[0]

    # Segmentor
    segmentor = Segmentor(net, n_clusters, n_slices_per_pass=4)

    # save args
    args_file = os.path.join(save_folder,
                             str(datetime.datetime.now()) + '.txt')
    with open(args_file, 'w') as f:
        f.write(str(args) + '\n\n')

    cluster_histogram_for_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)
    cluster_histogram_non_correspondences = np.zeros((n_clusters, ),
                                                     dtype=np.int64)
    histogram_bins = np.arange(n_clusters + 1)

    for i in range(0, len(corr_set_val), args['step']):
        # Only the first image of the pair and its points are used.
        img1, _, pts1, _, _ = corr_set_val[i]
        seg1 = segmentor.run_and_save(
            img1,
            None,
            pre_sliding_crop_transform=pre_validation_transform,
            input_transform=input_transform,
            sliding_crop=sliding_crop_im,
            use_gpu=args['use_gpu'])

        seg1 = np.array(seg1)
        # The builtin bool replaces np.bool, which NumPy 1.24 removed.
        corr_loc_mask = np.zeros(seg1.shape, dtype=bool)

        # Row 0 of pts1 is compared against the width (seg1.shape[1]) and
        # row 1 against the height; drop points outside the segmentation.
        valid_inds = (pts1[0, :] >= 0) & (pts1[0, :] < seg1.shape[1]) & (
            pts1[1, :] >= 0) & (pts1[1, :] < seg1.shape[0])

        pts1 = pts1[:, valid_inds]
        for j in range(pts1.shape[1]):
            pt = pts1[:, j]
            corr_loc_mask[pt[1], pt[0]] = True

        cluster_ids_corr = seg1[corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_corr, histogram_bins)
        cluster_histogram_for_correspondences += hist_tmp

        cluster_ids_no_corr = seg1[~corr_loc_mask]
        hist_tmp, _ = np.histogram(cluster_ids_no_corr, histogram_bins)
        cluster_histogram_non_correspondences += hist_tmp

        if ((i + 1) % 100) < args['step']:
            print('{}/{}'.format(i + 1, len(corr_set_val)))

    np.save(histogram_path, cluster_histogram_for_correspondences)
    np.save(os.path.join(save_folder, 'cluster_histogram_non_corr.npy'),
            cluster_histogram_non_correspondences)
    # Clusters that were never predicted give 0/0 = nan here, which compares
    # False below and therefore counts as non-stationary.
    frac = cluster_histogram_for_correspondences / \
        (cluster_histogram_for_correspondences +
         cluster_histogram_non_correspondences)
    stationary_inds = np.argwhere(frac > 0.01)
    np.save(os.path.join(save_folder, 'stationary_inds.npy'), stationary_inds)
    print('{} stationary clusters out of {}'.format(
        len(stationary_inds), len(cluster_histogram_for_correspondences)))
Пример #7
0
def train_with_correspondences(save_folder, startnet, args):
    """Train the segmentation network with segmentation and correspondence losses.

    Every iteration performs one optimizer step per segmentation data
    loader (Cityscapes, optionally Vistas, and the extra set belonging to
    the chosen correspondence dataset) and, once enabled, one step on the
    correspondence loss. Learning rates follow a polynomial decay over
    ``args['max_iter']``. Validators run every ``args['val_interval']``
    iterations.

    Args:
        save_folder: Directory for the TensorBoard events, the args dump,
            ``log.log`` and, when resuming, the snapshot, its ``opt_``
            optimizer state and ``bestval.txt``.
        startnet: Path to the initial weights; loaded only when
            ``args['snapshot']`` is empty.
        args: Dict of settings. Keys read here: ``snapshot``, ``corr_set``
            (``'rc'`` or ``'cmu'``), ``stride_rate``, ``train_batch_size``,
            ``n_workers``, ``include_vistas``, ``corr_loss_type``,
            ``feat_dist_threshold_match``, ``feat_dist_threshold_nomatch``,
            ``lr``, ``weight_decay``, ``momentum``, ``max_iter``,
            ``lr_decay``, ``seg_loss_weight``, ``corr_loss_weight``,
            ``n_iterations_before_corr_loss``, ``remove_same_class``,
            ``classes_to_ignore``, ``print_freq`` and ``val_interval``.
            ``snapshot`` (when ``'latest'``) and ``best_record`` are
            written back into the dict.

    Raises:
        ValueError: If ``args['corr_set']`` is neither ``'rc'`` nor
            ``'cmu'``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        # map_location lets weights saved on a GPU load on a CPU-only host.
        state_dict = torch.load(startnet, map_location=device)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder, args['snapshot']),
                       map_location=device))  # load weights
        # The iteration count is the second '_'-separated field of the
        # snapshot file name.
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        # SECURITY NOTE: eval() executes whatever bestval.txt contains.
        # Only resume from folders written by trusted training runs.
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    else:
        raise ValueError(
            "unsupported corr_set {!r}; expected 'rc' or 'cmu'".format(
                args['corr_set']))

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    def make_train_loader(config):
        """Build the shuffled segmentation training loader for ``config``."""
        seg_set = cityscapes.CityScapes(
            config.train_im_folder,
            config.train_seg_folder,
            config.im_file_ending,
            config.seg_file_ending,
            id_to_trainid=config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        return DataLoader(seg_set,
                          batch_size=args['train_batch_size'],
                          num_workers=args['n_workers'],
                          shuffle=True)

    def make_validator(config, save_snapshot, extra_name_str):
        """Build a sliding-crop validator on the validation split of ``config``."""
        val_set = cityscapes.CityScapes(
            config.val_im_folder,
            config.val_seg_folder,
            config.im_file_ending,
            config.seg_file_ending,
            id_to_trainid=config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader = DataLoader(val_set,
                                batch_size=1,
                                num_workers=args['n_workers'],
                                shuffle=False)
        return Validator(val_loader,
                         n_classes=config.n_classes,
                         save_snapshot=save_snapshot,
                         extra_name_str=extra_name_str)

    def keep_selected(items, keep_mask):
        """Return the entries of ``items`` whose ``keep_mask`` element is > 0."""
        return [p for b, p in zip(keep_mask, items) if b.item() > 0]

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training and Validation
    c_config = data_configs.CityscapesConfig()
    seg_loaders.append(make_train_loader(c_config))
    validators.append(
        make_validator(c_config,
                       save_snapshot=False,
                       extra_name_str='Cityscapes'))

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)
        seg_loaders.append(make_train_loader(v_config))
        validators.append(
            make_validator(v_config,
                           save_snapshot=False,
                           extra_name_str='Vistas'))

    # Extra Training and Validation (segmentation data of the
    # correspondence dataset); only this validator saves snapshots.
    seg_loaders.append(make_train_loader(corr_set_config))
    validators.append(
        make_validator(corr_set_config,
                       save_snapshot=True,
                       extra_name_str='Extra'))

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    # NOTE(review): 'elementwise_mean' is the PyTorch 0.4.1 spelling; newer
    # releases accept it with a deprecation warning in favour of 'mean'.
    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup: bias parameters get twice the learning rate and no
    # explicit weight decay; all other trainable parameters use args['lr']
    # with args['weight_decay'].
    bias_params = [
        param for name, param in net.named_parameters()
        if name.endswith('bias') and param.requires_grad
    ]
    other_params = [
        param for name, param in net.named_parameters()
        if not name.endswith('bias') and param.requires_grad
    ]
    optimizer = optim.SGD([{
        'params': bias_params,
        'lr': 2 * args['lr']
    }, {
        'params': other_params,
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot']),
                       map_location=device))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # Dump the settings of this run; the context manager closes the file.
    args_file = os.path.join(save_folder,
                             str(datetime.datetime.now()) + '.txt')
    with open(args_file, 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    # Same order as seg_loaders, so the loader index selects the meter.
    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    # Only run the iterations that are left. Going past max_iter would make
    # the base of the polynomial decay negative, which is not a valid
    # learning rate (and complex for a fractional lr_decay).
    remaining_iters = max(args['max_iter'] - start_iter, 0)

    for i in range(remaining_iters):
        lr_scale = (1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * lr_scale
        optimizer.param_groups[1]['lr'] = args['lr'] * lr_scale

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            # NOTE(review): a fresh DataLoader iterator is created on every
            # step; with n_workers > 0 this restarts the workers each time.
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass. The reference image is run without gradients;
            # the other image must come last so that gradients flow through
            # its forward pass.
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other)
                net.output_all = False

            else:  # Works on class probs
                # The hinge loss on classes takes the raw outputs; every
                # other loss type reaching this branch gets softmax
                # probabilities.
                use_softmax = args['corr_loss_type'] != 'hingeC'
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if use_softmax:
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                output_other, aux_other = net(img_other)
                if use_softmax:
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering on the main output
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = keep_selected(pts_ref_orig,
                                         batch_inds_to_keep_orig)
            pts_other_orig = keep_selected(pts_other_orig,
                                           batch_inds_to_keep_orig)
            weights_orig = keep_selected(weights_orig,
                                         batch_inds_to_keep_orig)
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            # Correspondence filtering on the auxiliary output
            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = keep_selected(pts_ref_aux, batch_inds_to_keep_aux)
            pts_other_aux = keep_selected(pts_other_aux,
                                          batch_inds_to_keep_aux)
            weights_aux = keep_selected(weights_aux, batch_inds_to_keep_aux)
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed; the aux features must be
                # selected with the aux mask so that they line up with the
                # aux points and weights filtered above.
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_aux]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss; when no sample is left, use a zero that
            # is still connected to the graph so backward() works.
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                curr_iter, len(corr_loader), train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
def segment_images_in_folder(network_file, img_folder, save_folder, args):
    """Segment every matching image below ``img_folder`` and save the results.

    The directory layout found under ``img_folder`` is mirrored under
    ``save_folder``; each segmentation is written next to its mirrored
    location with the image extension swapped for ``.png``.

    Args:
        network_file: Path of the weight file passed to ``torch.load``.
        img_folder: Root directory that is walked recursively for images.
        save_folder: Root directory receiving the segmentations.
        args: Dict of options. Keys read here: ``use_gpu``, ``img_ext``,
            ``sliding_transform_step``, ``n_slices_per_pass`` and, when
            ``n_classes`` is present, also ``use_original_base``.
    """
    # get current available device
    if args['use_gpu']:
        print("Using CUDA" if torch.cuda.is_available() else "Using CPU")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    if 'n_classes' in args:
        print('Initializing model with %d classes' % args['n_classes'])
        net = model_config.init_network(
            n_classes=args['n_classes'],
            for_clustering=False,
            output_features=False,
            use_original_base=args['use_original_base']).to(device)
    else:
        net = model_config.init_network().to(device)

    print('load model ' + network_file)
    # Keep the loaded tensors on the CPU; load_state_dict copies them into
    # the network's own parameters.
    state_dict = torch.load(network_file,
                            map_location=lambda storage, loc: storage)
    # needed since we slightly changed the structure of the network in pspnet
    state_dict = rename_keys_to_match(state_dict)
    net.load_state_dict(state_dict)
    net.eval()

    # data loading
    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform
    # make sure crop size and stride same as during training
    sliding_crop = joint_transforms.SlidingCropImageOnly(
        713, args['sliding_transform_step'])

    check_mkdir(save_folder)
    t0 = time.time()

    # get all file names
    img_ext = args['img_ext']
    filenames_ims = []
    filenames_segs = []
    print('Scanning %s for images to segment.' % img_folder)
    for root, _subdirs, files in os.walk(img_folder):
        filenames = [f for f in files if f.endswith(img_ext)]
        if not filenames:
            continue
        print('Found %d images in %s' % (len(filenames), root))

        # Mirror the sub-directory through its path relative to img_folder.
        # A plain str.replace would also rewrite later occurrences of the
        # folder name and glue path components together when img_folder
        # ends with a separator.
        rel_root = os.path.relpath(root, img_folder)
        if rel_root == os.curdir:
            seg_path = save_folder
        else:
            seg_path = os.path.join(save_folder, rel_root)
        check_mkdir(seg_path)

        filenames_ims += [os.path.join(root, f) for f in filenames]
        # Swap only the trailing extension; the same text may also occur
        # earlier in the file name.
        filenames_segs += [
            os.path.join(seg_path, f[:len(f) - len(img_ext)] + '.png')
            for f in filenames
        ]

    # Create segmentor. A 19-class network could be the 19 cityscapes
    # classes, so only then is the cityscapes colorizer used.
    colorize_fcn = cityscapes.colorize_mask if net.n_classes == 19 else None
    segmentor = Segmentor(net,
                          net.n_classes,
                          colorize_fcn=colorize_fcn,
                          n_slices_per_pass=args['n_slices_per_pass'])

    n_images = len(filenames_ims)
    for count, (im_file, save_path) in enumerate(
            zip(filenames_ims, filenames_segs), start=1):
        elapsed = time.time() - t0
        # Extrapolate the total run time from the images already finished
        # (count - 1); nothing is known yet before the first one completes.
        n_done = count - 1
        estimated_total = elapsed / n_done * n_images if n_done else 0.0
        print("[%d/%d (%.1fs/%.1fs)] %s" %
              (count, n_images, elapsed, estimated_total, im_file))
        segmentor.run_and_save(
            im_file,
            save_path,
            pre_sliding_crop_transform=pre_validation_transform,
            sliding_crop=sliding_crop,
            input_transform=input_transform,
            skip_if_seg_exists=True,
            use_gpu=args['use_gpu'])

    tend = time.time()
    print('Time: %f' % (tend - t0))
Пример #9
0
def _select_corr_set_config(corr_set):
    """Return the dataset configuration object named by ``corr_set``.

    Raises:
        ValueError: If ``corr_set`` is not 'rc', 'pola' or 'cmu'. The value
            'both' is rejected as well because this training variant reads
            a single configuration object.
    """
    if corr_set == 'rc':
        return data_configs.RobotcarConfig()
    if corr_set == 'pola':
        return data_configs.PolaConfig()
    if corr_set == 'cmu':
        return data_configs.CmuConfig()
    raise ValueError(
        "Unsupported corr_set %r: expected 'rc', 'pola' or 'cmu'" %
        (corr_set, ))


def _build_sgd_optimizer(net, args):
    """Build the SGD optimizer with separate bias and non-bias groups.

    Group 0 holds the trainable parameters whose name ends in 'bias' and
    uses twice the base learning rate; group 1 holds the remaining
    trainable parameters with the base learning rate and weight decay.
    """
    bias_params = [
        param for name, param in net.named_parameters()
        if name.endswith('bias') and param.requires_grad
    ]
    other_params = [
        param for name, param in net.named_parameters()
        if not name.endswith('bias') and param.requires_grad
    ]
    return optim.SGD([{
        'params': bias_params,
        'lr': 2 * args['lr']
    }, {
        'params': other_params,
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                     momentum=args['momentum'],
                     nesterov=True)


def _assign_pixels_to_clusters(features, deepcluster, pca_info):
    """Return an int64 array with one cluster id per feature location.

    ``features`` is indexed as (batch, channels, dim2, dim3); the result has
    shape (batch, dim2, dim3).
    """
    output = features.cpu().numpy()
    output_flat = output.reshape((output.shape[0], output.shape[1], -1))
    cluster_image = np.zeros(
        (output.shape[0], output.shape[2], output.shape[3]), dtype=np.int64)
    for b, out_f in enumerate(output_flat):
        # preprocess_features receives one row per location.
        out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                        pca_info=pca_info)
        cluster_image[b] = deepcluster.assign(out_f2).reshape(
            (output.shape[2], output.shape[3]))
    return cluster_image


def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    """Train the network on cluster labels obtained from its own features.

    The loop alternates between (1) extracting features with the current
    network, clustering them and saving the centroids and PCA transform to
    ``centroids_<iter>.h5``, and (2) training the network with a
    cross-entropy loss against the per-pixel cluster ids produced by a
    frozen copy of the network taken right after clustering.

    Args:
        save_folder: Directory receiving the TensorBoard events, the
            centroid files, a timestamped dump of ``args`` and ``log.log``.
        tmp_seg_folder: Directory that is ensured via ``check_mkdir``; it is
            not used otherwise by this function.
        startnet: Path of the initial weights. When the path contains
            'resnet101' the weights go through ``load_resnet101_weights``.
        args: Dict of options. Keys read here: ``n_clusters``,
            ``use_original_base``, ``corr_set``, ``n_workers``, ``lr``,
            ``weight_decay``, ``momentum``, ``lr_decay``, ``max_iter``,
            ``max_features_per_image``, ``cluster_interval``,
            ``chunk_size``, ``val_interval``, ``seg_loss_weight`` and
            ``print_freq``. ``args['best_record']`` is overwritten.

    Raises:
        ValueError: If ``args['corr_set']`` is not 'rc', 'pola' or 'cmu'.
    """
    print(save_folder.split('/')[-1])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    # Map the checkpoint onto the device the network lives on, so weights
    # saved from a GPU can still be loaded when falling back to the CPU.
    state_dict = torch.load(startnet, map_location=device)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # different amount of classes
        init_last_layers(state_dict, args['n_clusters'])

        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    corr_set_config = _select_corr_set_config(args['corr_set'])
    ref_image_lists = corr_set_config.reference_image_list

    input_transform = model_config.input_transform

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])

    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,
        transform_before_sliding=None)
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    # Optimizer setup
    optimizer = _build_sgd_optimizer(net, args)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])

    # Dump the run configuration; the context manager closes the file.
    args_dump_path = os.path.join(save_folder,
                                  str(datetime.datetime.now()) + '.txt')
    with open(args_dump_path, 'w') as args_file:
        args_file.write(str(args) + '\n\n')

    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    try:
        while curr_iter <= args['max_iter']:

            net.eval()
            net.output_features = True

            print(
                'ici on a la len de la ref im list --------------------------------------------------------'
            )
            print(len(ref_image_lists))
            features = extract_features_for_reference_nocorr(
                net,
                model_config,
                corr_set_train,
                10,
                max_num_features_per_image=args['max_features_per_image'])

            cluster_features = np.vstack(features)
            del features

            # cluster the features
            cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
                cluster_features, verbose=True, use_gpu=False)

            # save cluster centroids
            centroid_path = os.path.join(save_folder,
                                         'centroids_%d.h5' % curr_iter)
            with h5py.File(centroid_path, 'w') as h5f:
                h5f.create_dataset('cluster_centroids',
                                   data=cluster_centroids)
                h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
                h5f.create_dataset('pca_transform_bvec', data=pca_info[1])

            # Print distribution of clusters
            cluster_distribution, _ = np.histogram(
                cluster_indices,
                bins=np.arange(args['n_clusters'] + 1),
                density=True)
            str2write = 'cluster distribution ' + \
                np.array2string(cluster_distribution, formatter={
                                'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
            print(str2write)
            f_handle.write(str2write + "\n")

            # set last layer weight to a normal distribution
            reinit_last_layers(net)

            # make a copy of current network state to do cluster assignment
            net_for_clustering = copy.deepcopy(net)

            # Polynomial learning-rate decay, refreshed once per clustering
            # round.
            lr_factor = (1 - float(curr_iter) /
                         args['max_iter'])**args['lr_decay']
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * lr_factor
            optimizer.param_groups[1]['lr'] = args['lr'] * lr_factor

            net.train()
            freeze_bn(net)
            net.output_features = False
            cluster_training_count = 0

            # Only the segmentation meter is updated in this variant; the
            # other two stay at their initial value and exist so the log
            # line keeps its columns.
            corr_train_loss = AverageMeter()
            seg_train_loss = AverageMeter()
            feature_train_loss = AverageMeter()

            while cluster_training_count < args[
                    'cluster_interval'] and curr_iter <= args['max_iter']:

                # First extract cluster labels using saved network checkpoint
                print(
                    'on rentre dans la boucle extract cluster_______________________________________________'
                )
                # Only one of the two networks sits on the device at a time.
                net.to("cpu")
                net_for_clustering.to(device)
                net_for_clustering.eval()
                net_for_clustering.output_features = True

                data_samples = []
                extract_label_count = 0
                while (extract_label_count < args['chunk_size']) and (
                        cluster_training_count + extract_label_count <
                        args['cluster_interval']
                ) and (val_iter + extract_label_count <
                       args['val_interval']) and (
                           extract_label_count + curr_iter <=
                           args['max_iter']):
                    # The dataset already receives input_transform through
                    # its `transform` argument, so the loader itself is not
                    # passed through the transform (doing so replaced the
                    # loader with the transform's result).
                    print(
                        f'la valeur de corr loader train es de {corr_loader_train} lors de l iteration : {curr_iter}'
                    )
                    # A fresh iterator per sample draws one batch from a new
                    # shuffle each time.
                    # NOTE(review): this unpacking expects 5 items per batch;
                    # confirm that MonoDataset yields that layout.
                    img_ref, img_other, pts_ref, pts_other, _ = next(
                        iter(corr_loader_train))

                    # Transfer data to device
                    img_ref = img_ref.to(device)

                    with torch.no_grad():
                        features = net_for_clustering(img_ref)

                    # assign feature to clusters for entire patch
                    cluster_image = torch.from_numpy(
                        _assign_pixels_to_clusters(features, deepcluster,
                                                   pca_info)).to(device)

                    # assign cluster to correspondence positions
                    # NOTE(review): pts_ref / pts_other are not moved to the
                    # device here while `features` is; confirm the helper
                    # accepts that.
                    cluster_labels = assign_cluster_ids_to_correspondence_points(
                        features,
                        pts_ref, (deepcluster, pca_info),
                        inds_other=pts_other,
                        orig_im_size=(713, 713))

                    # Transfer data to cpu
                    img_ref = img_ref.cpu()
                    cluster_labels = [p.cpu() for p in cluster_labels]
                    cluster_image = cluster_image.cpu()
                    data_samples.append(
                        (img_ref, cluster_labels, cluster_image))
                    extract_label_count += 1

                net_for_clustering.to("cpu")
                net.to(device)

                for img_ref, cluster_labels, cluster_image in data_samples:
                    # Transfer data to device
                    img_ref = img_ref.to(device)
                    cluster_labels = [p.to(device) for p in cluster_labels]
                    cluster_image = cluster_image.to(device)

                    optimizer.zero_grad()

                    outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    loss = args['seg_loss_weight'] * \
                        (seg_main_loss + 0.4 * seg_aux_loss)

                    loss.backward()
                    optimizer.step()
                    cluster_training_count += 1

                    if isinstance(seg_main_loss, torch.Tensor):
                        seg_train_loss.update(seg_main_loss.item(), 1)

                    ##########################################################
                    #       LOGGING ETC
                    ##########################################################
                    curr_iter += 1
                    val_iter += 1

                    writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                      curr_iter)
                    writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                      curr_iter)

                    if (curr_iter + 1) % args['print_freq'] == 0:
                        # One argument per placeholder: the format has six.
                        str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f]. [lr %.10f]' % (
                            curr_iter + 1, args['max_iter'],
                            seg_train_loss.avg, corr_train_loss.avg,
                            feature_train_loss.avg,
                            optimizer.param_groups[1]['lr'])

                        print(str2write)
                        f_handle.write(str2write + "\n")

                    if curr_iter > args['max_iter']:
                        break

                # No validation is run in this variant, but the extraction
                # loop above is still capped by val_interval. Restart the
                # counter once the interval is reached; otherwise no further
                # samples would be extracted and this loop would never end.
                if val_iter >= args['val_interval']:
                    val_iter = 0
    finally:
        # Post training: release the log file and the summary writer even
        # when training is interrupted by an error.
        f_handle.close()
        writer.close()
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)
    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        if 'resnet101' in startnet:
            load_resnet101_weights(net, state_dict)
        else:
            # needed since we slightly changed the structure of the network in pspnet
            state_dict = rename_keys_to_match(state_dict)
            init_last_layers(state_dict,
                             args['n_clusters'])  # different amount of classes

            net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss_feat': 1e10,
            'val_loss_out': 1e10,
            'val_loss_cluster': 1e10
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

        if start_iter >= args['max_iter']:
            return

        if (start_iter % args['cluster_interval']) == 0:
            skip_clustering = False
        else:
            skip_clustering = True
            last_cluster_network_snapshot_iter = (
                start_iter //
                args['cluster_interval']) * args['cluster_interval']

            # load cluster info
            cluster_info = {}
            f = h5py.File(
                os.path.join(
                    save_folder, 'centroids_{}.h5'.format(
                        last_cluster_network_snapshot_iter)), 'r')
            for k, v in f.items():
                cluster_info[k] = np.array(v)
            cluster_centroids = cluster_info['cluster_centroids']
            pca_info = [
                cluster_info['pca_transform_Amat'],
                cluster_info['pca_transform_bvec']
            ]

            # load network that was used for last clustering
            net_for_clustering = model_config.init_network(
                n_classes=args['n_clusters'],
                for_clustering=True,
                output_features=True,
                use_original_base=args['use_original_base'])

            if last_cluster_network_snapshot_iter == 0:
                state_dict = torch.load(
                    startnet, map_location=lambda storage, loc: storage)
                if 'resnet101' in startnet:
                    load_resnet101_weights(net_for_clustering, state_dict)
                else:
                    # needed since we slightly changed the structure of the network in pspnet
                    state_dict = rename_keys_to_match(state_dict)
                    init_last_layers(
                        state_dict,
                        args['n_clusters'])  # different amount of classes

                    net_for_clustering.load_state_dict(
                        state_dict)  # load original weights
            else:
                cluster_network_weights = get_network_name_from_iteration(
                    save_folder, last_cluster_network_snapshot_iter)
                net_for_clustering.load_state_dict(
                    torch.load(os.path.join(save_folder,
                                            cluster_network_weights),
                               map_location=lambda storage, loc: storage)
                )  # load weights

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    if args['corr_set'] == 'both':
        ref_image_lists = [
            corr_set_config1.reference_image_list,
            corr_set_config2.reference_image_list
        ]
        corr_im_paths = [
            corr_set_config1.correspondence_im_path,
            corr_set_config2.correspondence_im_path
        ]
        ref_featurs_pos = [
            corr_set_config1.reference_feature_poitions,
            corr_set_config2.reference_feature_poitions
        ]
    else:
        ref_image_lists = [corr_set_config.reference_image_list]
        corr_im_paths = [corr_set_config.correspondence_im_path]
        ref_featurs_pos = [corr_set_config.reference_feature_poitions]

    input_transform = model_config.input_transform

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # Correspondences for training
    if args['corr_set'] == 'both':
        corr_set_train1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_train_list_file)
        corr_set_train2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_train_list_file)

        corr_set_train = merged.Merged([corr_set_train1, corr_set_train2])
    else:
        corr_set_train = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_train_list_file)

    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    # Correspondences for validation
    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_val_list_file)

        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_val_list_file)

        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_val_list_file)

    corr_loader_val = DataLoader(corr_set_val,
                                 batch_size=1,
                                 num_workers=args['n_workers'],
                                 shuffle=False)

    # Loss setup
    val_corr_loss_fct_feat = FeatureLoss(
        input_size=[713, 713],
        loss_type=args['feature_distance_measure'],
        feat_dist_threshold_match=0.8,
        feat_dist_threshold_nomatch=0.2,
        n_not_matching=0)

    val_corr_loss_fct_out = FeatureLoss(input_size=[713, 713],
                                        loss_type='KL',
                                        feat_dist_threshold_match=0.8,
                                        feat_dist_threshold_nomatch=0.2,
                                        n_not_matching=0)

    loss_fct = ClusterCorrespondenceLoss(input_size=[713, 713],
                                         size_average=True).to(device)
    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    if args['feature_hinge_loss_weight'] > 0:
        feature_loss_fct = FeatureLoss(input_size=[713, 713],
                                       loss_type='hingeF',
                                       feat_dist_threshold_match=0.8,
                                       feat_dist_threshold_nomatch=0.2,
                                       n_not_matching=0)

    # Validator
    corr_validator = CorrValidator(corr_loader_val,
                                   val_corr_loss_fct_feat,
                                   val_corr_loss_fct_out,
                                   loss_fct,
                                   save_snapshot=True,
                                   extra_name_str='Corr')

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    open(os.path.join(save_folder,
                      str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        if not skip_clustering:
            # Extract image features from reference images
            net.eval()
            net.output_features = True

            features, _ = extract_features_for_reference(
                net,
                model_config,
                ref_image_lists,
                corr_im_paths,
                ref_featurs_pos,
                max_num_features_per_image=args['max_features_per_image'],
                fraction_correspondeces=0.5)

            cluster_features = np.vstack(features)
            del features

            # cluster the features
            cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
                cluster_features, verbose=True, use_gpu=False)

            # save cluster centroids
            h5f = h5py.File(
                os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
            h5f.create_dataset('cluster_centroids', data=cluster_centroids)
            h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
            h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
            h5f.close()

            # Print distribution of clusters
            cluster_distribution, _ = np.histogram(
                cluster_indices,
                bins=np.arange(args['n_clusters'] + 1),
                density=True)
            str2write = 'cluster distribution ' + \
                np.array2string(cluster_distribution, formatter={'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
            print(str2write)
            f_handle.write(str2write + "\n")

            reinit_last_layers(
                net)  # set last layer weight to a normal distribution

            # make a copy of current network state to do cluster assignment
            net_for_clustering = copy.deepcopy(net)
        else:
            skip_clustering = False

        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:

            # First extract cluster labels using saved network checkpoint
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True
            if args['feature_hinge_loss_weight'] > 0:
                net_for_clustering.output_all = False
            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                img_other = img_other.cpu()
                pts_ref = [p.cpu() for p in pts_ref]
                pts_other = [p.cpu() for p in pts_other]
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, img_other, pts_ref, pts_other,
                                     cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, img_other, pts_ref, pts_other, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = True

                # Randomization to decide if reference or target image should be used for training
                if args['fraction_reference_bp'] is None:  # use both
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                            img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                            img_other)
                    else:
                        outputs_other, aux_other = net(img_other)

                elif np.random.rand(
                        1)[0] < args['fraction_reference_bp']:  # use reference
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                            img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                                img_other)
                        else:
                            outputs_other, aux_other = net(img_other)
                else:  # use target
                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                                img_ref)
                        else:
                            outputs_ref, aux_ref = net(img_ref)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                            img_other)
                    else:
                        outputs_other, aux_other = net(img_other)

                    seg_main_loss = 0.
                    seg_aux_loss = 0.

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = False

                main_loss, _ = loss_fct(outputs_ref,
                                        outputs_other,
                                        None,
                                        pts_ref,
                                        pts_other,
                                        cluster_labels=cluster_labels)
                aux_loss, _ = loss_fct(aux_ref,
                                       aux_other,
                                       None,
                                       pts_ref,
                                       pts_other,
                                       cluster_labels=cluster_labels)

                if args['feature_hinge_loss_weight'] > 0:
                    feature_loss = feature_loss_fct(out_feat_ref,
                                                    out_feat_other, pts_ref,
                                                    pts_other, None)
                    feature_loss_aux = feature_loss_fct(
                        aux_feat_ref, aux_feat_other, pts_ref, pts_other, None)
                    loss = args['corr_loss_weight']*(main_loss + 0.4 * aux_loss) + args['seg_loss_weight']*(seg_main_loss + 0.4 * seg_aux_loss) \
                        + args['feature_hinge_loss_weight']*(feature_loss + 0.4 * feature_loss_aux)
                else:
                    feature_loss = 0.
                    loss = args['corr_loss_weight']*(main_loss + 0.4 * aux_loss) + \
                        args['seg_loss_weight']*(seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                corr_train_loss.update(main_loss.item(), 1)
                if type(seg_main_loss) == torch.Tensor:
                    seg_train_loss.update(seg_main_loss.item(), 1)
                if type(feature_loss) == torch.Tensor:
                    feature_train_loss.update(feature_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_corr_loss', corr_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_feature_loss', feature_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f]. [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if val_iter >= args['val_interval']:
                    val_iter = 0
                    net_for_clustering.to(device)
                    corr_validator.run(net,
                                       net_for_clustering,
                                       (deepcluster, pca_info),
                                       optimizer,
                                       args,
                                       curr_iter,
                                       save_folder,
                                       f_handle,
                                       writer=writer)
                    net_for_clustering.to("cpu")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
Пример #11
0
    def run_and_save(
        self,
        img_path,
        save_path,
        pre_sliding_crop_transform=None,
        sliding_crop=joint_transforms.SlidingCropImageOnly(713, 2 / 3.),
        input_transform=standard_transforms.ToTensor(),
        verbose=False,
        skip_if_seg_exists=False,
        use_gpu=True,
    ):
        """
        Compute the feature map of one image and optionally write it to disk.

        img_path                   - Path of input image (str). Any non-str value is
                                     used directly as the image object; it must expose
                                     `.size` (presumably a PIL image - confirm with callers)
        save_path                  - Path of output file (feature map), `.mat` or `.npy`;
                                     None disables saving
        pre_sliding_crop_transform - Optional transform applied to the image before
                                     slicing (might reshape image)
        sliding_crop               - Transform that returns set of image slices
        input_transform            - Transform to apply to image before inputting to network
        verbose                    - Print a message when the output file already exists
        skip_if_seg_exists         - Whether to overwrite or skip if the output exists already
        use_gpu                    - Forwarded to `run_on_slices`

        Returns the feature map produced by `run_on_slices`, or None when the
        output already exists and is skipped, or when the input image cannot
        be read.

        Raises ValueError when `save_path` has an extension other than
        `mat` / `npy` (raised after the feature map has been computed).
        """

        if save_path is not None and os.path.exists(save_path):
            if skip_if_seg_exists:
                if verbose:
                    print(
                        "Segmentation already exists, skipping: {}".format(
                            save_path))
                return
            if verbose:
                print("Segmentation already exists, overwriting: {}".
                      format(save_path))

        if isinstance(img_path, str):
            try:
                img = Image.open(img_path).convert('RGB')
            except OSError:
                print(
                    "Error reading input image, skipping: {}".format(img_path))
                # Without this return `img` stays unbound and the code below
                # fails with UnboundLocalError instead of skipping the image.
                return
        else:
            img = img_path

        # creating sliding crop windows and transform them
        img_size_orig = img.size
        if pre_sliding_crop_transform is not None:  # might reshape image
            img = pre_sliding_crop_transform(img)

        img_slices, slices_info = sliding_crop(img)
        img_slices = [input_transform(e) for e in img_slices]
        img_slices = torch.stack(img_slices, 0)
        slices_info = torch.LongTensor(slices_info)
        slices_info.squeeze_(0)

        # Temporarily switch the network to feature output; the previous flags
        # are restored even if inference raises, so the shared network object
        # is never left in feature-output mode by accident.
        of_pre, oa_pre = self.net.output_features, self.net.output_all
        self.net.output_features, self.net.output_all = True, False
        try:
            feature_map = self.run_on_slices(
                img_slices,
                slices_info,
                sliding_transform_step=sliding_crop.stride_rate,
                use_gpu=use_gpu)
        finally:
            # restore previous settings
            self.net.output_features, self.net.output_all = of_pre, oa_pre

        if save_path is not None:
            check_mkdir(os.path.dirname(save_path))
            ext = save_path.split('.')[-1]
            if ext == 'mat':
                matdict = {
                    # first axis of the feature map is moved to the end
                    "features": np.transpose(feature_map, [1, 2, 0]),
                    # img.size order is swapped before storing
                    "original_image_size": (img_size_orig[1], img_size_orig[0])
                }
                sio.savemat(save_path, matdict, appendmat=False)
            elif ext == 'npy':
                np.save(save_path, feature_map)
            else:
                raise ValueError(
                    'invalid file extension for save_path, only mat and npy supported'
                )

        return feature_map
Пример #12
0
    def run_and_save(
        self,
        img_path,
        seg_path,
        pre_sliding_crop_transform=None,
        sliding_crop=joint_transforms.SlidingCropImageOnly(713, 2 / 3.),
        input_transform=standard_transforms.ToTensor(),
        verbose=False,
        skip_if_seg_exists=False,
        use_gpu=True,
    ):
        """
        Segment one image with sliding-window inference and optionally save it.

        img_path                   - Path of input image (str); a non-str value is
                                     used directly as the image object
        seg_path                   - Path of output image (segmentation); None
                                     disables saving
        pre_sliding_crop_transform - Optional transform applied before slicing
        sliding_crop               - Transform that returns set of image slices
        input_transform            - Transform to apply to image before inputting to network
        verbose                    - Print a message when seg_path already exists
        skip_if_seg_exists         - Whether to overwrite or skip if segmentation exists already
        use_gpu                    - Forwarded to `run_on_slices`

        Returns the (colorized) prediction image, or None when the
        segmentation is skipped or the input image cannot be read.
        """

        # Existing output: either bail out or announce the overwrite.
        if seg_path is not None and os.path.exists(seg_path):
            if skip_if_seg_exists:
                if verbose:
                    print(
                        "Segmentation already exists, skipping: {}".format(
                            seg_path))
                return
            if verbose:
                print("Segmentation already exists, overwriting: {}".
                      format(seg_path))

        # Load the image unless the caller handed one in directly.
        if not isinstance(img_path, str):
            img = img_path
        else:
            try:
                img = Image.open(img_path).convert('RGB')
            except OSError:
                print(
                    "Error reading input image, skipping: {}".format(img_path))
                return

        # Remember the size before any pre-transform, for the final resize.
        img_size_orig = img.size
        if pre_sliding_crop_transform is not None:  # might reshape image
            img = pre_sliding_crop_transform(img)

        # Cut the image into slices and stack them into one batch tensor.
        raw_slices, raw_info = sliding_crop(img)
        slice_batch = torch.stack(
            [input_transform(piece) for piece in raw_slices], 0)
        slice_info = torch.LongTensor(raw_info).squeeze_(0)

        logits = self.run_on_slices(
            slice_batch,
            slice_info,
            sliding_transform_step=sliding_crop.stride_rate,
            use_gpu=use_gpu,
            return_logits=True)

        # Index of the maximum along the first axis of the logits.
        _, best_idx = logits.max(0)
        label_map = best_idx.squeeze_(0).numpy()
        # Conversion kept from the original; its result is not used below.
        logits = logits.numpy()

        if self.colorize_fcn:
            prediction_colorized = self.colorize_fcn(label_map)
        else:
            as_int32 = label_map.astype(np.int32)
            prediction_colorized = Image.fromarray(as_int32).convert('I')

        # Bring the prediction back to the original image size if it differs.
        if prediction_colorized.size != img_size_orig:
            prediction_colorized = F.resize(prediction_colorized,
                                            img_size_orig[::-1],
                                            interpolation=Image.NEAREST)

        if seg_path is not None:
            out_dir = os.path.dirname(seg_path)
            check_mkdir(out_dir)
            prediction_colorized.save(seg_path)

        return prediction_colorized