Example #1
0
class Config(object):
    '''
    Holds every configuration parameter as a class-level ('static') attribute.

    The class body is executed when the class is defined, so ``npx`` is
    computed and the texture iterator stored in ``data_iter`` is created at
    that moment, not when an attribute is first read.
    '''
    ##
    # network parameters
    nz = 3            # dimensions of Z at each field position
    zx = 5            # spatial Z parameter used for learning; determines npx below
    zx_sample = 10    # spatial Z parameter used when producing samples
    # l = h/r, m = w/r
    nc = 1            # channels of the input X (e.g. 3 for r,g,b)

    # Kernel sizes per layer; odd values are expected so zero-padding works out.
    gen_ks = list(reversed([(3, 3)] * 5))
    dis_ks = 5 * [(3, 3)]

    # Layer counts follow directly from the kernel lists.
    gen_ls = len(gen_ks)    # layers in the generative network
    dis_ls = len(dis_ks)    # layers in the discriminative network

    # Filter counts: the generator shrinks from the widest layer down to nc,
    # the discriminator grows from 64 and ends in a single filter.
    gen_fn = [2 ** (6 + n) for n in reversed(range(gen_ls - 1))] + [nc]
    dis_fn = [2 ** (6 + n) for n in range(dis_ls - 1)] + [1]

    ##
    # optimisation parameters
    lr = 5e-4         # adam learning rate
    b1 = 0.5          # adam momentum term
    l2_fac = 1e-5     # L2 weight regularization factor

    batch_size = 64

    epoch_iters = 100 * batch_size

    k = 1             # D updates per G update

    # width/height in pixels of the images in X, derived from zx and depth
    npx = zx_to_npx(zx, gen_ls)

    ##
    # data input folder
    sub_name = 'ti_braided_river'
    home = os.path.expanduser("~")
    texture_dir = home + ("/SGANinv/SGAN/2D/training/%s/" % sub_name)
    data_iter = get_texture_iter(
        texture_dir,
        npx=npx,
        mirror=False,
        batch_size=batch_size,
        n_channel=nc,
    )

    save_name = sub_name + (
        "_filters%d_npx%d_%dgL_%ddL" % (dis_fn[0], npx, gen_ls, dis_ls)
    )

    # None means the network is initialized from scratch.
    load_name = None
    # load_name   = "braided_river_filters64_npx129_5gL_5dL_epoch92.sgan"

    @classmethod
    def print_info(cls):
        '''
        Print the image sizes implied by zx / zx_sample and the save name.
        '''
        # Image sizes are recomputed from the current class attributes.
        train_npx = zx_to_npx(cls.zx, cls.gen_ls)
        sample_npx = zx_to_npx(cls.zx_sample, cls.gen_ls)

        print("Learning and generating samples from zx ", cls.zx,
              ", which yields images of size npx ", train_npx)
        print("Producing samples from zx_sample ", cls.zx_sample,
              ", which yields images of size npx ", sample_npx)
        print("Saving samples and model data to file ", cls.save_name)
Example #2
0
 def data_iter(self, texture_path, batch_size, inverse=0, n_samples=20):
     """
     Create a texture iterator for ``texture_path``.

     Forwards ``batch_size``, ``inverse`` and ``n_samples`` unchanged to
     ``get_texture_iter`` together with this object's ``npx``, and returns
     whatever that call produces.
     """
     iter_options = {
         "npx": self.npx,
         "batch_size": batch_size,
         "inverse": inverse,
         "n_samples": n_samples,
     }
     return get_texture_iter(texture_path, **iter_options)
Example #3
0
 def data_iter(self):
     """
     Create a texture iterator over ``self.texture_dir``.

     Uses this object's ``npx`` and ``batch_size`` and always passes
     ``mirror=False``; returns the result of ``get_texture_iter``.
     """
     source_dir = self.texture_dir
     patch_size = self.npx
     per_batch = self.batch_size
     return get_texture_iter(source_dir, npx=patch_size, mirror=False, batch_size=per_batch)