Example #2
def main():
    # Ablation evaluation settings
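    # Parse a 'layername:reportedname' argument; a bare layer name maps to itself.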
    def strpair(arg):
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Ablation eval',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir',
                        type=str,
                        required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes',
                        type=str,
                        nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric',
                        type=str,
                        default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount',
                        type=int,
                        default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter',
                        type=str,
                        help='directory containing segmentation dataset')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size',
                        type=int,
                        default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size',
                        type=int,
                        default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    pbar.verbose(not args.quiet)

    # Set up CUDA, and speed up pytorch with cudnn autotuning.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Instantiate model
    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1],
                                   seed=2).view((args.size, ) +
                                                input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    # Now do the actual work.
    labelnames, catnames = (segmenter.get_label_and_category_names(dataset))
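    # Map each label to the index of its category (0 when the category is not
    # listed) and build a label-name-to-index lookup for --classes.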
    label_category = [
        catnames.index(c) if c in catnames else 0 for l, c in labelnames
    ]
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=args.batch_size,
                                            num_workers=10,
                                            pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None

    # For each sort-order, do an ablation
    for classname in pbar(args.classes):
        pbar.post(c=classname)
        for layername in pbar(model.ablation):
            pbar.post(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except StopIteration:
                print('%s not found' % rankname, file=sys.stderr)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            if os.path.isfile(os.path.join(ablationdir, '%s.json' % rankname)):
                with open(os.path.join(ablationdir,
                                       '%s.json' % rankname)) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                measurements = data.ablation_effects
            measurements = measure_ablation(segmenter, segloader, model,
                                            classnum, layername,
                                            ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(os.path.join(ablationdir, '%s.json' % rankname),
                      'w') as f:
                json.dump(
                    dict(classname=classname,
                         classnum=classnum,
                         baseline=measurements[0],
                         layer=layername,
                         metric=args.metric,
                         ablation_units=ordering.tolist(),
                         ablation_effects=measurements[1:]), f)
Example #3
def main():
    args = parseargs()
    experiment_dir = 'results/decoupled-%d-%s-resnet' % (args.selected_classes,
                                                         args.dataset)
    ds_dirname = dict(novelty='novelty/dataset_v1/known_classes/images',
                      imagenet='imagenet')[args.dataset]
    training_dir = 'datasets/%s/train' % ds_dirname
    val_dir = 'datasets/%s/val' % ds_dirname
    os.makedirs(experiment_dir, exist_ok=True)
    with open(os.path.join(experiment_dir, 'args.txt'), 'w') as f:
        f.write(str(args) + '\n')

    def printstat(s):
        with open(os.path.join(experiment_dir, 'log.txt'), 'a') as f:
            f.write(str(s) + '\n')
        pbar.print(s)

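    # Keep only samples from the first args.selected_classes classes
    # (each item is assumed to be a (sample, class_index) pair).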
    def filter_tuple(item):
        return item[1] < args.selected_classes

    # ImageNet has a couple of images with corrupt EXIF data.
    warnings.filterwarnings('ignore', message='.*orrupt EXIF.*')
    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [training_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64,
        shuffle=True,
        num_workers=48,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [val_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64,
        shuffle=False,
        num_workers=24,
        pin_memory=True)
    late_model = torchvision.models.resnet50(num_classes=args.selected_classes)
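    # Re-initialize all parameters: zero biases, ones for 1-d parameters
    # (batchnorm scales), and Kaiming-normal init for the remaining weights.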
    for n, p in late_model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        elif len(p.shape) <= 1:
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    late_model.train()
    late_model.cuda()

    model = late_model

    max_lr = 5e-3
    max_iter = args.training_iterations

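    # The loss is binary cross-entropy against one-hot targets (a multi-label
    # style objective) rather than the usual softmax cross-entropy.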
    def criterion(logits, true_class):
        goal = torch.zeros_like(logits)
        goal.scatter_(1, true_class[:, None], value=1.0)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits, goal)

    optimizer = torch.optim.Adam(model.parameters())
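    # total_steps is max_iter - 1 because scheduler.step() is only called when
    # iter_num > 0; the update at iteration 0 is skipped below so the first
    # checkpoint reflects the untrained model.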
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr,
                                                    total_steps=max_iter - 1)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    # Oh, hold on.  Let's actually resume training if we already have a model.
    checkpoint_filename = 'weights.pth'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(os.path.join(experiment_dir, best_filename))
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        filename = os.path.join(experiment_dir, checkpoint_filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for input, target in pbar(val_loader):
            # Load data
            input_var, target_var = [d.cuda() for d in [input, target]]
            # Evaluate model
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)
                _, pred = output.max(1)
                accuracy = (target_var.eq(pred).float().sum().item() /
                            input.size(0))
            val_loss.update(loss.item(), input.size(0))
            val_acc.update(accuracy, input.size(0))
            # Report running loss/accuracy on the progress bar
            pbar.post(l=val_loss.avg, a=val_acc.avg)
        # Save checkpoint
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        printstat('Iteration %d val accuracy %.2f' %
                  (iter_num, val_acc.avg * 100.0))

    # Here is our training loop.
    while iter_num < max_iter:
        for filtered_input, filtered_target in pbar(train_loader):
            # Track training loss/accuracy for this iteration.
            train_loss, train_acc = AverageMeter(), AverageMeter()
            # Load data
            input_var, target_var = [
                d.cuda() for d in [filtered_input, filtered_target]
            ]
            # Evaluate model
            output = model(input_var)
            loss = criterion(output, target_var)
            train_loss.update(loss.item(), filtered_input.size(0))
            # Perform one optimization step with Adam (skipped at iteration 0,
            # so the first checkpoint records the untrained model).
            if iter_num > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Learning rate schedule
                scheduler.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = (target_var.eq(pred).float().sum().item() /
                        filtered_input.size(0))
            train_acc.update(accuracy)
            remaining = 1 - iter_num / float(max_iter)
            pbar.post(l=train_loss.avg,
                      a=train_acc.avg,
                      v=best['val_accuracy'])
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
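
These examples rely on an AverageMeter helper that is not shown; a minimal sketch, assuming the usual running-average semantics (update(val, n) accumulates a weighted sum so that .avg is the sample-weighted mean), might look like this:

class AverageMeter:
    # Minimal running-average tracker assumed by the examples above.
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
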
Example #4
def train_ablation(args,
                   corpus,
                   cachefile,
                   model,
                   segmenter,
                   classnum,
                   initial_ablation=None):
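    # Optimizes a continuous per-unit ablation vector (clamped to [0, 1]) to
    # minimize the ACE loss for the given class, with an L2 penalty to
    # encourage sparsity; per-epoch snapshots are saved under
    # <cachedir>/snapshots.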
    cachedir = os.path.dirname(cachefile)
    snapdir = os.path.join(cachedir, 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    # high_replacement = corpus.feature_99[None,:,None,None].cuda()
    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None, :, None, None].cuda()
    elif '_tcm' in args.variant:
        # variant: top-conditional-mean
        high_replacement = (corpus.mean_present_feature[None, :, None,
                                                        None].cuda())
    else:  # default: weighted mean
        high_replacement = (corpus.weighted_mean_present_feature[None, :, None,
                                                                 None].cuda())
    fullimage_measurement = False
    ablation_only = False
    fullimage_ablation = False
    if '_fim' in args.variant:
        fullimage_measurement = True
    elif '_fia' in args.variant:
        fullimage_measurement = True
        ablation_only = True
        fullimage_ablation = True
    high_replacement.requires_grad = False
    for p in model.parameters():
        p.requires_grad = False

    ablation = torch.zeros(high_replacement.shape).cuda()
    if initial_ablation is not None:
        ablation.view(-1)[...] = initial_ablation
    ablation.requires_grad = True
    optimizer = torch.optim.Adam([ablation], lr=0.01)
    start_epoch = 0
    epoch = 0

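    # Evaluates the current ablation on the eval corpus: the continuous ACE
    # loss plus several discretized variants (the discrete_units /
    # discrete_pixels configurations below).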
    def eval_loss_and_reg():
        discrete_experiments = dict(
            # dpixel=dict(discrete_pixels=True),
            # dunits20=dict(discrete_units=20),
            # dumix20=dict(discrete_units=20, mixed_units=True),
            # dunits10=dict(discrete_units=10),
            # abonly=dict(ablation_only=True),
            # fimabl=dict(ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            dboth20=dict(discrete_units=20, discrete_pixels=True),
            # dbothm20=dict(discrete_units=20, mixed_units=True,
            #              discrete_pixels=True),
            # abdisc20=dict(discrete_units=20, discrete_pixels=True,
            #             ablation_only=True),
            # abdiscm20=dict(discrete_units=20, mixed_units=True,
            #             discrete_pixels=True,
            #             ablation_only=True),
            # fimadp=dict(discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadu10=dict(discrete_units=10,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadb10=dict(discrete_units=10, discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            fimadbm10=dict(discrete_units=10,
                           mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True),
            # fimadu20=dict(discrete_units=20,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadb20=dict(discrete_units=20, discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            fimadbm20=dict(discrete_units=20,
                           mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True))
        with torch.no_grad():
            total_loss = 0
            discrete_losses = {k: 0 for k in discrete_experiments}
            for [pbatch, ploc, cbatch,
                 cloc] in pbar(torch.utils.data.DataLoader(
                     TensorDataset(corpus.eval_present_sample,
                                   corpus.eval_present_location,
                                   corpus.eval_candidate_sample,
                                   corpus.eval_candidate_location),
                     batch_size=args.inference_batch_size,
                     num_workers=10,
                     shuffle=False,
                     pin_memory=True),
                               desc="Eval"):
                # First, put in zeros for the selected units.
                # Loss is amount of remaining object.
                total_loss = total_loss + ace_loss(
                    segmenter,
                    classnum,
                    model,
                    args.layer,
                    high_replacement,
                    ablation,
                    pbatch,
                    ploc,
                    cbatch,
                    cloc,
                    run_backward=False,
                    ablation_only=ablation_only,
                    fullimage_measurement=fullimage_measurement)
                for k, config in discrete_experiments.items():
                    discrete_losses[k] = discrete_losses[k] + ace_loss(
                        segmenter,
                        classnum,
                        model,
                        args.layer,
                        high_replacement,
                        ablation,
                        pbatch,
                        ploc,
                        cbatch,
                        cloc,
                        run_backward=False,
                        **config)
            avg_loss = (total_loss / args.eval_size).item()
            avg_d_losses = {
                k: (d / args.eval_size).item()
                for k, d in discrete_losses.items()
            }
            regularizer = (args.l2_lambda * ablation.pow(2).sum())
            pbar.print('Epoch %d Loss %g Regularizer %g' %
                       (epoch, avg_loss, regularizer))
            pbar.print(' '.join('%s: %g' % (k, d)
                                for k, d in avg_d_losses.items()))
            pbar.print(scale_summary(ablation.view(-1), 10, 3))
            return avg_loss, regularizer, avg_d_losses

    if args.eval_only:
        # For eval_only, just load each snapshot and re-run validation eval
        # pass on each one.
        for epoch in range(-1, args.train_epochs):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch)
            if not os.path.exists(snapfile):
                data = {}
                if epoch >= 0:
                    print('No epoch %d' % epoch)
                    continue
            else:
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
            avg_loss, regularizer, new_extra = eval_loss_and_reg()
            # Keep old values, and update any new ones.
            extra = {
                k: v
                for k, v in data.items()
                if k not in ['ablation', 'optimizer', 'avg_loss']
            }
            extra.update(new_extra)
            torch.save(
                dict(ablation=ablation,
                     optimizer=optimizer.state_dict(),
                     avg_loss=avg_loss,
                     **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        # Return loaded ablation.
        return ablation.view(-1).detach().cpu().numpy()

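    # Unless caching is disabled, resume from the most recent epoch snapshot.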
    if not args.no_cache:
        for start_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch)
            if os.path.exists(snapfile):
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
                start_epoch += 1
                break

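    # Before (re)starting training, evaluate once; at epoch -1 this also
    # records the baseline snapshot for the initial ablation.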
    if start_epoch < args.train_epochs:
        epoch = start_epoch - 1
        avg_loss, regularizer, extra = eval_loss_and_reg()
        if epoch == -1:
            torch.save(
                dict(ablation=ablation,
                     optimizer=optimizer.state_dict(),
                     avg_loss=avg_loss,
                     **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))

    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(start_epoch, args.train_epochs):
        candidate_shuffle = torch.randperm(len(corpus.candidate_sample))
        train_loss = 0
        for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(
                pbar(torch.utils.data.DataLoader(
                    TensorDataset(
                        corpus.object_present_sample,
                        corpus.object_present_location,
                        corpus.candidate_sample[candidate_shuffle],
                        corpus.candidate_location[candidate_shuffle]),
                    batch_size=args.train_batch_size,
                    num_workers=10,
                    shuffle=True,
                    pin_memory=True),
                     desc="ACE opt epoch %d" % epoch)):
            if batch_num % args.train_update_freq == 0:
                optimizer.zero_grad()
            # First, put in zeros for the selected units.  Loss is amount
            # of remaining object.
            loss = ace_loss(segmenter,
                            classnum,
                            model,
                            args.layer,
                            high_replacement,
                            ablation,
                            pbatch,
                            ploc,
                            cbatch,
                            cloc,
                            run_backward=True,
                            ablation_only=ablation_only,
                            fullimage_measurement=fullimage_measurement)
            with torch.no_grad():
                train_loss = train_loss + loss
            if (batch_num + 1) % args.train_update_freq == 0:
                # Third, add some L2 loss to encourage sparsity.
                regularizer = (args.l2_lambda * update_size *
                               ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
                with torch.no_grad():
                    ablation.clamp_(0, 1)  # keep ablation coefficients in [0, 1]
                    pbar.post(l=(train_loss / update_size).item(),
                              r=(regularizer / update_size).item())
                    train_loss = 0

        avg_loss, regularizer, extra = eval_loss_and_reg()
        torch.save(
            dict(ablation=ablation,
                 optimizer=optimizer.state_dict(),
                 avg_loss=avg_loss,
                 **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                   ablation.detach().cpu().numpy())

    # The output of this phase is this set of scores.
    return ablation.view(-1).detach().cpu().numpy()