Example #1
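# Imports assumed by this training script (a minimal sketch; exact module
# paths for project-local helpers such as record_setting, copy_param,
# Generator, Discriminator, Updater, ImageDataset, sample_generate,
# sample_generate_light, calc_inception and calc_FID are not shown in the
# original and are assumed to be importable from the repository).
import argparse
import os

import chainer
from chainer import training
from chainer.training import extension, extensions
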
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--data_path',
                        type=str,
                        default="data/datasets/cifar10/train",
                        help='dataset directory path')
    parser.add_argument('--class_name',
                        '-class',
                        type=str,
                        default='all_class',
                        help='class name (default: all_class)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of parallel data loading processes')
    parser.add_argument('--batchsize', '-b', type=int, default=16)
    parser.add_argument('--max_iter', '-m', type=int, default=40000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=2500,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=5000,
                        help='Interval of evaluation')
    parser.add_argument('--out_image_interval',
                        type=int,
                        default=1250,
                        help='Interval of saving generated sample images')
    parser.add_argument('--stage_interval',
                        type=int,
                        default=40000,
                        help='Interval of stage progress')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=1,
        help='number of discriminator updates per generator update')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--gamma',
                        type=float,
                        default=750,
                        help='gradient penalty target')
    parser.add_argument('--pooling_comp',
                        type=float,
                        default=1.0,
                        help='compensation coefficient for average pooling')
    parser.add_argument('--pretrained_generator', type=str, default="")
    parser.add_argument('--pretrained_discriminator', type=str, default="")
    parser.add_argument('--initial_stage', type=float, default=0.0)
    parser.add_argument('--generator_smoothing', type=float, default=0.999)

    args = parser.parse_args()
    record_setting(args.out)

    report_keys = [
        "stage", "loss_dis", "loss_gp", "loss_gen", "g", "inception_mean",
        "inception_std", "FID"
    ]
    max_iter = args.max_iter

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()

    generator = Generator()
    generator_smooth = Generator()
    discriminator = Discriminator(pooling_comp=args.pooling_comp)

    # select GPU
    if args.gpu >= 0:
        generator.to_gpu()
        generator_smooth.to_gpu()
        discriminator.to_gpu()
        print("use gpu {}".format(args.gpu))

    if args.pretrained_generator != "":
        chainer.serializers.load_npz(args.pretrained_generator, generator)
    if args.pretrained_discriminator != "":
        chainer.serializers.load_npz(args.pretrained_discriminator,
                                     discriminator)
    copy_param(generator_smooth, generator)

    # Setup an optimizer
    def make_optimizer(model, alpha=0.001, beta1=0.0, beta2=0.99):
        optimizer = chainer.optimizers.Adam(alpha=alpha,
                                            beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        # optimizer.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(generator)
    opt_dis = make_optimizer(discriminator)

    if args.class_name == 'all_class':
        data_path = args.data_path
        one_class_flag = False
    else:
        data_path = os.path.join(args.data_path, args.class_name)
        one_class_flag = True

    train_dataset = ImageDataset(data_path, one_class_flag=one_class_flag)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batchsize, n_processes=args.num_workers)

    # Set up a trainer
    updater = Updater(models=(generator, discriminator, generator_smooth),
                      iterator={'main': train_iter},
                      optimizer={
                          'opt_gen': opt_gen,
                          'opt_dis': opt_dis
                      },
                      device=args.gpu,
                      n_dis=args.n_dis,
                      lam=args.lam,
                      gamma=args.gamma,
                      smoothing=args.generator_smoothing,
                      initial_stage=args.initial_stage,
                      stage_interval=args.stage_interval)
    trainer = training.Trainer(updater, (max_iter, 'iteration'), out=args.out)
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_{.updater.iteration}.npz'),
                   trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator_smooth, 'generator_smooth_{.updater.iteration}.npz'),
                   trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_{.updater.iteration}.npz'),
                   trigger=(args.snapshot_interval, 'iteration'))

    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator_smooth, args.out),
                   trigger=(args.out_image_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator_smooth, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator_smooth),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator_smooth),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
Example #2
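# Assumed imports (sketch): make_optimizer, record_setting, sample_generate,
# sample_generate_light and the common.net / dataset.* modules are
# project-local; unlike Example #1, make_optimizer is used here as a
# module-level helper.
import argparse

import chainer
from chainer import training
from chainer.training import extension, extensions

import common.net
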
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator updates per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--image-npz', type=str, default="")
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--resume', type=str, default="")
    parser.add_argument('--ch', type=int, default=512)
    parser.add_argument('--snapshot-iter', type=int, default=0)
    parser.add_argument('--range', type=float, default=1.0)

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen"]

    # Set up dataset
    #train_dataset = Cifar10Dataset()
    if args.image_npz != '':
        from dataset.cifar10like import NPZColorDataset
        train_dataset = NPZColorDataset(npz=args.image_npz)
    elif args.data_dir != '':
        from dataset.cifar10like import CIFAR10Like
        train_dataset = CIFAR10Like(args.data_dir)
    else:
        raise ValueError(
            'specify a training dataset with --image-npz or --data-dir')
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "lim_dcgan":
            generator = common.net.LimDCGANGenerator(n_hidden=args.n_hidden,
                                                     range=args.range)
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        from stdgan.updater import Updater
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "sndcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.SNDCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]

    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma

    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "my_cramer":
        from cramer.my_updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "lim_cramer":
        #from cramer.my_updater import Updater
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.LimDCGANGenerator(range=args.range)
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.snapshot_iter == 0:
        snap_iter = args.max_iter // 100
    else:
        snap_iter = args.snapshot_iter
    trainer.extend(extensions.snapshot(), trigger=(snap_iter, 'iteration'))

    # resume
    if args.resume != "":
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Example #3
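# Assumed imports (sketch): Cifar10Dataset, make_optimizer, record_setting,
# sample_generate, sample_generate_light, calc_inception and calc_FID are
# project-local helpers.
import argparse

import chainer
from chainer import training
from chainer.training import extension, extensions

import common.net
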
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=2)
    parser.add_argument('--bound', type=float, default=1.0)
    parser.add_argument('--h_dim', type=int, default=128)
    parser.add_argument('--max_iter', type=int, default=150000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=1,
        help='number of discriminator updates per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--shift_interval',
                        type=float,
                        default=10000,
                        help='Interval of shift annealing')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["inception_mean", "inception_std", "FID"]
    # report_keys = ["loss_dis", "loss_gen", "ais", "inception_mean", "inception_std", "FID"]

    # Set up dataset
    train_dataset = Cifar10Dataset()
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "gan":
        from gan.net import Discriminator, Denoiser
        from gan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
        updater_args["n_dis"] = args.n_dis
        updater_args["shift_interval"] = args.shift_interval
        updater_args['n_iter'] = args.max_iter

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
            denoiser = common.net.Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        report_keys.append("loss_gp")
        models = [generator, discriminator, denoiser]
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
        updater_args["shift_interval"] = args.shift_interval
        updater_args['n_iter'] = args.max_iter
    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))

    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
Example #4
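# Assumed imports (sketch): net (model definitions), celebA (dataset
# loaders), Updater, evaluation and record_setting are project-local.
import argparse
import os

import chainer
from chainer import serializers, training
from chainer.training import extensions
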
def main():
    parser = argparse.ArgumentParser(description='Train StarGAN')
    parser.add_argument('--source_path',
                        default="source/celebA/",
                        help="data resource Directory")
    parser.add_argument('--att_list_path',
                        default="att_list.txt",
                        help="attribute list")
    parser.add_argument('--batch_size', '-b', type=int, default=16)
    parser.add_argument('--max_iter', '-m', type=int, default=200000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--eval_folder',
                        '-e',
                        default='test',
                        help='Directory to output the evaluation result')

    parser.add_argument('--eval_interval',
                        type=int,
                        default=1000,
                        help='Interval of evaluating generator')

    parser.add_argument("--learning_rate_g",
                        type=float,
                        default=0.0001,
                        help="Learning rate for generator")
    parser.add_argument("--learning_rate_d",
                        type=float,
                        default=0.0001,
                        help="Learning rate for discriminator")
    parser.add_argument("--load_gen_model",
                        default='',
                        help='load generator model')
    parser.add_argument("--load_dis_model",
                        default='',
                        help='load discriminator model')

    parser.add_argument('--gen_class',
                        default='StarGAN_Generator',
                        help='Default generator class')
    parser.add_argument('--dis_class',
                        default='StarGAN_Discriminator',
                        help='Default discriminator class')

    parser.add_argument("--n_dis",
                        type=int,
                        default=6,
                        help='number of WGAN discriminator updates per generator update')
    parser.add_argument("--lambda_gp",
                        type=float,
                        default=10.0,
                        help='lambda for gradient penalty of WGAN')
    parser.add_argument("--lambda_adv",
                        type=float,
                        default=1.0,
                        help='lambda for adversarial loss')
    parser.add_argument("--lambda_cls",
                        type=float,
                        default=1.0,
                        help='lambda for classification loss')
    parser.add_argument("--lambda_rec",
                        type=float,
                        default=10.0,
                        help='lambda for reconstruction loss')

    parser.add_argument("--flip",
                        type=int,
                        default=1,
                        help='flip images for data augmentation')
    parser.add_argument("--resize_to",
                        type=int,
                        default=128,
                        help='resize the image to')
    parser.add_argument("--crop_to",
                        type=int,
                        default=178,
                        help='crop the resized image to')
    parser.add_argument("--load_dataset",
                        default='celebA_train',
                        help='load dataset')
    parser.add_argument("--discriminator_layer_n",
                        type=int,
                        default=6,
                        help='number of discriminator layers')

    parser.add_argument("--learning_rate_anneal",
                        type=float,
                        default=10e-8,
                        help='anneal the learning rate')
    parser.add_argument("--learning_rate_anneal_start",
                        type=int,
                        default=100000,
                        help='iteration at which learning rate annealing starts')

    args = parser.parse_args()
    print(args)
    record_setting(args.out)
    max_iter = args.max_iter

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()

    with open(args.att_list_path, "r") as f:
        att_list = []
        att_name = []
        for line in f:
            line = line.strip().split(" ")
            if len(line) == 3:
                att_list.append(int(line[0]))  # attribute ID
                att_name.append(line[1])  # attribute name
    print("attribute list:", ",".join(att_name))

    # load dataset
    train_dataset = getattr(celebA,
                            args.load_dataset)(args.source_path,
                                               att_name,
                                               flip=args.flip,
                                               resize_to=args.resize_to,
                                               crop_to=args.crop_to)
    train_iter = chainer.iterators.MultiprocessIterator(train_dataset,
                                                        args.batch_size,
                                                        n_processes=4)

    #test_dataset = getattr(celebA, args.load_dataset)(root_celebA, flip=args.flip, resize_to=args.resize_to, crop_to=args.crop_to)
    test_batchsize = 8
    test_iter = chainer.iterators.SerialIterator(train_dataset, test_batchsize)

    # set up generator and discriminator
    nc_size = len(att_list)  # number of attributes
    gen = getattr(net, args.gen_class)(args.resize_to, nc_size)
    dis = getattr(net,
                  args.dis_class)(n_down_layers=args.discriminator_layer_n)

    if args.load_gen_model != '':
        serializers.load_npz(args.load_gen_model, gen)
        print("Generator model loaded")

    if args.load_dis_model != '':
        serializers.load_npz(args.load_dis_model, dis)
        print("Discriminator model loaded")

    if not os.path.exists(args.eval_folder):
        os.makedirs(args.eval_folder)

    # select GPU
    if args.gpu >= 0:
        gen.to_gpu()
        dis.to_gpu()
        print("use gpu {}".format(args.gpu))

    # Setup an optimizer
    def make_optimizer(model, alpha=0.0001, beta1=0.5, beta2=0.999):
        optimizer = chainer.optimizers.Adam(alpha=alpha,
                                            beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        return optimizer

    opt_gen = make_optimizer(gen, alpha=args.learning_rate_g)
    opt_dis = make_optimizer(dis, alpha=args.learning_rate_d)

    # Set up a trainer
    updater = Updater(models=(gen, dis),
                      iterator={
                          'main': train_iter,
                          'test': test_iter
                      },
                      optimizer={
                          'opt_gen': opt_gen,
                          'opt_dis': opt_dis,
                      },
                      device=args.gpu,
                      params={
                          'n_dis': args.n_dis,
                          'lambda_adv': args.lambda_adv,
                          'lambda_cls': args.lambda_cls,
                          'lambda_rec': args.lambda_rec,
                          'lambda_gp': args.lambda_gp,
                          'image_size': args.resize_to,
                          'eval_folder': args.eval_folder,
                          'nc_size': nc_size,
                          'learning_rate_anneal': args.learning_rate_anneal,
                          'learning_rate_anneal_start':
                          args.learning_rate_anneal_start,
                          'dataset': train_dataset
                      })

    model_save_interval = (4000, 'iteration')
    trainer = training.Trainer(updater, (max_iter, 'iteration'), out=args.out)
    trainer.extend(extensions.snapshot_object(gen,
                                              'gen{.updater.iteration}.npz'),
                   trigger=model_save_interval)
    trainer.extend(extensions.snapshot_object(dis,
                                              'dis{.updater.iteration}.npz'),
                   trigger=model_save_interval)

    log_keys = [
        'epoch', 'iteration', 'lr_g', 'lr_d', 'loss_dis_adv', 'loss_gen_adv',
        'loss_dis_cls', 'loss_gen_cls', 'loss_gen_rec', 'loss_gp'
    ]
    trainer.extend(
        extensions.LogReport(keys=log_keys, trigger=(20, 'iteration')))
    trainer.extend(extensions.PrintReport(log_keys), trigger=(20, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=50))

    trainer.extend(evaluation(gen, args.eval_folder,
                              image_size=args.resize_to),
                   trigger=(args.eval_interval, 'iteration'))

    #trainer.extend(CommandsExtension())
    # Run the training
    trainer.run()
Example #5
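# Assumed imports (sketch): Cifar10Dataset, make_optimizer, record_setting,
# sample_generate8 and the gen_models / dis_models / updaters packages are
# project-local. extensions.PlotReport additionally requires matplotlib at
# run time.
import argparse

import chainer
from chainer import training
from chainer.training import extension, extensions
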
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default='wgan_gp_res',
                        help='GAN algorithm')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=60000)
    parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU ID')
    parser.add_argument('--out',
                        type=str,
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of display')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator updates per generator update')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--adam_alpha', type=float, default=0.0002)
    parser.add_argument('--adam_beta1', type=float, default=0.0)
    parser.add_argument('--adam_beta2', type=float, default=0.9)

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ['loss_dis', 'loss_gen', 'inception_mean', 'inception_std']

    # set up dataset
    train_dataset = Cifar10Dataset()
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # set up networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == 'wgan_gp_res':
        from updaters.wgangp_updater import Updater
        import gen_models.resnet_generator
        import dis_models.resnet_discriminator
        generator = gen_models.resnet_generator.ResnetGenerator()
        discriminator = dis_models.resnet_discriminator.ResnetDiscriminator()
        models = [generator, discriminator]
        report_keys.append('loss_gp')
        updater_args['n_dis'] = args.n_dis
        updater_args['lam'] = args.lam

    elif args.algorithm == 'sngan_res':
        from updaters.stdgan_updater import Updater
        import gen_models.resnet_generator
        import dis_models.sn_resnet_discriminator
        generator = gen_models.resnet_generator.ResnetGenerator()
        discriminator = (
            dis_models.sn_resnet_discriminator.SNResnetDiscriminator())
        models = [generator, discriminator]
        updater_args['n_dis'] = args.n_dis

    elif args.algorithm == 'snwgan_res':
        from updaters.wgan_like_updater import Updater
        import gen_models.resnet_generator
        import dis_models.sn_resnet_discriminator
        generator = gen_models.resnet_generator.ResnetGenerator()
        discriminator = (
            dis_models.sn_resnet_discriminator.SNResnetDiscriminator())
        models = [generator, discriminator]
        updater_args['n_dis'] = args.n_dis

    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        for m in models:
            m.to_gpu()

    # set up optimizers
    opts['opt_gen'] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts['opt_dis'] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args['optimizer'] = opts
    updater_args['models'] = models

    # set up updater
    updater = Updater(**updater_args)

    # set up trainer
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(extensions.PlotReport(['loss_dis', 'loss_gen'],
                                         'iteration',
                                         trigger=(args.display_interval,
                                                  'iteration'),
                                         file_name='loss.png'),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate8(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    #     trainer.extend(calc_inception(generator, batchsize=100, n_ims=1000),
    #                        trigger=(args.evaluation_interval, 'iteration'),
    #                        priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # train
    trainer.run()
Example #6
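# Assumed imports (sketch): ImageDataset, make_optimizer, record_setting,
# sample_generate, sample_generate_light, calc_inception and calc_FID are
# project-local helpers.
import argparse
import os

import chainer
from chainer import training
from chainer.training import extension, extensions

import common.net
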
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--data_path',
                        type=str,
                        default="data/datasets/cifar10/train",
                        help='dataset directory path')
    parser.add_argument('--class_name',
                        '-class',
                        type=str,
                        default='all_class',
                        help='class name (default: all_class)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of parallel data loading processes')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=10000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator updates per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = [
        "loss_dis", "loss_gen", "inception_mean", "inception_std", "FID"
    ]

    # Set up dataset
    if args.class_name == 'all_class':
        data_path = args.data_path
        one_class_flag = False
    else:
        data_path = os.path.join(args.data_path, args.class_name)
        one_class_flag = True

    train_dataset = ImageDataset(data_path, one_class_flag=one_class_flag)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batchsize, n_processes=args.num_workers)
    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        from stdgan.updater import Updater
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "sndcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.SNDCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]

    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma

    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
Example #7
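# This snippet begins mid-script; the imports and the bare parser below are
# a minimal assumption so it runs standalone. DemodSNRDataset, model_map and
# record_setting are project-local.
import argparse
import os

import numpy as np
import chainer.links as L

parser = argparse.ArgumentParser()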
parser.add_argument('--out', '-o', default='result_classifier',
                    help='Directory to output the result')
parser.add_argument('--num_syms', '-n', type=int, default=3,
                    help='Number of symbols to demodulate at a time')
parser.add_argument('--model_type', '-t', type=str, default="AlexStock",
                    help='Which Model to run (AlexStock, ComplexNN)')
parser.add_argument('--snr', '-s', type=int, default=18,
                    help='SNR to use for demodulation training')
args = parser.parse_args()


# results_output_dir = os.path.join(os.environ['KUNGLAB_SHARE_RESULTS'], args.out) 
results_output_dir = args.out
if not os.path.exists(results_output_dir):
    os.makedirs(results_output_dir)
record_setting(results_output_dir)

# NOTE: the parsed --snr argument is overridden by this fixed SNR range
snr = list(range(-8, 6, 2))
print(snr)
num_syms = args.num_syms
data_train = DemodSNRDataset(test=False, snr=snr, num_syms=num_syms)
data_test = DemodSNRDataset(test=True, snr=snr, num_syms=num_syms)
num_classes = np.unique(data_train.ys).shape[0]


# train model
if args.model_type == "AlexStock" or args.model_type == "AlexSmall":
    print "AlexSmall"
    model = L.Classifier(model_map[args.model_type](num_classes, init_weights=True, filter_height=2))
else:
    model = L.Classifier(model_map[args.model_type](num_classes, init_weights=True, filter_height=1))
Example #8
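# Assumed imports (sketch): Cifar10Dataset, STL10Dataset, make_optimizer,
# record_setting and the per-algorithm updater packages are project-local.
import argparse

import chainer

import common.net
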
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--dataset',
                        type=str,
                        default="cifar10",
                        help='Dataset')
    parser.add_argument('--bottom_width', type=int, default=4)
    parser.add_argument('--udvmode', type=int, default=1)
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=5000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=5000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator updates per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = [
        "loss_dis", "loss_gen", "inception_mean", "inception_std", "FID",
        "loss_orth"
    ]

    # Set up dataset
    if args.dataset == "cifar10":
        train_dataset = Cifar10Dataset()
    elif args.dataset == "stl10":
        train_dataset = STL10Dataset()
        args.bottom_width = 6
    else:
        raise NotImplementedError()
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            from stdgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.DCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "sndcgan":
            from stdgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.SNDCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "snresdcgan":
            from stdgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.SNResnetDiscriminator(
                bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "orthgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "orthdcgan":
            from orthgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.ORTHDCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "orthresdcgan":
            from orthgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.ORTHResnetDiscriminator(
                bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "uvgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "uvdcgan":
            from orthgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.UVDCGANDiscriminator(
                args.udvmode, bottom_width=args.bottom_width)
        elif args.architecture == "uvresdcgan":
            from orthgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.UVResnetDiscriminator(
                args.udvmode, bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]

    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma

    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
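    # PRIORITY_WRITER runs the extensions below before LogReport in the same
    # trigger cycle, so the values they report (inception_mean, FID, ...) are
    # picked up in that cycle's log entry.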
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    if "res" not in args.architecture:
        trainer.extend(sv_generate(discriminator, args.out),
                       trigger=(args.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    IS_array = []
    FID_array = []
    trainer.extend(calc_inception(generator, IS_array, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator, FID_array, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
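
Both this script and the next rely on a make_optimizer helper defined elsewhere in the repository. A minimal sketch of what it presumably looks like, inferred from the (model, alpha, beta1, beta2) call sites above (an assumption, not the repository's exact code):

import chainer

def make_optimizer(model, alpha, beta1, beta2):
    # Assumed body: an Adam optimizer bound to the given model, matching
    # the positional arguments used at every call site in these scripts.
    optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
    optimizer.setup(model)
    return optimizer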
Example #9
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result', help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=10000, help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=10000, help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100, help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=1, help='number of discriminator updates per generator update') # 5
    parser.add_argument('--gamma', type=float, default=0.5, help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10, help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002, help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.5, help='beta1 in Adam optimizer') # 0.0
    parser.add_argument('--adam_beta2', type=float, default=0.9, help='beta2 in Adam optimizer') # 0.9
    parser.add_argument('--output_dim', type=int, default=256, help='output dimension of the discriminator (for Cramér GAN)')
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--image-npz', type=str, default="")
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--resume', type=str, default="")
    parser.add_argument('--ch', type=int, default=512)
    parser.add_argument('--snapshot-iter', type=int, default=0)

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen"]

    # Set up dataset: one of --image-npz or --data-dir must be given
    if args.image_npz != '':
        from c128dcgan.dataset import NPZColorDataset
        train_dataset = NPZColorDataset(npz=args.image_npz)
    elif args.data_dir != '':
        from c128dcgan.dataset import Color128x128Dataset
        train_dataset = Color128x128Dataset(args.data_dir)
    else:
        raise ValueError('specify the dataset with --image-npz or --data-dir')
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {
        "iterator": {'main': train_iter},
        "device": args.gpu
    }

    # The algorithm is fixed in this script: a 128x128 DCGAN generator paired
    # with a spectral-normalization (SND128) discriminator, trained by the
    # plain DCGAN updater.
    from dcgan.updater import Updater
    generator = common.net.C128Generator(ch=args.ch, n_hidden=args.n_hidden)
    discriminator = common.net.SND128Discriminator(ch=args.ch)
    models = [generator, discriminator]

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha, args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha, args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'), out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out), trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out), trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    if args.snapshot_iter == 0:
        snap_iter = args.max_iter // 100
    else:
        snap_iter = args.snapshot_iter
    trainer.extend(extensions.snapshot(), trigger=(snap_iter, 'iteration'))

    # Resume training from a trainer snapshot if specified
    if args.resume != "":
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
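
Assuming this script is saved as train.py (the actual filename is not shown here), a typical invocation on a directory of 128x128 training images might look like:

python train.py --data-dir /path/to/images --gpu 0 --batchsize 64 --out result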