예제 #1
0
 def __get_data(self, data_file):
     """Load every image listed in *data_file* into one float array.

     Parameters
     ----------
     data_file : str
         Path to a CSV whose records unpack as (index, image, ...);
         the second field is the image path/identifier.

     Returns
     -------
     numpy.ndarray
         Array of shape ``(N,) + self.img_shape`` (Keras float dtype)
         holding the grayscale, resized images.
     """
     # Keep the image identifiers separate from the output buffer.
     # (Previously the same name `data` was reused for both, so the
     # path list was clobbered by np.zeros before any image could be
     # opened, and the filled array was silently discarded.)
     paths = [img for _, img, _ in read_csv(data_file).to_records()]
     size = len(paths)
     data = np.zeros((size,) + self.img_shape, dtype=K.floatx())
     for i in tqdm(range(size)):
         img = Image.open(expand_path(paths[i]))
         img = img.convert('L')  # convert image to grayscale
         # NOTE(review): assumes img_shape[1:] is a PIL-compatible
         # (width, height) pair, i.e. channels-first shape -- confirm.
         img = img.resize(self.img_shape[1:])
         data[i] = img_to_array(img)
     return data
예제 #2
0
파일: synctrex.py 프로젝트: bkfox/synctrex
    def load(self, path):
        """
        Parse a configuration file and generate all needed objects.

        Files referenced under ``imports`` are loaded recursively;
        already-loaded paths are skipped (with a warning) to avoid
        import cycles.

        Raise a ValueError if there is at least one error raised when
        reading the configuration file.
        """
        def prepare_value(value):
            # Forced values always override what the file declares.
            value.update(self.force)
            return value

        if path in self.already_loaded:
            logger.warning('file %s has already been loaded, skip -- '
                           'you might want to check for cycles', path)
            return
        self.already_loaded.append(path)

        logger.debug('open configuration %s', path)
        # NOTE(review): nothing ever appends to `errors`, so the final
        # ValueError is currently unreachable -- kept for the documented
        # contract until error collection is wired in.
        errors = []
        with open(path, 'r') as file:
            # safe_load: config files can come from untrusted sources and
            # plain yaml.load without a Loader is removed in PyYAML >= 6.
            # `or {}` guards against an empty file (yaml returns None).
            data = yaml.safe_load(file) or {}

            # imports: resolve relative to this file's directory
            imports = data.get('imports')
            if imports:
                base_dir = os.path.dirname(path)
                for i_path in utils.as_list_of(str, imports):
                    i_path = utils.expand_path(base_dir, i_path,
                                               expand_vars=True)
                    self.load(i_path)

            # syncs: build one Sync object per named entry
            syncs = data.get('syncs')
            if syncs:
                syncs = [ (name, Sync(self, name, **prepare_value(value)))
                          for name, value in syncs.items() ]
                self.syncs.update(syncs)

        if errors:
            raise ValueError('configuration contains {} errors: abort'
                             .format(len(errors)))
