Beispiel #1
0
def get_config():
    """Return the 'ncsnpp' / 'vesde' configuration for a 128-pixel CUSTOM dataset.

    Starts from ``get_default_configs()`` and overrides fields of its
    ``training``, ``sampling``, ``data`` and ``model`` sections in place.

    Returns:
        The object returned by ``get_default_configs()`` with the overrides
        below applied.
    """
    cfg = get_default_configs()

    # --- training overrides ---
    train_cfg = cfg.training
    train_cfg.sde = 'vesde'
    train_cfg.continuous = True
    train_cfg.batch_size = 16

    # --- sampling overrides: predictor-corrector ('pc') sampler ---
    sample_cfg = cfg.sampling
    sample_cfg.method = 'pc'
    sample_cfg.predictor = 'reverse_diffusion'
    sample_cfg.corrector = 'langevin'

    # --- data overrides ---
    data_cfg = cfg.data
    data_cfg.dataset = 'CUSTOM'
    data_cfg.image_size = 128
    # NOTE(review): absolute, environment-specific path; adjust per machine.
    data_cfg.tfrecords_path = "/content/drive/MyDrive/Training/tf_dataset/tf_dataset-r07.tfrecords"

    # --- model overrides ---
    net_cfg = cfg.model
    net_cfg.name = 'ncsnpp'
    net_cfg.sigma_max = 217
    net_cfg.scale_by_sigma = True
    net_cfg.ema_rate = 0.999
    net_cfg.normalization = 'GroupNorm'
    net_cfg.nonlinearity = 'swish'
    net_cfg.nf = 128
    net_cfg.ch_mult = (1, 1, 2, 2, 2, 2, 2)
    net_cfg.num_res_blocks = 2
    net_cfg.attn_resolutions = (16, )
    net_cfg.resamp_with_conv = True
    net_cfg.conditional = True
    net_cfg.fir = True
    net_cfg.fir_kernel = [1, 3, 3, 1]
    net_cfg.skip_rescale = True
    net_cfg.resblock_type = 'biggan'
    net_cfg.progressive = 'output_skip'
    net_cfg.progressive_input = 'input_skip'
    net_cfg.progressive_combine = 'sum'
    net_cfg.attention_type = 'ddpm'
    net_cfg.init_scale = 0.
    net_cfg.fourier_scale = 16
    net_cfg.conv_size = 3

    return cfg
Beispiel #2
0
def get_config():
    """Return the 'ddpm' / 'vpsde' configuration for 256-pixel CelebAHQ.

    Starts from ``get_default_configs()`` and overrides fields of its
    ``training``, ``sampling``, ``data``, ``model`` and ``optim`` sections
    in place.

    Returns:
        The object returned by ``get_default_configs()`` with the overrides
        below applied.
    """
    cfg = get_default_configs()

    # --- training overrides ---
    train_cfg = cfg.training
    train_cfg.sde = 'vpsde'
    train_cfg.continuous = False
    train_cfg.reduce_mean = True

    # --- sampling overrides: ancestral predictor, corrector set to 'none' ---
    sample_cfg = cfg.sampling
    sample_cfg.method = 'pc'
    sample_cfg.predictor = 'ancestral_sampling'
    sample_cfg.corrector = 'none'

    # --- data overrides ---
    data_cfg = cfg.data
    data_cfg.dataset = 'CelebAHQ'
    data_cfg.centered = True
    # NOTE(review): absolute, environment-specific path; adjust per machine.
    data_cfg.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'
    data_cfg.image_size = 256

    # --- model overrides ---
    net_cfg = cfg.model
    net_cfg.name = 'ddpm'
    net_cfg.scale_by_sigma = False
    net_cfg.num_scales = 1000
    net_cfg.ema_rate = 0.9999
    net_cfg.normalization = 'GroupNorm'
    net_cfg.nonlinearity = 'swish'
    net_cfg.nf = 128
    net_cfg.ch_mult = (1, 1, 2, 2, 4, 4)
    net_cfg.num_res_blocks = 2
    net_cfg.attn_resolutions = (16, )
    net_cfg.resamp_with_conv = True
    net_cfg.conditional = True

    # --- optimizer overrides ---
    opt_cfg = cfg.optim
    opt_cfg.lr = 2e-5

    return cfg
