Code example #1
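Tests a trained generator: loads the generator parameters saved at TEST_EPOCH, feeds one batch of noise through the network, and saves a figure comparing the generated samples with real MNIST test images.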
def main():
    # =============setting============
    dataset = config.dataset.dataset
    batch_size = config.TRAIN.BATCH_SIZE
    Z = 100
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    assert len(ctx) == 1, 'Multi GPU not supported.'
    ctx = ctx[0]
    epoch = config.TEST.TEST_EPOCH

    logger, final_output_path = create_logger(config.output_path, args.cfg)
    prefix = os.path.join(final_output_path, config.TRAIN.model_prefix)
    test_fig_path = os.path.join(final_output_path, 'test_fig')

    if not os.path.exists(test_fig_path):
        os.makedirs(test_fig_path)

    test_fig_prefix = os.path.join(test_fig_path, dataset)

    mx.random.seed(config.RNG_SEED)
    np.random.seed(config.RNG_SEED)

    # ==============data==============
    if dataset == 'mnist':
        X_train, X_test = get_mnist()
        test_iter = mx.io.NDArrayIter(X_test, batch_size=batch_size)
    else:
        raise NotImplementedError

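    # RandIter yields batches of Z-dimensional noise for the generator's 'rand' input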
    rand_iter = RandIter(batch_size, Z)

    # print config
    pprint.pprint(config)
    # logger.info('system:{}'.format(os.uname()))
    # logger.info('mxnet path:{}'.format(mx.__file__))
    # logger.info('rng seed:{}'.format(config.RNG_SEED))
    # logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # =============Generator Module=============
    generatorSymbol = get_symbol_generator()
    generator = mx.mod.Module(symbol=generatorSymbol, data_names=('rand',), label_names=None, context=ctx)
    generator.bind(data_shapes=rand_iter.provide_data)
    generator.load_params(prefix + '-generator-%04d.params' % epoch)

    test_iter.reset()
    batch = test_iter.next()
    rbatch = rand_iter.next()
    generator.forward(rbatch, is_train=False)
    outG = generator.get_outputs()
    visualize(outG[0].asnumpy(), batch.data[0].asnumpy(), test_fig_prefix + '-test-%04d.png' % epoch)
Code example #2
File: train_pix2pix.py  Project: zhengxiawu/FGIR-GAN
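Full pix2pix training loop: builds generator (encoder-decoder or U-Net) and discriminator modules, initializes their parameters by hand, then alternates discriminator updates on fake and real (A, B) pairs with generator updates driven by the discriminator's input gradients plus an L1 reconstruction loss.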
def main():
    # set debug
    DEBUG = False
    # =============setting============
    dataset = config.dataset.dataset
    batch_size = config.TRAIN.BATCH_SIZE
    lr = config.TRAIN.lr
    beta1 = config.TRAIN.beta1
    sigma = 0.02
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    assert len(ctx) == 1, 'Multi GPU not supported.'
    ctx = ctx[0]
    frequent = config.default.frequent
    check_point = True

    logger, final_output_path = create_logger(config.output_path, args.cfg)
    prefix = os.path.join(final_output_path, config.TRAIN.model_prefix)
    train_fig_path = os.path.join(final_output_path, 'train_fig')

    train_fig_prefix = os.path.join(train_fig_path, dataset)

    if not os.path.exists(train_fig_path):
        os.makedirs(train_fig_path)

    # set random seed for reproducibility
    mx.random.seed(config.RNG_SEED)
    np.random.seed(config.RNG_SEED)

    # ==============data==============
    train_data = pix2pixIter(config, shuffle=True, ctx=ctx)

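    # the lr schedulers count optimizer updates, so convert epochs to update counts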
    step = config.TRAIN.step_epoch * train_data.size / batch_size
    step_decay = config.TRAIN.decay_epoch * train_data.size / batch_size
    if config.TRAIN.end_epoch == (config.TRAIN.step_epoch +
                                  config.TRAIN.decay_epoch):
        lr_scheduler_g = PIX2PIXScheduler(step=int(step),
                                          step_decay=int(step_decay),
                                          base_lr=lr)
        lr_scheduler_d = PIX2PIXScheduler(step=int(step),
                                          step_decay=int(step_decay),
                                          base_lr=lr / 2.0)
    else:
        lr_scheduler_g = None
        lr_scheduler_d = None

    label = mx.nd.zeros((batch_size, ), ctx=ctx)

    # print config
    pprint.pprint(config)
    logger.info('system:{}'.format(os.uname()))
    logger.info('mxnet path:{}'.format(mx.__file__))
    logger.info('rng seed:{}'.format(config.RNG_SEED))
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # =============Generator Module=============
    if batch_size == 1:
        if config.netG == 'autoencoder':
            generatorSymbol = defineG_encoder_decoder(config)
        elif config.netG == 'unet':
            generatorSymbol = defineG_unet(config)
        else:
            raise NotImplementedError
    else:
        if config.netG == 'autoencoder':
            generatorSymbol = defineG_encoder_decoder_batch(config)
        elif config.netG == 'unet':
            generatorSymbol = defineG_unet_batch(config)
        else:
            raise NotImplementedError

    if DEBUG:
        generatorGroup = generatorSymbol.get_internals()
        name_list = generatorGroup.list_outputs()
        out_name = []
        for name in name_list:
            if 'output' in name:
                out_name += [generatorGroup[name]]
        out_group = mx.sym.Group(out_name)
        # both inputs are needed for full shape inference
        out_shapes = out_group.infer_shape(A=(4, 3, 256, 256),
                                           B=(4, 3, 256, 256))

    generator = mx.mod.Module(symbol=generatorSymbol,
                              data_names=(
                                  'A',
                                  'B',
                              ),
                              label_names=None,
                              context=ctx)
    generator.bind(data_shapes=train_data.provide_data)
    # optionally draw the network graph:
    # network_test(generatorSymbol)
    # init params
    arg_params = {}
    aux_params = {}
    arg_names = generatorSymbol.list_arguments()
    aux_names = generatorSymbol.list_auxiliary_states()
    arg_shapes, _, aux_shapes = generatorSymbol.infer_shape(
        A=train_data.provide_data[0][1], B=train_data.provide_data[1][1])

    for idx, arg_name in enumerate(arg_names):
        if 'weight' in arg_name:
            arg_params[arg_name] = mx.random.normal(0.0,
                                                    sigma,
                                                    shape=arg_shapes[idx])
        elif 'gamma' in arg_name:
            arg_params[arg_name] = mx.random.normal(1.0,
                                                    sigma,
                                                    shape=arg_shapes[idx])
        elif 'bias' in arg_name or 'beta' in arg_name:
            arg_params[arg_name] = mx.nd.zeros(shape=arg_shapes[idx])
        # remaining arguments are data inputs and need no initialization
    if batch_size != 1:
        # only the batch variants carry aux states (BatchNorm running stats)
        for idx, aux_name in enumerate(aux_names):
            if 'mean' in aux_name:
                aux_params[aux_name] = mx.nd.zeros(shape=aux_shapes[idx])
            elif 'var' in aux_name:
                aux_params[aux_name] = mx.nd.ones(shape=aux_shapes[idx])
            else:
                raise NameError('Unknown aux_name.')

    generator.init_params(arg_params=arg_params, aux_params=aux_params)

    if lr_scheduler_g is not None:
        generator.init_optimizer(optimizer='adam',
                                 optimizer_params={
                                     'learning_rate': lr,
                                     'lr_scheduler': lr_scheduler_g,
                                     'beta1': beta1,
                                     'rescale_grad': 1.0 / batch_size
                                 })
    else:
        generator.init_optimizer(optimizer='adam',
                                 optimizer_params={
                                     'learning_rate': lr,
                                     'beta1': beta1,
                                     'rescale_grad': 1.0 / batch_size
                                 })
    mods = [generator]

    # =============Discriminator Module=============
    if batch_size == 1:
        if config.netD == 'basic':
            discriminatorSymbol = defineD_basic()
        elif config.netD == 'n_layers':
            discriminatorSymbol = defineD_n_layers(n_layers=config.n_layers)
        else:
            raise NotImplementedError
    else:
        if config.netD == 'basic':
            discriminatorSymbol = defineD_basic_batch(batch_size=batch_size)
        elif config.netD == 'n_layers':
            discriminatorSymbol = defineD_n_layers_batch(
                n_layers=config.n_layers, batch_size=batch_size)
        else:
            raise NotImplementedError

    if DEBUG:
        discriminatorGroup = discriminatorSymbol.get_internals()
        name_list = discriminatorGroup.list_outputs()
        out_name = []
        for name in name_list:
            if 'output' in name:
                out_name += [discriminatorGroup[name]]
        out_group = mx.sym.Group(out_name)
        out_shapes = out_group.infer_shape(A=(1, 3, 256, 256),
                                           B=(1, 3, 256, 256))

    discriminator = mx.mod.Module(symbol=discriminatorSymbol,
                                  data_names=(
                                      'A',
                                      'B',
                                  ),
                                  label_names=('label', ),
                                  context=ctx)
    discriminator.bind(data_shapes=train_data.provide_data,
                       label_shapes=[('label', (batch_size, ))],
                       inputs_need_grad=True)

    # init params
    arg_params = {}
    aux_params = {}
    arg_names = discriminatorSymbol.list_arguments()
    aux_names = discriminatorSymbol.list_auxiliary_states()
    arg_shapes, _, aux_shapes = discriminatorSymbol.infer_shape(
        A=train_data.provide_data[0][1],
        B=train_data.provide_data[1][1],
        label=(batch_size, ))

    for idx, arg_name in enumerate(arg_names):
        if 'weight' in arg_name:
            arg_params[arg_name] = mx.random.normal(0.0,
                                                    sigma,
                                                    shape=arg_shapes[idx])
        elif 'gamma' in arg_name:
            arg_params[arg_name] = mx.random.normal(1.0,
                                                    sigma,
                                                    shape=arg_shapes[idx])
        elif 'bias' in arg_name or 'beta' in arg_name:
            arg_params[arg_name] = mx.nd.zeros(shape=arg_shapes[idx])
        # remaining arguments are data inputs and need no initialization
    if batch_size != 1:
        # only the batch variants carry aux states (BatchNorm running stats)
        for idx, aux_name in enumerate(aux_names):
            if 'mean' in aux_name:
                aux_params[aux_name] = mx.nd.zeros(shape=aux_shapes[idx])
            elif 'var' in aux_name:
                aux_params[aux_name] = mx.nd.ones(shape=aux_shapes[idx])
            else:
                raise NameError('Unknown aux_name.')

    discriminator.init_params(arg_params=arg_params, aux_params=aux_params)

    # the gradient is already scaled in the LogisticRegression layer, so there is no need to rescale it
    if lr_scheduler_d is not None:
        discriminator.init_optimizer(optimizer='adam',
                                     optimizer_params={
                                         'learning_rate': lr / 2.0,
                                         'lr_scheduler': lr_scheduler_d,
                                         'beta1': beta1,
                                         'rescale_grad': 1.0
                                     })
    else:
        discriminator.init_optimizer(optimizer='adam',
                                     optimizer_params={
                                         'learning_rate': lr / 2.0,
                                         'beta1': beta1,
                                         'rescale_grad': 1.0
                                     })
    mods.append(discriminator)

    # metric
    mG = metric.CrossEntropyMetric()
    mD = metric.CrossEntropyMetric()
    mACC = metric.AccMetric()
    mL1 = metric.L1LossMetric(config)

    t_accumulate = 0

    # =============train===============
    for epoch in range(config.TRAIN.end_epoch):
        train_data.reset()
        mACC.reset()
        mG.reset()
        mD.reset()
        mL1.reset()
        for t, batch in enumerate(train_data):

            t_start = time.time()

            # generator input real A, output fake B
            generator.forward(batch, is_train=True)
            outG = generator.get_outputs()

            # update discriminator on fake
            # discriminator input real A and fake B
            # want discriminator to predict fake (0)
            label[:] = 0
            discriminator.forward(mx.io.DataBatch([batch.data[0], outG[1]],
                                                  [label]),
                                  is_train=True)
            discriminator.backward()
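            # keep a copy of the fake-pass gradients so they can be summed
            # with the real-pass gradients before the single update below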
            gradD = [[grad.copyto(grad.context) for grad in grads]
                     for grads in discriminator._exec_group.grad_arrays]

            discriminator.update_metric(mD, [label])
            discriminator.update_metric(mACC, [label])

            # update discriminator on real
            # discriminator input real A and real B
            # want discriminator to predict real (1)
            label[:] = 1
            batch.label = [label]
            discriminator.forward(batch, is_train=True)
            discriminator.backward()
            for gradsr, gradsf in zip(discriminator._exec_group.grad_arrays,
                                      gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    # gradr =  (gradr + gradf)/2
                    gradr += gradf
            discriminator.update()

            discriminator.update_metric(mD, [label])
            discriminator.update_metric(mACC, [label])

            # update generator
            # discriminator input real A and fake B
            # want discriminator to predict real (1)
            label[:] = 1
            discriminator.forward(mx.io.DataBatch([batch.data[0], outG[1]],
                                                  [label]),
                                  is_train=True)
            discriminator.backward()
            diffD = discriminator.get_input_grads()
            # generator outputs: [L1 loss, fake B]; the loss head takes a
            # gradient of ones, the image head takes the discriminator's
            # input gradient scaled by the GAN loss weight
            generator.backward([
                mx.nd.ones((batch_size, ), ctx=ctx),
                diffD[1] * config.GAN_loss
            ])
            generator.update()

            mG.update([label], discriminator.get_outputs())
            mL1.update(None, outG)

            t_accumulate += time.time() - t_start

            t += 1
            if t % frequent == 0:
                if config.TRAIN.batch_end_plot_figure:
                    visualize(
                        batch.data[0].asnumpy(), batch.data[1].asnumpy(),
                        outG[1].asnumpy(), train_fig_prefix +
                        '-train-%04d-%06d.png' % (epoch + 1, t))
                print('Epoch[{}] Batch[{}] Time[{:.4f}] dACC: {:.4f} gCE: {:.4f} dCE: {:.4f} gL1: {:.4f}'.format(
                    epoch, t, t_accumulate,
                    mACC.get()[1],
                    mG.get()[1],
                    mD.get()[1],
                    mL1.get()[1]))
                logger.info(
                    'Epoch[{}] Batch[{}] Speed[{:.4f} samples/s] dACC: {:.4f} gCE: {:.4f} dCE: {:.4f} gL1: {:.4f}\n'
                    .format(epoch, t, frequent * batch_size / t_accumulate,
                            mACC.get()[1],
                            mG.get()[1],
                            mD.get()[1],
                            mL1.get()[1]))
                t_accumulate = 0

        if check_point:
            print('Saving...')
            if config.TRAIN.epoch_end_plot_figure:
                visualize(batch.data[0].asnumpy(), batch.data[1].asnumpy(),
                          outG[1].asnumpy(),
                          train_fig_prefix + '-train-%04d.png' % (epoch + 1))
            if (epoch + 1) % config.TRAIN.save_interval == 0:
                generator.save_params(prefix + '-generator-%04d.params' %
                                      (epoch + 1))
                discriminator.save_params(prefix +
                                          '-discriminator-%04d.params' %
                                          (epoch + 1))

    generator.save_params(prefix +
                          '-generator-%04d.params' % config.TRAIN.end_epoch)
    discriminator.save_params(prefix + '-discriminator-%04d.params' %
                              config.TRAIN.end_epoch)
Code example #3
File: train_cgan.py  Project: zmj93GitHub/MXNet-GAN
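Conditional GAN training on MNIST: one-hot class labels are fed to both the generator (alongside the noise) and the discriminator (alongside the images), otherwise following the standard alternating GAN update.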
def main():
    # =============setting============
    dataset = config.dataset.dataset
    batch_size = config.TRAIN.BATCH_SIZE
    Z = 100
    num_classes = 10
    lr = config.TRAIN.lr
    beta1 = config.TRAIN.beta1
    sigma = 0.02
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    assert len(ctx) == 1, 'Multi GPU not supported.'
    ctx = ctx[0]
    frequent = config.default.frequent
    check_point = True

    logger, final_output_path = create_logger(config.output_path, args.cfg)
    prefix = os.path.join(final_output_path, config.TRAIN.model_prefix)
    train_fig_path = os.path.join(final_output_path, 'train_fig')

    train_fig_prefix = os.path.join(train_fig_path, dataset)

    if not os.path.exists(train_fig_path):
        os.makedirs(train_fig_path)

    mx.random.seed(config.RNG_SEED)
    np.random.seed(config.RNG_SEED)

    # ==============data==============
    if dataset == 'mnist':
        X_train, X_test, Y_train, Y_test = get_mnist()
        train_iter = mx.io.NDArrayIter(X_train,
                                       label=Y_train,
                                       batch_size=batch_size)
    else:
        raise NotImplementedError

    rand_iter = RandIter(batch_size, Z)
    label = mx.nd.zeros((batch_size, ), ctx=ctx)

    # print config
    pprint.pprint(config)
    logger.info('system:{}'.format(os.uname()))
    logger.info('mxnet path:{}'.format(mx.__file__))
    logger.info('rng seed:{}'.format(config.RNG_SEED))
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # =============Generator Module=============
    generatorSymbol = get_symbol_generator()
    generator = mx.mod.Module(symbol=generatorSymbol,
                              data_names=(
                                  'class_label',
                                  'rand',
                              ),
                              label_names=None,
                              context=ctx)
    generator.bind(data_shapes=[('class_label', (batch_size, num_classes))] +
                   rand_iter.provide_data)
    generator.init_params(initializer=mx.init.Normal(sigma))
    generator.init_optimizer(optimizer='adam',
                             optimizer_params={
                                 'learning_rate': lr,
                                 'beta1': beta1,
                             })
    mods = [generator]

    # =============Discriminator Module=============
    discriminatorSymbol = get_symbol_discriminator()
    discriminator = mx.mod.Module(symbol=discriminatorSymbol,
                                  data_names=(
                                      'data',
                                      'class_label',
                                  ),
                                  label_names=('label', ),
                                  context=ctx)
    discriminator.bind(data_shapes=train_iter.provide_data +
                       [('class_label', (batch_size, num_classes))],
                       label_shapes=[('label', (batch_size, ))],
                       inputs_need_grad=True)
    discriminator.init_params(initializer=mx.init.Normal(sigma))
    discriminator.init_optimizer(optimizer='adam',
                                 optimizer_params={
                                     'learning_rate': lr,
                                     'beta1': beta1,
                                 })
    mods.append(discriminator)

    # metric
    # mG = mx.metric.CustomMetric(metric.fentropy)
    # mD = mx.metric.CustomMetric(metric.fentropy)
    # mACC = mx.metric.CustomMetric(metric.facc)
    # test_metric = metric.CrossEntropyMetric()
    # test_metric.reset()
    mG = metric.CrossEntropyMetric()
    mD = metric.CrossEntropyMetric()
    mACC = metric.AccMetric()

    # =============train===============
    for epoch in range(config.TRAIN.end_epoch):
        train_iter.reset()
        mACC.reset()
        mG.reset()
        mD.reset()
        for t, batch in enumerate(train_iter):
            rbatch = rand_iter.next()

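            # one-hot encode the integer class labels for the conditional inputs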
            batch_label_one_hot = np.zeros((batch_size, num_classes),
                                           dtype=np.float32)
            batch_label_np = batch.label[0].asnumpy()
            for i in range(batch_size):
                batch_label_one_hot[i, int(batch_label_np[i])] = 1
            batch_label_one_hot = mx.nd.array(batch_label_one_hot)

            generator.forward(mx.io.DataBatch([batch_label_one_hot] +
                                              rbatch.data, []),
                              is_train=True)
            outG = generator.get_outputs()

            # update discriminator on fake
            label[:] = 0
            discriminator.forward(mx.io.DataBatch(outG + [batch_label_one_hot],
                                                  [label]),
                                  is_train=True)
            discriminator.backward()
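            # keep a copy of the fake-pass gradients so they can be summed
            # with the real-pass gradients before the single update below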
            gradD = [[grad.copyto(grad.context) for grad in grads]
                     for grads in discriminator._exec_group.grad_arrays]

            discriminator.update_metric(mD, [label])
            discriminator.update_metric(mACC, [label])
            # test_metric.update([label], discriminator.get_outputs())

            # update discriminator on real
            label[:] = 1
            batch.label = [label]
            discriminator.forward(mx.io.DataBatch(
                batch.data + [batch_label_one_hot], [label]),
                                  is_train=True)
            discriminator.backward()
            for gradsr, gradsf in zip(discriminator._exec_group.grad_arrays,
                                      gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    gradr += gradf
            discriminator.update()

            discriminator.update_metric(mD, [label])
            discriminator.update_metric(mACC, [label])
            # test_metric.update([label], discriminator.get_outputs())

            # update generator
            # discriminator input fake image and class label
            # want discriminator to predict real (1)
            label[:] = 1
            discriminator.forward(mx.io.DataBatch(outG + [batch_label_one_hot],
                                                  [label]),
                                  is_train=True)
            discriminator.backward()
            diffD = discriminator.get_input_grads()
            # only the gradient w.r.t. the image input flows back into the
            # generator (assuming its single output is the generated image);
            # the class-label input needs no gradient
            generator.backward([diffD[0]])
            generator.update()

            mG.update([label], discriminator.get_outputs())

            t += 1
            if t % frequent == 0:
                # visualize(outG[0].asnumpy(), batch.data[0].asnumpy())
                print('Epoch[{}] Batch[{}] dACC: {:.4f} gCE: {:.4f} dCE: {:.4f}'.format(
                    epoch, t,
                    mACC.get()[1],
                    mG.get()[1],
                    mD.get()[1]))
                logger.info(
                    'Epoch[{}] Batch[{}] dACC: {:.4f} gCE: {:.4f} dCE: {:.4f}\n'
                    .format(epoch, t,
                            mACC.get()[1],
                            mG.get()[1],
                            mD.get()[1]))

        if check_point:
            print('Saving...')
            visualize(outG[0].asnumpy(), batch.data[0].asnumpy(),
                      train_fig_prefix + '-train-%04d.png' % (epoch + 1))
            generator.save_params(prefix + '-generator-%04d.params' %
                                  (epoch + 1))
            discriminator.save_params(prefix + '-discriminator-%04d.params' %
                                      (epoch + 1))