예제 #3
0
def main(_):
    """Entry point: normalize FLAGS, create output dirs, then train or
    load a DCGAN and optionally export/freeze/visualize it."""
    # NOTE(review): __flags is a private absl/TF attribute and breaks on
    # newer absl-py; a sibling variant below uses {k: FLAGS[k].value}.
    pp.pprint(flags.FLAGS.__flags)

    # expand user name and environment variables
    FLAGS.data_dir = expand_path(FLAGS.data_dir)
    FLAGS.out_dir = expand_path(FLAGS.out_dir)
    FLAGS.out_name = expand_path(FLAGS.out_name)
    FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
    FLAGS.sample_dir = expand_path(FLAGS.sample_dir)

    # Unspecified dimensions fall back to square input/output images.
    if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
    if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height

    # output folders: auto-name the run from the data dir and dataset
    if FLAGS.out_name == "":
        # FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
        FLAGS.out_name = '{} - {}'.format(
            FLAGS.data_dir.split('/')[-1], FLAGS.dataset)
        if FLAGS.train:
            # Encode the key hyper-parameters in the folder name.
            FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(
                FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist,
                FLAGS.output_width, FLAGS.batch_size)

    # checkpoint and sample dirs are nested under the run's out_dir
    FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
    FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
    FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)

    # Persist the full flag set next to the results for reproducibility.
    with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
        flags_dict = {k: flags.FLAGS.__flags[k] for k in flags.FLAGS.__flags}
        json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)

    #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
    run_config = tf.ConfigProto()
    # Grow GPU memory on demand instead of grabbing it all up front.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        if FLAGS.dataset == 'mnist':
            # MNIST branch is class-conditional: y_dim=10 one-hot labels.
            dcgan = DCGAN(sess,
                          input_width=FLAGS.input_width,
                          input_height=FLAGS.input_height,
                          output_width=FLAGS.output_width,
                          output_height=FLAGS.output_height,
                          batch_size=FLAGS.batch_size,
                          sample_num=FLAGS.batch_size,
                          y_dim=10,
                          z_dim=FLAGS.z_dim,
                          dataset_name=FLAGS.dataset,
                          input_fname_pattern=FLAGS.input_fname_pattern,
                          crop=FLAGS.crop,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          data_dir=FLAGS.data_dir,
                          out_dir=FLAGS.out_dir,
                          max_to_keep=FLAGS.max_to_keep)
        else:
            # Unconditional model: same config minus y_dim.
            dcgan = DCGAN(sess,
                          input_width=FLAGS.input_width,
                          input_height=FLAGS.input_height,
                          output_width=FLAGS.output_width,
                          output_height=FLAGS.output_height,
                          batch_size=FLAGS.batch_size,
                          sample_num=FLAGS.batch_size,
                          z_dim=FLAGS.z_dim,
                          dataset_name=FLAGS.dataset,
                          input_fname_pattern=FLAGS.input_fname_pattern,
                          crop=FLAGS.crop,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          data_dir=FLAGS.data_dir,
                          out_dir=FLAGS.out_dir,
                          max_to_keep=FLAGS.max_to_keep)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
            if not load_success:
                raise Exception("Checkpoint not found in " +
                                FLAGS.checkpoint_dir)

        # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
        #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
        #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
        #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
        #                 [dcgan.h4_w, dcgan.h4_b, None])

        # Below is codes for visualization
            # NOTE(review): the blocks below are indented under the
            # `else:` branch above, so export/freeze/visualize only run
            # in non-training mode -- confirm this is intentional.
            if FLAGS.export:
                export_dir = os.path.join(FLAGS.checkpoint_dir,
                                          'export_b' + str(FLAGS.batch_size))
                dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)

            if FLAGS.freeze:
                export_dir = os.path.join(FLAGS.checkpoint_dir,
                                          'frozen_b' + str(FLAGS.batch_size))
                dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)

            # NOTE(review): `if True` forces visualization unconditionally
            # here; sibling variants gate this on FLAGS.visualize.
            if True:
                OPTION = 1
                visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)
예제 #4
0
def txt_file_to_str(path):
    """
    Given a path to a textfile, read its contents.

    The path is expanded (user/env vars) first, and the file handle is
    closed deterministically via a context manager -- the original left
    it open until garbage collection.
    """
    with open(expand_path(path)) as f:
        return f.read()
예제 #5
0
def main(_):
  """Entry point: pick a DCGAN variant per transform type, normalize
  FLAGS, create output dirs, then train (or test) the model."""
  # NOTE(review): __flags is a private absl/TF attribute and breaks on
  # newer absl-py (FLAGS.json below already uses FLAGS[k].value).
  pp.pprint(flags.FLAGS.__flags)

  # NOTE(review): training is forced on here, so the testing branch
  # further below is unreachable unless this line is removed.
  FLAGS.train = True
  alpha_max_str = str(FLAGS.alpha_max)
