Example #1
def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
                     count, image_dir, summary_writer):
    num = cfg.TRAIN.VIS_COUNT

    # The range of real_img (i.e., imgs_tcpu[-1][0:num])
    # is rescaled to [0, 1] by vutils.save_image
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(
        real_img, '%s/real_samples.png' % (image_dir),
        normalize=True)
    real_img_set = vutils.make_grid(real_img).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    summary_writer.add_summary(sup_real_img, count)

    for i in range(num_imgs):
        fake_img = fake_imgs[i][0:num]
        # The range of fake_img.data (i.e., fake_imgs[i][0:num])
        # is still [-1, 1]...
        vutils.save_image(
            fake_img.data, '%s/count_%09d_fake_samples%d.png' %
            (image_dir, count, i), normalize=True)

        fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()

        fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
        fake_img_set = (fake_img_set + 1) * 255 / 2
        fake_img_set = fake_img_set.astype(np.uint8)

        sup_fake_img = summary.image('fake_img%d' % i, fake_img_set)
        summary_writer.add_summary(sup_fake_img, count)
        summary_writer.flush()
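
For context, a minimal sketch of a call site for save_img_results follows. Every name in it is an assumption rather than something shown in the snippet: FileWriter and summary are taken to come from the legacy standalone tensorboard package these examples appear to rely on, and cfg to be a StackGAN-style config object exposing TRAIN.VIS_COUNT.

# Hypothetical call site; all names below are illustrative assumptions.
import numpy as np
import torch
import torchvision.utils as vutils            # the `vutils` used inside save_img_results
from tensorboard import FileWriter, summary   # assumed: legacy standalone tensorboard package
from miscc.config import cfg                  # assumed: project config with cfg.TRAIN.VIS_COUNT

image_dir = 'output/images'
summary_writer = FileWriter('output/log')

# One batch of real images and two generator stages, all 64x64 in [-1, 1].
imgs_tcpu = [torch.rand(8, 3, 64, 64) * 2 - 1]
fake_imgs = [torch.rand(8, 3, 64, 64) * 2 - 1 for _ in range(2)]

save_img_results(imgs_tcpu, fake_imgs, num_imgs=len(fake_imgs),
                 count=0, image_dir=image_dir, summary_writer=summary_writer)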
Example #2
def save_img_results(imgs_tcpu, fake_imgs, num_imgs, count, image_dir,
                     summary_writer):
    num = cfg.TRAIN.VIS_COUNT

    # The range of real_img (i.e., imgs_tcpu[-1][0:num])
    # is rescaled to [0, 1] by vutils.save_image
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(real_img,
                      '%s/real_samples.png' % (image_dir),
                      normalize=True)
    real_img_set = vutils.make_grid(real_img).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    summary_writer.add_summary(sup_real_img, count)

    for i in range(num_imgs):
        fake_img = fake_imgs[i][0:num]
        # The range of fake_img.data (i.e., fake_imgs[i][0:num])
        # is still [-1, 1]...
        vutils.save_image(fake_img.data,
                          '%s/count_%09d_fake_samples%d.png' %
                          (image_dir, count, i),
                          normalize=True)

        fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()

        fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
        fake_img_set = (fake_img_set + 1) * 255 / 2
        fake_img_set = fake_img_set.astype(np.uint8)

        sup_fake_img = summary.image('fake_img%d' % i, fake_img_set)
        summary_writer.add_summary(sup_fake_img, count)
        summary_writer.flush()
Example #3
def save_img_results(imgs_tcpu, fake_imgs, num_imgs, count, image_dir,
                     summary_writer, rec_ids, im_ids):
    num = cfg.TRAIN.VIS_COUNT
    last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Image/'

    # The range of real_img (i.e., imgs_tcpu[-1][0:num])
    # is rescaled to [0, 1] by vutils.save_image
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(real_img,
                      '%s/count_%09d_real_samples.png' % (image_dir, count),
                      normalize=True)

    vutils.save_image(real_img,
                      last_run_dir + 'real_samples.png',
                      normalize=True)

    # write the recipe and image IDs to a text file alongside the saved grid
    rec_ids = [t.tobytes().decode('UTF-8') for t in rec_ids.numpy()]
    im_ids = [t.tobytes().decode('UTF-8') for t in im_ids.numpy()]
    with open('%s/count_%09d_real_samples_IDs.txt' % (image_dir, count),
              "w") as f:
        for rec_id, im_id in zip(rec_ids, im_ids):
            f.write("rec_id=%s, img_id=%s\n" % (rec_id, im_id))

    with open(last_run_dir + 'real_samples_IDs.txt', "w") as f:
        for rec_id, im_id in zip(rec_ids, im_ids):
            f.write("rec_id=%s, img_id=%s\n" % (rec_id, im_id))

    real_img_set = vutils.make_grid(real_img).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    summary_writer.add_summary(sup_real_img, count)

    for i in range(num_imgs):
        fake_img = fake_imgs[i][0:num]
        # The range of fake_img.data (i.e., fake_imgs[i][0:num])
        # is still [-1, 1]...
        vutils.save_image(fake_img.data,
                          '%s/count_%09d_fake_samples%d.png' %
                          (image_dir, count, i),
                          normalize=True)

        vutils.save_image(fake_img.data,
                          last_run_dir + 'fake_samples%d.png' % (i),
                          normalize=True)

        fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()

        fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
        fake_img_set = (fake_img_set + 1) * 255 / 2
        fake_img_set = fake_img_set.astype(np.uint8)

        sup_fake_img = summary.image('fake_img%d' % i, fake_img_set)
        summary_writer.add_summary(sup_fake_img, count)
        summary_writer.flush()
Example #4
    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        """Add batched image data to summary.
        Note that this requires the ``pillow`` package.
        Args:
            tag (string): Data identifier
            img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            dataformats (string): Image data format specification of the form
              NCHW, NHWC, CHW, HWC, HW, WH, etc.
        Shape:
            img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shapes
            will be accepted, e.g. NCHW or NHWC.
        Examples::
            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            img_batch = np.zeros((16, 3, 100, 100))
            for i in range(16):
                img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
                img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i
            writer = SummaryWriter()
            writer.add_images('my_image_batch', img_batch, 0)
            writer.close()
        Expected result:
        .. image:: _static/img/tensorboard/add_images.png
           :scale: 30 %
        """

        self._get_file_writer().add_summary(
            summary.image(tag, img_tensor, dataformats=dataformats), global_step,
            walltime)
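
A short usage sketch of the dataformats argument documented above: the docstring's own example logs a default NCHW batch, so this one logs a channels-last (NHWC) batch instead. It assumes the torch.utils.tensorboard SummaryWriter, as in the docstring example; the tag and array are illustrative.

# Sketch: logging an NHWC (channels-last) image batch via dataformats.
import numpy as np
from torch.utils.tensorboard import SummaryWriter

img_batch_hwc = np.random.rand(4, 64, 64, 3)  # N, H, W, C, float values in [0, 1]

writer = SummaryWriter()
writer.add_images('my_image_batch_hwc', img_batch_hwc, global_step=0, dataformats='NHWC')
writer.close()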
Example #5
def save_real(imgs_tcpu, image_dir):
    num = cfg.TRAIN.VIS_COUNT

    # The range of real_img (i.e., imgs_tcpu[-1][0:num])
    # is rescaled to [0, 1] by vutils.save_image
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(real_img,
                      '%s/real_samples.png' % (image_dir),
                      normalize=True)
    real_img_set = vutils.make_grid(real_img).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    return sup_real_img
Example #6
def test_log_image_summary():
    logdir = './experiment/image'
    writer = FileWriter(logdir)

    path = 'http://yann.lecun.com/exdb/mnist/'
    (train_lbl, train_img) = read_data(path + 'train-labels-idx1-ubyte.gz',
                                       path + 'train-images-idx3-ubyte.gz')

    for i in range(10):
        tensor = np.reshape(train_img[i], (28, 28, 1))
        im = summary.image(
            'mnist/' + str(i),
            tensor)  # in this case, images are grouped under `mnist` tag.
        writer.add_summary(im, i + 1)
    writer.flush()
    writer.close()
Example #7
def train():
    '''train wgan
    '''
    ctxs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    batch_size = args.batch_size
    z_dim = args.z_dim
    lr = args.lr
    epoches = args.epoches
    wclip = args.wclip
    frequency = args.frequency
    model_prefix = args.model_prefix
    rand_iter = RandIter(batch_size, z_dim)
    image_iter = ImageIter(args.data_path, batch_size, (3, 64, 64))
    # G and D
    symG, symD = dcgan64x64(ngf=args.ngf, ndf=args.ndf, nc=args.nc)
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctxs)
    modG.bind(data_shapes=rand_iter.provide_data)
    modG.init_params(initializer=mx.init.Normal(0.002))
    modG.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=None, context=ctxs)
    modD.bind(data_shapes=image_iter.provide_data,
              inputs_need_grad=True)
    modD.init_params(mx.init.Normal(0.002))
    modD.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    # train
    logging.info('Start training')
    metricD = WGANMetric()
    metricG = WGANMetric()
    fix_noise_batch = mx.io.DataBatch([mx.random.normal(0, 1, shape=(batch_size, z_dim, 1, 1))], [])
    # visualization with TensorBoard if possible
    if use_tb:
        writer = FileWriter('tmp/exp')
    for epoch in range(epoches):
        image_iter.reset()
        metricD.reset()
        metricG.reset()
        for i, batch in enumerate(image_iter):
            # clip weight
            for params in modD._exec_group.param_arrays:
                for param in params:
                    mx.nd.clip(param, -wclip, wclip, out=param)
            # forward G
            rbatch = rand_iter.next()
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            # fake
            modD.forward(mx.io.DataBatch(outG, label=[]), is_train=True)
            fw_g = modD.get_outputs()[0].asnumpy()
            modD.backward([mx.nd.ones((batch_size, 1)) / batch_size])
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]
            # real
            modD.forward(batch, is_train=True)
            fw_r = modD.get_outputs()[0].asnumpy()
            modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
            for grads_real, grads_fake in zip(modD._exec_group.grad_arrays, gradD):
                for grad_real, grad_fake in zip(grads_real, grads_fake):
                    grad_real += grad_fake
            modD.update()
            errorD = -(fw_r - fw_g) / batch_size
            metricD.update(errorD.mean())
            # update G
            rbatch = rand_iter.next()
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            modD.forward(mx.io.DataBatch(outG, []), is_train=True)
            errorG = -modD.get_outputs()[0] / batch_size
            modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
            modG.backward(modD.get_input_grads())
            modG.update()
            metricG.update(errorG.asnumpy().mean())
            # log running metrics periodically
            if (i + 1) % frequency == 0:
                print("epoch:", epoch + 1, "iter:", i + 1,
                      "G: ", metricG.get(), "D: ", metricD.get())
        # save checkpoint
        modG.save_checkpoint('model/%s-G' % model_prefix, epoch + 1)
        modD.save_checkpoint('model/%s-D' % model_prefix, epoch + 1)
        # sample images from fresh random noise and from the fixed noise batch,
        # save them to disk and (optionally) log them to TensorBoard
        rbatch = rand_iter.next()
        modG.forward(rbatch)
        outG = modG.get_outputs()[0]
        canvas = visual('tmp/gout-rand-%d.png' % (epoch + 1), outG.asnumpy())
        if use_tb:
            canvas = canvas[:, :, ::-1]  # BGR -> RGB
            writer.add_summary(summary.image('gout-rand-%d' % (epoch + 1), canvas))
        modG.forward(fix_noise_batch)
        outG = modG.get_outputs()[0]
        canvas = visual('tmp/gout-fix-%d.png' % (epoch + 1), outG.asnumpy())
        if use_tb:
            canvas = canvas[:, :, ::-1]  # BGR -> RGB
            writer.add_summary(summary.image('gout-fix-%d' % (epoch + 1), canvas))
            writer.flush()
    if use_tb:
        writer.close()
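
The visual helper used above is not shown in the snippet. Purely as a hypothetical sketch of what such a tiler might do (assuming NCHW generator output in [-1, 1] and OpenCV for writing, which would also explain the BGR -> RGB flip before logging), it could look roughly like this; none of it is the original implementation.

# Hypothetical sketch of a `visual`-style helper (assumed behavior, not the original code).
import numpy as np
import cv2

def visual_sketch(path, batch, grid_cols=8):
    # batch: NCHW float array in [-1, 1]; returns the tiled HWC uint8 canvas it saved.
    batch = ((batch + 1.0) * 127.5).clip(0, 255).astype(np.uint8)  # [-1, 1] -> [0, 255]
    batch = batch.transpose(0, 2, 3, 1)                            # NCHW -> NHWC
    n, h, w, c = batch.shape
    rows = (n + grid_cols - 1) // grid_cols
    canvas = np.zeros((rows * h, grid_cols * w, c), dtype=np.uint8)
    for idx, img in enumerate(batch):
        r, col = divmod(idx, grid_cols)
        canvas[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    cv2.imwrite(path, canvas)  # OpenCV interprets the canvas as BGR when writing
    return canvas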