Exemplo n.º 1
0
def main():
    """Train a GAN on CIFAR-10-like image data.

    Parses command-line options selecting the GAN algorithm and network
    architecture, builds the dataset iterator, the model/optimizer set and
    a Chainer ``Trainer``, then runs training. Snapshots, logs and sample
    images are written under ``--out``.

    Raises:
        ValueError: if neither ``--image-npz`` nor ``--data-dir`` is given.
        NotImplementedError: for unknown algorithm/architecture combinations.
    """
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator update per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--image-npz', type=str, default="")
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--resume', type=str, default="")
    # NOTE(review): --ch is accepted but never read below — dead option?
    parser.add_argument('--ch', type=int, default=512)
    parser.add_argument('--snapshot-iter', type=int, default=0)
    parser.add_argument('--range', type=float, default=1.0)

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen"]

    # Set up dataset. Exactly one of --image-npz / --data-dir must be given;
    # --image-npz takes precedence when both are set.
    if args.image_npz != '':
        from dataset.cifar10like import NPZColorDataset
        train_dataset = NPZColorDataset(npz=args.image_npz)
    elif args.data_dir != '':
        from dataset.cifar10like import CIFAR10Like
        train_dataset = CIFAR10Like(args.data_dir)
    else:
        # Bug fix: previously `train_dataset` was left unbound here (both
        # options default to "") and the iterator line below raised a
        # confusing NameError. Fail fast with a clear message instead.
        raise ValueError(
            'No dataset specified: pass either --image-npz or --data-dir')
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # Setup algorithm specific networks and updaters. Each branch binds a
    # branch-local `Updater` via import; the last executed import is the one
    # instantiated further below.
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "lim_dcgan":
            generator = common.net.LimDCGANGenerator(n_hidden=args.n_hidden,
                                                     range=args.range)
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        from stdgan.updater import Updater
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "sndcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.SNDCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        # Denoising feature matching: trains an extra denoiser network.
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]

    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma

    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "my_cramer":
        from cramer.my_updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "lim_cramer":
        # Cramer GAN with the range-limited generator.
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.LimDCGANGenerator(range=args.range)
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers (generator/discriminator always; extras added above).
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging: per-model weight snapshots, console/file reports and
    # periodic sample-image generation.
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Full-trainer snapshots (for --resume); default to 100 over the run.
    if args.snapshot_iter == 0:
        snap_iter = args.max_iter // 100
    else:
        snap_iter = args.snapshot_iter
    trainer.extend(extensions.snapshot(), trigger=(snap_iter, 'iteration'))

    # resume from a previously saved trainer snapshot
    if args.resume != "":
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Exemplo n.º 2
0
def main():
    """Train a GAN on an image-folder dataset and track inception/FID scores.

    Parses command-line options selecting the GAN algorithm and network
    architecture, builds a multiprocess dataset iterator, the models and
    optimizers, then runs a Chainer ``Trainer`` that also evaluates
    inception score and FID periodically. Results go under ``--out``.

    Raises:
        NotImplementedError: for unknown algorithm/architecture combinations.
    """
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--data_path',
                        type=str,
                        default="data/datasets/cifar10/train",
                        help='dataset directory path')
    parser.add_argument('--class_name',
                        '-class',
                        type=str,
                        default='all_class',
                        help='class name (default: all_class(str))')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of parallel data loading processes')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=10000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator update per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = [
        "loss_dis", "loss_gen", "inception_mean", "inception_std", "FID"
    ]

    # Set up dataset: either the whole directory or one class subdirectory.
    if args.class_name == 'all_class':
        data_path = args.data_path
        one_class_flag = False
    else:
        data_path = os.path.join(args.data_path, args.class_name)
        one_class_flag = True

    train_dataset = ImageDataset(data_path, one_class_flag=one_class_flag)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batchsize, n_processes=args.num_workers)
    # Setup algorithm specific networks and updaters. Each branch binds a
    # branch-local `Updater` via import; the executed import is the one
    # instantiated further below.
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        from stdgan.updater import Updater
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "sndcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.SNDCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        # Denoising feature matching: trains an extra denoiser network.
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]

    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            # BEGAN generator runs without batch normalization here.
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma

    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers (generator/discriminator always; extras added above).
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging: snapshots, reports, sample images and score evaluation.
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
Exemplo n.º 3
0
Arquivo: train.py Projeto: min9813/GAN
def train(step=(5, 100),
          data_set='cifar10',
          debug=True,
          n_hidden=128,
          batch_size=128,
          save_snapshot=False,
          epoch_intervel=(1, "epoch"),
          dispplay_interval=(100, "iteration"),
          out_folder="./wgan_mnist/",
          max_time=(1, "epoch"),
          out_image_edge_num=100,
          is_cgan=False,
          gamma=0.5,
          perturb_weight=0.5,
          method="wgangp",
          techniques=None):
    """Train a GAN with the given ``method`` on cifar10/mnist/toy data.

    Builds the dataset iterator, networks, optimizers and Chainer trainer
    for the selected algorithm, then runs training on GPU 0 (models and
    fixed noise are allocated with cupy, so a GPU is required).

    NOTE(review): the parameter names ``epoch_intervel`` and
    ``dispplay_interval`` are typos, kept unchanged for backward
    compatibility with existing callers; both are currently unused in the
    body (a hard-coded ``display_interval`` is used instead).
    """
    check_and_make_dir(out_folder)
    # Each run gets a fresh numbered subfolder inside out_folder.
    new_folder = "folder_{}".format(len(os.listdir(out_folder)))
    out_folder = os.path.join(out_folder, new_folder)
    if debug:
        max_time = (1000, "iteration")
    # Make a specified GPU current

    updater_args = {"n_dis": step, "device": 0}

    if data_set == 'cifar10':
        print("use cifar dataset")
        # ndim=3 : (ch,width,height)
        train = dataset.Cifar10Dataset(is_need_conditional=is_cgan)
        train_iter = chainer.iterators.SerialIterator(train, batch_size)
    elif data_set == "mnist":
        print("use mnist dataset")
        # Load the MNIST dataset
        # ndim=3 : (ch,width,height)
        train = dataset.Mnist10Dataset(is_need_conditional=is_cgan)
        train_iter = chainer.iterators.SerialIterator(train,
                                                      batch_size,
                                                      shuffle=False)
    elif data_set == "toy":
        if is_cgan:
            raise NotImplementedError
        train = gaussian_mixture_circle(60000, std=0.1)
        train_iter = chainer.iterators.SerialIterator(train, batch_size)
    else:
        sys.exit("data_set argument must be next argument [{}]".format(
            "'cifar10','mnist','toy'"))

    # Setup an optimizer appropriate for the chosen method.
    def make_optimizer(model, **params):
        if method == "dcgan":
            # parametor require 'alpha','beta1','beta2'
            optimizer = chainer.optimizers.Adam(alpha=params["alpha"],
                                                beta1=params["beta1"])
            optimizer.setup(model)
            # optimizer.add_hook(
            # chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        elif method == "wgan":
            optimizer = chainer.optimizers.RMSprop(lr=params["lr"])
            optimizer.setup(model)
            try:
                # WGAN critic weights are clipped; generator passes no "clip".
                optimizer.add_hook(WeightClipping(params["clip"]))
            except KeyError:
                pass
        elif method in ("sngan", "wgangp", "began", "cramer", "dragan",
                        "improve_technique"):
            optimizer = chainer.optimizers.Adam(alpha=params["alpha"],
                                                beta1=params["beta1"],
                                                beta2=params["beta2"])
            optimizer.setup(model)
        else:
            raise NotImplementedError

        return optimizer

    if method == "dcgan":
        if data_set == "mnist":
            gen = common.net.Generator(n_hidden)
            dis = common.net.Discriminator()
        elif data_set == 'cifar10':
            gen = common.net.CifarGenerator(n_hidden)
            dis = common.net.CifarDiscriminator()
        else:
            gen = common.net.FCGenerator(n_hidden)
            dis = common.net.FCDiscriminator()
        from dcgan.updater import Updater
        opt_gen = make_optimizer(gen, alpha=0.0002, beta1=0.5)
        opt_dis = make_optimizer(dis, alpha=0.0002, beta1=0.5)

        plot_report = ["gen/loss", "dis/loss"]
        print_report = plot_report

    elif method == "wgan":
        if data_set == "mnist":
            gen = common.net.Generator(n_hidden)
            dis = common.net.Discriminator()
        elif data_set == 'cifar10':
            gen = common.net.CifarGenerator(n_hidden)
            dis = common.net.CifarDiscriminator()
        else:
            gen = common.net.FCGenerator(n_hidden)
            dis = common.net.FCDiscriminator()
        from wgan.updater import Updater
        opt_gen = make_optimizer(gen, lr=5e-5)
        opt_dis = make_optimizer(dis, lr=5e-5, clip=0.01)
        plot_report = ["gen/loss", 'wasserstein distance']
        print_report = plot_report

    elif method == "wgangp":
        if data_set == "mnist":
            gen = common.net.Generator(n_hidden)
            dis = common.net.Discriminator()
        elif data_set == 'cifar10':
            gen = common.net.CifarGenerator(n_hidden)
            dis = common.net.WGANDiscriminator()
        else:
            gen = common.net.FCGenerator(n_hidden)
            dis = common.net.FCDiscriminator()
        from wgangp.updater import Updater, CGANUpdater
        opt_gen = make_optimizer(gen, alpha=0.0002, beta1=0, beta2=0.9)
        opt_dis = make_optimizer(dis, alpha=0.0002, beta1=0, beta2=0.9)

        updater_args["gradient_penalty_weight"] = GRADIENT_PENALTY_WEIGHT

        plot_report = ["gen/loss", 'wasserstein distance']
        print_report = plot_report + ["critic/loss_grad", "critic/loss"]
    elif method == "began":
        import began
        from began.updater import Updater
        if data_set == "mnist":
            gen = began.net.MnistGenerator(n_hidden)
            dis = began.net.MnistDiscriminator()
        elif data_set == 'cifar10':
            gen = began.net.CifarGenerator(n_hidden)
            dis = began.net.CifarDiscriminator()
        updater_args["gamma"] = gamma
        updater_args["lambda_k"] = 0.001
        opt_gen = make_optimizer(gen, alpha=0.0002, beta1=0, beta2=0.9)
        opt_dis = make_optimizer(dis, alpha=0.0002, beta1=0, beta2=0.9)
        plot_report = ["dis/loss", "gen/loss"]
        print_report = plot_report + ["kt", "measurement"]
    elif method == "dragan":
        from dragan.updater import Updater
        if data_set == "mnist":
            gen = common.net.Generator(n_hidden)
            dis = common.net.Discriminator()
        elif data_set == 'cifar10':
            gen = common.net.CifarGenerator(n_hidden)
            dis = common.net.WGANDiscriminator()
        updater_args["gradient_penalty_weight"] = GRADIENT_PENALTY_WEIGHT
        updater_args["perturb_weight"] = perturb_weight
        opt_gen = make_optimizer(gen, alpha=0.0002, beta1=0, beta2=0.9)
        opt_dis = make_optimizer(dis, alpha=0.0002, beta1=0, beta2=0.9)
        plot_report = ["gen/loss", 'dis/loss']
        print_report = plot_report + ["dis/loss_grad"]
    elif method == "improve_technique":
        import improve_technique.net
        if data_set == "mnist":
            gen = common.net.Generator(n_hidden)
            dis = improve_technique.net.MnistMinibatchDiscriminator(
                use_feature_matching=techniques["feature_matching"])
        elif data_set == 'cifar10':
            # Bug fix: this previously built `common.net.Discriminator(n_hidden)`
            # as the generator; a CifarGenerator was clearly intended (matches
            # every other cifar10 branch).
            gen = common.net.CifarGenerator(n_hidden)
            dis = improve_technique.net.CifarDeepMinibatchDiscriminator(
                use_feature_matching=techniques["feature_matching"])
        else:
            raise NotImplementedError
        if techniques["feature_matching"]:
            print("**feature matching**")
            from improve_technique.updater import Updater
            plot_report = ['dis/loss', "gen/loss", "gen/loss_feature"]
            print_report = plot_report
        else:
            from dcgan.updater import Updater
            plot_report = ['dis/loss', "gen/loss"]
            print_report = plot_report
        opt_gen = make_optimizer(gen, alpha=0.0002, beta1=0, beta2=0.9)
        opt_dis = make_optimizer(dis, alpha=0.0002, beta1=0, beta2=0.9)
    elif method == "sngan":
        from sngan.updater import Updater
        if data_set == "mnist":
            gen = common.net.Generator(n_hidden)
            dis = common.net.SNMnistDiscriminator()
        elif data_set == 'cifar10':
            gen = common.net.CifarGenerator(n_hidden)
            dis = common.net.SNCifarDiscriminator()
        else:
            raise NotImplementedError
        opt_gen = make_optimizer(gen, alpha=0.0002, beta1=0, beta2=0.9)
        opt_dis = make_optimizer(dis, alpha=0.0002, beta1=0, beta2=0.9)
        plot_report = ["gen/loss", 'dis/loss']
        print_report = plot_report

    else:
        raise NotImplementedError

    gen.to_gpu()  # Copy the model to the GPU
    dis.to_gpu()

    models = {"gen": gen, "dis": dis}
    opt = {"gen": opt_gen, "dis": opt_dis}
    updater_args["optimizer"] = opt
    updater_args["models"] = models
    updater_args["iterator"] = train_iter
    # Fixed latent vectors so preview images are comparable across iterations.
    fixed_noise = cupy.random.uniform(
        -1, 1, (out_image_edge_num**2, n_hidden, 1, 1)).astype("f")
    if is_cgan:
        if method != "wgangp":
            # CGANUpdater is only imported in the wgangp branch; previously
            # other methods crashed here with a NameError.
            raise NotImplementedError(
                "conditional GAN is only supported with method='wgangp'")
        label_num = train.class_label_num
        updater_args["class_num"] = label_num
        updater = CGANUpdater(**updater_args)
        # Append one-hot class labels to the fixed noise channels.
        one_hot_label = cupy.eye(label_num)[cupy.arange(label_num)][:, :, None,
                                                                    None]
        one_hot_label = cupy.concatenate([one_hot_label] * 10)
        fixed_noise = cupy.concatenate([fixed_noise, one_hot_label],
                                       axis=1).astype("f")
        print(fixed_noise.shape)
    else:
        # Bug fix: this was previously unconditional, silently overwriting
        # the CGANUpdater built above in the is_cgan case.
        updater = Updater(**updater_args)
    # Set up a trainer
    trainer = training.Trainer(updater, stop_trigger=max_time, out=out_folder)

    # epoch_interval = (1, 'epoch')
    save_snapshot_interval = (10000, "iteration")
    display_interval = (100, 'iteration')

    out_image_folder = os.path.join(out_folder, "preview")
    check_and_make_dir(out_image_folder)

    # trainer.extend(extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
    if method == "wgangp" or method == "wgan":
        trainer.extend(extensions.dump_graph('critic/loss'))
    if save_snapshot:
        trainer.extend(extensions.snapshot_object(
            gen, 'gen_epoch_{.updater.iteration}.npz'),
                       trigger=save_snapshot_interval)
        trainer.extend(extensions.snapshot_object(
            dis, 'dis_epoch_{.updater.iteration}.npz'),
                       trigger=save_snapshot_interval)
    trainer.extend(
        extensions.PlotReport(plot_report,
                              x_key='iteration',
                              file_name='loss.png',
                              trigger=display_interval))
    if method == "began":
        trainer.extend(
            extensions.PlotReport(["measurement"],
                                  x_key='iteration',
                                  file_name='convergence_measure.png',
                                  trigger=display_interval))
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport(['iteration', 'elapsed_time'] +
                                          print_report),
                   trigger=display_interval)
    trainer.extend(out_generated_image(gen, out_image_edge_num,
                                       out_image_edge_num, out_image_folder,
                                       fixed_noise),
                   trigger=(200, "iteration"))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
