Пример #1
0
def visualize(args, config):
    """Run the configured visualizer for the configured number of steps.

    Optionally sets up file logging and seeds ``random``, then looks up the
    visualizer class named by ``config.visualizer.type`` in ``visualizers``,
    opens it as a context manager and calls its ``feed()`` method
    ``config.explanation.iterations`` times.

    Args:
        args: parsed command-line namespace; only ``seed`` is read here.
        config: project configuration object.
    """
    if config.log:
        log_path = config.sub('log')
        if config.get('debug'):
            level = logging.DEBUG
        else:
            level = logging.INFO
        mkfilelogger('ecGAN', log_path, level)

    seed = args.seed
    if seed:
        # A falsy seed (e.g. 0) leaves the generator unseeded.
        random.seed(seed)

    vis_conf = config.visualizer
    visualizer_cls = visualizers[vis_conf.type]
    extra_args = vis_conf.get('args', [])
    extra_kwargs = vis_conf.get('kwargs', {})

    with visualizer_cls(config, *extra_args, **extra_kwargs) as vs:
        for _ in range(config.explanation.iterations):
            vs.feed()
Пример #2
0
def main():
    """Parse the command line, build the configuration and run a command.

    The positional ``command`` selects an entry of ``commands``, which is
    called as ``commands[command](args, config)``.  When at least one
    ``-f/--config`` file is given, a ``Config`` is built from those files in
    order and then updated with every ``-u/--update`` YAML snippet; otherwise
    ``config`` is ``None``.

    A seed of 0 disables seeding because the seed is only applied when it is
    truthy.  ``--debug`` drops into an ``ipdb`` session right before the
    command is dispatched.
    """
    parser = ArgumentParser()

    parser.add_argument('command', choices=commands.keys())
    parser.add_argument('-f', '--config', action='append', default=[])
    parser.add_argument('-c', '--chain', action='append', default=[])
    parser.add_argument('-u', '--update', action='append', default=[])
    parser.add_argument('--epoch_range', type=int, nargs=3)
    parser.add_argument('--iter', type=int, default=1)
    parser.add_argument('--seed', type=int, default=0xDEADBEEF)
    parser.add_argument('--classnum', type=int, default=10)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--mkdirs', action='store_true')
    parser.add_argument('--crash_chain', action='store_true')
    parser.add_argument('-k', '--pskip', action='append', type=int, default=[])
    parser.add_argument(
        '-t',
        '--tag',
        type=str,
        default='',
        help=
        '-t \'(tag1|!tag2)&tag3|tag4\' := ((tag1 or not tag2) and tag3) or tag4'
    )

    args = parser.parse_args(sys.argv[1:])

    if args.config:
        config = Config()

        # Later files override earlier ones; updates are applied last.
        for cpath in args.config:
            config.update_from_file(cpath)

        for ustr in args.update:
            config.update(yaml.safe_load(ustr))

        # The returned module is not used here; the call is kept because
        # loading the file may have side effects -- TODO confirm.
        load_module_file(config.sub('net_file'), 'net_module')

        if args.mkdirs:
            # Previously passed the undefined name ``lconf`` (NameError).
            # NOTE(review): assumes mkdirs follows the (args, config)
            # signature used for the other commands -- confirm.
            mkdirs(None, config)
    else:
        config = None

    if args.seed:
        random.seed(args.seed)

    if args.debug:
        import ipdb
        ipdb.set_trace()

    commands[args.command](args, config)
Пример #3
0
# Module setup: imports, plotting defaults and RNG seeding.
import mxnet as mx
from mxnet import gluon
from mxnet import autograd
from mxnet import random
from mxnet import ndarray as nd
# NOTE: this rebinds the name ``random`` to the standard-library module,
# replacing the ``mxnet.random`` binding imported above; MXNet's generator
# remains reachable as ``mx.random``.
import random
import numpy as np

import matplotlib as mlp

# Set the default figure resolution before pyplot is imported.
mlp.rcParams['figure.dpi'] = 120
import matplotlib.pyplot as plt

# Fixed seeds: first MXNet's generator, then (because of the rebinding
# above) the standard-library ``random`` module.
mx.random.seed(1)
random.seed(1)


