コード例 #1
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('image_dir', type=str)
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=1000)
    parser.add_argument('--noise_dim', '-nd', type=int, default=100)
    parser.add_argument('--height', '-ht', type=int, default=128)
    parser.add_argument('--width', '-wd', type=int, default=128)
    parser.add_argument('--save_steps', '-ss', type=int, default=1)
    parser.add_argument('--visualize_steps', '-vs', type=int, default=1)
    parser.add_argument('--logdir', '-ld', type=str, default="../logs")
    parser.add_argument('--noise_mode', '-nm', type=str, default="uniform")
    parser.add_argument('--upsampling', '-up', type=str, default="deconv")
    parser.add_argument('--metrics', '-m', type=str, default="JSD")
    parser.add_argument('--lr_d', type=float, default=1e-4)
    parser.add_argument('--lr_g', type=float, default=1e-4)
    parser.add_argument('--norm_d', type=str, default=None)
    parser.add_argument('--norm_g', type=str, default=None)
    parser.add_argument('--model', type=str, default='residual')

    args = parser.parse_args()

    # output config to csv
    args_to_csv(os.path.join(args.logdir, 'config.csv'), args)

    input_shape = (args.height, args.width, 3)

    image_sampler = ImageSampler(args.image_dir,
                                 target_size=(args.width, args.height))
    noise_sampler = NoiseSampler(args.noise_mode)

    if args.model == 'residual':
        generator = ResidualGenerator(args.noise_dim,
                                      target_size=(args.width, args.height),
                                      upsampling=args.upsampling,
                                      normalization=args.norm_g)
        discriminator = ResidualDiscriminator(input_shape,
                                              normalization=args.norm_d)
    elif args.model == 'plane':
        generator = Generator(args.noise_dim,
                              upsampling=args.upsampling,
                              normalization=args.norm_g)
        discriminator = Discriminator(input_shape,
                                      normalization=args.norm_d)
    else:
        raise ValueError

    gan = GAN(generator,
              discriminator,
              metrics=args.metrics,
              lr_d=args.lr_d,
              lr_g=args.lr_g)

    gan.fit(image_sampler.flow_from_directory(args.batch_size),
            noise_sampler,
            nb_epoch=args.nb_epoch,
            logdir=args.logdir,
            save_steps=args.save_steps,
            visualize_steps=args.visualize_steps)
コード例 #2
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('x_dir', type=str)
    parser.add_argument('y_dir', type=str)
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=1000)
    parser.add_argument('--height', '-ht', type=int, default=256)
    parser.add_argument('--width', '-wd', type=int, default=256)
    parser.add_argument('--save_steps', '-ss', type=int, default=10)
    parser.add_argument('--visualize_steps', '-vs', type=int, default=10)
    parser.add_argument('--gp_weight', '-gp', type=float, default=10.)
    parser.add_argument('--l1_weight', '-l1', type=float, default=1.)
    parser.add_argument('--initial_steps', '-is', type=int, default=20)
    parser.add_argument('--initial_critics', '-ic', type=int, default=20)
    parser.add_argument('--normal_critics', '-nc', type=int, default=5)
    parser.add_argument('--model_dir', '-md', type=str, default="./params")
    parser.add_argument('--result_dir', '-rd', type=str, default="./result")
    parser.add_argument('--noise_mode', '-nm', type=str, default="uniform")
    parser.add_argument('--upsampling', '-up', type=str, default="deconv")
    parser.add_argument('--dis_norm', '-dn', type=str, default=None)

    args = parser.parse_args()

    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    image_sampler = ImageSampler(target_size=(args.width, args.height),
                                 color_mode_x='rgb',
                                 color_mode_y='rgb',
                                 normalization_x='tanh',
                                 normalization_y='tanh',
                                 is_flip=False)

    generator = UNet((args.height, args.width, 3),
                     color_mode='rgb',
                     upsampling=args.upsampling,
                     is_training=True)

    discriminator = ResidualDiscriminator((args.height, args.width, 6),
                                          normalization=args.dis_norm,
                                          is_training=True)

    pix2pix = Pix2Pix(generator,
                      discriminator,
                      l1_weight=args.l1_weight,
                      gradient_penalty_weight=args.gp_weight,
                      is_training=True)
    pix2pix.fit(image_sampler.flow_from_directory(args.x_dir,
                                                  args.y_dir,
                                                  batch_size=args.batch_size),
                result_dir=args.result_dir,
                model_dir=args.model_dir,
                save_steps=args.save_steps,
                visualize_steps=args.visualize_steps,
                nb_epoch=args.nb_epoch,
                initial_steps=args.initial_steps,
                initial_critics=args.initial_critics,
                normal_critics=args.normal_critics)