Exemplo n.º 4
0
def main() -> None:
    """Train a GAN on CIFAR-10 or STL-10 with a selectable algorithm.

    Parses command-line options, instantiates the generator/discriminator
    pair (and an algorithm-specific ``Updater``) for the requested
    ``--algorithm`` / ``--architecture`` combination, then runs a Chainer
    ``Trainer`` with snapshot, logging, preview-image and Inception/FID
    evaluation extensions.

    NOTE(review): relies on project-local modules (``common.net``,
    ``record_setting``, ``make_optimizer``, the per-algorithm ``updater``
    packages, ``sample_generate*``, ``sv_generate``, ``calc_inception``,
    ``calc_FID``) that are defined elsewhere in this repository.
    """
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture',
                        type=str,
                        default="dcgan",
                        help='Network architecture')
    parser.add_argument('--dataset',
                        type=str,
                        default="cifar10",
                        help='Dataset')
    # Spatial size of the generator's first feature map; overridden to 6
    # below when the STL-10 dataset is selected.
    parser.add_argument('--bottom_width', type=int, default=4)
    # Mode flag forwarded to the UV* discriminators — exact semantics are
    # defined in common.net; verify there before changing the default.
    parser.add_argument('--udvmode', type=int, default=1)
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=5000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval',
                        type=int,
                        default=5000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis',
        type=int,
        default=5,
        help='number of discriminator update per generator update')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam',
                        type=float,
                        default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha',
                        type=float,
                        default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1',
                        type=float,
                        default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2',
                        type=float,
                        default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument(
        '--output_dim',
        type=int,
        default=256,
        help='output dimension of the discriminator (for cramer GAN)')

    args = parser.parse_args()
    record_setting(args.out)
    # Base set of metrics reported to LogReport/PrintReport; some algorithm
    # branches below append extra keys (e.g. "loss_gp", "kt", "measure").
    report_keys = [
        "loss_dis", "loss_gen", "inception_mean", "inception_std", "FID",
        "loss_orth"
    ]

    # Set up dataset
    if args.dataset == "cifar10":
        train_dataset = Cifar10Dataset()
    elif args.dataset == "stl10":
        train_dataset = STL10Dataset()
        # STL-10 images are larger, so the nets need a wider bottom layer.
        args.bottom_width = 6
    else:
        raise NotImplementedError()
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # Setup algorithm specific networks and updaters.
    # Each branch binds `generator`, `discriminator`, the branch-local
    # `Updater` class (imported inside the branch), and may extend
    # `report_keys`, `opts` and `updater_args`.
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            from stdgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.DCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "sndcgan":
            from stdgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.SNDCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "snresdcgan":
            # Resnet variants use a hinge loss, hence a different updater.
            from stdgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.SNResnetDiscriminator(
                bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "orthgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "orthdcgan":
            from orthgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.ORTHDCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "orthresdcgan":
            from orthgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.ORTHResnetDiscriminator(
                bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "uvgan":
        # NOTE(review): uvgan reuses the orthgan updaters — presumably
        # intentional; confirm against the orthgan package.
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "uvdcgan":
            from orthgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.UVDCGANDiscriminator(
                args.udvmode, bottom_width=args.bottom_width)
        elif args.architecture == "uvresdcgan":
            from orthgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.UVResnetDiscriminator(
                args.udvmode, bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        # Denoising feature matching: adds a third model (denoiser) with
        # its own optimizer and report key.
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]

    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            # BEGAN generator runs without batch normalization.
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma

    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            # Cramer GAN uses a multi-dimensional critic output.
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam

    else:
        raise NotImplementedError()

    # Move all models to the selected GPU (negative ID keeps them on CPU).
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging: per-model weight snapshots, console/log reports,
    # periodic sample images, and Inception-score / FID evaluation.
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                       trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(
        extensions.LogReport(keys=report_keys,
                             trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    # Lightweight samples run 10x more often than the full evaluation.
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    # Singular-value dump only supports the non-resnet discriminators.
    if "res" not in args.architecture:
        trainer.extend(sv_generate(discriminator, args.out),
                       trigger=(args.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    IS_array = []
    FID_array = []
    trainer.extend(calc_inception(generator, IS_array, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator, FID_array, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
Exemplo n.º 5
0
def main() -> None:
    """Train a 128x128 color-image GAN (SN-DCGAN updater, fixed algorithm).

    Parses command-line options, loads a 128x128 color dataset from either
    an NPZ archive (``--image-npz``) or an image directory (``--data-dir``),
    builds a ``C128Generator``/``SND128Discriminator`` pair and trains it
    with the plain DCGAN updater, with snapshotting, logging, sample-image
    extensions and optional trainer resumption (``--resume``).

    NOTE(review): relies on project-local modules (``common.net``,
    ``c128dcgan.dataset``, ``dcgan.updater``, ``record_setting``,
    ``make_optimizer``, ``sample_generate*``) defined elsewhere in this
    repository.
    """
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result', help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=10000, help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=10000, help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100, help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=1, help='number of discriminator update per generator update') # 5
    parser.add_argument('--gamma', type=float, default=0.5, help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10, help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002, help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.5, help='beta1 in Adam optimizer') # 0.0
    parser.add_argument('--adam_beta2', type=float, default=0.9, help='beta2 in Adam optimizer') # 0.9
    parser.add_argument('--output_dim', type=int, default=256, help='output dimension of the discriminator (for cramer GAN)')
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--image-npz', type=str, default="")
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--resume', type=str, default="")
    parser.add_argument('--ch', type=int, default=512)
    parser.add_argument('--snapshot-iter', type=int, default=0)

    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen"]

    # Set up dataset: exactly one of --image-npz / --data-dir must be given
    # (--image-npz wins when both are). Previously, omitting both fell
    # through and crashed later with an opaque NameError on train_dataset;
    # fail fast with a clear CLI error instead.
    if args.image_npz != '':
        from c128dcgan.dataset import NPZColorDataset
        train_dataset = NPZColorDataset(npz=args.image_npz)
    elif args.data_dir != '':
        from c128dcgan.dataset import Color128x128Dataset
        train_dataset = Color128x128Dataset(args.data_dir)
    else:
        parser.error('one of --image-npz or --data-dir is required')
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {
        "iterator": {'main': train_iter},
        "device": args.gpu
    }

    # fixed algorithm
    #from c128gan import Updater
    generator = common.net.C128Generator(ch=args.ch, n_hidden=args.n_hidden)
    discriminator = common.net.SND128Discriminator(ch=args.ch)
    models = [generator, discriminator]
    from dcgan.updater import Updater

    # Move both models to the selected GPU (negative ID keeps them on CPU).
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha, args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha, args.adam_beta1, args.adam_beta2)

    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'), out=args.out)

    # Set up logging: per-model weight snapshots, console/log reports and
    # periodic sample images (light samples run 10x more often).
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'), trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(keys=report_keys,
                                        trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys), trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out), trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out), trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    # Full-trainer snapshots (for --resume). Default to 1% of max_iter,
    # clamped to at least 1 so the trigger stays valid for small runs.
    if args.snapshot_iter == 0:
        snap_iter = max(1, args.max_iter // 100)
    else:
        snap_iter = args.snapshot_iter
    trainer.extend(extensions.snapshot(), trigger=(snap_iter, 'iteration'))

    # resume
    if args.resume != "":
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()