def main(_):
    """Configure the module-level clients and serve `MyHandler` over HTTP.

    Creates the domain-translation client, points the labeler client at its
    image sources when `FLAGS.labeler_source_image_path` is set, creates the
    supervised sketch-refinement client when its server flag is set, then
    serves forever on FLAGS.host:FLAGS.port.
    """
    global DT_CLIENT
    global SKETCH_REFINEMENT_CLIENT

    DT_CLIENT = domain_translation_client.DomainTranslationClient(
        FLAGS.domain_translation_server, FLAGS.image_hw,
        concurrency=FLAGS.concurrency)

    if FLAGS.labeler_source_image_path:
        LABELER_CLIENT.set_image_paths(
            FLAGS.labeler_source_image_path,
            os.path.join(FLAGS.labeler_output_image_path, 'finished_images.txt'),
            FLAGS.labeler_sketch_folder)
    else:
        tf.logging.warning(
            'labeler client not set! empty flag `labeler_source_image_path`.')

    if FLAGS.sketch_refinement_supervised_server:
        SKETCH_REFINEMENT_CLIENT = sketch_refinement_client.SketchRefinementClient(
            FLAGS.sketch_refinement_supervised_server, FLAGS.image_hw,
            concurrency=FLAGS.concurrency, supervised=True)

    if FLAGS.labeler_sketch_refinement_folder:
        util_io.touch_folder(FLAGS.labeler_sketch_refinement_folder)

    # Single pre-formatted argument: prints identically under Python 2 and 3
    # (a bare `print x` statement is a SyntaxError on Python 3).
    print('GPU: {}'.format(FLAGS.gpu))
    httpd = BaseHTTPServer.HTTPServer((FLAGS.host, FLAGS.port), MyHandler)
    print('serving at {} : {}'.format(FLAGS.host, FLAGS.port))
    try:
        httpd.serve_forever()
    finally:
        # Release the listening socket even when serving is interrupted
        # (e.g. KeyboardInterrupt), so the port is freed promptly.
        httpd.server_close()
def set_image_paths(self, image_path, finished_image_txt_path, sketch_folder,
                    exclude_file_start=frozenset(('e', 'q'))):
    """Configure the labeler's source images, finished-image list and sketches.

    Args:
      image_path: Source passed to `util_io.get_all_image_paths`. When falsy,
        the previously loaded `self.image_paths` is kept (and is still filtered
        by sketch availability below).
      finished_image_txt_path: Path of the finished-images list. When set, its
        directory is created and `colored_sketch_pair.txt` is placed next to it.
      sketch_folder: Folder whose images (matched by extension-less basename)
        determine which source images are kept.
      exclude_file_start: Leading basename characters to exclude (Danbooru
        specific NSFW filter). Any container supporting `in` works; the default
        is immutable so it cannot be shared-and-mutated across calls.
    """
    if image_path:
        candidates = util_io.get_all_image_paths(image_path)
        # Danbooru specific method to filter out nsfw images.
        self.image_paths = [
            p for p in candidates
            if os.path.basename(p)[0] not in exclude_file_start
        ]

    if finished_image_txt_path:
        self.done_image_txt_path = finished_image_txt_path
        output_dir = os.path.dirname(finished_image_txt_path)
        self.colored_sketch_pair_txt_path = os.path.join(
            output_dir, 'colored_sketch_pair.txt')
        # A bare filename has no directory component ('' means the cwd).
        if output_dir:
            util_io.touch_folder(output_dir)
        try:
            self.done_image_paths = set(
                util_io.get_all_image_paths(finished_image_txt_path))
        except AssertionError:
            # Best effort: presumably raised when the finished-image list does
            # not exist yet (fresh run) -- TODO confirm against util_io.
            pass

    self.sketch_folder = sketch_folder
    sketches = {
        util_misc.get_no_ext_base(p)
        for p in util_io.get_all_image_paths(sketch_folder)
    }
    self.image_paths = [
        p for p in self.image_paths
        if util_misc.get_no_ext_base(p) in sketches
    ]

    if image_path:
        # Size the per-image sketch slots *after* the sketch filter so that
        # sketch_paths and image_paths stay index-aligned and equally long.
        self.sketch_paths = [None] * len(self.image_paths)
        self.index = 0
