Ejemplo n.º 1
0
def prepare_images_list(data_dir, dump_path):
    """Collect the image file paths of every class and pickle them.

    For each entry returned by ``find_classes(data_dir)``, the sub-directory
    ``data_dir/<class_name>`` is listed and every file accepted by
    ``is_image_file`` is recorded by its full path. The result is a list of
    lists -- one inner list per class, in the order ``find_classes`` returns
    them -- which is written to *dump_path* with ``dump_pickle``.

    Args:
        data_dir: root directory holding one sub-directory per class.
        dump_path: destination path handed to ``dump_pickle``.
    """
    classes = find_classes(data_dir)
    data_images_list = []
    for i, class_name in enumerate(classes):
        print('processing %d-th class: %s' % (i, class_name))
        class_dir = os.path.join(data_dir, class_name)
        # os.listdir returns entries in an arbitrary, filesystem-dependent
        # order; sort them so the dumped lists are reproducible across
        # machines and runs.
        data_images_list.append([
            os.path.join(class_dir, filename)
            for filename in sorted(os.listdir(class_dir))
            if is_image_file(filename)
        ])

    dump_pickle(data_images_list, dump_path)
Ejemplo n.º 2
0
          'Sparsity: %.4f, Accuracy: %.4f\n' %
          (np.mean(test_loss_ce), test_sparsity.item(), acc))
    return acc, test_sparsity


def _save_checkpoint(model, gates_params, logdir):
    """Save the model weights and a CPU copy of the channel-gate tensors.

    Writes ``checkpoint.pth`` (``model.state_dict()`` via ``torch.save``) and
    ``channel_gates.pkl`` (a list holding ``p.data.clone().cpu()`` for each
    ``p`` in *gates_params*, via ``misc.dump_pickle``) inside *logdir*.
    """
    torch.save(model.state_dict(), os.path.join(logdir, 'checkpoint.pth'))
    # clone() + cpu() so the pickle holds independent host tensors rather
    # than views of the live (possibly GPU-resident) parameter storage.
    temp_params = [p.data.clone().cpu() for p in gates_params]
    misc.dump_pickle(temp_params, os.path.join(logdir, 'channel_gates.pkl'))


best_acc = 0
saved_checkpoint = False
for epoch in range(args.epochs):
    train(epoch)
    acc, test_sparsity = test()
    # Only models within the sparsity budget are eligible; among those keep
    # the most accurate one seen so far.
    if test_sparsity <= args.sparsity_level and acc > best_acc:
        best_acc = acc
        _save_checkpoint(model, gates_params, args.logdir)
        saved_checkpoint = True

# No eligible epoch produced a checkpoint: save the final model instead so
# the run always leaves checkpoint.pth / channel_gates.pkl behind.
if not saved_checkpoint:
    _save_checkpoint(model, gates_params, args.logdir)
Ejemplo n.º 3
0
            continue
        else:
            pruned_cfg[i] = masks[counter].sum().long().item()
            counter += 1

    model = models.__dict__[args.arch](args.num_classes,
                                       args.expanded_inchannel, pruned_cfg)

    pruned_flops = calculate_flops(model)
    actual_pruned_ratio = 1 - pruned_flops / full_flops  # ratio of flops pruned
    print('Iter %d, start %.2f, end %.2f, pruned ratio = %.4f' %
          (j, start_pruned_ratio, end_pruned_ratio, actual_pruned_ratio))

    if abs(actual_pruned_ratio -
           args.pruned_ratio) / args.pruned_ratio <= args.eps:
        print(
            'Successfully reach the target pruned ratio with FLOPS = %.4f (M)'
            % (pruned_flops / 1e6))
        break

    if actual_pruned_ratio > args.pruned_ratio:
        end_pruned_ratio = cur_pruned_ratio
    else:
        start_pruned_ratio = cur_pruned_ratio

# Persist the final pruned configuration and channel masks. Both file names
# embed the target pruned ratio (two decimals), so runs with different
# targets write to different files in the same log directory.
cfg_path = os.path.join(args.logdir,
                        'pruned_cfg-%.2f.pkl' % args.pruned_ratio)
misc.dump_pickle(pruned_cfg, cfg_path)
masks_path = os.path.join(args.logdir, 'masks-%.2f.pkl' % args.pruned_ratio)
misc.dump_pickle(masks, masks_path)
Ejemplo n.º 4
0
    # Using xz leads to this anyway, but it's worth reminding the reader.
    # To permute in 2D, use the --permute flag.
    return use,res;