# Adam: uses first- and second-moment estimates of the gradient to adapt the learning rate
def adam(params, lr, vals, sqrs, iter, batch_size, beta1=0.9, beta2=0.999):
    """Apply one Adam update to every parameter, in place.

    For each ``(param, val, sqr)`` triple the gradient ``param.grad`` is
    divided by ``batch_size``, the running first-moment (``val``) and
    second-moment (``sqr``) estimates are updated in place, both are
    bias-corrected with ``1 - beta**iter`` and the resulting step is
    subtracted from ``param``.

    Args:
        params: parameters to update; each must expose ``.grad``.
        lr: learning rate.
        vals: first-moment buffers, one per parameter (updated in place).
        sqrs: second-moment buffers, one per parameter (updated in place).
        iter: step counter used for bias correction.  With ``iter == 0``
            both correction denominators are zero.
        batch_size: divisor applied to the raw gradient.
        beta1: decay rate of the first-moment estimate.
        beta2: decay rate of the second-moment estimate.
    """
    eps_stable = 1e-8
    for weight, first_moment, second_moment in zip(params, vals, sqrs):
        grad = weight.grad / batch_size

        # Exponential moving averages of the gradient and its square.
        first_moment[:] = beta1 * first_moment + (1 - beta1) * grad
        second_moment[:] = (beta2 * second_moment
                            + (1 - beta2) * nd.square(grad))

        # Bias-corrected estimates.
        first_hat = first_moment / (1. - beta1**iter)
        second_hat = second_moment / (1. - beta2**iter)

        weight[:] -= lr * first_hat / (nd.sqrt(second_hat) + eps_stable)
# Script: load saved DCGAN generator parameters, draw four noise vectors and
# show the generated images side by side.
import DCGAN as dcgan
from mxnet import nd
from mxnet import random
from matplotlib import pyplot as plt
import numpy as np

# Seed MXNet's generator (``random`` is ``mxnet.random`` here) from the
# system time; with a fixed seed every run would give the same results.
import time
seed = int(time.time() * 100)
random.seed(seed)

# Make ./dependencies importable before importing ``utils``
# (presumably that is where it lives -- verify).
import sys
sys.path.append('./dependencies')
import utils
# NOTE(review): presumably returns a GPU context when one is available --
# confirm against utils.try_gpu.
ctx = utils.try_gpu()

filename = './params/dcgan.netG.save'
netG = dcgan.Generator()
# The return value of collect_params() is discarded here.
netG.collect_params()
netG.load_params(filename, ctx=ctx)

# Four noise inputs of shape (100, 1, 1), drawn with mean 0 and std 1.
z = nd.random_normal(0, 1, shape=(4, 100, 1, 1), ctx=ctx)
#print(z)
output = netG(z)