#   if FLAGS.steer:
#     print('Training with steerable G -> loading model_argminGW2_{} ...'.format(FLAGS.transform_type))
#     DCGAN = getattr(importlib.import_module('model_argminGW2_{}'.format(FLAGS.transform_type)), 'DCGAN')
#   else:
#     print('Training vanilla G -> loading model_vanilla_{} ...'.format(FLAGS.transform_type))
#     DCGAN = getattr(importlib.import_module('model_vanilla_{}'.format(FLAGS.transform_type)), 'DCGAN')

  # Select the DCGAN implementation for the requested transform type;
  # steerable (argminGW2) vs vanilla generator per FLAGS.steer.
  print('Training with steerable G for {} transformation ...'.format(FLAGS.transform_type))
  if FLAGS.transform_type == 'zoom':
    if FLAGS.steer:
      from model_argminGW2_zoom import DCGAN
    else:
      from model_vanilla_zoom import DCGAN

  if FLAGS.transform_type == 'shiftx':
    # Shift/rotation transforms encode alpha_max as an integer string.
    alpha_max_str = str(np.uint8(FLAGS.alpha_max))
    if FLAGS.steer:
      from model_argminGW2_shiftx import DCGAN
    else:
      from model_vanilla_shiftx import DCGAN

  if FLAGS.transform_type == 'shifty':
    alpha_max_str = str(np.uint8(FLAGS.alpha_max))
    if FLAGS.steer:
      from model_argminGW2_shifty import DCGAN
    else:
      from model_vanilla_shifty import DCGAN

  if FLAGS.transform_type == 'rot2d':
    alpha_max_str = str(np.uint8(FLAGS.alpha_max))
    if FLAGS.steer:
      from model_argminGW2_rot2d import DCGAN
    else:
      from model_vanilla_rot2d import DCGAN

  # Tags used to build the default output folder name below.
  augment_flag_str = 'NoAug'
  if FLAGS.aug:
    augment_flag_str = 'aug'

  steer_flag_str = 'vanilla'
  if FLAGS.steer:
    steer_flag_str = 'argminGW'
  else:
    if FLAGS.aug:
        steer_flag_str = 'argminW'

  if FLAGS.out_name:
    FLAGS.out_name = expand_path(FLAGS.out_name)
  else:
    # Default run name encodes transform, augmentation, steering mode,
    # alpha_max and learning rate.
    FLAGS.out_name = FLAGS.transform_type+'_'+augment_flag_str+'_'+steer_flag_str+\
                     '_alphamax'+alpha_max_str+'_lr'+ str(FLAGS.learning_rate)
  print('Results will be saved in {}'.format(FLAGS.out_name))

  # expand user name and environment variables
  FLAGS.data_dir = expand_path(FLAGS.data_dir)
  FLAGS.out_dir = expand_path(FLAGS.out_dir)
#   FLAGS.out_name = expand_path(FLAGS.out_name)
  FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
  FLAGS.sample_dir = expand_path(FLAGS.sample_dir)

  # Unspecified dimensions fall back to square input/output images.
  if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
  if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height

  # output folders (dead branch: out_name is always non-empty here)
  if FLAGS.out_name == "":
      FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
      if FLAGS.train:
        FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size)

  # checkpoint and sample dirs are nested under the run's out_dir
  FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
  FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
  FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)

  if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)

  # Persist the full flag set next to the results for reproducibility.
  with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
    flags_dict = {k:FLAGS[k].value for k in FLAGS}
    json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)


  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  # Grow GPU memory on demand instead of grabbing it all up front.
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      # MNIST branch is class-conditional: y_dim=10 one-hot labels.
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.z_dim,
          dataset_name=FLAGS.dataset,
          aug=FLAGS.aug,
          alpha_max=FLAGS.alpha_max,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir,
          out_dir=FLAGS.out_dir,
          max_to_keep=FLAGS.max_to_keep)
    else:
      # Unconditional model: no y_dim (and no alpha_max here).
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.z_dim,
          dataset_name=FLAGS.dataset,
          aug=FLAGS.aug,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir,
          out_dir=FLAGS.out_dir,
          max_to_keep=FLAGS.max_to_keep)

    show_all_variables()

    if FLAGS.train:
      print('>>>---Traning mode is set to {}---<<<'.format(FLAGS.train))
      time.sleep(10)
      dcgan.train(FLAGS)
    else:
      print('<<<---Testing mode--->>>')
      time.sleep(10)
      load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
      if not load_success:
        raise Exception("Checkpoint not found in " + FLAGS.checkpoint_dir)


    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])

    # Below is codes for visualization
      # NOTE(review): these blocks are indented under the `else:` above,
      # so export/freeze/visualize only run in testing mode -- currently
      # unreachable because FLAGS.train is forced True at the top.
      if FLAGS.export:
        export_dir = os.path.join(FLAGS.checkpoint_dir, 'export_b'+str(FLAGS.batch_size))
        dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)

      if FLAGS.freeze:
        export_dir = os.path.join(FLAGS.checkpoint_dir, 'frozen_b'+str(FLAGS.batch_size))
        dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)

      if FLAGS.visualize:
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)
예제 #6
0
def get_files_from_dir(path):
    """Return the names of all entries in the directory at *path*.

    The path is expanded (user/env vars) before listing.
    """
    directory = expand_path(path)
    return os.listdir(directory)