def _get_shared_info(self):
    """Wrapper around _read_shared_info() and _generate_shared_info().

    Reads `FLAGS.shared_info_output` if it is already a file. Otherwise it
    ensures the file's parent directory exists and generates the shared info
    from the train and validation directories.
    """
    # If shared info already exists, read it instead of generating it again.
    if os.path.isfile(FLAGS.shared_info_output):
        return self._read_shared_info(FLAGS.shared_info_output)

    shared_info_output_dir = os.path.dirname(FLAGS.shared_info_output)
    # dirname() is '' for a bare filename (current directory); there is
    # nothing to create in that case.
    if shared_info_output_dir and not os.path.exists(shared_info_output_dir):
        util_io.touch_folder(shared_info_output_dir)
    return self._generate_shared_info(FLAGS.train_directory,
                                      FLAGS.validation_directory,
                                      FLAGS.shared_info_output)
def main(self):
    """Write the validation, test and train splits into FLAGS.output_directory.

    Both shard counts must be divisible by FLAGS.num_threads. The splits are
    processed in the order validation, test, train. Each call passes `None` as
    the directory, plus a per-split tag ('2', '1', '0') whose meaning is
    defined by `_process_dataset`. Note that the test split reuses
    FLAGS.validation_shards.
    """
    assert FLAGS.train_shards % FLAGS.num_threads == 0, (
        'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
    assert FLAGS.validation_shards % FLAGS.num_threads == 0, (
        'Please make the FLAGS.num_threads commensurate with '
        'FLAGS.validation_shards')

    print('Saving results to %s' % FLAGS.output_directory)
    util_io.touch_folder(FLAGS.output_directory)

    # (split name, shard count, tag) in processing order.
    split_specs = (
        ('validation', FLAGS.validation_shards, '2'),
        ('test', FLAGS.validation_shards, '1'),
        ('train', FLAGS.train_shards, '0'),
    )
    for split_name, num_shards, split_tag in split_specs:
        self._process_dataset(split_name, None, num_shards, split_tag)
def main(_):
    """Send FLAGS.image_path to the TwinGAN server and wait for the result.

    Output is written under FLAGS.output_dir/<image basename>. Returns early,
    with a message, when FLAGS.twingan_server is empty.
    """
    usage_hint = ("Another way to test the inference model: saved_model_cli run "
                  "--dir 'path/to/export/model' \\ --tag_set serve "
                  "--signature_def serving_default "
                  "--input_exprs 'inputs=np.ones((1,4,4,3))'")
    print(usage_hint)

    if not FLAGS.twingan_server:
        print('please specify twingan_server host:port')
        return

    util_io.touch_folder(FLAGS.output_dir)
    result_dir = os.path.join(FLAGS.output_dir,
                              os.path.basename(FLAGS.image_path))
    twingan = TwinGANClient(FLAGS.twingan_server, FLAGS.image_hw,
                            concurrency=FLAGS.concurrency)
    twingan.do_inference(image_path=FLAGS.image_path, output_dir=result_dir)
    # Wait for the client's callbacks for this output before reporting success.
    twingan.block_on_callback(result_dir)
    print('\nDone')
def main(self):
    """Write the configured validation and train splits into sharded output.

    Requires both shard counts to be divisible by FLAGS.num_threads. Creates
    FLAGS.output_directory and computes the shared info once. Then it
    processes each split whose directory flag is non-empty, validation first.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    assert FLAGS.train_shards % FLAGS.num_threads == 0, (
        'Please make the FLAGS.num_threads mod FLAGS.train_shards == 0')
    assert FLAGS.validation_shards % FLAGS.num_threads == 0, (
        'Please make the FLAGS.num_threads mod FLAGS.validation_shards == 0')

    print('Saving results to %s' % FLAGS.output_directory)
    util_io.touch_folder(FLAGS.output_directory)
    shared_info = self._get_shared_info()

    # (split name, source directory, shard count) in processing order.
    split_specs = (
        ('validation', FLAGS.validation_directory, FLAGS.validation_shards),
        ('train', FLAGS.train_directory, FLAGS.train_shards),
    )
    for split_name, source_dir, num_shards in split_specs:
        if not source_dir:
            continue
        self._process_dataset(split_name, source_dir, num_shards, shared_info)
def main(_):
    """Send FLAGS.image_path to the domain-translation server.

    Output goes to FLAGS.output_dir/<image basename>. Returns early, with a
    message, when FLAGS.num_tests exceeds 10000 or when
    FLAGS.domain_translation_server is empty.
    """
    usage_hint = ("Another way to test the inference model: saved_model_cli run "
                  "--dir 'path/to/export/model' \\ --tag_set serve "
                  "--signature_def serving_default "
                  "--input_exprs 'inputs=np.ones((1,4,4,3))'")
    print(usage_hint)

    if FLAGS.num_tests > 10000:
        print('num_tests should not be greater than 10k')
        return
    if not FLAGS.domain_translation_server:
        print('please specify domain_translation_server host:port')
        return

    util_io.touch_folder(FLAGS.output_dir)
    target_path = os.path.join(FLAGS.output_dir,
                               os.path.basename(FLAGS.image_path))
    translator = DomainTranslationClient(FLAGS.domain_translation_server,
                                         FLAGS.image_hw,
                                         concurrency=FLAGS.concurrency,
                                         sample_images_path='')
    translator.do_inference(FLAGS.image_path, target_path)
    print('\nDone')
def main(_):
    """Send FLAGS.image_path to the sketch-refinement server and wait.

    The same file is passed as both the image and the sketch input. Output is
    written under FLAGS.output_dir/<image basename>. Like the sibling client
    entry points, this returns early, with a message, when the server flag is
    empty, instead of failing inside the client constructor.
    """
    print("""Another way to test the inference model: saved_model_cli run --dir 'path/to/export/model' \\ --tag_set serve --signature_def serving_default --input_exprs 'inputs=np.ones((1,4,4,3))'""")
    if not FLAGS.sketch_refinement_server:
        print('please specify sketch_refinement_server host:port')
        return
    util_io.touch_folder(FLAGS.output_dir)
    img_basename = os.path.basename(FLAGS.image_path)
    client = SketchRefinementClient(FLAGS.sketch_refinement_server,
                                    FLAGS.image_hw,
                                    concurrency=FLAGS.concurrency)
    output_dir = os.path.join(FLAGS.output_dir, img_basename)
    # NOTE(review): center_point_xy=[14, 14] is a hard-coded test value whose
    # coordinate system is defined by SketchRefinementClient -- confirm.
    client.do_inference(output_dir, center_point_xy=[14, 14],
                        image_path=FLAGS.image_path,
                        sketch_image_path=FLAGS.image_path)
    # Wait for the client's callbacks for this output before reporting success.
    client.block_on_callback(output_dir)
    print('\nDone')
def main(_):
    """Run `ImageInferer` and save its output image(s).

    With FLAGS.input_image_path set, inference is conditioned on that input.
    Otherwise FLAGS.num_output images are requested from random vectors. A
    list result is saved into the FLAGS.output_image_path folder, one file per
    corresponding input basename. A single result is saved to
    FLAGS.output_image_path itself.
    """
    inferer = ImageInferer()
    infer_kwargs = {'return_image_paths': True}
    if not FLAGS.input_image_path:
        print('Generating images conditioned on random vector.')
        assert FLAGS.num_output >= 0, 'you have to specify the `num_output` flag for non-translational generators.'
        infer_kwargs['num_output'] = FLAGS.num_output
    outputs, image_paths = inferer.infer(FLAGS.input_image_path, **infer_kwargs)

    if not isinstance(outputs, list):
        util_io.touch_folder(os.path.dirname(FLAGS.output_image_path))
        util_io.imsave(FLAGS.output_image_path, outputs)
        return

    util_io.touch_folder(FLAGS.output_image_path)
    for idx, image in enumerate(outputs):
        destination = os.path.join(FLAGS.output_image_path,
                                   os.path.basename(image_paths[idx]))
        util_io.imsave(destination, image)
def save_images(fetches, image_dir):
    """Given a list of `OutputTensor`s, save the images to `image_dir`.

    Args:
      fetches: Iterable of `(name, is_image, vals)` triples. For image entries,
        every element of `vals` is written to its own `.jpg` file. For other
        entries, `vals` is passed through unchanged.
      image_dir: Destination directory, created via `util_io.touch_folder`
        when it does not exist.

    Returns:
      A list of `(name, is_image, items)` triples in `fetches` order. `items`
      is the list of written file names (relative to `image_dir`) for image
      entries, and the original `vals` otherwise.
    """
    if not os.path.isdir(image_dir):
        util_io.touch_folder(image_dir)
    filesets = []
    # Millisecond timestamp shared by all files of this call; it keeps file
    # names from different calls apart (to millisecond resolution).
    now = str(int(time.time() * 1000))
    for name, is_image, vals in fetches:
        if not is_image:
            filesets.append((name, is_image, vals))
            continue
        image_names = []
        for i, val in enumerate(vals):
            filename = name + '_' + now + '_' + str(i) + '.jpg'
            image_names.append(filename)
            # Write in binary mode: text mode raises TypeError for bytes on
            # Python 3 and mangles binary data via newline translation on
            # Windows. Text payloads are encoded as UTF-8.
            if not isinstance(val, bytes):
                val = val.encode('utf-8')
            with open(os.path.join(image_dir, filename), 'wb') as f:
                f.write(val)
        filesets.append((name, is_image, image_names))
    return filesets
def face2face(input_image_path='/root/CSC4001/data/test_face', num_output=None,
              output_root='/root/CSC4001/results/face/'):
    """Run `ImageInferer` on `input_image_path` and save the result(s).

    Args:
      input_image_path: Inference input. When falsy (including None), images
        are generated from random vectors instead, and `num_output` must be
        given.
      num_output: Number of images to generate in the random-vector branch.
        Previously that branch referenced an undefined name and always raised
        NameError.
      output_root: Directory the output path is built in.

    Returns:
      `output_root/<stem>.jpg`, where `<stem>` is the input's basename up to
      its first '.'. The stem is empty when `input_image_path` is falsy. For a
      list result this path is used as a folder holding one image per input
      basename; otherwise the single image is written to it.
    """
    # Guarded so a None input reaches the random-vector branch instead of
    # crashing in .split().
    stem = input_image_path.split('/')[-1].split('.')[0] if input_image_path else ''
    output_image_path = os.path.join(output_root, stem + '.jpg')

    inferer = ImageInferer()
    if input_image_path:
        outputs, image_paths = inferer.infer(input_image_path,
                                             return_image_paths=True)
    else:
        print('Generating images conditioned on random vector.')
        assert num_output is not None and num_output >= 0, 'you have to specify the `num_output` flag for non-translational generators.'
        outputs, image_paths = inferer.infer(input_image_path,
                                             return_image_paths=True,
                                             num_output=num_output)

    if isinstance(outputs, list):
        util_io.touch_folder(output_image_path)
        for i, output in enumerate(outputs):
            util_io.imsave(
                os.path.join(output_image_path, os.path.basename(image_paths[i])),
                output)
    else:
        util_io.touch_folder(os.path.dirname(output_image_path))
        util_io.imsave(output_image_path, outputs)
    return output_image_path
def main(self):
    """Build the model graph and run the mode selected by FLAGS.

    With FLAGS.is_training, the model is cloned via `model_deploy` and trained
    with `slim.learning.train`. Otherwise a single non-training graph is
    built, and the first matching mode runs:
      * FLAGS.do_output: run the output ops over all batches and write results.
      * FLAGS.do_export: restore the checkpoint and save a SavedModel.
      * otherwise: evaluation. This is `slim.evaluation.evaluate_once`, unless
        FLAGS.do_eval_debug and/or FLAGS.do_custom_eval select the debug or
        custom eval paths.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set session_config to allow some operations to be run on cpu.
    session_config = tf.ConfigProto(allow_soft_placement=True, )

    with tf.Graph().as_default():
        # Select the dataset, network and preprocessing function.
        dataset = self._select_dataset()
        networks = self._select_network()
        image_preprocessing_fn = self._select_image_preprocessing_fn()

        # Config model_deploy.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        global_step = slim.create_global_step()

        # Create a dataset provider that loads data from the dataset.
        data = self._prepare_data(dataset, image_preprocessing_fn, deploy_config, )
        data_batched = self._get_batch(data)
        # Materialize as lists: on Python 3 keys()/values() are views, which
        # queue enqueue ops do not accept as a tensor list. Both calls see the
        # same unmodified dict, so names and tensors stay aligned.
        batch_names = list(data_batched.keys())
        batch = list(data_batched.values())

        ###############
        # Is Training #
        ###############
        if FLAGS.is_training:
            if not os.path.isdir(FLAGS.train_dir):
                util_io.touch_folder(FLAGS.train_dir)
            if not os.path.exists(os.path.join(FLAGS.train_dir, FLAGS_FILE_NAME)):
                FLAGS.append_flags_into_file(os.path.join(FLAGS.train_dir, FLAGS_FILE_NAME))

            try:
                batch_queue = slim.prefetch_queue.prefetch_queue(
                    batch, capacity=4 * deploy_config.num_clones)
            except ValueError as e:
                # Fall back to feeding the un-prefetched batch tensors directly.
                tf.logging.warning('Cannot use batch_queue due to error %s', e)
                batch_queue = batch

            # Gather initial summaries.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

            clones = model_deploy.create_clones(
                deploy_config, self._clone_fn,
                GeneralModel._dtype_string_to_dtype(FLAGS.variable_dtype),
                [networks, batch_queue, batch_names],
                {'global_step': global_step,
                 'is_training': FLAGS.is_training})
            first_clone_scope = deploy_config.clone_scope(0)

            # Add summaries for end_points.
            end_points = clones[0].outputs
            self._end_points_for_debugging = end_points
            self._add_end_point_summaries(end_points, summaries)
            # Add summaries for images, if there are any.
            self._add_image_summaries(end_points, summaries)
            # Add summaries for losses.
            self._add_loss_summaries(first_clone_scope, summaries, end_points)
            # Add summaries for variables.
            for variable in slim.get_model_variables():
                summaries.add(tf.summary.histogram(variable.op.name, variable))

            # Configure the moving averages.
            if FLAGS.moving_average_decay:
                moving_average_variables = slim.get_model_variables()
                variable_averages = tf.train.ExponentialMovingAverage(
                    FLAGS.moving_average_decay, global_step)
            else:
                moving_average_variables, variable_averages = None, None

            # Configure the optimization procedure.
            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by generator_network_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
            with tf.device(deploy_config.optimizer_device()):
                learning_rate = self._configure_learning_rate(self.num_samples, global_step)
                optimizer = self._configure_optimizer(learning_rate)

            if FLAGS.sync_replicas:
                # If sync_replicas is enabled, the averaging will be done in the chief
                # queue runner.
                optimizer = tf.train.SyncReplicasOptimizer(
                    opt=optimizer,
                    replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                    total_num_replicas=FLAGS.worker_replicas,
                    variable_averages=variable_averages,
                    variables_to_average=moving_average_variables)
            elif FLAGS.moving_average_decay:
                # Update ops executed locally by trainer.
                update_ops.append(variable_averages.apply(moving_average_variables))

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            # Define optimization process.
            train_tensor = self._add_optimization(clones, optimizer, summaries,
                                                  update_ops, global_step)

            # Add the summaries from the first clone. These contain the summaries
            # created by model_fn and either optimize_clones() or _gather_clone_loss().
            summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries), name='summary_op')

            # Run the subclass hook before every slim train step.
            def train_step_fn(session, *args, **kwargs):
                self.do_extra_train_step(session, end_points, global_step)
                total_loss, should_stop = slim.learning.train_step(session, *args, **kwargs)
                return [total_loss, should_stop]

            # Kicks off the training.
            slim.learning.train(
                train_tensor,
                train_step_fn=train_step_fn,
                logdir=FLAGS.train_dir,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                init_fn=self._get_init_fn(FLAGS.checkpoint_path,
                                          FLAGS.checkpoint_exclude_scopes),
                summary_op=summary_op,
                number_of_steps=FLAGS.max_number_of_steps,
                log_every_n_steps=FLAGS.log_every_n_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                sync_optimizer=optimizer if FLAGS.sync_replicas else None,
                session_config=session_config)

        ##########################
        # Eval, Export or Output #
        ##########################
        else:
            # Write flags file.
            if not os.path.isdir(FLAGS.eval_dir):
                util_io.touch_folder(FLAGS.eval_dir)
            if not os.path.exists(os.path.join(FLAGS.eval_dir, FLAGS_FILE_NAME)):
                FLAGS.append_flags_into_file(os.path.join(FLAGS.eval_dir, FLAGS_FILE_NAME))

            # Build the single non-training graph with the same variable dtype
            # custom getter used by the training clones.
            with tf.variable_scope(tf.get_variable_scope(),
                                   custom_getter=model_deploy.get_custom_getter(
                                       GeneralModel._dtype_string_to_dtype(FLAGS.variable_dtype)),
                                   reuse=False):
                end_points = self._clone_fn(networks, batch_queue=None,
                                            batch_names=batch_names,
                                            data_batched=data_batched,
                                            is_training=False,
                                            global_step=global_step)

            num_batches = int(math.ceil(self.num_samples / float(FLAGS.batch_size)))
            checkpoint_path = util_misc.get_latest_checkpoint_path(FLAGS.checkpoint_path)

            if FLAGS.moving_average_decay:
                # Restore the moving-average shadow values into the model variables.
                variable_averages = tf.train.ExponentialMovingAverage(
                    FLAGS.moving_average_decay, global_step)
                variables_to_restore = variable_averages.variables_to_restore(
                    slim.get_model_variables())
                variables_to_restore[global_step.op.name] = global_step
            else:
                variables_to_restore = slim.get_variables_to_restore()

            saver = None
            if variables_to_restore is not None:
                saver = tf.train.Saver(variables_to_restore)

            session_creator = tf.train.ChiefSessionCreator(
                scaffold=tf.train.Scaffold(saver=saver),
                checkpoint_filename_with_path=checkpoint_path,
                master=FLAGS.master,
                config=session_config)

            ##########
            # Output #
            ##########
            if FLAGS.do_output:
                tf.logging.info('Output mode.')
                output_ops = self._maybe_encode_output_tensor(
                    self._define_outputs(end_points, data_batched))
                start_time = time.time()
                with tf.train.MonitoredSession(
                        session_creator=session_creator) as session:
                    for i in range(num_batches):
                        output_results = session.run([item[-1] for item in output_ops])
                        self._write_outputs(output_results, output_ops)
                        if i % FLAGS.log_every_n_steps == 0:
                            current_time = time.time()
                            # Average seconds per batch so far (i + 1 batches done).
                            speed = (current_time - start_time) / (i + 1)
                            # Batches still to run after this one; the previous
                            # `num_batches - i + 1` over-counted by two.
                            time_left = speed * (num_batches - i - 1)
                            tf.logging.info('%d / %d done. Time left: %f',
                                            i + 1, num_batches, time_left)

            ################
            # Export Model #
            ################
            elif FLAGS.do_export:
                tf.logging.info('Exporting trained model to %s', FLAGS.export_path)
                with tf.Session(config=session_config) as session:
                    saver.restore(session, checkpoint_path)
                    builder = tf.saved_model.builder.SavedModelBuilder(FLAGS.export_path)
                    signature_def_map = self._build_signature_def_map(end_points, data_batched)
                    assets_collection = self._build_assets_collection(end_points, data_batched)
                    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
                    builder.add_meta_graph_and_variables(
                        session, [tf.saved_model.tag_constants.SERVING],
                        signature_def_map=signature_def_map,
                        legacy_init_op=legacy_init_op,
                        assets_collection=assets_collection,
                    )
                builder.save()
                tf.logging.info('Done exporting!')

            ########
            # Eval #
            ########
            else:
                tf.logging.info('Eval mode.')
                # Add summaries for images, if there are any.
                self._add_image_summaries(end_points, None)

                # Define the metrics:
                metric_map = self._define_eval_metrics(end_points, data_batched)
                names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(metric_map)
                names_to_values = collections.OrderedDict(**names_to_values)
                names_to_updates = collections.OrderedDict(**names_to_updates)

                # Print the summaries to screen.
                for name, value in names_to_values.items():
                    summary_name = 'eval/%s' % name
                    # Non-scalar metric values need a tensor summary.
                    if len(value.shape):
                        op = tf.summary.tensor_summary(summary_name, value, collections=[])
                    else:
                        op = tf.summary.scalar(summary_name, value, collections=[])
                    op = tf.Print(op, [value], summary_name)
                    tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

                if not (FLAGS.do_eval_debug or FLAGS.do_custom_eval):
                    tf.logging.info('Evaluating %s' % checkpoint_path)
                    slim.evaluation.evaluate_once(
                        master=FLAGS.master,
                        checkpoint_path=checkpoint_path,
                        logdir=FLAGS.eval_dir,
                        num_evals=num_batches,
                        eval_op=list(names_to_updates.values()),
                        variables_to_restore=variables_to_restore,
                        session_config=session_config)
                    return

                ################################
                # `do_eval_debug` flag is true.#
                ################################
                if FLAGS.do_eval_debug:
                    eval_ops = list(names_to_updates.values())
                    eval_names = list(names_to_updates.keys())

                    # Items to write to a html page.
                    encode_ops = self._maybe_encode_output_tensor(
                        self.get_items_to_encode(end_points, data_batched))

                    with tf.train.MonitoredSession(session_creator=session_creator) as session:
                        if eval_ops is not None:
                            for i in range(num_batches):
                                eval_result = session.run(eval_ops, None)
                                print('; '.join(('%s:%s' % (name, str(eval_result[i]))
                                                 for i, name in enumerate(eval_names))))

                        # Write to HTML
                        if encode_ops:
                            for i in range(num_batches):
                                encode_ops_feed_dict = self._get_encode_op_feed_dict(
                                    end_points, encode_ops, i)
                                encoded_items = session.run(
                                    [item[-1] for item in encode_ops], encode_ops_feed_dict)
                                encoded_list = []
                                for j in range(len(encoded_items)):
                                    encoded_list.append((encode_ops[j][0], encode_ops[j][1],
                                                         encoded_items[j].tolist()))

                                eval_items = self.save_images(
                                    encoded_list, os.path.join(FLAGS.eval_dir, 'images'))
                                eval_items = self.to_human_friendly(eval_items, )
                                self._write_eval_html(eval_items)
                                if i % 10 == 0:
                                    tf.logging.info('%d/%d' % (i, num_batches))

                if FLAGS.do_custom_eval:
                    extra_eval = self._define_extra_eval_actions(end_points, data_batched)
                    with tf.train.MonitoredSession(session_creator=session_creator) as session:
                        self._do_extra_eval_actions(session, extra_eval)