コード例 #3
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('train_dir')
    parser.add_argument('val_dir')
    parser.add_argument('--nb_classes', '-nc', type=int, default=2)
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=100)
    parser.add_argument('--height', '-ht', type=int, default=32)
    parser.add_argument('--width', '-wd', type=int, default=32)
    parser.add_argument('--save_steps', '-ss', type=int, default=10)
    parser.add_argument('--validation_steps', '-vs', type=int, default=10)
    parser.add_argument('--logdir', '-ld', default='../logs')

    args = parser.parse_args()

    image_sampler = ImageSampler(target_size=(args.width, args.height),
                                 color_mode='rgb',
                                 normalize_mode='sigmoid')

    val_sampler = ImageSampler(target_size=(args.width, args.height),
                               color_mode='rgb',
                               normalize_mode='sigmoid',
                               is_training=False)

    model = CifarCNN((args.width, args.height, 3),
                     nb_classes=args.nb_classes,
                     logdir=args.logdir)

    model.fit_generator(
        image_sampler.flow_from_directory(args.train_dir,
                                          batch_size=args.batch_size,
                                          with_class=True),
        val_sampler.flow_from_directory(args.val_dir,
                                        batch_size=args.batch_size,
                                        with_class=True,
                                        shuffle=False),
        nb_epoch=args.nb_epoch,
        validation_steps=args.validation_steps,
        save_steps=args.save_steps,
        model_dir=args.logdir)
コード例 #4
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('image_dir', type=str)
    parser.add_argument('--channel', '-c', type=int, default=3)
    parser.add_argument('--nb_growing', '-g', type=int, default=8)
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=1000)
    parser.add_argument('--latent_dim', '-ld', type=int, default=100)
    parser.add_argument('--save_steps', '-ss', type=int, default=1)
    parser.add_argument('--visualize_steps', '-vs', type=int, default=1)
    parser.add_argument('--logdir', '-log', type=str, default="../logs")
    parser.add_argument('--distribution', '-dis', type=str, default="uniform")
    parser.add_argument('--upsampling', '-up', type=str, default="deconv")
    parser.add_argument('--downsampling', '-down', type=str, default="stride")
    parser.add_argument('--lr_d', type=float, default=1e-4)
    parser.add_argument('--lr_g', type=float, default=1e-4)
    parser.add_argument('--gp_lambda', '-lmd', type=float, default=10.)
    parser.add_argument('--d_norm_eps', '-eps', type=float, default=1e-3)

    args = parser.parse_args()

    args_to_csv(os.path.join(args.logdir, 'config.csv'), args)

    image_sampler = ImageSampler()
    noise_sampler = NoiseSampler(args.distribution)

    model = PGGAN(channel=args.channel,
                  latent_dim=args.latent_dim,
                  nb_growing=args.nb_growing,
                  gp_lambda=args.gp_lambda,
                  d_norm_eps=args.d_norm_eps,
                  upsampling=args.upsampling,
                  downsampling=args.downsampling,
                  lr_d=args.lr_d,
                  lr_g=args.lr_g)
    model.fit(image_sampler.flow_from_directory(args.image_dir,
                                                args.batch_size,
                                                with_class=False),
              noise_sampler,
              nb_epoch=args.nb_epoch,
              logdir=args.logdir,
              save_steps=args.save_steps,
              visualize_steps=args.visualize_steps)
コード例 #5
0
ファイル: generate.py プロジェクト: rearwist3/aae_solder_tf
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('ok_image_dir', type=str)
    parser.add_argument('ng_image_dir', type=str)
    parser.add_argument('--batch_size', '-bs', type=int, default=234)
    parser.add_argument('--latent_dim', '-ld', type=int, default=16)
    parser.add_argument('--height', '-ht', type=int, default=64)
    parser.add_argument('--width', '-wd', type=int, default=64)
    parser.add_argument('--channel', '-ch', type=int, default=8)
    parser.add_argument('--model_path',
                        '-mp',
                        type=str,
                        default="./params/epoch_200/model.ckpt")
    parser.add_argument('--result_dir', '-rd', type=str, default="./result")
    parser.add_argument('--nb_visualize_batch', '-nvb', type=int, default=1)
    parser.add_argument('--select_gpu', '-sg', type=str, default="0")

    args = parser.parse_args()
    os.makedirs(args.result_dir, exist_ok=True)

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        visible_device_list=args.select_gpu,  # specify GPU number
        allow_growth=True))

    input_shape = (args.height, args.width, args.channel)

    autoencoder = AutoEncoder(input_shape,
                              args.latent_dim,
                              is_training=False,
                              channel=args.channel)
    discriminator = Discriminator(is_training=False)

    aae = AAE(autoencoder, discriminator, is_training=False)
    aae.restore(args.model_path)

    result_dir_inlier = os.path.join(args.result_dir, "decoded/inlier")
    result_dir_outlier = os.path.join(args.result_dir, "decoded/outlier")

    image_sampler = ImageSampler(
        target_size=(args.width, args.height),
        color_mode='rgb' if args.channel == 3 else 'gray',
        is_training=False)

    data_generator_inlier = image_sampler.flow_from_directory(args.inlier_dir,
                                                              args.batch_size,
                                                              shuffle=False)
    df_inlier = get_encoded_save_decoded(aae,
                                         data_generator_inlier,
                                         args.latent_dim,
                                         result_dir_inlier,
                                         label='inlier',
                                         nb_visualize=args.nb_visualize_batch)

    data_generator_outlier = image_sampler.flow_from_directory(
        args.outlier_dir, args.batch_size, shuffle=False)
    df_outlier = get_encoded_save_decoded(aae,
                                          data_generator_outlier,
                                          args.latent_dim,
                                          result_dir_outlier,
                                          label='outlier',
                                          nb_visualize=args.nb_visualize_batch)

    df = pd.concat([df_inlier, df_outlier], ignore_index=True)
    os.makedirs(args.result_dir, exist_ok=True)
    df.to_csv(os.path.join(args.result_dir, "output.csv"), index=False)