# Shared imports for the Evaler and Trainer classes below. In the original
# project these classes likely live in separate modules; the project-local
# import paths (`util`, `input_ops`) are assumptions, and `PoseEvalManager`
# is assumed to be defined elsewhere in the project.
import os
import time
from pprint import pprint

import h5py
import tensorflow as tf
import tensorflow.contrib.slim as slim

from util import log                                    # project-local logger
from input_ops import create_input_ops, check_data_id   # project-local input pipeline


class Evaler(object):

    @staticmethod
    def get_model_class(model_name):
        if model_name == 'mlp':
            from model_mlp import Model
        elif model_name == 'conv':
            from model_conv import Model
        else:
            raise ValueError(model_name)
        return Model

    def __init__(self, config, dataset):
        self.config = config
        self.train_dir = config.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size
        self.dataset = dataset

        check_data_id(dataset, config.data_id)
        _, self.batch = create_input_ops(dataset, self.batch_size,
                                         data_id=config.data_id,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        self.model = Model(config)

        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        tf.set_random_seed(1234)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint_path = config.checkpoint_path
        if self.checkpoint_path is None and self.train_dir:
            self.checkpoint_path = tf.train.latest_checkpoint(self.train_dir)
        if self.checkpoint_path is None:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            log.info("Checkpoint path : %s", self.checkpoint_path)

    def eval_run(self):
        # load checkpoint
        if self.checkpoint_path:
            self.saver.restore(self.session, self.checkpoint_path)
            log.info("Loaded from checkpoint!")

        log.infov("Start 1-epoch Inference and Evaluation")
        log.info("# of examples = %d", len(self.dataset))

        length_dataset = len(self.dataset)
        max_steps = int(length_dataset / self.batch_size) + 1
        log.info("max_steps = %d", max_steps)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord, start=True)

        evaler = PoseEvalManager()
        try:
            for s in xrange(max_steps):
                step, accuracy, step_time, batch_chunk, prediction_pred, prediction_gt = \
                    self.run_single_step(self.batch)
                self.log_step_message(s, accuracy, step_time)
                evaler.add_batch(batch_chunk['id'], prediction_pred, prediction_gt)
        except Exception as e:
            coord.request_stop(e)

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))

        evaler.report()
        log.infov("Evaluation complete.")

    def run_single_step(self, batch, step=None, is_train=False):
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

        [step, accuracy, all_preds, all_targets, _] = self.session.run(
            [self.global_step, self.model.accuracy,
             self.model.all_preds, self.model.all_targets, self.step_op],
            feed_dict=self.model.get_feed_dict(batch_chunk)
        )

        _end_time = time.time()

        return step, accuracy, (_end_time - _start_time), \
            batch_chunk, all_preds, all_targets

    def log_step_message(self, step, accuracy, step_time, is_train=False):
        if step_time == 0:
            step_time = 0.001
        log_fn = (is_train and log.info or log.infov)
        log_fn((" [{split_mode:5s} step {step:4d}] " +
                "batch total-accuracy (test): {test_accuracy:.2f}% " +
                "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
                ).format(split_mode=(is_train and 'train' or 'val'),
                         step=step,
                         test_accuracy=accuracy * 100,
                         sec_per_batch=step_time,
                         instance_per_sec=self.batch_size / step_time))
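# Usage sketch (illustrative, not part of the original source): a minimal
# driver for Evaler. The argparse flags mirror the config attributes the
# class reads above; `create_default_splits` is a hypothetical stand-in for
# whatever dataset factory this project actually provides.
def main_eval():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='mlp', choices=['mlp', 'conv'])
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--data_id', nargs='*', default=None)
    parser.add_argument('--train_dir', default=None)
    parser.add_argument('--checkpoint_path', default=None)
    config = parser.parse_args()

    _, dataset_test = create_default_splits()  # hypothetical dataset factory
    evaler = Evaler(config, dataset_test)
    evaler.eval_run()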
class Trainer(object):

    @staticmethod
    def get_model_class(model_name):
        if model_name == 'mlp':
            from model_mlp import Model
        elif model_name == 'conv':
            from model_conv import Model
        else:
            raise ValueError(model_name)
        return Model

    def __init__(self, config, dataset, dataset_test):
        self.config = config

        hyper_parameter_str = (config.dataset + '_lr_' + str(config.learning_rate) +
                               '_update_G' + str(config.update_rate) + '_D' + str(1))
        self.train_dir = './train_dir/%s-%s-%s-%s' % (
            config.model,
            config.prefix,
            hyper_parameter_str,
            time.strftime("%Y%m%d-%H%M%S"),
        )

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        _, self.batch_train = create_input_ops(dataset, self.batch_size,
                                               is_training=True)
        _, self.batch_test = create_input_ops(dataset_test, self.batch_size,
                                              is_training=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
        self.learning_rate = config.learning_rate
        if config.lr_weight_decay:
            self.learning_rate = tf.train.exponential_decay(
                self.learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True,
                name='decaying_learning_rate'
            )

        # print all the trainable variables
        # tf.contrib.slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
        # self.check_op = tf.add_check_numerics_ops()
        self.check_op = tf.no_op()

        # --- checkpoint and monitoring ---
        all_vars = tf.trainable_variables()

        d_var = [v for v in all_vars if v.name.startswith('Discriminator')]
        log.warn("********* d_var ********** ")
        slim.model_analyzer.analyze_vars(d_var, print_info=True)

        g_var = [v for v in all_vars if v.name.startswith('Generator')]
        log.warn("********* g_var ********** ")
        slim.model_analyzer.analyze_vars(g_var, print_info=True)

        # Every trainable variable must belong to either G or D.
        rem_var = set(all_vars) - set(d_var) - set(g_var)
        print([v.name for v in rem_var])
        assert not rem_var

        self.d_optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.d_loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate * 0.5,
            optimizer=tf.train.AdamOptimizer(beta1=0.5),
            clip_gradients=20.0,
            name='d_optimize_loss',
            variables=d_var
        )

        self.g_optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.g_loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer(beta1=0.5),
            clip_gradients=20.0,
            name='g_optimize_loss',
            variables=g_var
        )

        self.summary_op = tf.summary.merge_all()

        self.saver = tf.train.Saver(max_to_keep=100)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)

        self.checkpoint_secs = 600  # 10 min

        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(config=session_config)

        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.saver.restore(self.session, self.ckpt_path)
            log.info("Loaded the pretrained parameters from the provided checkpoint path")

    def train(self):
        log.infov("Training Starts!")
        pprint(self.batch_train)

        max_steps = 1000000
        output_save_step = 1000
        test_sample_step = 100

        for s in xrange(max_steps):
            step, accuracy, summary, d_loss, g_loss, s_loss, step_time, \
                prediction_train, gt_train, g_img = \
                self.run_single_step(self.batch_train, step=s, is_train=True)

            # periodic inference
            if s % test_sample_step == 0:
                accuracy_test, prediction_test, gt_test = \
                    self.run_test(self.batch_test, is_train=False)
            else:
                accuracy_test = 0.0

            if s % 10 == 0:
                self.log_step_message(step, accuracy, accuracy_test,
                                      d_loss, g_loss, s_loss, step_time)

            self.summary_writer.add_summary(summary, global_step=step)

            if s % output_save_step == 0:
                log.infov("Saved checkpoint at %d", s)
                save_path = self.saver.save(self.session,
                                            os.path.join(self.train_dir, 'model'),
                                            global_step=step)
                # dump the generated images alongside the checkpoint
                f = h5py.File(os.path.join(self.train_dir,
                                           'g_img_' + str(s) + '.hy'), 'w')
                f['image'] = g_img
                f.close()

    def run_single_step(self, batch, step=None, is_train=True):
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

        fetch = [self.global_step, self.model.accuracy, self.summary_op,
                 self.model.d_loss, self.model.g_loss, self.model.S_loss,
                 self.model.all_preds, self.model.all_targets,
                 self.model.fake_img, self.check_op]

        if step % (self.config.update_rate + 1) > 0:
            # Train the generator
            fetch.append(self.g_optimizer)
        else:
            # Train the discriminator
            fetch.append(self.d_optimizer)

        fetch_values = self.session.run(
            fetch,
            feed_dict=self.model.get_feed_dict(batch_chunk, step=step)
        )
        [step, accuracy, summary, d_loss, g_loss, s_loss,
         all_preds, all_targets, g_img] = fetch_values[:9]

        _end_time = time.time()

        return step, accuracy, summary, d_loss, g_loss, s_loss, \
            (_end_time - _start_time), all_preds, all_targets, g_img

    def run_test(self, batch, is_train=False, repeat_times=8):
        batch_chunk = self.session.run(batch)

        [step, accuracy, all_preds, all_targets] = self.session.run(
            [self.global_step, self.model.accuracy,
             self.model.all_preds, self.model.all_targets],
            feed_dict=self.model.get_feed_dict(batch_chunk, is_training=False)
        )

        return accuracy, all_preds, all_targets

    def log_step_message(self, step, accuracy, accuracy_test, d_loss, g_loss,
                         s_loss, step_time, is_train=True):
        if step_time == 0:
            step_time = 0.001
        log_fn = (is_train and log.info or log.infov)
        log_fn((" [{split_mode:5s} step {step:4d}] " +
                "Supervised loss: {s_loss:.5f} " +
                "D loss: {d_loss:.5f} " +
                "G loss: {g_loss:.5f} " +
                "Accuracy: {accuracy:.5f} " +
                "test accuracy: {test_accuracy:.5f} " +
                "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
                ).format(split_mode=(is_train and 'train' or 'val'),
                         step=step,
                         d_loss=d_loss,
                         g_loss=g_loss,
                         s_loss=s_loss,
                         accuracy=accuracy,
                         test_accuracy=accuracy_test,
                         sec_per_batch=step_time,
                         instance_per_sec=self.batch_size / step_time))
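# A minimal, self-contained sketch of the G/D alternation implemented in
# run_single_step above: the discriminator trains only on steps where
# step % (update_rate + 1) == 0, and the generator trains on every other
# step, i.e. `update_rate` G updates per D update. The default value below
# is an assumption; the real one comes from config.update_rate.
def _print_update_schedule(update_rate=5, num_steps=12):
    schedule = ['D' if step % (update_rate + 1) == 0 else 'G'
                for step in range(num_steps)]
    print(schedule)
    # With update_rate=5:
    # ['D', 'G', 'G', 'G', 'G', 'G', 'D', 'G', 'G', 'G', 'G', 'G']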