if __name__ == "__main__":
    # Script entry point. Parses the CLI described by the module docstring,
    # computes (or loads) sample indices `i` over the coordinate arrays `xs`
    # with simple_nearest_indices, then pickles the selected variables and
    # coordinates to <output>.
    from docopt import docopt;
    from misc import readfile, mkvprint, dump_pickle;
    opts=docopt(__doc__,help=True);
    dims,res = handle_dims(opts);
    # NOTE(review): this binds the mkvprint function itself (it is not
    # called), and `vprint` is never used below -- confirm intent.
    vprint = mkvprint;
    # NOTE(review): assumes docopt yields a list for <var> (e.g. `<var>...`
    # in the usage string); for a plain string, list(var) and the
    # `for v in var` comprehension below would iterate over its characters.
    var = opts['<var>'];
    # `readvars` (requested vars plus dims) is built but never used below;
    # presumably it was meant to select what gets read into `d` -- TODO confirm.
    readvars = list(var);
    if readvars:
        readvars+=dims;
    # NOTE(review): `d` is not defined anywhere in this block, yet every
    # branch below indexes it (d[l], d[v][i]). Unless it is a module-level
    # global defined elsewhere, this raises NameError; the step that reads
    # the input data appears to be missing -- TODO confirm.
    if opts['--gen-samples']:
        # Generate-only mode: compute the indices and dump (i, xs) to
        # <output>; presumably reused later through --sample, which reads
        # back an (i, xs) pair.
        xs = tuple([d[l] for l in dims]);
        i = simple_nearest_indices(xs,res);
        dump_pickle(opts["<output>"],(i,xs));
        # NOTE(review): exits with status 1 even though this path succeeded,
        # and `exit` is the site-module helper rather than sys.exit --
        # confirm the intended exit status.
        exit(1);
    if opts['--sample']:
        # Reuse a previously saved (i, xs) pair instead of recomputing it.
        i,xs = readfile(opts['--sample'], dumpfull=True);
    else:
        xs = tuple([d[l] for l in dims]);
        i = simple_nearest_indices(xs,res);
    # Select each requested variable at the sample indices...
    did = {v:d[v][i] for v in var};
    #Has _D_ been _I_nterpolate_D_?  Yes it DID.
    # ...then add the coordinate array for each dimension under its name.
    did.update({l:x for l,x in zip(dims,xs)});
    #get it?
    #alright I'll stop
    # dump_pickle is called here with the output path first, then the object.
    dump_pickle(opts['<output>'], did);

Ejemplo n.º 5
0
parser = argparse.ArgumentParser(
    description='Extract the ILSVRC2012 val dataset')
parser.add_argument('--in_file',
                    default='val224_compressed.pkl',
                    help='input file path')
parser.add_argument('--out_root',
                    default='~/public_dataset/pytorch/imagenet-data/',
                    help='output file path')
args = parser.parse_args()
# os.path.exists / os.makedirs never expand '~' themselves, so the default
# --out_root would otherwise create a literal "~" directory under the
# current working directory. Expand user paths once, up front.
args.in_file = os.path.expanduser(args.in_file)
args.out_root = os.path.expanduser(args.out_root)

d = misc.load_pickle(args.in_file)
# The script expects exactly 50000 samples. Use explicit checks rather than
# `assert`, which is stripped when Python runs with -O.
if len(d['data']) != 50000:
    raise ValueError('expected 50000 images, got %d' % len(d['data']))
if len(d['target']) != 50000:
    raise ValueError('expected 50000 targets, got %d' % len(d['target']))

# Decode each stored image with misc.str2img and make a 299x299 copy of it
# with cv2.resize; targets are reused unchanged for both outputs.
data224 = []
data299 = []
for img in tqdm.tqdm(d['data'], total=len(d['data'])):
    img224 = misc.str2img(img)
    img299 = cv2.resize(img224, (299, 299))
    data224.append(img224)
    data299.append(img299)
# transpose(0, 3, 1, 2) moves the last axis to position 1 (presumably
# NHWC -> NCHW; confirm that str2img returns HWC images).
data_dict224 = dict(data=np.array(data224).transpose(0, 3, 1, 2),
                    target=d['target'])
data_dict299 = dict(data=np.array(data299).transpose(0, 3, 1, 2),
                    target=d['target'])

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(args.out_root, exist_ok=True)
misc.dump_pickle(data_dict224, os.path.join(args.out_root, 'val224.pkl'))
misc.dump_pickle(data_dict299, os.path.join(args.out_root, 'val299.pkl'))
Ejemplo n.º 6
0
    val_adv_cg_list, _ = get_critical_path(adv_inp, model)

    val_all_orig_cglist.append(torch.cat(val_orig_cg_list))
    val_all_adv_cglist.append(torch.cat(val_adv_cg_list))
    if i % args.display_freq == 0:
        print('generate [%d/%d] val image...' % (i, len(val_loader)))

# Collapse the per-batch lists into one tensor each: originals and adversarials.
val_all_orig_cglist = torch.stack(val_all_orig_cglist)
val_all_adv_cglist = torch.stack(val_all_adv_cglist)

# Validation set: original samples (label 1) followed by adversarial samples
# (label 0), converted to a NumPy array.
all_val_samples = torch.cat(
    [val_all_orig_cglist, val_all_adv_cglist]).cpu().numpy()
all_val_labels = np.concatenate((np.ones(len(val_all_orig_cglist)),
                                 np.zeros(len(val_all_adv_cglist))))

# Shuffle samples and labels with one shared permutation so pairs stay aligned.
_idx = np.random.permutation(len(all_val_labels))
all_val_samples, all_val_labels = all_val_samples[_idx], all_val_labels[_idx]

# Score the classifier on the validation set.
# NOTE(review): roc_auc_score receives hard labels from clf.predict, so the
# AUC reflects a single operating point; a continuous score (predict_proba /
# decision_function, if clf provides one) would give a true ROC AUC.
preds = clf.predict(all_val_samples)
prec = precision_score(all_val_labels, preds)
ras = roc_auc_score(all_val_labels, preds)
print('precision = %.4f, roc_auc_score = %.4f' % (prec, ras))

# Persist train/val samples with labels, and the classifier, to the log dir.
misc.dump_pickle([all_train_samples, all_train_labels],
                 os.path.join(args.logdir, 'train_infos.pkl'))
misc.dump_pickle([all_val_samples, all_val_labels],
                 os.path.join(args.logdir, 'val_infos.pkl'))
misc.dump_pickle(clf, os.path.join(args.logdir, 'clf.pkl'))