import pylib as py  # helper library used throughout these examples (import name assumed)

import data  # project-local data pipeline module

py.arg('--adversarial_loss_mode', default='gan', choices=['gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'])
py.arg('--gradient_penalty_mode', default='none', choices=['none', 'dragan', 'wgan-gp'])
py.arg('--gradient_penalty_weight', type=float, default=10.0)
py.arg('--experiment_name', default='none')
args = py.args()
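
# Example invocation (script name illustrative):
#   python train.py --adversarial_loss_mode hinge_v2 --gradient_penalty_mode wgan-gp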

# output_dir
if args.experiment_name == 'none':
    args.experiment_name = '%s_%s' % (args.dataset, args.adversarial_loss_mode)
    if args.gradient_penalty_mode != 'none':
        args.experiment_name += '_%s' % args.gradient_penalty_mode
output_dir = py.join('output', '%s_BN%d_DPG%d' % (args.experiment_name, args.batch_size, args.n_d))
py.mkdir(output_dir)

# save settings
py.args_to_yaml(py.join(output_dir, 'settings.yml'), args)
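
# The YAML written above can be loaded back when resuming or evaluating a run.
# Minimal sketch, assuming pylib also provides the args_from_yaml counterpart:
restored_args = py.args_from_yaml(py.join(output_dir, 'settings.yml'))
assert restored_args.adversarial_loss_mode == args.adversarial_loss_mode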


# ==============================================================================
# =                               data and model                               =
# ==============================================================================

# setup dataset
if args.dataset in ['cifar10', 'fashion_mnist', 'mnist']:  # 32x32
    dataset, shape, len_dataset = data.make_32x32_dataset(args.dataset, args.batch_size)
    n_G_upsamplings = n_D_downsamplings = 3

elif args.dataset == 'celeba':  # 64x64
    img_paths = py.glob('data/img_align_celeba', '*.jpg')
    dataset, shape, len_dataset = data.make_celeba_dataset(img_paths, args.batch_size)
    n_G_upsamplings = n_D_downsamplings = 4
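
# Sketch: either branch yields a batched tf.data.Dataset of images; a quick peek
# at one batch (illustrative, shapes depend on the dataset):
# for batch in dataset.take(1):
#     print(batch.shape)  # e.g. (batch_size, 32, 32, 3) for cifar10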

Example #2

import logging.config
from argparse import Namespace

import tensorflow as tf

import pylib as py  # import name assumed


class Trainer:  # class name assumed; the snippet begins inside __init__
    def __init__(
            self,
            epochs=200,
            epoch_decay=100,
            pool_size=50,
            output_dir='output',
            datasets_dir="datasets",
            dataset="drawing",
            image_ext="png",
            crop_size=256,
            load_size=286,
            batch_size=0,
            adversarial_loss_mode="lsgan",  # ['gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan']
            lr=0.0002,
            gradient_penalty_mode='none',  # ['none', 'dragan', 'wgan-gp']
            gradient_penalty_weight=10.0,
            cycle_loss_weight=0.0,
            identity_loss_weight=0.0,
            beta_1=0.5,
            color_depth=1,
            progressive=False):
        logging.config.fileConfig(fname='log.conf')
        self.logger = logging.getLogger('dev')

        if batch_size == 0:
            batch_size = 1  # TODO: decide how an unset batch size should really be handled
        epoch_decay = min(epoch_decay, epochs // 2)

        self.output_dataset_dir = py.join(output_dir, dataset)
        py.mkdir(self.output_dataset_dir)
        py.args_to_yaml(
            py.join(self.output_dataset_dir, 'settings.yml'),
            Namespace(
                epochs=epochs,
                epoch_decay=epoch_decay,
                pool_size=pool_size,
                output_dir=output_dir,
                datasets_dir=datasets_dir,
                dataset=dataset,
                image_ext=image_ext,
                crop_size=crop_size,
                load_size=load_size,
                batch_size=batch_size,
                adversarial_loss_mode=adversarial_loss_mode,  # ['gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan']
                lr=lr,
                gradient_penalty_mode=gradient_penalty_mode,  # ['none', 'dragan', 'wgan-gp']
                gradient_penalty_weight=gradient_penalty_weight,
                cycle_loss_weight=cycle_loss_weight,
                identity_loss_weight=identity_loss_weight,
                beta_1=beta_1,
                color_depth=color_depth,
                progressive=progressive))
        self.sample_dir = py.join(self.output_dataset_dir, 'samples_training')
        py.mkdir(self.sample_dir)

        self.epochs = epochs
        self.epoch_decay = epoch_decay
        self.pool_size = pool_size
        self.gradient_penalty_mode = gradient_penalty_mode
        self.gradient_penalty_weight = gradient_penalty_weight
        self.cycle_loss_weight = cycle_loss_weight
        self.identity_loss_weight = identity_loss_weight
        self.color_depth = color_depth
        self.adversarial_loss_mode = adversarial_loss_mode
        self.batch_size = batch_size
        self.beta_1 = beta_1
        self.dataset = dataset
        self.datasets_dir = datasets_dir
        self.image_ext = image_ext
        self.progressive = progressive
        self.lr = lr

        self.crop_size = crop_size
        self.load_size = load_size

        self.A_img_paths = py.glob(py.join(datasets_dir, dataset, 'trainA'),
                                   '*.{}'.format(image_ext))
        self.B_img_paths = py.glob(py.join(datasets_dir, dataset, 'trainB'),
                                   '*.{}'.format(image_ext))

        # summary
        self.train_summary_writer = tf.summary.create_file_writer(
            py.join(self.output_dataset_dir, 'summaries', 'train'))
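
# Sketch: how such a writer is used during training. Standalone illustration
# with dummy values; a real loop would log actual losses at each optimizer step.
demo_writer = tf.summary.create_file_writer('/tmp/summaries_demo')
with demo_writer.as_default():
    tf.summary.scalar('loss/G', 0.0, step=0)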

Example #3

import pylib as py  # import name assumed

import data  # project-local data pipeline module

py.arg(
    "--adversarial_loss_mode",  # argument name reconstructed from the choices; the snippet starts mid-call
    choices=["gan", "hinge_v1", "hinge_v2", "lsgan", "wgan"],
)
py.arg("--cycle_loss_weight", type=float, default=10.0)
py.arg("--identity_loss_weight", type=float, default=0.0)
py.arg("--resnet_blocks", type=int, default=9)
py.arg("--DnCNN", type=str, default=None)
py.arg("--bidirectional", type=bool, default=True)
py.arg("--pool_size", type=int, default=50)  # pool size to store fake samples
args = py.args()

# output_dir
output_dir = py.join("output", args.output_dir)
py.mkdir(output_dir)

# save settings
py.args_to_yaml(py.join(output_dir, "settings.yml"), args)

# ==============================================================================
# =                                    data                                    =
# ==============================================================================

A_img_paths = py.glob(py.join(args.datasets_dir, args.dataset, "trainA"),
                      "*.jpg")
B_img_paths = py.glob(py.join(args.datasets_dir, args.dataset, "trainB"),
                      "*.jpg")
A_B_dataset, len_dataset = data.make_zip_dataset(
    A_img_paths,
    B_img_paths,
    args.batch_size,
    args.load_size,
    args.crop_size,
    training=True,  # assumed completion; the original call is cut off here
)

Example #4

import glob
from os import makedirs
from os.path import exists, join

import pylib as py  # import name assumed

import data  # project-local data pipeline module
py.arg('--epochs', type=int, default=225)
py.arg('--epoch_decay', type=int, default=25)  # epoch to start decaying learning rate
py.arg('--lr', type=float, default=0.000001)
py.arg('--beta_1', type=float, default=0.5)
py.arg('--adversarial_loss_mode', default='gan', choices=['gan', 'lsgan'])
# py.arg('--gradient_penalty_mode', default='none', choices=['none', 'dragan', 'wgan-gp'])
py.arg('--gradient_penalty_weight', type=float, default=10.0)
py.arg('--cycle_loss_weight', type=float, default=12.0)
py.arg('--identity_loss_weight', type=float, default=0.5)
py.arg('--output_dir', default='./Results')
py.arg('--pool_size', type=int, default=50)  # pool size to store fake samples
args = py.args()

if not exists(args.output_dir):
    makedirs(args.output_dir)
py.args_to_yaml(join(args.output_dir, 'settings.yml'), args)


# ==============================================================================
# =                                    data                                    =
# ==============================================================================

A_img_paths = glob.glob(join(args.datasets_dir, 'MRI', '*.png'))
B_img_paths = glob.glob(join(args.datasets_dir, 'CT', '*.png'))

A_B_dataset, len_dataset = data.make_zip_dataset(A_img_paths, B_img_paths, args.batch_size, args.size, training=True, repeat=False)

A2B_pool = data.ItemPool(args.pool_size)
B2A_pool = data.ItemPool(args.pool_size)
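
# Sketch of how the pools are typically consumed in the D update: freshly
# generated fakes go in and a mix of new and historical fakes comes out (the
# CycleGAN history-buffer trick). Assuming ItemPool instances are callable on
# a batch of generated images (illustrative names):
# A2B_for_D = A2B_pool(A2B_fake)
# B2A_for_D = B2A_pool(B2A_fake)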

A_img_paths_test = glob.glob(join(args.datasets_dir, 'MRI_test', '*.png'))
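
# Sketch: the matching CT test split would be globbed and zipped the same way
# for evaluation, mirroring the training call above ('CT_test' folder name assumed):
B_img_paths_test = glob.glob(join(args.datasets_dir, 'CT_test', '*.png'))
A_B_dataset_test, _ = data.make_zip_dataset(A_img_paths_test, B_img_paths_test, args.batch_size, args.size, training=False, repeat=True)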