예제 #7
0
def main(_):
    """Entry point: normalize FLAGS, create output dirs, then train a
    DCGAN or load a checkpoint and export/freeze/visualize it."""
    # NOTE(review): __flags is a private absl/TF attribute and breaks on
    # newer absl-py (FLAGS.json below already uses FLAGS[k].value).
    pp.pprint(flags.FLAGS.__flags)

    # expand user name and environment variables
    FLAGS.data_dir = expand_path(FLAGS.data_dir)
    FLAGS.out_dir = expand_path(FLAGS.out_dir)
    FLAGS.out_name = expand_path(FLAGS.out_name)
    FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
    FLAGS.sample_dir = expand_path(FLAGS.sample_dir)

    # Unspecified dimensions fall back to square input/output images.
    if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
    if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height

    # Auto-name the run from timestamp, data dir and dataset.
    if FLAGS.out_name == "":
        FLAGS.out_name = '{} - {} - {}'.format(timestamp(),
                                               FLAGS.data_dir.split('/')[-1],
                                               FLAGS.dataset)
        if FLAGS.train:
            # Encode the key hyper-parameters in the folder name.
            FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(
                FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist,
                FLAGS.output_width, FLAGS.batch_size)

    # NOTE(review): unlike sibling variants, checkpoint/sample dirs are
    # deliberately NOT nested under out_dir here (joins commented out).
    #FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
    #FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
    #FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)

    # Persist the full flag set next to the results for reproducibility.
    with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
        flags_dict = {k: FLAGS[k].value for k in FLAGS}
        json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)

    run_config = tf.ConfigProto()
    # Grow GPU memory on demand instead of grabbing it all up front.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(sess,
                      input_width=FLAGS.input_width,
                      input_height=FLAGS.input_height,
                      output_width=FLAGS.output_width,
                      output_height=FLAGS.output_height,
                      batch_size=FLAGS.batch_size,
                      sample_num=FLAGS.batch_size,
                      z_dim=FLAGS.z_dim,
                      dataset_name=FLAGS.dataset,
                      input_fname_pattern=FLAGS.input_fname_pattern,
                      crop=FLAGS.crop,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      data_dir=FLAGS.data_dir,
                      out_dir=FLAGS.out_dir,
                      max_to_keep=FLAGS.max_to_keep)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            # Non-training mode requires an existing checkpoint; the
            # export/freeze/visualize steps below run only in this branch.
            load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
            if not load_success:
                raise Exception("checkpoint not found in " +
                                FLAGS.checkpoint_dir)

            if FLAGS.export:
                export_dir = os.path.join(FLAGS.checkpoint_dir,
                                          'export_b' + str(FLAGS.batch_size))
                dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)

            if FLAGS.freeze:
                export_dir = os.path.join(FLAGS.checkpoint_dir,
                                          'frozen_b' + str(FLAGS.batch_size))
                dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)

            if FLAGS.visualize:
                OPTION = 1
                visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)

        # Redundant: the `with tf.Session(...)` context already closes
        # the session on exit; harmless double-close.
        sess.close()
def read_raw_image(p):
    """Open and return the image at path *p* (user/env vars expanded)."""
    return Image.open(expand_path(p))