for i in range(4):
    plt.subplot(1, 4, i + 1)
    # Move axis 0 to the end for imshow and rescale to uint8; the
    # (x + 1) * 127.5 mapping assumes outputs lie in [-1, 1] -- TODO confirm.
    plt.imshow(((output[i].asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(
        np.uint8))
    plt.axis('off')
plt.show()
Пример #5
0
def explain_cgan(args, config):
    """Sample inputs for a conditional GAN and store their explanations.

    Builds the model named by ``config.model.type`` and, for each of
    ``config.explanation.iterations`` iterations, draws noise and condition
    inputs from the configured sampler, runs the model's explanation methods
    (the pattern variants when ``config.use_pattern`` is set) and writes
    inputs, outputs and relevances to one file via ``save_data_h5``.
    Iterations whose output file already exists are skipped unless
    ``config.overwrite`` is set.

    Args:
        args: parsed command-line namespace; ``classnum`` and ``seed`` are
            read.
        config: project configuration object.
    """
    ctx = ress(make_ctx, config.device, config.device_id)

    if config.log:
        log_path = config.sub('log')
        level = logging.DEBUG if config.get('debug') else logging.INFO
        mkfilelogger('ecGAN', log_path, level)

    model_cls = models[config.model.type]
    model = model_cls(ctx=ctx, config=config, **config.model.kwargs)
    if config.use_pattern:
        model.load_pattern_params()

    batch_size = config.explanation.batch_size
    num_classes = args.classnum

    net_desc = config.nets[config.model.kwargs.discriminator]
    net_name = '%s<%s>' % (net_desc.name, net_desc.type)
    net_epoch = net_desc.epoch

    onehot_values = config.get('ohcond', [0, 1])
    onehot_kwargs = dict(zip(['off_value', 'on_value'], onehot_values))

    # Output path template depends on the explanation mode.
    templ = (config.pattern.output
             if config.use_pattern else config.explanation.output)

    if args.seed:
        random.seed(args.seed)

    for step in range(config.explanation.iterations):
        fpath = config.exsub(templ,
                             data_desc='result<%s>' % config.sampler.type,
                             ftype='h5',
                             iter=step,
                             net=net_name,
                             net_epoch=net_epoch)

        already_done = not config.overwrite and os.path.isfile(fpath)
        if already_done:
            getLogger('ecGAN').info("File already exits, skipping '%s'...",
                                    fpath)
            continue

        sampler = samplers[config.sampler.type]
        noise, cond = sampler(batch_size, num_classes, ctx,
                              ohkw=onehot_kwargs)

        call_args = [num_classes] + asim(noise, cond)
        if config.use_pattern:
            call_kwargs = {
                'single_out': config.pattern.get('single_out', True),
                'attribution': config.pattern.get('type') == 'attribution',
            }
            s_gen, gen = model.explain_pattern_top(*call_args, **call_kwargs)
            s_noise, s_cond = model.explain_pattern(*call_args, **call_kwargs)
        else:
            call_kwargs = {
                'single_out': config.explanation.get('single_out', False),
                'mkwargs': config.explanation.get('kwargs', {}),
            }
            s_gen, gen = model.explain_top(*call_args, **call_kwargs)
            s_noise, s_cond = model.explain(*call_args, **call_kwargs)

        # ``model._out`` is read only after the explain calls above.
        info = {
            'input/noise': noise.asnumpy(),
            'input/cond': cond.asnumpy(),
            'generated': gen.asnumpy(),
            'prediction': model._out.argmax(axis=1).asnumpy(),
            'out': model._out.asnumpy(),
            'label': cond.argmax(axis=1).asnumpy(),
            'relevance/noise': s_noise.squeeze().asnumpy(),
            'relevance/cond': s_cond.squeeze().asnumpy(),
            'relevance/generated': s_gen.asnumpy(),
        }
        save_data_h5(info, fpath)
Пример #6
0
def explain_clss(args, config):
    """Explain a classifier's outputs on batches of a dataset and store them.

    Builds the model named by ``config.model.type``, loads the dataset
    produced by ``data_funcs[config.data.func]`` and iterates it, unshuffled,
    in batches of ``config.explanation.batch_size``.  For at most
    ``config.explanation.iterations`` batches the relevance is computed (the
    pattern variant when ``config.use_pattern`` is set) and written together
    with input, output and label to one file via ``save_data_h5``.  Batches
    whose output file already exists are skipped unless ``config.overwrite``
    is set.

    Args:
        args: parsed command-line namespace; only ``seed`` is read.
        config: project configuration object.
    """
    ctx = ress(make_ctx, config.device, config.device_id)

    if config.log:
        log_path = config.sub('log')
        level = logging.DEBUG if config.get('debug') else logging.INFO
        mkfilelogger('ecGAN', log_path, level)

    model_cls = models[config.model.type]
    model = model_cls(ctx=ctx, config=config, **config.model.kwargs)
    if config.use_pattern:
        model.load_pattern_params()

    dataset = ress(data_funcs[config.data.func],
                   *(config.data.args),
                   ctx=ctx,
                   **(config.data.kwargs))
    data_iter = gluon.data.DataLoader(dataset,
                                      config.explanation.batch_size,
                                      shuffle=False,
                                      last_batch='discard')

    # Not used below; kept so the ``dataset.classes`` access still happens.
    K = len(dataset.classes)

    net_desc = config.nets[config.model.kwargs.classifier]
    net_name = '%s<%s>' % (net_desc.name, net_desc.type)
    net_epoch = net_desc.epoch

    # Output path template depends on the explanation mode.
    templ = (config.pattern.output
             if config.use_pattern else config.explanation.output)

    if args.seed:
        random.seed(args.seed)

    for step, (data, label) in enumerate(data_iter):
        if step >= config.explanation.iterations:
            break

        fpath = config.exsub(templ,
                             data_desc='result<%s>' % config.data.func,
                             ftype='h5',
                             iter=step,
                             net=net_name,
                             net_epoch=net_epoch)

        already_done = not config.overwrite and os.path.isfile(fpath)
        if already_done:
            getLogger('ecGAN').info("File already exits, skipping '%s'...",
                                    fpath)
            continue

        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)

        if config.use_pattern:
            explain = model.explain_pattern
            explain_kwargs = {
                'cond': label.squeeze(),
                'single_out': config.pattern.get('single_out', True),
                'attribution': config.pattern.get('type') == 'attribution',
            }
        else:
            explain = model.explain
            explain_kwargs = {
                'cond': label.squeeze(),
                'single_out': config.explanation.get('single_out', False),
                'mkwargs': config.explanation.get('kwargs', {}),
            }
        relevance = explain(data, **explain_kwargs)

        # ``model._out`` is read only after the explain call above.
        info = {
            'input': data.asnumpy(),
            'prediction': model._out.argmax(axis=1).asnumpy(),
            'out': model._out.asnumpy(),
            'label': label.asnumpy(),
            'relevance': relevance.asnumpy(),
        }
        save_data_h5(info, fpath)

    del model