Example #1
from tensorboard import FileWriter, summary


def test_log_scalar_summary():
    logdir = './experiment/scalar'
    writer = FileWriter(logdir)
    for i in range(10):
        # one scalar value per step, logged under the tag 'scalar'
        s = summary.scalar('scalar', i)
        writer.add_summary(s, i + 1)
    writer.flush()
    writer.close()
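A single writer can hold several scalar tags at once. A minimal sketch reusing the FileWriter/summary API from the test above; the tag names `loss/train` and `loss/val` are made up for illustration, and the slash prefix is what groups them in the TensorBoard UI:

from tensorboard import FileWriter, summary

writer = FileWriter('./experiment/scalars')
for step in range(1, 11):
    # hypothetical tags; the 'loss/' prefix groups both curves together
    writer.add_summary(summary.scalar('loss/train', 1.0 / step), step)
    writer.add_summary(summary.scalar('loss/val', 1.5 / step), step)
writer.flush()
writer.close()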
Example #2
import numpy as np

from tensorboard import FileWriter, summary


def test_log_histogram_summary():
    logdir = './experiment/histogram'
    writer = FileWriter(logdir)
    for i in range(10):
        # shift the mean each step so the histogram drifts over time
        mu, sigma = i * 0.1, 1.0
        values = np.random.normal(mu, sigma,
                                  10000)  # larger sample gives a smoother histogram
        hist = summary.histogram('discrete_normal', values)
        writer.add_summary(hist, i + 1)
    writer.flush()
    writer.close()
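A quick way to sanity-check the histogram view is to log a second, differently shaped distribution next to the first. A minimal sketch under the same API as above; the tag names are illustrative:

import numpy as np

from tensorboard import FileWriter, summary

writer = FileWriter('./experiment/histogram_compare')
for step in range(1, 11):
    normal = np.random.normal(0.0, 1.0, 10000)
    uniform = np.random.uniform(-3.0, 3.0, 10000)
    # two tags, so both distributions can be compared side by side
    writer.add_summary(summary.histogram('normal', normal), step)
    writer.add_summary(summary.histogram('uniform', uniform), step)
writer.flush()
writer.close()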
Example #3
import numpy as np

from tensorboard import FileWriter, summary


def test_log_image_summary():
    logdir = './experiment/image'
    writer = FileWriter(logdir)

    # read_data is an external helper (see the sketch after this example)
    path = 'http://yann.lecun.com/exdb/mnist/'
    (train_lbl, train_img) = read_data(path + 'train-labels-idx1-ubyte.gz',
                                       path + 'train-images-idx3-ubyte.gz')

    for i in range(10):
        tensor = np.reshape(train_img[i], (28, 28, 1))  # HWC, single channel
        im = summary.image(
            'mnist/' + str(i),
            tensor)  # the 'mnist/' prefix groups the images under one tag
        writer.add_summary(im, i + 1)
    writer.flush()
    writer.close()
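`read_data` is not defined in this snippet; it comes from MXNet's MNIST tutorial utilities. A self-contained stand-in that parses the raw IDX format could look like this (an assumption about its behavior, not the original implementation):

import gzip
import struct
import urllib.request

import numpy as np


def read_data(label_url, image_url):
    # Hedged stand-in for the tutorial helper used above.
    labels_raw = gzip.decompress(urllib.request.urlopen(label_url).read())
    images_raw = gzip.decompress(urllib.request.urlopen(image_url).read())
    # IDX headers: labels are (magic, count); images are (magic, count, rows, cols)
    _, num = struct.unpack('>II', labels_raw[:8])
    label = np.frombuffer(labels_raw[8:], dtype=np.uint8)
    _, num, rows, cols = struct.unpack('>IIII', images_raw[:16])
    image = np.frombuffer(images_raw[16:], dtype=np.uint8).reshape(num, rows, cols)
    return label, image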
Example #4
import logging

import mxnet as mx
from tensorboard import FileWriter, summary

# args, use_tb, RandIter, ImageIter, dcgan64x64, WGANMetric, and visual
# are assumed to be defined elsewhere in the surrounding script.


def train():
    '''Train a WGAN: alternate critic (D) updates with weight clipping
    and generator (G) updates.
    '''
    ctxs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    batch_size = args.batch_size
    z_dim = args.z_dim
    lr = args.lr
    epoches = args.epoches
    wclip = args.wclip
    frequency = args.frequency
    model_prefix = args.model_prefix
    rand_iter = RandIter(batch_size, z_dim)
    image_iter = ImageIter(args.data_path, batch_size, (3, 64, 64))
    # G and D
    symG, symD = dcgan64x64(ngf=args.ngf, ndf=args.ndf, nc=args.nc)
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctxs)
    modG.bind(data_shapes=rand_iter.provide_data)
    modG.init_params(initializer=mx.init.Normal(0.002))
    modG.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=None, context=ctxs)
    modD.bind(data_shapes=image_iter.provide_data,
              inputs_need_grad=True)
    modD.init_params(mx.init.Normal(0.002))
    modD.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    # train
    logging.info('Start training')
    metricD = WGANMetric()
    metricG = WGANMetric()
    fix_noise_batch = mx.io.DataBatch([mx.random.normal(0, 1, shape=(batch_size, z_dim, 1, 1))], [])
    # visualization with TensorBoard if possible
    if use_tb:
        writer = FileWriter('tmp/exp')
    for epoch in range(epoches):
        image_iter.reset()
        metricD.reset()
        metricG.reset()
        for i, batch in enumerate(image_iter):
            # clip critic weights to [-wclip, wclip] (WGAN Lipschitz constraint)
            for params in modD._exec_group.param_arrays:
                for param in params:
                    mx.nd.clip(param, -wclip, wclip, out=param)
            # forward G
            rbatch = rand_iter.next()
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            # critic on fake images; keep these gradients to combine with the real pass
            modD.forward(mx.io.DataBatch(outG, label=[]), is_train=True)
            fw_g = modD.get_outputs()[0].asnumpy()
            modD.backward([mx.nd.ones((batch_size, 1)) / batch_size])
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]
            # critic on real images; accumulate the fake-pass gradients before updating
            modD.forward(batch, is_train=True)
            fw_r = modD.get_outputs()[0].asnumpy()
            modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
            for grads_real, grads_fake in zip(modD._exec_group.grad_arrays, gradD):
                for grad_real, grad_fake in zip(grads_real, grads_fake):
                    grad_real += grad_fake
            modD.update()
            errorD = -(fw_r - fw_g) / batch_size  # critic objective: -(D(real) - D(fake))
            metricD.update(errorD.mean())
            # update G
            rbatch = rand_iter.next()
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            modD.forward(mx.io.DataBatch(outG, []), is_train=True)
            errorG = -modD.get_outputs()[0] / batch_size
            modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
            modG.backward(modD.get_input_grads())
            modG.update()
            metricG.update(errorG.asnumpy().mean())
            # periodic logging
            if (i + 1) % frequency == 0:
                print("epoch:", epoch + 1, "iter:", i + 1,
                      "G:", metricG.get(), "D:", metricD.get())
        # save checkpoints for G and D
        modG.save_checkpoint('model/%s-G' % model_prefix, epoch + 1)
        modD.save_checkpoint('model/%s-D' % model_prefix, epoch + 1)
        rbatch = rand_iter.next()
        modG.forward(rbatch)
        outG = modG.get_outputs()[0]
        canvas = visual('tmp/gout-rand-%d.png' % (epoch + 1), outG.asnumpy())
        if use_tb:
            canvas = canvas[:, :, ::-1]  # BGR -> RGB
            writer.add_summary(summary.image('gout-rand-%d' % (epoch + 1), canvas))
        modG.forward(fix_noise_batch)
        outG = modG.get_outputs()[0]
        canvas = visual('tmp/gout-fix-%d.png' % (epoch + 1), outG.asnumpy())
        if use_tb:
            canvas = canvas[:, :, ::-1]  # BGR -> RGB
            writer.add_summary(summary.image('gout-fix-%d' % (epoch + 1), canvas))
        if use_tb:
            writer.flush()
    if use_tb:
        writer.close()  # only close when a writer was actually created
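WGANMetric is defined elsewhere in the script. Given how it is used above (reset(), update(value), get()), a minimal running-mean version might look like this (an assumption for illustration, not the original implementation):

class WGANMetric(object):
    # Hypothetical minimal metric: keeps a running mean of the loss
    # values passed to update(); matches the reset/update/get usage above.
    def __init__(self):
        self.reset()

    def reset(self):
        self._sum = 0.0
        self._count = 0

    def update(self, value):
        self._sum += float(value)
        self._count += 1

    def get(self):
        return self._sum / max(self._count, 1)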