# Example 1
def train_with_correspondences(save_folder, startnet, args):
    """Train a PSPNet segmentation network jointly on segmentation losses
    and a cross-season correspondence loss.

    Args:
        save_folder: directory where snapshots, logs and tensorboard files
            are written.
        startnet: path to the initial weight file used when training starts
            from scratch.
        args: dict of hyper-parameters and bookkeeping state ('lr',
            'snapshot', 'corr_set', 'train_batch_size', loss weights, ...).
            Mutated in place ('snapshot', 'best_record').
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        # NOTE(review): eval() on a locally written checkpoint file; fine for
        # trusted output but consider ast.literal_eval if the file could ever
        # be tampered with.
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)  # BatchNorm statistics are kept frozen during fine-tuning

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    # Correspondence pairs need transforms that keep point positions in sync
    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation (optional second segmentation dataset)
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training (images from the correspondence dataset itself)
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation (this validator is the one that saves snapshots)
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    # 'elementwise_mean' is the pre-1.0 PyTorch spelling of 'mean'; kept for
    # compatibility with the torch version this project targets.
    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup: biases get twice the learning rate and no weight decay
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # BUGFIX: use a context manager so the args-dump file handle is closed
    # (was a bare open(...).write(...)).
    with open(
            os.path.join(save_folder,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    # Meter order must match the order loaders were appended to seg_loaders
    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        # Polynomial learning-rate decay, applied to both parameter groups
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            # NOTE(review): next(iter(loader)) rebuilds the iterator (and its
            # workers) every iteration — it draws one random shuffled batch
            # per step, at the cost of worker respawn overhead.
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            # standard PSPNet weighting of the auxiliary head
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args[
                            'corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                # BUGFIX: filter aux features with the aux keep-mask; the
                # original used batch_inds_to_keep_orig, which does not match
                # the pts_*_aux / weights_aux lists filtered above.
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_aux]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss (zero-valued but graph-connected fallback
            # when all samples were filtered out)
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            # BUGFIX: report progress against the training schedule
            # (args['max_iter']); the original printed len(corr_loader),
            # which is unrelated to the iteration count.
            str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    """Train a PSPNet-style network with pseudo-labels produced by k-means
    clustering of its own features (deep-clustering loop).

    Alternates between (a) extracting reference features and clustering them
    into args['n_clusters'] centroids and (b) training the network to predict
    the cluster assignment of every pixel.

    Args:
        save_folder: directory for snapshots, logs, centroid files and
            tensorboard output.
        tmp_seg_folder: scratch directory for intermediate segmentations.
        startnet: path to the initial weight file.
        args: dict of hyper-parameters ('n_clusters', 'lr', 'chunk_size',
            'cluster_interval', ...). Mutated in place ('best_record').
    """
    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # different amount of classes
        init_last_layers(state_dict, args['n_clusters'])

        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    # NOTE(review): when args['corr_set'] == 'both', only corr_set_config1/2
    # are bound, so the next line raises NameError — confirm how the two
    # configs were meant to be combined.
    ref_image_lists = corr_set_config.reference_image_list

    input_transform = model_config.input_transform

    scales = [0, 1, 2, 3]

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])

    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,
        transform_before_sliding=None
    )
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    # 'elementwise_mean' is the pre-1.0 PyTorch spelling of 'mean'; kept for
    # compatibility with the torch version this project targets.
    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    # Optimizer setup: biases get twice the learning rate and no weight decay
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        # NOTE(review): cluster_centroids is not defined anywhere before this
        # point; this branch (currently unreachable, skip_clustering is
        # hard-coded False) would raise NameError if enabled.
        deepcluster.set_index(cluster_centroids)

    # BUGFIX: use a context manager so the args-dump file handle is closed
    # (was a bare open(...).write(...)).
    with open(
            os.path.join(save_folder,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:

        # ---- Clustering phase: extract features and fit k-means ----
        net.eval()
        net.output_features = True

        print(
            'ici on a la len de la ref im list --------------------------------------------------------'
        )
        print(len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net,
            model_config,
            corr_set_train,
            10,
            max_num_features_per_image=args['max_features_per_image'])

        cluster_features = np.vstack(features)
        del features  # free the per-image feature list before clustering

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
            cluster_features, verbose=True, use_gpu=False)

        # save cluster centroids and the PCA transform used to whiten features
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + \
            np.array2string(cluster_distribution, formatter={
                            'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # set last layer weight to a normal distribution
        reinit_last_layers(net)

        # make a copy of current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

        # Polynomial learning-rate decay, applied to both parameter groups
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:

            # First extract cluster labels using saved network checkpoint
            print(
                'on rentre dans la boucle extract cluster_______________________________________________'
            )
            # Swap the training net out for the frozen clustering copy to
            # keep GPU memory usage down.
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                # BUGFIX: removed `corr_loader_train =
                # input_transform(corr_loader_train)` — it replaced the
                # DataLoader with the result of an image transform, breaking
                # every subsequent iteration (debugging artifact).
                print(
                    f'la valeur de corr loader train es de {corr_loader_train} lors de l iteration : {curr_iter}'
                )
                # NOTE(review): assumes the dataset yields 5-tuples
                # (img_ref, img_other, pts_ref, pts_other, _) — confirm
                # against Poladata.MonoDataset.__getitem__.
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu (store samples, train later on GPU)
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            # Swap the nets back: training net returns to the device
            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                outputs_ref, aux_ref = net(img_ref)

                # cluster ids act as the segmentation ground truth
                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                if isinstance(seg_main_loss, torch.Tensor):
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    # BUGFIX: the format string has six placeholders but only
                    # four values were supplied, so the % formatting raised
                    # TypeError on the first log line; supply the corr and
                    # feature loss averages as well.
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f]. [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    """Train a segmentation network on self-supervised k-means pseudo-labels.

    Alternates two phases until ``args['max_iter']`` is reached:
      1. extract features from the reference images and cluster them with
         k-means (``deepcluster``), saving the centroids to an HDF5 file;
      2. train the network to predict those cluster assignments on
         correspondence image pairs, optionally combined with a
         correspondence loss and a feature hinge loss.

    Supports resuming from a snapshot; when resuming mid-cluster-interval,
    the previous centroids and the network used for the last clustering
    round are reloaded instead of re-clustering.

    Args:
        save_folder: output directory for logs, snapshots and centroids.
        tmp_seg_folder: scratch directory (only created here).
        startnet: path to initial weights; a path containing 'resnet101'
            selects a dedicated weight-loading routine.
        args: hyper-parameter dict (``'snapshot'``, ``'n_clusters'``,
            ``'lr'``, ``'max_iter'``, ``'cluster_interval'``,
            ``'corr_set'``, ...); mutated in place (``'snapshot'``,
            ``'best_record'``).
    """
    import ast  # stdlib; used to safely parse the saved best-record dict

    print(save_folder.split('/')[-1])
    skip_clustering = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)
    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        if 'resnet101' in startnet:
            load_resnet101_weights(net, state_dict)
        else:
            # needed since we slightly changed the structure of the network in pspnet
            state_dict = rename_keys_to_match(state_dict)
            init_last_layers(state_dict,
                             args['n_clusters'])  # different amount of classes

            net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss_feat': 1e10,
            'val_loss_out': 1e10,
            'val_loss_cluster': 1e10
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        # bestval.txt contains a plain dict literal written by the validator,
        # so ast.literal_eval is the safe replacement for eval() here.
        args['best_record'] = ast.literal_eval(best_val_dict_str.rstrip())

        if start_iter >= args['max_iter']:
            return

        if (start_iter % args['cluster_interval']) == 0:
            skip_clustering = False
        else:
            # Resuming mid-interval: reuse the centroids from the last
            # completed clustering round instead of re-clustering.
            skip_clustering = True
            last_cluster_network_snapshot_iter = (
                start_iter //
                args['cluster_interval']) * args['cluster_interval']

            # load cluster info (context manager so the HDF5 handle is
            # closed even if reading fails)
            cluster_info = {}
            with h5py.File(
                    os.path.join(
                        save_folder, 'centroids_{}.h5'.format(
                            last_cluster_network_snapshot_iter)), 'r') as f:
                for k, v in f.items():
                    cluster_info[k] = np.array(v)
            cluster_centroids = cluster_info['cluster_centroids']
            pca_info = [
                cluster_info['pca_transform_Amat'],
                cluster_info['pca_transform_bvec']
            ]

            # load network that was used for last clustering
            net_for_clustering = model_config.init_network(
                n_classes=args['n_clusters'],
                for_clustering=True,
                output_features=True,
                use_original_base=args['use_original_base'])

            if last_cluster_network_snapshot_iter == 0:
                state_dict = torch.load(
                    startnet, map_location=lambda storage, loc: storage)
                if 'resnet101' in startnet:
                    load_resnet101_weights(net_for_clustering, state_dict)
                else:
                    # needed since we slightly changed the structure of the network in pspnet
                    state_dict = rename_keys_to_match(state_dict)
                    init_last_layers(
                        state_dict,
                        args['n_clusters'])  # different amount of classes

                    net_for_clustering.load_state_dict(
                        state_dict)  # load original weights
            else:
                cluster_network_weights = get_network_name_from_iteration(
                    save_folder, last_cluster_network_snapshot_iter)
                net_for_clustering.load_state_dict(
                    torch.load(os.path.join(save_folder,
                                            cluster_network_weights),
                               map_location=lambda storage, loc: storage)
                )  # load weights

    # Data loading setup
    # NOTE(review): any corr_set value outside {'rc', 'cmu', 'both'} leaves
    # corr_set_config undefined and raises NameError below -- presumably
    # validated by the caller.
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    if args['corr_set'] == 'both':
        ref_image_lists = [
            corr_set_config1.reference_image_list,
            corr_set_config2.reference_image_list
        ]
        corr_im_paths = [
            corr_set_config1.correspondence_im_path,
            corr_set_config2.correspondence_im_path
        ]
        ref_featurs_pos = [
            corr_set_config1.reference_feature_poitions,
            corr_set_config2.reference_feature_poitions
        ]
    else:
        ref_image_lists = [corr_set_config.reference_image_list]
        corr_im_paths = [corr_set_config.correspondence_im_path]
        ref_featurs_pos = [corr_set_config.reference_feature_poitions]

    input_transform = model_config.input_transform

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # Correspondences for training
    if args['corr_set'] == 'both':
        corr_set_train1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_train_list_file)
        corr_set_train2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_train_list_file)

        corr_set_train = merged.Merged([corr_set_train1, corr_set_train2])
    else:
        corr_set_train = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_train_list_file)

    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    # Correspondences for validation
    if args['corr_set'] == 'both':
        corr_set_val1 = correspondences.Correspondences(
            corr_set_config1.correspondence_path,
            corr_set_config1.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config1.correspondence_val_list_file)

        corr_set_val2 = correspondences.Correspondences(
            corr_set_config2.correspondence_path,
            corr_set_config2.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config2.correspondence_val_list_file)

        corr_set_val = merged.Merged([corr_set_val1, corr_set_val2])
    else:
        corr_set_val = correspondences.Correspondences(
            corr_set_config.correspondence_path,
            corr_set_config.correspondence_im_path,
            input_size=(713, 713),
            input_transform=input_transform,
            joint_transform=train_joint_transform_corr,
            listfile=corr_set_config.correspondence_val_list_file)

    corr_loader_val = DataLoader(corr_set_val,
                                 batch_size=1,
                                 num_workers=args['n_workers'],
                                 shuffle=False)

    # Loss setup
    val_corr_loss_fct_feat = FeatureLoss(
        input_size=[713, 713],
        loss_type=args['feature_distance_measure'],
        feat_dist_threshold_match=0.8,
        feat_dist_threshold_nomatch=0.2,
        n_not_matching=0)

    val_corr_loss_fct_out = FeatureLoss(input_size=[713, 713],
                                        loss_type='KL',
                                        feat_dist_threshold_match=0.8,
                                        feat_dist_threshold_nomatch=0.2,
                                        n_not_matching=0)

    loss_fct = ClusterCorrespondenceLoss(input_size=[713, 713],
                                         size_average=True).to(device)
    # 'elementwise_mean' is the pre-1.0 spelling of reduction='mean'; kept
    # for compatibility with the torch version this project pins.
    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    if args['feature_hinge_loss_weight'] > 0:
        feature_loss_fct = FeatureLoss(input_size=[713, 713],
                                       loss_type='hingeF',
                                       feat_dist_threshold_match=0.8,
                                       feat_dist_threshold_nomatch=0.2,
                                       n_not_matching=0)

    # Validator
    corr_validator = CorrValidator(corr_loader_val,
                                   val_corr_loss_fct_feat,
                                   val_corr_loss_fct_out,
                                   loss_fct,
                                   save_snapshot=True,
                                   extra_name_str='Corr')

    # Optimizer setup: biases get twice the learning rate and no weight
    # decay (common PSPNet training convention).
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        deepcluster.set_index(cluster_centroids)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # Record the hyper-parameters used for this run (with-block so the
    # handle is closed instead of leaked).
    with open(os.path.join(save_folder,
                           str(datetime.datetime.now()) + '.txt'),
              'w') as args_file:
        args_file.write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        if not skip_clustering:
            # Extract image features from reference images
            net.eval()
            net.output_features = True

            features, _ = extract_features_for_reference(
                net,
                model_config,
                ref_image_lists,
                corr_im_paths,
                ref_featurs_pos,
                max_num_features_per_image=args['max_features_per_image'],
                fraction_correspondeces=0.5)

            cluster_features = np.vstack(features)
            del features  # free memory before clustering

            # cluster the features
            cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
                cluster_features, verbose=True, use_gpu=False)

            # save cluster centroids
            h5f = h5py.File(
                os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
            h5f.create_dataset('cluster_centroids', data=cluster_centroids)
            h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
            h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
            h5f.close()

            # Print distribution of clusters
            cluster_distribution, _ = np.histogram(
                cluster_indices,
                bins=np.arange(args['n_clusters'] + 1),
                density=True)
            str2write = 'cluster distribution ' + \
                np.array2string(cluster_distribution, formatter={'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
            print(str2write)
            f_handle.write(str2write + "\n")

            reinit_last_layers(
                net)  # set last layer weight to a normal distribution

            # make a copy of current network state to do cluster assignment
            net_for_clustering = copy.deepcopy(net)
        else:
            # Clustering was restored from disk; only skip it once.
            skip_clustering = False

        # Polynomial learning-rate decay.
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:

            # First extract cluster labels using saved network checkpoint.
            # Networks are swapped between CPU and GPU to keep only one of
            # them on the device at a time.
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True
            if args['feature_hinge_loss_weight'] > 0:
                net_for_clustering.output_all = False
            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                # NOTE(review): next(iter(loader)) builds a fresh DataLoader
                # iterator (and workers) per sample, always drawing from a
                # new shuffle -- expensive but kept as-is since changing it
                # alters the sampling behavior.
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]

                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))

                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                img_other = img_other.cpu()
                pts_ref = [p.cpu() for p in pts_ref]
                pts_other = [p.cpu() for p in pts_other]
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, img_other, pts_ref, pts_other,
                                     cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, img_other, pts_ref, pts_other, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = True

                # Randomization to decide if reference or target image should be used for training
                if args['fraction_reference_bp'] is None:  # use both
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                            img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                            img_other)
                    else:
                        outputs_other, aux_other = net(img_other)

                elif np.random.rand(
                        1)[0] < args['fraction_reference_bp']:  # use reference
                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                            img_ref)
                    else:
                        outputs_ref, aux_ref = net(img_ref)

                    seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                    seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)

                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                                img_other)
                        else:
                            outputs_other, aux_other = net(img_other)
                else:  # use target
                    with torch.no_grad():
                        if args['feature_hinge_loss_weight'] > 0:
                            out_feat_ref, aux_feat_ref, outputs_ref, aux_ref = net(
                                img_ref)
                        else:
                            outputs_ref, aux_ref = net(img_ref)

                    if args['feature_hinge_loss_weight'] > 0:
                        out_feat_other, aux_feat_other, outputs_other, aux_other = net(
                            img_other)
                    else:
                        outputs_other, aux_other = net(img_other)

                    # No gradient flowed through the reference pass, so the
                    # segmentation losses are plain zeros here.
                    seg_main_loss = 0.
                    seg_aux_loss = 0.

                if args['feature_hinge_loss_weight'] > 0:
                    net.output_all = False

                main_loss, _ = loss_fct(outputs_ref,
                                        outputs_other,
                                        None,
                                        pts_ref,
                                        pts_other,
                                        cluster_labels=cluster_labels)
                aux_loss, _ = loss_fct(aux_ref,
                                       aux_other,
                                       None,
                                       pts_ref,
                                       pts_other,
                                       cluster_labels=cluster_labels)

                if args['feature_hinge_loss_weight'] > 0:
                    feature_loss = feature_loss_fct(out_feat_ref,
                                                    out_feat_other, pts_ref,
                                                    pts_other, None)
                    feature_loss_aux = feature_loss_fct(
                        aux_feat_ref, aux_feat_other, pts_ref, pts_other, None)
                    loss = args['corr_loss_weight']*(main_loss + 0.4 * aux_loss) + args['seg_loss_weight']*(seg_main_loss + 0.4 * seg_aux_loss) \
                        + args['feature_hinge_loss_weight']*(feature_loss + 0.4 * feature_loss_aux)
                else:
                    feature_loss = 0.
                    loss = args['corr_loss_weight']*(main_loss + 0.4 * aux_loss) + \
                        args['seg_loss_weight']*(seg_main_loss + 0.4 * seg_aux_loss)

                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                corr_train_loss.update(main_loss.item(), 1)
                # Losses may be plain floats (0.) when their branch was
                # skipped; only tensors carry a real value to log.
                if type(seg_main_loss) == torch.Tensor:
                    seg_train_loss.update(seg_main_loss.item(), 1)
                if type(feature_loss) == torch.Tensor:
                    feature_train_loss.update(feature_loss.item(), 1)

                ####################################################################################################
                #       LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_corr_loss', corr_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('train_feature_loss', feature_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f]. [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])

                    print(str2write)
                    f_handle.write(str2write + "\n")

                if val_iter >= args['val_interval']:
                    val_iter = 0
                    net_for_clustering.to(device)
                    corr_validator.run(net,
                                       net_for_clustering,
                                       (deepcluster, pca_info),
                                       optimizer,
                                       args,
                                       curr_iter,
                                       save_folder,
                                       f_handle,
                                       writer=writer)
                    net_for_clustering.to("cpu")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
Exemple #4
0
    def run(self, net, net_for_clustering, cluster_info, optimizer, args, curr_iter, save_dir, f_handle, writer=None):
        """Validate ``net`` against k-means pseudo-ground-truth labels.

        For every validation image, features from ``net_for_clustering`` are
        PCA-projected and assigned to the clusters in ``cluster_info`` to
        build a per-pixel cluster-label image; ``net``'s argmax segmentation
        is then scored against it via an incrementally accumulated confusion
        matrix. Optionally snapshots network/optimizer state and updates
        ``args['best_record']`` when mean IoU improves.

        Args:
            net: network under validation (set to eval, restored to train).
            net_for_clustering: feature extractor for cluster assignment.
            cluster_info: pair where index 0 has an ``assign`` method and
                index 1 is PCA info passed to ``preprocess_features`` --
                presumably ``(kmeans, pca_info)``; TODO confirm.
            optimizer: read for the current lr and saved when snapshotting.
            args: hyper-parameter dict; ``args['best_record']`` is mutated
                when a new best mean IoU is reached.
            curr_iter: iteration number used in snapshot names.
            save_dir: directory for snapshots and ``bestval.txt``.
            f_handle: open log-file handle.
            writer: optional tensorboard ``SummaryWriter``.

        Returns:
            Mean IoU over the validation set.
        """
        # the following code is written assuming that batch size is 1
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # NOTE(review): unused in this method

        net.eval()
        net_for_clustering.eval()
        # Save the output_features flags so they can be restored after
        # validation, then make the clustering net emit features and the
        # trained net emit segmentation logits.
        ps_output_features1 = net_for_clustering.output_features
        ps_output_features2 = net.output_features
        net_for_clustering.output_features = True
        net.output_features = False

        feat_ext_for_cluster = FeatureExtractor(net_for_clustering, n_slices_per_pass=5)
        feat_ext = FeatureExtractor(net, n_slices_per_pass=5)

        confmat = np.zeros((self.n_clusters, self.n_clusters))
        for vi, data in enumerate(self.data_loader):
            img_slices, _, slices_info = data

            # Features for cluster assignment and logits for scoring.
            output = feat_ext_for_cluster.run_on_slices(img_slices, slices_info)
            seg_out = feat_ext.run_on_slices(img_slices, slices_info)
            seg = np.argmax(seg_out, 0)  # per-pixel predicted class

            # Flatten spatial dims, PCA-project, and assign each pixel's
            # feature vector to its nearest cluster centroid.
            output_flat = output.reshape((output.shape[0], -1))
            out_f2, _ = preprocess_features(np.swapaxes(output_flat, 0, 1), pca_info=cluster_info[1])
            cluster_labels = cluster_info[0].assign(out_f2)
            cluster_image = cluster_labels.reshape((output.shape[1], output.shape[2]))

            acc, acc_cls, mean_iu, fwavacc, confmat, _ = evaluate_incremental(
                confmat, seg, cluster_image, self.n_clusters)

            if (vi % 100) == 0:
                str2write = 'validating: %d / %d' % (
                    vi + 1, len(self.data_loader))
                print(str2write)

        # Restore the saved flags and put the net back into training mode.
        net_for_clustering.output_features = ps_output_features1
        net.output_features = ps_output_features2
        net.train()
        if 'freeze_bn' not in args or args['freeze_bn']:
            freeze_bn(net)

        # NOTE(review): acc/acc_cls/mean_iu/fwavacc are undefined if the
        # validation loader is empty -- presumably never the case here.
        if self.save_snapshot:
            snapshot_name = 'iter_%d_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
                curr_iter, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'])
            torch.save(net.state_dict(), os.path.join(save_dir, snapshot_name + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(save_dir, 'opt_' + snapshot_name + '.pth'))

            if args['best_record']['mean_iu'] < mean_iu:
                args['best_record']['iter'] = curr_iter
                args['best_record']['acc'] = acc
                args['best_record']['acc_cls'] = acc_cls
                args['best_record']['mean_iu'] = mean_iu
                args['best_record']['fwavacc'] = fwavacc
                args['best_record']['snapshot'] = snapshot_name
                open(os.path.join(save_dir, 'bestval.txt'), 'w').write(
                    str(args['best_record']) + '\n\n')

            str2write = '%s best record: [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (self.extra_name_str,
                                                                                                        args['best_record']['acc'], args['best_record']['acc_cls'], args['best_record']['mean_iu'], args['best_record']['fwavacc'])

            print(str2write)
            f_handle.write(str2write + "\n")

            str2write = '%s [iter %d], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (self.extra_name_str,
                                                                                                      curr_iter, acc, acc_cls, mean_iu, fwavacc)
            print(str2write)
            f_handle.write(str2write + "\n")

            if writer is not None:
                writer.add_scalar(self.extra_name_str + ': acc', acc, curr_iter)
                writer.add_scalar(self.extra_name_str +
                                  ': acc_cls', acc_cls, curr_iter)
                writer.add_scalar(self.extra_name_str +
                                  ': mean_iu', mean_iu, curr_iter)
                writer.add_scalar(self.extra_name_str +
                                  ': fwavacc', fwavacc, curr_iter)

        return mean_iu
Exemple #5
0
    def run(self, net, optimizer, args, curr_iter, save_dir, f_handle, writer=None):
        """Validate ``net`` on the held-out set against real ground truth.

        Runs sliding-window segmentation over every validation image,
        accumulates a confusion matrix, pickles it to ``save_dir/confmat``,
        optionally snapshots network/optimizer state, and updates
        ``args['best_record']`` when mean IoU improves.

        Args:
            net: network under validation (set to eval, restored to train).
            optimizer: read for the current lr and saved when snapshotting.
            args: hyper-parameter dict; ``args['best_record']`` is mutated
                when a new best mean IoU is reached.
            curr_iter: iteration number used in log lines and snapshot names.
            save_dir: directory for snapshots, confmat and ``bestval.txt``.
            f_handle: open log-file handle.
            writer: optional tensorboard ``SummaryWriter``.

        Returns:
            Mean IoU over the validation set.
        """
        # the following code is written assuming that batch size is 1
        net.eval()
        segmentor = Segmentor(net, self.n_classes, colorize_fcn=None, n_slices_per_pass=10)

        confmat = np.zeros((self.n_classes, self.n_classes))
        for vi, data in enumerate(self.data_loader):
            img_slices, gt, slices_info = data
            gt.squeeze_(0)  # drop the batch dimension (batch size is 1)
            prediction_tmp = segmentor.run_on_slices(img_slices.squeeze_(0), slices_info.squeeze_(0))

            # Resize the prediction to the ground-truth resolution with
            # nearest-neighbour so label values are preserved.
            if prediction_tmp.shape != gt.size():
                prediction_tmp = Image.fromarray(prediction_tmp.astype(np.uint8)).convert('P')
                prediction_tmp = F.resize(prediction_tmp, gt.size(), interpolation=Image.NEAREST)

            acc, acc_cls, mean_iu, fwavacc, confmat, _ = evaluate_incremental(
                confmat, np.asarray(prediction_tmp), gt.numpy(), self.n_classes)

            str2write = 'validating: %d / %d' % (vi + 1, len(self.data_loader))
            print(str2write)
            # f_handle.write(str2write + "\n")

        # Store confusion matrix
        # NOTE(review): acc/acc_cls/mean_iu/fwavacc are undefined if the
        # validation loader is empty -- presumably never the case here.
        confmatdir = os.path.join(save_dir, 'confmat')
        os.makedirs(confmatdir, exist_ok=True)
        with open(os.path.join(confmatdir, self.extra_name_str + str(curr_iter) + '_confmat.pkl'), 'wb') as confmat_file:
            pickle.dump(confmat, confmat_file)

        if self.save_snapshot:
            snapshot_name = 'iter_%d_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
                curr_iter, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'])
            torch.save(net.state_dict(), os.path.join(
                save_dir, snapshot_name + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(
                save_dir, 'opt_' + snapshot_name + '.pth'))

            if args['best_record']['mean_iu'] < mean_iu:
                args['best_record']['iter'] = curr_iter
                args['best_record']['acc'] = acc
                args['best_record']['acc_cls'] = acc_cls
                args['best_record']['mean_iu'] = mean_iu
                args['best_record']['fwavacc'] = fwavacc
                args['best_record']['snapshot'] = snapshot_name
                open(os.path.join(save_dir, 'bestval.txt'), 'w').write(
                    str(args['best_record']) + '\n\n')

            str2write = '%s best record: [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (self.extra_name_str,
                                                                                                        args['best_record']['acc'], args['best_record']['acc_cls'], args['best_record']['mean_iu'], args['best_record']['fwavacc'])

            print(str2write)
            f_handle.write(str2write + "\n")

        str2write = '%s [iter %d], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (self.extra_name_str,
                                                                                                  curr_iter, acc, acc_cls, mean_iu, fwavacc)
        print(str2write)
        f_handle.write(str2write + "\n")

        if writer is not None:
            writer.add_scalar(self.extra_name_str + ': acc', acc, curr_iter)
            writer.add_scalar(self.extra_name_str +
                              ': acc_cls', acc_cls, curr_iter)
            writer.add_scalar(self.extra_name_str +
                              ': mean_iu', mean_iu, curr_iter)
            writer.add_scalar(self.extra_name_str +
                              ': fwavacc', fwavacc, curr_iter)

        net.train()
        if 'freeze_bn' not in args or args['freeze_bn']:
            freeze_bn(net)

        return mean_iu
Exemple #6
0
    def run(self, net, net_for_clustering, cluster_info, optimizer, args, curr_iter, save_dir, f_handle, writer=None):
        """Validate `net` on the correspondence data loader and log the losses.

        Runs the network in eval mode over every batch of `self.data_loader`,
        accumulating the feature-, output- and cluster-correspondence losses,
        then restores the network to training mode. When `self.save_snapshot`
        is set, the network/optimizer states are saved and the best record in
        `args['best_record']` is updated on improvement of the cluster loss.

        Args:
            net: segmentation network; must expose an `output_all` flag that
                makes the forward pass return (features, output).
            net_for_clustering: frozen network producing features for cluster
                assignment.
            cluster_info: cluster centroids/metadata consumed by
                `assign_cluster_ids_to_correspondence_points`.
            optimizer: optimizer whose state is snapshotted; its second param
                group's learning rate is embedded in the snapshot name.
            args: run configuration dict (reads 'freeze_bn', 'best_record').
            curr_iter: current training iteration (used in logging/snapshots).
            save_dir: directory for snapshot and bestval.txt files.
            f_handle: open log-file handle appended to.
            writer: optional TensorBoard SummaryWriter.

        Returns:
            float: average feature loss over the validation set.
        """
        # NOTE: the following code is written assuming that batch size is 1.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        net.eval()
        # Temporarily make the network return features as well as outputs;
        # the original setting is restored before returning.
        previous_setting_output_all = net.output_all
        net.output_all = True
        val_loss_feat = AverageMeter()
        val_loss_out = AverageMeter()
        val_loss_cluster = AverageMeter()
        for vi, data in enumerate(self.data_loader):
            img_ref, img_other, pts_ref, pts_other, weights = data

            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Pure validation: keep the loss computation inside no_grad as
            # well (the forward outputs are already grad-free, so values are
            # unchanged; this just guarantees no graph is ever built here).
            with torch.no_grad():
                feats_ref, out_ref = net(img_ref)
                feats_other, out_other = net(img_other)
                features = net_for_clustering(img_ref)

                loss_feat = self.loss_fct_feat(
                    feats_ref, feats_other, pts_ref, pts_other, weights)
                loss_out = self.loss_fct_out(
                    out_ref, out_other, pts_ref, pts_other, weights)

                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features, pts_ref, cluster_info, inds_other=pts_other,
                    orig_im_size=self.loss_fct_cluster.input_size)
                loss_cluster, _ = self.loss_fct_cluster(
                    out_ref, out_other, feats_ref, pts_ref, pts_other,
                    cluster_labels=cluster_labels)

            val_loss_feat.update(loss_feat.item())
            val_loss_out.update(loss_out.item())
            val_loss_cluster.update(loss_cluster.item())

            if (vi % 100) == 0:
                str2write = 'validating: %d / %d' % (
                    vi + 1, len(self.data_loader))
                print(str2write)

        net.output_all = previous_setting_output_all
        net.train()
        # Re-freeze batch-norm statistics unless explicitly disabled.
        if 'freeze_bn' not in args or args['freeze_bn']:
            freeze_bn(net)

        if self.save_snapshot:
            snapshot_name = 'iter_%d_lossfeat_%.5f_lossout_%.5f_lr_%.10f' % (
                curr_iter, val_loss_feat.avg, val_loss_out.avg,
                optimizer.param_groups[1]['lr'])
            torch.save(net.state_dict(), os.path.join(
                save_dir, snapshot_name + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(
                save_dir, 'opt_' + snapshot_name + '.pth'))

            if args['best_record']['val_loss_cluster'] > val_loss_cluster.avg:
                args['best_record']['val_loss_cluster'] = val_loss_cluster.avg
                args['best_record']['val_loss_feat'] = val_loss_feat.avg
                args['best_record']['val_loss_out'] = val_loss_out.avg
                args['best_record']['iter'] = curr_iter
                args['best_record']['snapshot'] = snapshot_name
                # Fix: close the file deterministically (the original left an
                # unclosed handle to the garbage collector).
                with open(os.path.join(save_dir, 'bestval.txt'), 'w') as f:
                    f.write(str(args['best_record']) + '\n\n')

            str2write = '%s best record: [val loss feat %.5f], [val loss out %.5f], [val loss cluster %.5f]' % (
                self.extra_name_str, args['best_record']['val_loss_feat'], args['best_record']['val_loss_out'], args['best_record']['val_loss_cluster'])

            print(str2write)
            f_handle.write(str2write + "\n")

        str2write = '%s [iter %d], [val loss feat %.5f], [val loss out %.5f], [val loss cluster %.5f]' % (
            self.extra_name_str, curr_iter, val_loss_feat.avg, val_loss_out.avg, val_loss_cluster.avg)
        print(str2write)
        f_handle.write(str2write + "\n")

        if writer is not None:
            writer.add_scalar(self.extra_name_str +
                              ': val_loss_feat', val_loss_feat.avg, curr_iter)
            writer.add_scalar(self.extra_name_str +
                              ': val_loss_out', val_loss_out.avg, curr_iter)
            writer.add_scalar(
                self.extra_name_str + ': val_loss_cluster', val_loss_cluster.avg, curr_iter)

        return val_loss_feat.avg