Beispiel #3
0
def get_config():
    """Return the 'ncsnv2_128' / 'vesde' configuration for the 'bedroom' category.

    Starts from ``get_default_configs()`` and overrides fields of its
    ``training``, ``sampling``, ``data``, ``model`` and ``optim`` sections
    in place.

    Returns:
        The object returned by ``get_default_configs()`` with the overrides
        below applied.
    """
    config = get_default_configs()

    # training
    training = config.training
    training.batch_size = 128
    training.sde = 'vesde'
    # The flag is spelled `continuous` (as in the sibling configs); a
    # misspelled attribute name would leave the default value untouched.
    training.continuous = False

    # sampling: no predictor, 'ald' corrector only
    sampling = config.sampling
    sampling.method = 'pc'
    sampling.predictor = 'none'
    sampling.corrector = 'ald'
    sampling.n_steps_each = 3
    sampling.snr = 0.095

    # data
    data = config.data
    data.category = 'bedroom'
    data.image_size = 128

    # model
    model = config.model
    model.name = 'ncsnv2_128'
    model.scale_by_sigma = True
    model.sigma_max = 190
    model.num_scales = 1086
    model.ema_rate = 0.9999
    model.sigma_min = 0.01
    model.normalization = 'InstanceNorm++'
    model.nonlinearity = 'elu'
    model.nf = 128
    model.interpolation = 'bilinear'

    # optim
    optim = config.optim
    optim.weight_decay = 0
    optim.optimizer = 'Adam'
    optim.lr = 1e-4
    optim.beta1 = 0.9
    optim.amsgrad = False
    optim.eps = 1e-8
    optim.warmup = 0
    # NOTE(review): -1 presumably disables gradient clipping; confirm
    # against the optimizer setup code.
    optim.grad_clip = -1

    return config
Beispiel #4
0
def get_config():
    """Return the 'ddpm' / 'vpsde' configuration for the 'church_outdoor' category.

    Starts from ``get_default_configs()`` and overrides fields of its
    ``training``, ``sampling``, ``data``, ``model`` and ``optim`` sections
    in place.

    Returns:
        The object returned by ``get_default_configs()`` with the overrides
        below applied.
    """
    cfg = get_default_configs()

    # --- training overrides ---
    train_cfg = cfg.training
    train_cfg.sde = 'vpsde'
    train_cfg.continuous = False
    train_cfg.reduce_mean = True

    # --- sampling overrides: ancestral predictor, corrector set to 'none' ---
    sample_cfg = cfg.sampling
    sample_cfg.method = 'pc'
    sample_cfg.predictor = 'ancestral_sampling'
    sample_cfg.corrector = 'none'

    # --- data overrides ---
    data_cfg = cfg.data
    data_cfg.category = 'church_outdoor'
    data_cfg.centered = True

    # --- model overrides ---
    net_cfg = cfg.model
    net_cfg.name = 'ddpm'
    net_cfg.scale_by_sigma = False
    net_cfg.num_scales = 1000
    net_cfg.ema_rate = 0.9999
    net_cfg.normalization = 'GroupNorm'
    net_cfg.nonlinearity = 'swish'
    net_cfg.nf = 128
    net_cfg.ch_mult = (1, 1, 2, 2, 4, 4)
    net_cfg.num_res_blocks = 2
    net_cfg.attn_resolutions = (16, )
    net_cfg.resamp_with_conv = True
    net_cfg.conditional = True

    # --- optimizer overrides ---
    opt_cfg = cfg.optim
    opt_cfg.lr = 2e-5

    return cfg