def tally_topk(compute,
               dataset,
               sample_size=None,
               batch_size=10,
               k=100,
               cachefile=None,
               **kwargs):
    '''
    Computes the topk statistics for a large data sample that can be
    computed from a dataset.  The compute function should return one
    batch of samples as a (sample, unit)-dimension tensor.

    k specifies the number of top samples to retain.
    Results are returned as a RunningTopK object.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size, k=k)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningTopK(state=cached_state)
        rtk = runningstats.RunningTopK(k=k)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        for batch in pbar(loader):
            sample = call_compute(compute, batch)
            rtk.add(sample)
        rtk.to_('cpu')
        save_cached_state(cachefile, rtk, args)
        return rtk
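
A minimal usage sketch for the tally_* pattern (the model, the retained layer
name 'layer4', and the result() accessor are assumptions based on how the rest
of this codebase is used elsewhere, not part of tally_topk itself):

def compute_unit_samples(batch, *args):
    # Hypothetical compute function: run a retained-layer model and flatten
    # every featuremap location into one row, matching the (sample, unit)
    # contract described in the docstring above.
    _ = model(batch.cuda())
    acts = model.retained_layer('layer4')  # (N, units, H, W), assumed layer
    return acts.permute(0, 2, 3, 1).reshape(-1, acts.shape[1])

topk = tally_topk(compute_unit_samples, dataset, sample_size=1000, k=100,
                  cachefile='results/topk.npz')
top_values, top_indices = topk.result()  # assumed RunningTopK accessor
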
def tally_bincount(compute,
                   dataset,
                   sample_size=None,
                   batch_size=10,
                   multi_label_axis=None,
                   cachefile=None,
                   **kwargs):
    '''
    Computes bincount totals for a large data sample that can be
    computed from a dataset.  The compute function should return one
    batch of samples as a (sample, unit)-dimension tensor.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningBincount(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        rbc = runningstats.RunningBincount()
        for batch in pbar(loader):
            sample = call_compute(compute, batch)
            if multi_label_axis is not None:
                multilabel = sample.shape[multi_label_axis]
                size = sample.numel() // multilabel
            else:
                size = None
            rbc.add(sample, size=size)
        rbc.to_('cpu')
        save_cached_state(cachefile, rbc, args)
        return rbc
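
A sketch of the multi_label_axis option (segmodel and its segment_batch are
hypothetical stand-ins): when each pixel carries several labels along axis 1,
passing multi_label_axis=1 normalizes the totals so that each pixel is counted
once rather than once per label copy.

def compute_seg_labels(batch, *args):
    # Hypothetical segmenter returning integer labels of shape
    # (N, labels_per_pixel, H, W).
    return segmodel.segment_batch(batch.cuda())

bc = tally_bincount(compute_seg_labels, dataset, sample_size=1000,
                    multi_label_axis=1)
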
Example #3
def compute_mean_present_features(args, corpus, cache_filename, model):
    # Phase 1.5.  Figure mean activations for every channel where there
    # is a doorway.
    if all(k in corpus for k in ['mean_present_feature']):
        return
    with torch.no_grad():
        total_present_feature = 0
        for [zbatch, featloc] in pbar(torch.utils.data.DataLoader(
                TensorDataset(corpus.object_present_sample,
                              corpus.object_present_location),
                batch_size=args.inference_batch_size,
                num_workers=10,
                pin_memory=True),
                                      desc="Mean activations"):
            zbatch = zbatch.cuda()
            featloc = featloc.cuda()
            tensor_image = model(zbatch)
            feat = model.retained_layer(args.layer)
            flatfeat = feat.view(feat.shape[0], feat.shape[1], -1)
            sum_feature_at_obj = flatfeat[
                torch.arange(feat.shape[0]).to(feat.device), :, featloc].sum(0)
            total_present_feature = total_present_feature + sum_feature_at_obj
        corpus.mean_present_feature = (
            total_present_feature / len(corpus.object_present_sample)).cpu()
    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def tally_directory(directory, size=10000, seed=1):
    dataset = parallelfolder.ParallelImageFolders(
        [directory],
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]))
    loader = DataLoader(
        dataset,
        sampler=FixedRandomSubsetSampler(dataset, end=size, seed=seed),
        # sampler=FixedSubsetSampler(range(size)),
        batch_size=10,
        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(len(batch)):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            result[batch_index:batch_index + len(batch)] = (
                batch_result[:len(batch)].cpu().numpy())
            batch_index += len(batch)
    return result
def tally_conditional_mean(compute,
                           dataset,
                           sample_size=None,
                           batch_size=1,
                           cachefile=None,
                           **kwargs):
    '''
    Computes conditional mean and variance for a large data sample that
    can be computed from a dataset.  The compute function should return a
    sequence of sample batch tuples (condition, (sample, unit)-tensor),
    one for each condition relevant to the batch.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningConditionalVariance(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        cv = runningstats.RunningConditionalVariance()
        for i, batch in enumerate(pbar(loader)):
            sample_set = call_compute(compute, batch)
            for cond, sample in sample_set:
                cv.add(cond, sample)
        # At the end, move all to the CPU
        cv.to_('cpu')
        save_cached_state(cachefile, cv, args)
        return cv
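
A sketch of a compute function for the conditional form (the retained layer
and segmenter are hypothetical; tally.conditional_samples is the helper used
the same way in Example #22 below, pairing each segmentation label with the
activation rows observed under it):

def compute_conditional(batch, *args):
    _ = model(batch.cuda())
    acts = model.retained_layer('layer4')       # (N, units, H, W), assumed
    seg = segmodel.segment_batch(batch.cuda())  # (N, labels, H, W), assumed
    return tally.conditional_samples(acts, seg)

cv = tally_conditional_mean(compute_conditional, dataset, sample_size=1000)
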
Example #6
def compute_feature_quantiles(args, corpus, cache_filename, model,
                              full_sample):
    # Phase 1.6.  Figure the 99% and 99.9%ile of every feature.
    if all(k in corpus for k in ['feature_99', 'feature_999']):
        return
    with torch.no_grad():
        rq = RunningQuantile(r=5000)  # 10x what's needed.
        for [zbatch] in pbar(torch.utils.data.DataLoader(
                TensorDataset(full_sample),
                batch_size=args.inference_batch_size,
                num_workers=10,
                pin_memory=True),
                             desc="Calculating 0.999 quantile"):
            zbatch = zbatch.cuda()
            tensor_image = model(zbatch)
            feat = model.retained_layer(args.layer)
            rq.add(
                feat.permute(0, 2, 3, 1).contiguous().view(-1, feat.shape[1]))
        result = rq.quantiles([0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999])
        corpus.feature_001 = result[:, 0].cpu()
        corpus.feature_01 = result[:, 1].cpu()
        corpus.feature_10 = result[:, 2].cpu()
        corpus.feature_50 = result[:, 3].cpu()
        corpus.feature_90 = result[:, 4].cpu()
        corpus.feature_99 = result[:, 5].cpu()
        corpus.feature_999 = result[:, 6].cpu()
    numpy.savez(cache_filename, **corpus)
def tally_quantile(compute,
                   dataset,
                   sample_size=None,
                   batch_size=10,
                   r=4096,
                   cachefile=None,
                   **kwargs):
    '''
    Computes quantile sketch statistics for a large data sample that can
    be computed from a dataset.  The compute function should return one
    batch of samples as a (sample, unit)-dimension tensor.

    The underlying quantile sketch is an optimal KLL sorted sampler that
    retains at least r samples (where r is the specified resolution).
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size, r=r)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningQuantile(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        rq = runningstats.RunningQuantile()
        for batch in pbar(loader):
            sample = call_compute(compute, batch)
            rq.add(sample)
        rq.to_('cpu')
        save_cached_state(cachefile, rq, args)
        return rq
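
The returned RunningQuantile can then be queried directly, as Example #22
below does; a small sketch (compute_unit_samples is the hypothetical compute
function from the tally_topk sketch above):

rq = tally_quantile(compute_unit_samples, dataset, sample_size=1000, r=4096)
level_99 = rq.quantiles(0.99)        # per-unit 99th-percentile activation
levels = rq.quantiles([0.5, 0.99])   # several quantile levels at once
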
Example #9
def my_test_perclass(model,
                     dataset,
                     layername=None,
                     ablated_units=None,
                     sample_size=None,
                     cachefile=None):
    model.remove_edits()
    if ablated_units is not None:
        def ablate_the_units(x, *args):
            x[:, ablated_units] = 0
            return x
        model.edit_layer(layername, rule=ablate_the_units)
    # sampler = None if sample_size is None else (
    # FixedSubsetSampler(list(range(sample_size))))
    with torch.no_grad():
        num_classes = 101

        loader = torch.utils.data.DataLoader(
            dataset, batch_size=16, num_workers=6,  # num_workers=20, batch_size=100
            pin_memory=True)
        test_running_correct = 0
        total_count = 0
        test_correct_per_class = [0] * num_classes
        total_per_class = [0] * num_classes
        acc_per_class = [0] * num_classes

        for batch_idx, (image_batch, class_batch) in enumerate(pbar(loader)):
            total_count += len(image_batch)
            image_batch, class_batch = [d.cuda() for d in [image_batch, class_batch]]
            # image_batch has shape [bs, test_frame_size, c, l, w]

            batch_size, test_frame_size, channels, l, w = image_batch.shape
            images = torch.reshape(image_batch, (batch_size * test_frame_size, channels, l, -1))
            # images has shape [bs*test_frame_size, c, l, w]
            scores = model(images)
            scores = torch.reshape(scores, (batch_size, test_frame_size, -1))
            # scores has shape [bs, test_frame_size, 101]
            scores = scores.mean(dim=1)
            # scores has shape [bs, 101]

            # loss = criterion(outputs, labels)
            # test_running_loss += loss.item()
            _, preds = scores.max(1)
            correct = (preds == class_batch)
            test_running_correct += correct.sum().item()
            torch.cuda.empty_cache()

            for label, pred in zip(class_batch, preds):
                if label.item() == pred.item():
                    test_correct_per_class[label.item()] += 1
                total_per_class[label.item()] += 1
            torch.cuda.empty_cache()

    for idx, (corr, total) in enumerate(zip(test_correct_per_class,
                                            total_per_class)):
        acc_per_class[idx] = round(corr / total, 4) if total else 0.0

    print("Predicted correctly %d" % (test_running_correct))
    print("Total data points: %d" % (total_count))
    accuracy = round(test_running_correct / total_count, 4)
    print('Test Acc: %.3f percent' % (100 * accuracy))

    return accuracy, np.asarray(acc_per_class)
Example #10
def measure_ablation(segmenter, loader, model, classnum, layer, ordering):
    total_bincount = 0
    data_size = 0
    device = next(model.parameters()).device
    for l in model.ablation:
        model.ablation[l] = None
    feature_units = model.feature_shape[layer][1]
    feature_shape = model.feature_shape[layer][2:]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    for i, batch in enumerate(pbar(loader)):
        z_batch = batch[0]
        model.ablation[layer] = None
        tensor_images = model(z_batch.to(device))
        seg = segmenter.segment_batch(tensor_images, downsample=2)
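        # Collapse the segmentation's label channels into a single boolean
        # mask marking pixels predicted as the target class.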
        mask = (seg == classnum).max(1)[0]
        downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
            mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
        total_scores[0] += downsampled_seg.sum().cpu()
        # Now we need to do an intervention for every location
        # that had a nonzero downsampled_seg, if any.
        interventions_needed = downsampled_seg.nonzero()
        location_count = len(interventions_needed)
        if location_count == 0:
            continue
        interventions_needed = interventions_needed.repeat(repeats, 1)
        inter_z = batch[0][interventions_needed[:, 0]].to(device)
        inter_chan = torch.zeros(repeats,
                                 location_count,
                                 feature_units,
                                 device=device)
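        # Row j of inter_chan ablates the first j+1 units in the ordering,
        # so the resulting scores measure cumulative ablation of top units.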
        for j, u in enumerate(ordering):
            inter_chan[j:, :, u] = 1
        inter_chan = inter_chan.view(len(inter_z), feature_units)
        inter_loc = interventions_needed[:, 1:]
        scores = torch.zeros(len(inter_z))
        batch_size = len(batch[0])
        for j in range(0, len(inter_z), batch_size):
            ibz = inter_z[j:j + batch_size]
            ibl = inter_loc[j:j + batch_size].t()
            imask = torch.zeros((len(ibz), ) + feature_shape,
                                device=ibz.device)
            imask[(torch.arange(len(ibz)), ) + tuple(ibl)] = 1
            ibc = inter_chan[j:j + batch_size]
            model.ablation[layer] = (imask.float()[:, None, :, :] *
                                     ibc[:, :, None, None])
            tensor_images = model(ibz)
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            mask = (seg == classnum).max(1)[0]
            downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
                mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
            scores[j:j +
                   batch_size] = downsampled_iseg[(torch.arange(len(ibz)), ) +
                                                  tuple(ibl)]
        scores = scores.view(repeats, location_count).sum(1)
        total_scores[1:] += scores
    return total_scores
Example #11
def compute_present_locations(args, corpus, cache_filename, model, segmenter,
                              classnum, full_sample):
    # Phase 1.  Identify a set of locations where there are doorways.
    # Segment the image and find featuremap pixels that maximize the number
    # of doorway pixels under the featuremap pixel.
    if all(k in corpus for k in [
            'present_indices', 'object_present_sample',
            'object_present_location', 'object_location_popularity',
            'weighted_mean_present_feature'
    ]):
        return
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    num_units = model.feature_shape[args.layer][1]
    with torch.no_grad():
        weighted_feature_sum = torch.zeros(num_units).cuda()
        object_presence_scores = []
        for [zbatch] in pbar(torch.utils.data.DataLoader(
                TensorDataset(full_sample),
                batch_size=args.inference_batch_size,
                num_workers=10,
                pin_memory=True),
                             desc="Object pool"):
            zbatch = zbatch.cuda()
            tensor_image = model(zbatch)
            segmented_image = segmenter.segment_batch(tensor_image,
                                                      downsample=2)
            mask = (segmented_image == classnum).max(1)[0]
            score = torch.nn.functional.adaptive_avg_pool2d(
                mask.float(), feature_shape)
            object_presence_scores.append(score.cpu())
            feat = model.retained_layer(args.layer)
            weighted_feature_sum += (feat * score[:, None, :, :]).view(
                feat.shape[0], feat.shape[1], -1).sum(2).sum(0)
        object_presence_at_feature = torch.cat(object_presence_scores)
        object_presence_at_image, object_location_in_image = (
            object_presence_at_feature.view(args.search_size, -1).max(1))
        best_presence_scores, best_presence_images = torch.sort(
            -object_presence_at_image)
        all_present_indices = torch.sort(
            best_presence_images[:(args.train_size + args.eval_size)])[0]
        corpus.present_indices = all_present_indices[:args.train_size]
        corpus.object_present_sample = full_sample[corpus.present_indices]
        corpus.object_present_location = object_location_in_image[
            corpus.present_indices]
        corpus.object_location_popularity = torch.bincount(
            corpus.object_present_location, minlength=num_locations)
        corpus.weighted_mean_present_feature = (
            weighted_feature_sum.cpu() /
            (1e-20 + object_presence_at_feature.view(-1).sum()))
        corpus.eval_present_indices = all_present_indices[-args.eval_size:]
        corpus.eval_present_sample = full_sample[corpus.eval_present_indices]
        corpus.eval_present_location = object_location_in_image[
            corpus.eval_present_indices]

    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def tally_cat(compute, dataset, sample_size=None, batch_size=10, **kwargs):
    '''
    Computes a concatenated tensor over data batches that can be
    computed from a dataset.  The compute function should return a
    tensor for each batch; the results are concatenated along the
    first dimension.
    '''
    with torch.no_grad():
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        result = []
        for batch in pbar(loader):
            result.append(call_compute(compute, batch).cpu())
        return torch.cat(result)
Example #13
def walk_image_files(rootdir, verbose=None):
    print("Walking image files ... ")
    indexfile = '%s.txt' % rootdir
    if os.path.isfile(indexfile):
        print("from index file: ", indexfile)
        basedir = os.path.dirname(rootdir)
        with open(indexfile) as f:
            result = sorted([
                os.path.join(basedir, line.strip()) for line in f.readlines()
            ])
            return result
    result = []
    for dirname, _, fnames in sorted(
            pbar(os.walk(rootdir),
                 desc='Walking %s' % os.path.basename(rootdir))):
        for fname in sorted(fnames):
            if is_image_file(fname) or is_npy_file(fname):
                result.append(os.path.join(dirname, fname))
    return result
Example #14
def visualize_training_locations(args, corpus, cachedir, model):
    # Phase 2.5 Create visualizations of the corpus images.
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    with torch.no_grad():
        imagedir = os.path.join(cachedir, 'image')
        os.makedirs(imagedir, exist_ok=True)
        image_saver = WorkerPool(SaveImageWorker)
        for group, group_sample, group_location, group_indices in [
            ('present', corpus.object_present_sample,
             corpus.object_present_location, corpus.present_indices),
            ('candidate', corpus.candidate_sample, corpus.candidate_location,
             corpus.candidate_indices)
        ]:
            for [zbatch, featloc,
                 indices] in pbar(torch.utils.data.DataLoader(
                     TensorDataset(group_sample, group_location,
                                   group_indices),
                     batch_size=args.inference_batch_size,
                     num_workers=10,
                     pin_memory=True),
                                  desc="Visualize %s" % group):
                zbatch = zbatch.cuda()
                tensor_image = model(zbatch)
                feature_mask = torch.zeros((len(zbatch), 1) + feature_shape)
                feature_mask.view(len(zbatch),
                                  -1).scatter_(1, featloc[:, None], 1)
                feature_mask = torch.nn.functional.adaptive_max_pool2d(
                    feature_mask.float(), tensor_image.shape[-2:]).cuda()
                yellow = torch.Tensor([1.0, 1.0, -1.0])[None, :, None,
                                                        None].cuda()
                tensor_image = tensor_image * (1 - 0.5 * feature_mask) + (
                    0.5 * feature_mask * yellow)
                byte_image = (((tensor_image + 1) / 2) * 255).clamp(
                    0, 255).byte()
                numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy()
                for i, index in enumerate(indices):
                    image_saver.add(
                        numpy_image[i],
                        os.path.join(imagedir, '%s_%d.jpg' % (group, index)))
    image_saver.join()
def tally_generated_objects(model, size=10000):
    zds = zdataset.z_dataset_for_model(model, size)
    loader = DataLoader(zds, batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [zbatch] in pbar(loader):
            img = model(zbatch.cuda())
            seg_result = upp.segment_batch(img)
            for i in range(len(zbatch)):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            result[batch_index:batch_index + len(zbatch)] = (
                batch_result[:len(zbatch)].cpu().numpy())
            batch_index += len(zbatch)
    return result
def tally_conditional_quantile(compute,
                               dataset,
                               sample_size=None,
                               batch_size=1,
                               gpu_cache=64,
                               r=1024,
                               cachefile=None,
                               **kwargs):
    '''
    Computes conditional quantile sketches for a large data sample that
    can be computed from a dataset.  The compute function should return a
    sequence of sample batch tuples (condition, (sample, unit)-tensor),
    one for each condition relevant to the batch.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size, r=r)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningConditionalQuantile(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        cq = runningstats.RunningConditionalQuantile(r=r)
        most_common_conditions = set()
        for i, batch in enumerate(pbar(loader)):
            sample_set = call_compute(compute, batch)
            for cond, sample in sample_set:
                # Move uncommon conditional data to the cpu before collating.
                if cond not in most_common_conditions:
                    sample = sample.cpu()
                cq.add(cond, sample)
            # Move uncommon conditions off the GPU.
            if i and not i & (i - 1):  # if i is a power of 2:
                common_conditions = set(cq.most_common_conditions(gpu_cache))
                cq.to_('cpu',
                       [k for k in cq.keys() if k not in common_conditions])
        # At the end, move all to the CPU
        cq.to_('cpu')
        save_cached_state(cachefile, cq, args)
        return cq
def tally_dataset_objects(dataset, size=10000):
    loader = DataLoader(dataset,
                        sampler=FixedRandomSubsetSampler(dataset, end=size),
                        batch_size=10,
                        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(len(batch)):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            result[batch_index:batch_index + len(batch)] = (
                batch_result[:len(batch)].cpu().numpy())
            batch_index += len(batch)
    return result
def tally_mean(compute,
               dataset,
               sample_size=None,
               batch_size=10,
               cachefile=None,
               **kwargs):
    '''
    Computes unitwise mean and variance stats for a large data sample that
    can be computed from a dataset.  The compute function should return one
    batch of samples as a (sample, unit)-dimension tensor.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningVariance(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        rv = runningstats.RunningVariance()
        for batch in pbar(loader):
            sample = call_compute(compute, batch)
            rv.add(sample)
        rv.to_('cpu')
        save_cached_state(cachefile, rv, args)
        return rv
Example #19
 def eval_loss_and_reg():
     discrete_experiments = dict(
         # dpixel=dict(discrete_pixels=True),
         # dunits20=dict(discrete_units=20),
         # dumix20=dict(discrete_units=20, mixed_units=True),
         # dunits10=dict(discrete_units=10),
         # abonly=dict(ablation_only=True),
         # fimabl=dict(ablation_only=True,
         #             fullimage_ablation=True,
         #             fullimage_measurement=True),
         dboth20=dict(discrete_units=20, discrete_pixels=True),
         # dbothm20=dict(discrete_units=20, mixed_units=True,
         #              discrete_pixels=True),
         # abdisc20=dict(discrete_units=20, discrete_pixels=True,
         #             ablation_only=True),
         # abdiscm20=dict(discrete_units=20, mixed_units=True,
         #             discrete_pixels=True,
         #             ablation_only=True),
         # fimadp=dict(discrete_pixels=True,
         #             ablation_only=True,
         #             fullimage_ablation=True,
         #             fullimage_measurement=True),
         # fimadu10=dict(discrete_units=10,
         #             ablation_only=True,
         #             fullimage_ablation=True,
         #             fullimage_measurement=True),
         # fimadb10=dict(discrete_units=10, discrete_pixels=True,
         #             ablation_only=True,
         #             fullimage_ablation=True,
         #             fullimage_measurement=True),
         fimadbm10=dict(discrete_units=10,
                        mixed_units=True,
                        discrete_pixels=True,
                        ablation_only=True,
                        fullimage_ablation=True,
                        fullimage_measurement=True),
         # fimadu20=dict(discrete_units=20,
         #             ablation_only=True,
         #             fullimage_ablation=True,
         #             fullimage_measurement=True),
         # fimadb20=dict(discrete_units=20, discrete_pixels=True,
         #             ablation_only=True,
         #             fullimage_ablation=True,
         #             fullimage_measurement=True),
         fimadbm20=dict(discrete_units=20,
                        mixed_units=True,
                        discrete_pixels=True,
                        ablation_only=True,
                        fullimage_ablation=True,
                        fullimage_measurement=True))
     with torch.no_grad():
         total_loss = 0
         discrete_losses = {k: 0 for k in discrete_experiments}
         for [pbatch, ploc, cbatch,
              cloc] in pbar(torch.utils.data.DataLoader(
                  TensorDataset(corpus.eval_present_sample,
                                corpus.eval_present_location,
                                corpus.eval_candidate_sample,
                                corpus.eval_candidate_location),
                  batch_size=args.inference_batch_size,
                  num_workers=10,
                  shuffle=False,
                  pin_memory=True),
                            desc="Eval"):
             # First, put in zeros for the selected units.
             # Loss is amount of remaining object.
             total_loss = total_loss + ace_loss(
                 segmenter,
                 classnum,
                 model,
                 args.layer,
                 high_replacement,
                 ablation,
                 pbatch,
                 ploc,
                 cbatch,
                 cloc,
                 run_backward=False,
                 ablation_only=ablation_only,
                 fullimage_measurement=fullimage_measurement)
             for k, config in discrete_experiments.items():
                 discrete_losses[k] = discrete_losses[k] + ace_loss(
                     segmenter,
                     classnum,
                     model,
                     args.layer,
                     high_replacement,
                     ablation,
                     pbatch,
                     ploc,
                     cbatch,
                     cloc,
                     run_backward=False,
                     **config)
         avg_loss = (total_loss / args.eval_size).item()
         avg_d_losses = {
             k: (d / args.eval_size).item()
             for k, d in discrete_losses.items()
         }
         regularizer = (args.l2_lambda * ablation.pow(2).sum())
         pbar.print('Epoch %d Loss %g Regularizer %g' %
                    (epoch, avg_loss, regularizer))
         pbar.print(' '.join('%s: %g' % (k, d)
                             for k, d in avg_d_losses.items()))
         pbar.print(scale_summary(ablation.view(-1), 10, 3))
         return avg_loss, regularizer, avg_d_losses
Example #20
def compute_candidate_locations(args, corpus, cache_filename, model, segmenter,
                                classnum, second_sample):
    # Phase 2.  Identify a set of candidate locations for doorways.
    # Place the median doorway activation in every location of an image
    # and identify where it can go that doorway pixels increase.
    if all(k in corpus for k in [
            'candidate_indices', 'candidate_sample', 'candidate_score',
            'candidate_location', 'object_score_at_candidate',
            'candidate_location_popularity'
    ]):
        return
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    with torch.no_grad():
        # Simplify - just treat all locations as possible
        possible_locations = numpy.arange(num_locations)

        # Speed up search for locations, by weighting probed locations
        # according to observed distribution.
        location_weights = (corpus.object_location_popularity).double()
        location_weights += (location_weights.mean()) / 10.0
        location_weights = location_weights / location_weights.sum()

        candidate_scores = []
        object_scores = []
        prng = numpy.random.RandomState(1)
        for [zbatch] in pbar(torch.utils.data.DataLoader(
                TensorDataset(second_sample),
                batch_size=args.inference_batch_size,
                num_workers=10,
                pin_memory=True),
                             desc="Candidate pool"):
            batch_scores = torch.zeros((len(zbatch), ) + feature_shape).cuda()
            flat_batch_scores = batch_scores.view(len(zbatch), -1)
            zbatch = zbatch.cuda()
            tensor_image = model(zbatch)
            segmented_image = segmenter.segment_batch(tensor_image,
                                                      downsample=2)
            mask = (segmented_image == classnum).max(1)[0]
            object_score = torch.nn.functional.adaptive_avg_pool2d(
                mask.float(), feature_shape)
            baseline_presence = mask.float().view(mask.shape[0], -1).sum(1)

            edit_mask = torch.zeros((1, 1) + feature_shape).cuda()
            if '_tcm' in args.variant:
                # variant: top-conditional-mean
                replace_vec = (corpus.mean_present_feature[None, :, None,
                                                           None].cuda())
            else:  # default: weighted mean
                replace_vec = (
                    corpus.weighted_mean_present_feature[None, :, None,
                                                         None].cuda())
            # Sample 5 random locations to examine.
            for loc in prng.choice(possible_locations,
                                   replace=False,
                                   p=location_weights,
                                   size=5):
                edit_mask.zero_()
                edit_mask.view(-1)[loc] = 1
                model.edit_layer(args.layer,
                                 ablation=edit_mask,
                                 replacement=replace_vec)
                tensor_image = model(zbatch)
                segmented_image = segmenter.segment_batch(tensor_image,
                                                          downsample=2)
                mask = (segmented_image == classnum).max(1)[0]
                modified_presence = mask.float().view(mask.shape[0], -1).sum(1)
                flat_batch_scores[:, loc] = (modified_presence -
                                             baseline_presence)
            candidate_scores.append(batch_scores.cpu())
            object_scores.append(object_score.cpu())

        object_scores = torch.cat(object_scores)
        candidate_scores = torch.cat(candidate_scores)
        # Eliminate candidates where the object is present.
        candidate_scores = candidate_scores * (object_scores == 0).float()
        candidate_score_at_image, candidate_location_in_image = (
            candidate_scores.view(args.search_size, -1).max(1))
        best_candidate_scores, best_candidate_images = torch.sort(
            -candidate_score_at_image)
        all_candidate_indices = torch.sort(
            best_candidate_images[:(args.train_size + args.eval_size)])[0]
        corpus.candidate_indices = all_candidate_indices[:args.train_size]
        corpus.candidate_sample = second_sample[corpus.candidate_indices]
        corpus.candidate_location = candidate_location_in_image[
            corpus.candidate_indices]
        corpus.candidate_score = candidate_score_at_image[
            corpus.candidate_indices]
        corpus.object_score_at_candidate = object_scores.view(
            len(object_scores), -1)[corpus.candidate_indices,
                                    corpus.candidate_location]
        corpus.candidate_location_popularity = torch.bincount(
            corpus.candidate_location, minlength=num_locations)
        corpus.eval_candidate_indices = all_candidate_indices[-args.eval_size:]
        corpus.eval_candidate_sample = second_sample[
            corpus.eval_candidate_indices]
        corpus.eval_candidate_location = candidate_location_in_image[
            corpus.eval_candidate_indices]
    numpy.savez(cache_filename, **corpus)
Example #21
def main():
    # Training settings
    def strpair(arg):
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Ablation eval',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir',
                        type=str,
                        default='dissect',
                        required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes',
                        type=str,
                        nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric',
                        type=str,
                        default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount',
                        type=int,
                        default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter',
                        type=str,
                        help='directory containing segmentation dataset')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size',
                        type=int,
                        default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size',
                        type=int,
                        default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    pbar.verbose(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Determine the model's device and input shape
    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1],
                                   seed=2).view((args.size, ) +
                                                input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    # Now do the actual work.
    labelnames, catnames = (segmenter.get_label_and_category_names(dataset))
    label_category = [
        catnames.index(c) if c in catnames else 0 for l, c in labelnames
    ]
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=args.batch_size,
                                            num_workers=10,
                                            pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None

    # For each sort-order, do an ablation
    for classname in pbar(args.classes):
        pbar.post(c=classname)
        for layername in pbar(model.ablation):
            pbar.post(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (StopIteration, KeyError):
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            if os.path.isfile(os.path.join(ablationdir, '%s.json' % rankname)):
                with open(os.path.join(ablationdir,
                                       '%s.json' % rankname)) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                measurements = data.ablation_effects
            measurements = measure_ablation(segmenter, segloader, model,
                                            classnum, layername,
                                            ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(os.path.join(ablationdir, '%s.json' % rankname),
                      'w') as f:
                json.dump(
                    dict(classname=classname,
                         classnum=classnum,
                         baseline=measurements[0],
                         layer=layername,
                         metric=args.metric,
                         ablation_units=ordering.tolist(),
                         ablation_effects=measurements[1:]), f)
Example #22
def main():
    args = parseargs()
    resdir = 'results/%s-%s-%s-%s-%s' % (args.model, args.dataset, args.seg,
                                         args.layer, int(args.quantile * 1000))

    def resfile(f):
        return os.path.join(resdir, f)

    model = load_model(args)
    layername = instrumented_layername(args)
    model.retain_layer(layername)
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    is_generator = (args.model == 'progan')
    percent_level = 1.0 - args.quantile

    # Tally rq.np (representation quantile, unconditional).
    torch.set_grad_enabled(False)
    pbar.descnext('rq')

    def compute_samples(batch, *args):
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
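        # Flatten spatial positions so every featuremap location contributes
        # one row, matching the (sample, unit) contract of the tally_* calls.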
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    rq = tally.tally_quantile(compute_samples,
                              dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Grab the 99th percentile, and tally conditional means at that level.
    level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')

    def compute_conditional_indicator(batch, *args):
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float()  # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
                                           dataset,
                                           sample_size=sample_size,
                                           num_workers=3,
                                           pin_memory=True,
                                           cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    unit_label_99 = [(concept.item(), seglabels[concept],
                      segcatlabels[concept], bestiou.item())
                     for (bestiou, concept) in zip(*iou_99.max(0))]

    def measure_segclasses_with_zeroed_units(zeroed_units, sample_size=100):
        model.remove_edits()

        def zero_some_units(x, *args):
            x[:, zeroed_units] = 0
            return x

        model.edit_layer(layername, rule=zero_some_units)
        num_seglabels = len(segmodel.get_label_and_category_names()[0])

        def compute_mean_seg_in_images(batch_z, *args):
            img = model(batch_z.cuda())
            seg = segmodel.segment_batch(img, downsample=4)
            seg_area = seg.shape[2] * seg.shape[3]
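            # Batched bincount trick: offset each sample's labels by
            # num_seglabels * sample_index so a single flat bincount yields
            # per-sample histograms after the final reshape.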
            seg_counts = torch.bincount(
                (seg + (num_seglabels * torch.arange(
                    seg.shape[0], dtype=seg.dtype,
                    device=seg.device)[:, None, None, None])).view(-1),
                minlength=num_seglabels * seg.shape[0]).view(seg.shape[0], -1)
            seg_fracs = seg_counts.float() / seg_area
            return seg_fracs

        result = tally.tally_mean(compute_mean_seg_in_images,
                                  dataset,
                                  batch_size=30,
                                  sample_size=sample_size,
                                  pin_memory=True)
        model.remove_edits()
        return result

    # Intervention experiment here:
    # segs_baseline = measure_segclasses_with_zeroed_units([])
    # segs_without_treeunits = measure_segclasses_with_zeroed_units(tree_units)
    num_units = len(unit_label_99)
    baseline_segmean = test_generator_segclass_stats(
        model,
        dataset,
        segmodel,
        layername=layername,
        cachefile=resfile('segstats/baseline.npz')).mean()

    pbar.descnext('unit ablation')
    unit_ablation_segmean = torch.zeros(num_units, len(baseline_segmean))
    for unit in pbar(random.sample(range(num_units), num_units)):
        stats = test_generator_segclass_stats(
            model,
            dataset,
            segmodel,
            layername=layername,
            zeroed_units=[unit],
            cachefile=resfile('segstats/ablated_unit_%d.npz' % unit))
        unit_ablation_segmean[unit] = stats.mean()

    ablate_segclass_name = 'tree'
    ablate_segclass = seglabels.index(ablate_segclass_name)
    best_iou_units = iou_99[ablate_segclass, :].sort(0)[1].flip(0)
    byiou_unit_ablation_seg = torch.zeros(30)
    for unitcount in pbar(random.sample(range(0, 30), 30)):
        zero_units = best_iou_units[:unitcount].tolist()
        stats = test_generator_segclass_delta_stats(
            model,
            dataset,
            segmodel,
            layername=layername,
            zeroed_units=zero_units,
            cachefile=resfile('deltasegstats/ablated_best_%d_iou_%s.npz' %
                              (unitcount, ablate_segclass_name)))
        byiou_unit_ablation_seg[unitcount] = stats.mean()[ablate_segclass]

    # Generator context experiment.
    num_segclass = len(seglabels)
    door_segclass = seglabels.index('door')
    door_units = iou_99[door_segclass].sort(0)[1].flip(0)[:20]
    door_high_values = rq.quantiles(0.995)[door_units].cuda()

    def compute_seg_impact(zbatch, *args):
        zbatch = zbatch.cuda()
        model.remove_edits()
        orig_img = model(zbatch)
        orig_seg = segmodel.segment_batch(orig_img, downsample=4)
        orig_segcount = tally.batch_bincount(orig_seg, num_segclass)
        rep = model.retained_layer(layername).clone()
        ysize = orig_seg.shape[2] // rep.shape[2]
        xsize = orig_seg.shape[3] // rep.shape[3]

        def gen_conditions():
            for y in range(rep.shape[2]):
                for x in range(rep.shape[3]):
                    # Take as the context location the segmentation
                    # labels at the center of the square.
                    selsegs = orig_seg[:, :, y * ysize + ysize // 2,
                                       x * xsize + xsize // 2]
                    changed_rep = rep.clone()
                    changed_rep[:, door_units, y,
                                x] = (door_high_values[None, :])
                    model.edit_layer(layername,
                                     ablation=1.0,
                                     replacement=changed_rep)
                    changed_img = model(zbatch)
                    changed_seg = segmodel.segment_batch(changed_img,
                                                         downsample=4)
                    changed_segcount = tally.batch_bincount(
                        changed_seg, num_segclass)
                    delta_segcount = (changed_segcount - orig_segcount).float()
                    for sel, delta in zip(selsegs, delta_segcount):
                        for cond in torch.bincount(sel).nonzero()[:, 0]:
                            if cond == 0:
                                continue
                            yield (cond.item(), delta)

        return gen_conditions()

    cond_changes = tally.tally_conditional_mean(
        compute_seg_impact,
        dataset,
        sample_size=10000,
        batch_size=20,
        cachefile=resfile('big_door_cond_changes.npz'))
def main():
    args = parseargs()
    experiment_dir = 'results/decoupled-%d-%s-resnet' % (args.selected_classes,
                                                         args.dataset)
    ds_dirname = dict(novelty='novelty/dataset_v1/known_classes/images',
                      imagenet='imagenet')[args.dataset]
    training_dir = 'datasets/%s/train' % ds_dirname
    val_dir = 'datasets/%s/val' % ds_dirname
    os.makedirs(experiment_dir, exist_ok=True)
    with open(os.path.join(experiment_dir, 'args.txt'), 'w') as f:
        f.write(str(args) + '\n')

    def printstat(s):
        with open(os.path.join(experiment_dir, 'log.txt'), 'a') as f:
            f.write(str(s) + '\n')
        pbar.print(s)

    def filter_tuple(item):
        return item[1] < args.selected_classes

    # Imagenet has a couple of images with corrupt EXIF data.
    warnings.filterwarnings('ignore', message='.*orrupt EXIF.*')
    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [training_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64,
        shuffle=True,
        num_workers=48,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [val_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64,
        shuffle=False,
        num_workers=24,
        pin_memory=True)
    late_model = torchvision.models.resnet50(num_classes=args.selected_classes)
    for n, p in late_model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        elif len(p.shape) <= 1:
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    late_model.train()
    late_model.cuda()

    model = late_model

    max_lr = 5e-3
    max_iter = args.training_iterations

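    # One-hot binary cross-entropy trains each class logit as an independent
    # detector (the 'decoupled' setup named in experiment_dir above), rather
    # than coupling the classes through softmax cross-entropy.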
    def criterion(logits, true_class):
        goal = torch.zeros_like(logits)
        goal.scatter_(1, true_class[:, None], value=1.0)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits, goal)

    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr,
                                                    total_steps=max_iter - 1)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    # Oh, hold on.  Let's actually resume training if we already have a model.
    checkpoint_filename = 'weights.pth'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(os.path.join(experiment_dir, best_filename))
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        filename = os.path.join(experiment_dir, checkpoint_filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for input, target in pbar(val_loader):
            # Load data
            input_var, target_var = [d.cuda() for d in [input, target]]
            # Evaluate model
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)
                _, pred = output.max(1)
                accuracy = (target_var.eq(pred)
                            ).data.float().sum().item() / input.size(0)
            val_loss.update(loss.data.item(), input.size(0))
            val_acc.update(accuracy, input.size(0))
            # Check accuracy
            pbar.post(l=val_loss.avg, a=val_acc.avg)
        # Save checkpoint
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        printstat('Iteration %d val accuracy %.2f' %
                  (iter_num, val_acc.avg * 100.0))

    # Here is our training loop.
    while iter_num < max_iter:
        # Track the average training loss/accuracy over each pass of the data.
        train_loss, train_acc = AverageMeter(), AverageMeter()
        for filtered_input, filtered_target in pbar(train_loader):
            # Load data
            input_var, target_var = [
                d.cuda() for d in [filtered_input, filtered_target]
            ]
            # Evaluate model
            output = model(input_var)
            loss = criterion(output, target_var)
            train_loss.update(loss.data.item(), filtered_input.size(0))
            # Perform one optimization step (skipped at iteration 0)
            if iter_num > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Learning rate schedule
                scheduler.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = (target_var.eq(pred)).data.float().sum().item() / (
                filtered_input.size(0))
            train_acc.update(accuracy)
            pbar.post(l=train_loss.avg,
                      a=train_acc.avg,
                      v=best['val_accuracy'])
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
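
# The AverageMeter helper used above is not defined in this example.  A
# minimal sketch of the usual running-average pattern (an assumption about
# the helper, not code from this repository):
class AverageMeter:
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value, n=1):
        # Accumulate a weighted running average of observed values.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count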
Example #24
def train_ablation(args,
                   corpus,
                   cachefile,
                   model,
                   segmenter,
                   classnum,
                   initial_ablation=None):
    cachedir = os.path.dirname(cachefile)
    snapdir = os.path.join(cachedir, 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None, :, None, None].cuda()
    elif '_tcm' in args.variant:
        # variant: top-conditional-mean
        high_replacement = (
            corpus.mean_present_feature[None, :, None, None].cuda())
    else:  # default: weighted mean
        high_replacement = (
            corpus.weighted_mean_present_feature[None, :, None, None].cuda())
    fullimage_measurement = False
    ablation_only = False
    fullimage_ablation = False
    if '_fim' in args.variant:
        fullimage_measurement = True
    elif '_fia' in args.variant:
        fullimage_measurement = True
        ablation_only = True
        fullimage_ablation = True
    high_replacement.requires_grad = False
    for p in model.parameters():
        p.requires_grad = False

    ablation = torch.zeros(high_replacement.shape).cuda()
    if initial_ablation is not None:
        ablation.view(-1)[...] = initial_ablation
    ablation.requires_grad = True
    optimizer = torch.optim.Adam([ablation], lr=0.01)
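    # The ablation vector is the only trainable tensor: one coefficient per
    # channel, optimized with Adam while the model's parameters stay frozen,
    # and clamped to [0, 1] after each update (see the training loop below).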
    start_epoch = 0
    epoch = 0

    def eval_loss_and_reg():
        discrete_experiments = dict(
            # dpixel=dict(discrete_pixels=True),
            # dunits20=dict(discrete_units=20),
            # dumix20=dict(discrete_units=20, mixed_units=True),
            # dunits10=dict(discrete_units=10),
            # abonly=dict(ablation_only=True),
            # fimabl=dict(ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            dboth20=dict(discrete_units=20, discrete_pixels=True),
            # dbothm20=dict(discrete_units=20, mixed_units=True,
            #              discrete_pixels=True),
            # abdisc20=dict(discrete_units=20, discrete_pixels=True,
            #             ablation_only=True),
            # abdiscm20=dict(discrete_units=20, mixed_units=True,
            #             discrete_pixels=True,
            #             ablation_only=True),
            # fimadp=dict(discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadu10=dict(discrete_units=10,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadb10=dict(discrete_units=10, discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            fimadbm10=dict(discrete_units=10,
                           mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True),
            # fimadu20=dict(discrete_units=20,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadb20=dict(discrete_units=20, discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            fimadbm20=dict(discrete_units=20,
                           mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True))
        with torch.no_grad():
            total_loss = 0
            discrete_losses = {k: 0 for k in discrete_experiments}
            for [pbatch, ploc, cbatch,
                 cloc] in pbar(torch.utils.data.DataLoader(
                     TensorDataset(corpus.eval_present_sample,
                                   corpus.eval_present_location,
                                   corpus.eval_candidate_sample,
                                   corpus.eval_candidate_location),
                     batch_size=args.inference_batch_size,
                     num_workers=10,
                     shuffle=False,
                     pin_memory=True),
                               desc="Eval"):
                # Ablate the selected units (replacing their activations with
                # high_replacement); the loss measures how much of the object
                # remains.
                total_loss = total_loss + ace_loss(
                    segmenter,
                    classnum,
                    model,
                    args.layer,
                    high_replacement,
                    ablation,
                    pbatch,
                    ploc,
                    cbatch,
                    cloc,
                    run_backward=False,
                    ablation_only=ablation_only,
                    fullimage_measurement=fullimage_measurement)
                for k, config in discrete_experiments.items():
                    discrete_losses[k] = discrete_losses[k] + ace_loss(
                        segmenter,
                        classnum,
                        model,
                        args.layer,
                        high_replacement,
                        ablation,
                        pbatch,
                        ploc,
                        cbatch,
                        cloc,
                        run_backward=False,
                        **config)
            avg_loss = (total_loss / args.eval_size).item()
            avg_d_losses = {
                k: (d / args.eval_size).item()
                for k, d in discrete_losses.items()
            }
            regularizer = (args.l2_lambda * ablation.pow(2).sum())
            pbar.print('Epoch %d Loss %g Regularizer %g' %
                       (epoch, avg_loss, regularizer))
            pbar.print(' '.join('%s: %g' % (k, d)
                                for k, d in avg_d_losses.items()))
            pbar.print(scale_summary(ablation.view(-1), 10, 3))
            return avg_loss, regularizer, avg_d_losses

    if args.eval_only:
        # For eval_only, just load each snapshot and re-run validation eval
        # pass on each one.
        for epoch in range(-1, args.train_epochs):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch)
            if not os.path.exists(snapfile):
                data = {}
                if epoch >= 0:
                    print('No epoch %d' % epoch)
                    continue
            else:
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
            avg_loss, regularizer, new_extra = eval_loss_and_reg()
            # Keep old values, and update any new ones.
            extra = {
                k: v
                for k, v in data.items()
                if k not in ['ablation', 'optimizer', 'avg_loss']
            }
            extra.update(new_extra)
            torch.save(
                dict(ablation=ablation,
                     optimizer=optimizer.state_dict(),
                     avg_loss=avg_loss,
                     **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        # Return loaded ablation.
        return ablation.view(-1).detach().cpu().numpy()

    if not args.no_cache:
        for start_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch)
            if os.path.exists(snapfile):
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
                start_epoch += 1
                break

    if start_epoch < args.train_epochs:
        epoch = start_epoch - 1
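        # Evaluate before (re)starting training; the epoch -1 snapshot below
        # records the untrained baseline.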
        avg_loss, regularizer, extra = eval_loss_and_reg()
        if epoch == -1:
            torch.save(
                dict(ablation=ablation,
                     optimizer=optimizer.state_dict(),
                     avg_loss=avg_loss,
                     **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))

    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(start_epoch, args.train_epochs):
        candidate_shuffle = torch.randperm(len(corpus.candidate_sample))
        train_loss = 0
        for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(
                pbar(torch.utils.data.DataLoader(
                    TensorDataset(
                        corpus.object_present_sample,
                        corpus.object_present_location,
                        corpus.candidate_sample[candidate_shuffle],
                        corpus.candidate_location[candidate_shuffle]),
                    batch_size=args.train_batch_size,
                    num_workers=10,
                    shuffle=True,
                    pin_memory=True),
                     desc="ACE opt epoch %d" % epoch)):
            if batch_num % args.train_update_freq == 0:
                optimizer.zero_grad()
            # Ablate the selected units (replacing their activations with
            # high_replacement).  Loss is the amount of remaining object.
            loss = ace_loss(segmenter,
                            classnum,
                            model,
                            args.layer,
                            high_replacement,
                            ablation,
                            pbatch,
                            ploc,
                            cbatch,
                            cloc,
                            run_backward=True,
                            ablation_only=ablation_only,
                            fullimage_measurement=fullimage_measurement)
            with torch.no_grad():
                train_loss = train_loss + loss
            if (batch_num + 1) % args.train_update_freq == 0:
                # Then add an L2 penalty to discourage large ablation values.
                regularizer = (args.l2_lambda * update_size *
                               ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
                with torch.no_grad():
                    ablation.clamp_(0, 1)
                    pbar.post(l=(train_loss / update_size).item(),
                              r=(regularizer / update_size).item())
                    train_loss = 0

        avg_loss, regularizer, extra = eval_loss_and_reg()
        torch.save(
            dict(ablation=ablation,
                 optimizer=optimizer.state_dict(),
                 avg_loss=avg_loss,
                 **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                   ablation.detach().cpu().numpy())

    # The output of this phase is this set of scores.
    return ablation.view(-1).detach().cpu().numpy()
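
# The inner loop above accumulates gradients over train_update_freq batches
# before taking a single optimizer step.  A minimal standalone sketch of the
# same pattern (hypothetical model/loss_fn/loader names, for clarity only):
def accumulate_and_step(model, loss_fn, loader, optimizer, update_freq):
    for batch_num, (inputs, targets) in enumerate(loader):
        if batch_num % update_freq == 0:
            optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()  # backward() adds into any existing gradients
        if (batch_num + 1) % update_freq == 0:
            optimizer.step()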
Example #25
def main():
    args = parseargs()

    model = setting.load_classifier(args.model)
    model = nethook.InstrumentedModel(model).cuda().eval()
    layername = args.layer
    model.retain_layer(layername)
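    # retain_layer records the named layer's activations on every forward
    # pass, so retained_layer() below can report the layer's unit count.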
    dataset = setting.load_dataset(args.dataset, crop_size=224)
    train_dataset = setting.load_dataset(args.dataset,
                                         crop_size=224,
                                         split='train')

    # Probe layer to get sizes
    model(dataset[0][0][None].cuda())
    num_units = model.retained_layer(layername).shape[1]
    classlabels = dataset.classes

    # Measure baseline classification accuracy on val set, and cache.
    pbar.descnext('baseline_pra')
    baseline_precision, baseline_recall, baseline_accuracy, baseline_ba = (
        test_perclass_pra(model,
                          dataset,
                          cachefile=sharedfile('pra-%s-%s/pra_baseline.npz' %
                                               (args.model, args.dataset))))
    pbar.print('baseline acc', baseline_ba.mean().item())

    # Now erase each unit, one at a time, and retest accuracy.
    unit_list = random.sample(list(range(num_units)), num_units)
    val_single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
    for unit in pbar(unit_list):
        pbar.descnext('test unit %d' % unit)
        # Get the balanced accuracy of the model after ablating the unit.
        _, _, _, ablation_ba = test_perclass_pra(
            model,
            dataset,
            layername=layername,
            ablated_units=[unit],
            cachefile=sharedfile('pra-%s-%s/pra_ablate_unit_%d.npz' %
                                 (args.model, args.dataset, unit)))
        val_single_unit_ablation_ba[unit] = ablation_ba

    # For the purpose of ranking units by importance to a class, we
    # measure using the training set (to avoid training unit ordering
    # on the test set).
    sample_size = None
    # Measure baseline classification accuracy, and cache.
    pbar.descnext('train_baseline_pra')
    baseline_precision, baseline_recall, baseline_accuracy, baseline_ba = (
        test_perclass_pra(
            model,
            train_dataset,
            sample_size=sample_size,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_train_baseline.npz' %
                                 (args.model, args.dataset))))
    pbar.print('baseline acc', baseline_ba.mean().item())

    # Measure accuracy on the val set.
    pbar.descnext('val_baseline_pra')
    _, _, _, val_baseline_ba = (test_perclass_pra(
        model,
        dataset,
        cachefile=sharedfile('ttv-pra-%s-%s/pra_val_baseline.npz' %
                             (args.model, args.dataset))))
    pbar.print('val baseline acc', val_baseline_ba.mean().item())

    # Do in shuffled order to allow multiprocessing.
    single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
    for unit in pbar(unit_list):
        pbar.descnext('test unit %d' % unit)
        _, _, _, ablation_ba = test_perclass_pra(
            model,
            train_dataset,
            layername=layername,
            ablated_units=[unit],
            sample_size=sample_size,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_train_ablate_unit_%d.npz' %
                                 (args.model, args.dataset, unit)))
        single_unit_ablation_ba[unit] = ablation_ba

    # Now, for every class, remove the N most-important units for that
    # class (and, separately, the complementary set of least-important
    # units), and measure accuracy.
    for classnum in pbar(
            random.sample(range(len(classlabels)), len(classlabels))):
        # For a few classes, let's chart the whole range of ablations.
        if classnum in [100, 169, 351, 304]:
            num_best_list = range(1, num_units)
        else:
            num_best_list = [1, 2, 3, 4, 5, 20, 64, 128, 256]
        pbar.descnext('numbest')
        for num_best in pbar(random.sample(num_best_list, len(num_best_list))):
            num_worst = num_units - num_best
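            # sort(0) orders units by post-ablation accuracy, ascending, so
            # the first indices are the units whose removal hurts this class
            # the most.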
            unitlist = (
                single_unit_ablation_ba[:, classnum].sort(0)[1][:num_best])
            _, _, _, testba = test_perclass_pra(
                model,
                dataset,
                layername=layername,
                ablated_units=unitlist,
                cachefile=sharedfile(
                    'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_ba_%d.npz' %
                    (args.model, args.dataset, classlabels[classnum],
                     len(unitlist))))
            unitlist = (
                single_unit_ablation_ba[:, classnum].sort(0)[1][-num_worst:])
            _, _, _, testba2 = test_perclass_pra(
                model,
                dataset,
                layername=layername,
                ablated_units=unitlist,
                cachefile=sharedfile(
                    'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz'
                    % (args.model, args.dataset, classlabels[classnum],
                       len(unitlist))))
            pbar.print('%s: best %d %.3f vs worst N %.3f' %
                       (classlabels[classnum], num_best,
                        testba[classnum] - val_baseline_ba[classnum],
                        testba2[classnum] - val_baseline_ba[classnum]))
Example #26
def test_perclass_pra(model,
                      dataset,
                      layername=None,
                      ablated_units=None,
                      sample_size=None,
                      cachefile=None):
    '''Classifier precision/recall/accuracy measurement.
    Disables a set of units in the specified layer, and then
    measures per-class precision, recall, accuracy and
    balanced (binary classification) accuracy for each class,
    compared to the ground truth in the given dataset.'''
    try:
        if cachefile is not None:
            data = numpy.load(cachefile)
            # verify that this is computed.
            data['true_negative_rate']
            result = tuple(
                torch.tensor(data[key]) for key in
                ['precision', 'recall', 'accuracy', 'balanced_accuracy'])
            pbar.print('Loading cached %s' % cachefile)
            return result
    except Exception:
        # Cache missing or incomplete; fall through and recompute.
        pass
    model.remove_edits()
    if ablated_units is not None:

        def ablate_the_units(x, *args):
            x[:, ablated_units] = 0
            return x

        model.edit_layer(layername, rule=ablate_the_units)
    with torch.no_grad():
        num_classes = len(dataset.classes)
        true_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
        pred_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
        correct_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
        total_count = 0
        sampler = None if sample_size is None else (FixedSubsetSampler(
            list(range(sample_size))))
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=100,
                                             num_workers=20,
                                             sampler=sampler,
                                             pin_memory=True)
        for image_batch, class_batch in pbar(loader):
            total_count += len(image_batch)
            image_batch, class_batch = [
                d.cuda() for d in [image_batch, class_batch]
            ]
            scores = model(image_batch)
            preds = scores.max(1)[1]
            correct = (preds == class_batch)
            true_counts.add_(class_batch.bincount(minlength=num_classes))
            pred_counts.add_(preds.bincount(minlength=num_classes))
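            # A bincount weighted by the boolean 'correct' tensor counts,
            # per ground-truth class, how many samples were classified
            # correctly.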
            correct_counts.add_(
                class_batch.bincount(correct, minlength=num_classes).long())
    model.remove_edits()
    true_neg_counts = ((total_count - true_counts) -
                       (pred_counts - correct_counts))
    precision = (correct_counts.float() / pred_counts.float()).cpu()
    recall = (correct_counts.float() / true_counts.float()).cpu()
    accuracy = (correct_counts + true_neg_counts).float().cpu() / total_count
    true_neg_rate = (true_neg_counts.float() /
                     (total_count - true_counts).float()).cpu()
    balanced_accuracy = (recall + true_neg_rate) / 2
    if cachefile is not None:
        numpy.savez(cachefile,
                    precision=precision.numpy(),
                    recall=recall.numpy(),
                    accuracy=accuracy.numpy(),
                    true_negative_rate=true_neg_rate.numpy(),
                    balanced_accuracy=balanced_accuracy.numpy())
    return precision, recall, accuracy, balanced_accuracy
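
# Balanced accuracy above is the mean of per-class recall and true-negative
# rate.  A tiny sanity check of the same arithmetic on made-up counts
# (hypothetical numbers, not results from any experiment):
import torch
total_count = 100
true_counts = torch.tensor([10.])                   # ground-truth positives
pred_counts = torch.tensor([9.])                    # predicted positives
correct_counts = torch.tensor([8.])                 # true positives
true_neg = (total_count - true_counts) - (pred_counts - correct_counts)
recall = correct_counts / true_counts               # 0.8
tnr = true_neg / (total_count - true_counts)        # 89 / 90
print(((recall + tnr) / 2).item())                  # ~0.894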