def run(self):
    """Evaluates checkpoints from `FLAGS.train_dir`, optionally under attack.

    Sets up logging and evaluation device placement (GPU towers when any are
    found, otherwise a single CPU tower). Then builds the reader, optional
    attack, model, loss and evaluation graph, and does one of the following:

      * With `FLAGS.eval_under_attack`: runs `self.eval_attack()` once and
        appends the attack name, its config and the clean/adversarial
        accuracies to `<train_dir>_logs/attacks_score.txt`.
      * Otherwise: calls `self.eval_loop` while `self.counter` is below
        `self.stopped_at_n`, then writes the best global step and accuracy
        to `<train_dir>_logs/best_accuracy.txt`.

    Raises:
      IOError: if `FLAGS.data_pattern` is empty.
      ValueError: if `FLAGS.attack_method` is not an attribute of `attacks`.
    """
    # Setup logging & log the version.
    if not FLAGS.debug:
        tf.logging.set_verbosity(logging.INFO)
    else:
        tf.logging.set_verbosity(logging.DEBUG)
    logging.info("Tensorflow version: {}.".format(tf.__version__))

    self.train_dir = FLAGS.train_dir
    self.logs_dir = "{}_logs".format(self.train_dir)
    if FLAGS.eval_num_gpu:
        self.batch_size = FLAGS.eval_batch_size * FLAGS.eval_num_gpu
    else:
        self.batch_size = FLAGS.eval_batch_size

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
    gpus = gpus[:FLAGS.eval_num_gpu]
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the {} GPUs".format(num_gpus))
        self.num_towers = num_gpus
        self.device_string = '/gpu:{}'
        logging.info("Using total batch size of {} for evaluation "
                     "over {} GPUs: batch size of {} per GPUs.".format(
                         self.batch_size, self.num_towers,
                         self.batch_size // self.num_towers))
    else:
        logging.info("No GPUs found. Eval on CPU.")
        self.num_towers = 1
        self.device_string = '/cpu:{}'
        logging.info("Using total batch size of {} for evaluation "
                     "on CPU.".format(self.batch_size))

    pp = pprint.PrettyPrinter(indent=2, compact=True)
    logging.info(pp.pformat(FLAGS.values()))

    self.config = tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement,
        allow_soft_placement=True)

    with tf.Graph().as_default():

        self.reader = find_class_by_name(FLAGS.reader, [readers])(
            self.batch_size, is_training=False)

        if FLAGS.eval_under_attack:
            attack_method = FLAGS.attack_method
            attack_cls = getattr(attacks, attack_method, None)
            if attack_cls is None:
                raise ValueError("Attack is not recognized.")
            # The attack's keyword config is stored under a flag named
            # after the attack itself.
            attack_config = getattr(FLAGS, attack_method)
            self.attack = attack_cls(
                batch_size=self.batch_size, sample=FLAGS.attack_sample,
                **attack_config)

        data_pattern = FLAGS.data_pattern
        # Validate before parsing: re.findall(...)[0] on an empty pattern
        # raises IndexError, which used to hide this clearer error.
        if not data_pattern:
            raise IOError("'data_pattern' was not specified. "
                          "Nothing to evaluate.")
        # First lower-case alphanumeric run of the pattern; reused below to
        # tag the summary file names.
        self.dataset = re.findall("[a-z0-9]+", data_pattern.lower())[0]

        self.model = find_class_by_name(FLAGS.model, [models])()
        self.loss_fn = find_class_by_name(FLAGS.loss, [losses])()

        self.build_graph()
        logging.info("Built evaluation graph")

        if FLAGS.eval_under_attack:
            self.saver = tf.train.Saver(tf.global_variables(scope="tower"))
            acc_val, acc_adv_val = self.eval_attack()
            path = join(self.logs_dir, "attacks_score.txt")
            with open(path, 'a') as f:
                f.write("{}\n".format(FLAGS.attack_method))
                f.write("sample {}, {}\n".format(FLAGS.attack_sample,
                                                 json.dumps(attack_config)))
                f.write("{:.5f}\t{:.5f}\n\n".format(acc_val, acc_adv_val))
        else:
            self.saver = tf.train.Saver(tf.global_variables())
            filename_suffix = "_{}_{}".format("eval", self.dataset)
            self.summary_writer = tf.summary.FileWriter(
                self.train_dir, filename_suffix=filename_suffix,
                graph=tf.get_default_graph())

            if FLAGS.stopped_at_n == "auto":
                # Number of training steps per epoch, then number of saved
                # checkpoints over FLAGS.num_epochs epochs.
                one_epoch = self.reader.n_train_files / \
                    (FLAGS.train_batch_size * FLAGS.train_num_gpu)
                self.stopped_at_n = (FLAGS.num_epochs * one_epoch) // \
                    FLAGS.save_checkpoint_steps
            else:
                self.stopped_at_n = FLAGS.stopped_at_n
            logging.info("Making evaluation for {} ckpts.".format(
                int(self.stopped_at_n)))

            self.best_global_step = None
            self.best_accuracy = None
            self.counter = 0
            last_global_step_val = 0
            while self.counter < self.stopped_at_n:
                last_global_step_val = self.eval_loop(last_global_step_val)

            path = join(self.logs_dir, "best_accuracy.txt")
            # If no evaluation ever set a best accuracy (e.g. stopped_at_n
            # is 0), formatting None with {:.4f} would raise TypeError.
            if self.best_accuracy is None:
                logging.warning(
                    "No best accuracy recorded; not writing {}.".format(path))
            else:
                with open(path, 'w') as f:
                    f.write("{}\t{:.4f}\n".format(self.best_global_step,
                                                  self.best_accuracy))
            logging.info("Done evaluation -- number of eval reached.")
def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Either recovers the graph from a meta file or builds it with
    `self.build_model`. Then runs the collected `train_op` inside a
    `tf.train.MonitoredTrainingSession` until the session asks to stop.
    Stop hooks are a `NanTensorHook` on the loss and a `StopAtStepHook` at
    `self.max_steps`. Progress is logged every `FLAGS.frequency_log_steps`
    steps on the master, and always at step 1.

    Args:
      start_new_model: when True and running as master, an existing
        `self.train_dir` is removed first. The flag is also forwarded to
        `self.get_meta_filename`, which decides whether a meta graph is
        recovered.

    Returns:
      None. Results are reported through logging, summaries and checkpoints.
    """
    from os.path import expanduser

    if self.is_master and start_new_model and exists(self.train_dir):
        self.remove_training_directory(self.train_dir)

    pp = pprint.PrettyPrinter(indent=2, compact=True)
    logging.info(pp.pformat(FLAGS.values()))

    # Persist the flags (once) next to the logs so the run can be reproduced.
    model_flags_json = FLAGS.to_json()
    log_folder = '{}_logs'.format(self.train_dir)
    flags_json_path = join(log_folder, "model_flags.json")
    if not exists(flags_json_path):
        with open(flags_json_path, "w") as fout:
            fout.write(model_flags_json)

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default():

        if meta_filename:
            saver = self.recover_model(meta_filename)

        with tf.device(device_fn):
            if not meta_filename:
                saver = self.build_model(self.model, self.reader)

            global_step = tf.train.get_global_step()
            loss = tf.get_collection("loss")[0]
            logits = tf.get_collection("logits")[0]
            labels = tf.get_collection("labels")[0]
            learning_rate = tf.get_collection("learning_rate")[0]
            train_op = tf.get_collection("train_op")[0]
            summary_op = tf.get_collection("summary_op")[0]
            init_op = tf.global_variables_initializer()
            gradients_norm = tf.get_collection("gradients_norm")[0]

        scaffold = tf.train.Scaffold(
            saver=saver,
            init_op=init_op,
            summary_op=summary_op,
        )

        hooks = [
            tf.train.NanTensorHook(loss),
            tf.train.StopAtStepHook(num_steps=self.max_steps),
        ]

        session_args = dict(
            # Connect to the distributed server's target when there is one;
            # "" (the default) keeps a local in-process session.
            master=target or "",
            is_chief=self.is_master,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_dir,
            hooks=hooks,
            save_checkpoint_steps=FLAGS.save_checkpoint_steps,
            save_summaries_steps=10,
            save_summaries_secs=None,
            log_step_count_steps=0,
            config=self.config,
        )

        logging.info("Start training")
        with tf.train.MonitoredTrainingSession(**session_args) as sess:

            summary_writer = tf.summary.FileWriterCache.get(FLAGS.train_dir)

            if FLAGS.profiler:
                profiler = tf.profiler.Profiler(sess.graph)
                # TF does not expand '~' in file paths; do it here.
                timeline_path = expanduser('~/profile.logs')

            # The fetched tensors never change between steps: build once.
            fetches = OrderedDict(train_op=train_op,
                                  global_step=global_step,
                                  loss=loss,
                                  learning_rate=learning_rate,
                                  logits=logits,
                                  labels=labels)
            # The "gradients_norm" collection holds 0 instead of a tensor
            # when the gradient norm is not tracked; only fetch real tensors.
            fetch_grad_norm = gradients_norm != 0
            if fetch_grad_norm:
                fetches['gradients_norm'] = gradients_norm
            grad_norm_val = 0

            global_step_val = 0
            while not sess.should_stop():

                make_profile = False
                profile_args = {}

                # Uses the step from the previous iteration (0 at first).
                if global_step_val % 1000 == 0 and FLAGS.profiler:
                    make_profile = True
                    run_meta = tf.RunMetadata()
                    profile_args = {
                        'options': tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        'run_metadata': run_meta
                    }

                batch_start_time = time.time()
                values = sess.run(list(fetches.values()), **profile_args)
                fetches_values = OrderedDict(zip(fetches.keys(), values))
                seconds_per_batch = time.time() - batch_start_time
                examples_per_second = self.batch_size / seconds_per_batch

                global_step_val = fetches_values['global_step']
                loss_val = fetches_values['loss']
                learning_rate_val = fetches_values['learning_rate']
                predictions_val = fetches_values['logits']
                labels_val = fetches_values['labels']

                if fetch_grad_norm:
                    grad_norm_val = fetches_values['gradients_norm']

                if FLAGS.gradients['compute_hessian'] and \
                        global_step_val != 0 and \
                        global_step_val % \
                        FLAGS.gradients['hessian_every_n_step'] == 0:
                    compute_hessian_and_summary(sess, summary_writer,
                                                global_step_val)

                # make_profile can only be True when FLAGS.profiler is set.
                if make_profile:
                    profiler.add_step(global_step_val, run_meta)

                    # Profile the parameters of your model.
                    profiler.profile_name_scope(
                        options=(tf.profiler.ProfileOptionBuilder.
                                 trainable_variables_parameter()))

                    # Or profile the timing of your model operations.
                    opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
                    profiler.profile_operations(options=opts)

                    # Or you can generate a timeline:
                    opts = (tf.profiler.ProfileOptionBuilder(
                        tf.profiler.ProfileOptionBuilder.time_and_memory())
                        .with_step(global_step_val)
                        .with_timeline_output(timeline_path)
                        .build())
                    profiler.profile_graph(options=opts)

                to_print = global_step_val % FLAGS.frequency_log_steps == 0
                if (self.is_master and to_print) or global_step_val == 1:
                    epoch = ((global_step_val * self.batch_size) /
                             self.reader.n_train_files)

                    message = MessageBuilder()
                    message.add("epoch", epoch, format="4.2f")
                    message.add("step", global_step_val, width=5, format=".0f")
                    message.add("lr", learning_rate_val, format=".6f")
                    message.add("loss", loss_val, format=".4f")
                    if "YT8M" in self.reader.__class__.__name__:
                        gap = eval_util.calculate_gap(
                            predictions_val, labels_val)
                        message.add("gap", gap, format=".3f")
                    message.add("imgs/sec", examples_per_second,
                                width=5, format=".0f")
                    if FLAGS.gradients['perturbed_gradients']:
                        message.add("grad norm", grad_norm_val, format=".4f")
                    logging.info(message.get_message())

            # End training
            logging.info(
                "{}: Done training -- epoch limit reached.".format(
                    task_as_string(self.task)))
            if FLAGS.profiler:
                profiler.advise()

    logging.info("{}: Exited training loop.".format(
        task_as_string(self.task)))
def run(self):
    """Evaluates checkpoints from `FLAGS.train_dir` on `FLAGS.data_pattern`.

    Unless CUDA_VISIBLE_DEVICES is already set, restricts the visible GPUs
    to the first `FLAGS.eval_num_gpu` (none when it is 0/unset). Builds the
    reader, model, loss and evaluation graph under a fixed graph-level seed.
    Then calls `self.eval_loop` while `self.counter` is below
    `FLAGS.stopped_at_n`.

    Raises:
      IOError: if `FLAGS.data_pattern` is empty.
    """
    # Setup logging & log the version.
    tf.logging.set_verbosity(logging.INFO)
    logging.info("Tensorflow version: {}.".format(tf.__version__))

    if os.environ.get('CUDA_VISIBLE_DEVICES') is None:
        # Same truthiness test as the batch-size logic below.
        if not FLAGS.eval_num_gpu:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
                map(str, range(FLAGS.eval_num_gpu)))

    self.train_dir = FLAGS.train_dir

    pp = pprint.PrettyPrinter(indent=2, compact=True)
    logging.info(pp.pformat(FLAGS.values()))

    with tf.Graph().as_default():
        # tf.set_random_seed only affects the *current default* graph, so it
        # must be called inside this context. Calling it beforehand seeded a
        # graph that was never used.
        tf.set_random_seed(0)  # for reproducibility

        if FLAGS.eval_num_gpu:
            self.batch_size = FLAGS.eval_batch_size * FLAGS.eval_num_gpu
        else:
            self.batch_size = FLAGS.eval_batch_size

        self.reader = find_class_by_name(FLAGS.reader, [readers])(
            self.batch_size, is_training=False)
        self.model = find_class_by_name(FLAGS.model, [models])()
        self.loss_fn = find_class_by_name(FLAGS.loss, [losses])()

        data_pattern = FLAGS.data_pattern
        if not data_pattern:
            raise IOError("'data_pattern' was not specified. "
                          "Nothing to evaluate.")

        self.build_graph()
        logging.info("Built evaluation graph")

        self.saver = tf.train.Saver(tf.global_variables())
        # Tag summary files with the first alphanumeric run of the pattern.
        filename_suffix = "_{}_{}".format(
            "eval", re.findall("[a-z0-9]+", data_pattern.lower())[0])
        self.summary_writer = tf.summary.FileWriter(
            self.train_dir, filename_suffix=filename_suffix,
            graph=tf.get_default_graph())

        # NOTE(review): 20 is presumably the top-k used by the metrics;
        # confirm against eval_util.EvaluationMetrics.
        evl_metrics = eval_util.EvaluationMetrics(self.reader.n_classes, 20)

        self.counter = 0
        last_global_step_val = 0
        while self.counter < FLAGS.stopped_at_n:
            last_global_step_val = self.eval_loop(last_global_step_val,
                                                  evl_metrics)
        logging.info("Done evaluation -- number of eval reached.")