class Solver(object):
    """Solver for generic networks.

    Wraps a network object and drives its training, validation and test
    loops inside TensorFlow sessions. Subclasses must implement `val_step`,
    `validate` and `save_outputs`; `evaluate` is an optional hook.
    """

    def __init__(self, net, graph, is_training):
        """Builds the saver, init ops and summary ops for `net`.

        Args:
            net: Network object. Must provide `build_summary_ops` and
                `vars_to_restore`, plus `build_train_ops` when training.
            graph: Graph handed to `net.build_summary_ops`.
            is_training: If True, the train ops are built as well.
        """
        self.net = net
        self.graph = graph
        self.is_training = is_training
        self.num_epochs = cfg.TRAIN.NUM_EPOCHS
        self.train_timer = Timer()
        self.data_timer = Timer()

        self.global_step = slim.get_or_create_global_step()

        # Build basic ops and tensors
        if self.is_training:
            self.net.build_train_ops(self.global_step)

        if isinstance(net, Classifier):
            # Classifiers only save/restore the subset the net declares.
            self.saver = tf.train.Saver(var_list=net.vars_to_restore,
                                        name='saver')
        else:
            # Save all vars
            self.saver = tf.train.Saver(max_to_keep=None,
                                        name='saver_all_var')
        self.init_ops = self.build_init_ops()
        # The validation loss is computed in Python and fed back through this
        # placeholder so it can be written out as a scalar summary.
        self.val_loss_ph = tf.placeholder(tf.float32, shape=(),
                                          name='val_loss_ph')
        self.net.build_summary_ops(self.graph)
        self.val_loss_summary = tf.summary.scalar(name='val_loss',
                                                  tensor=self.val_loss_ph)

        print('saver variables:')
        print_tensor_shapes(net.vars_to_restore, prefix='-->')

    def build_init_ops(self):
        """Builds the init ops.

        Returns:
            init_op: Op that initializes all global variables.
            ready_op: Op that reports the variables that are still
                uninitialized.
            local_init_op: Grouped op that initializes local variables and
                tables.
        """
        with tf.name_scope('init_ops'):
            init_op = tf.global_variables_initializer()
            ready_op = tf.report_uninitialized_variables()
            local_init_op = tf.group(tf.local_variables_initializer(),
                                     tf.tables_initializer())
        return init_op, ready_op, local_init_op

    def restore_checkpoint(self, sess):
        """Restores the network to a previously saved checkpoint if a path is
        provided from the config.

        Args:
            sess: Current session.
        """
        if cfg.DIR.CKPT_PATH is not None:
            tf.logging.info('Restoring checkpoint.')
            self.saver.restore(sess, cfg.DIR.CKPT_PATH)
        else:
            tf.logging.info('Using network with random weights.')

    def train_step(self, sess, train_queue, step):
        """Executes a train step by delegating to the network.

        Args:
            sess: Current session.
            train_queue: Data queue containing train set minibatches.
            step: The current training iteration (0-based indexing).

        Returns:
            print_dict: Dictionary of items such as losses (for just one
                minibatch) to print.
        """
        print_dict = self.net.train_step(sess, train_queue, step,
                                         data_timer=self.data_timer)
        return print_dict

    def val_step(self, sess, val_queue):
        """Executes a validation step, which simply computes the loss.

        Args:
            sess: Current session.
            val_queue: Data queue containing validation set minibatches.

        Returns:
            val_loss: Loss for a single minibatch of validation data.

        Raises:
            NotImplementedError: Always; subclasses provide the real logic.
        """
        raise NotImplementedError('Must be implemented by a subclass.')

    def validate(self, sess, val_queue, step, num_val_iter):
        """Runs a validation pass; must be implemented by a subclass.

        `train` writes the returned value to the `val_loss` summary and
        treats a return value of -1 as a request to stop training.

        Raises:
            NotImplementedError: Always; subclasses provide the real logic.
        """
        raise NotImplementedError('Must be implemented by a subclass.')

    def train(self, train_iters_per_epoch, train_queue,
              val_iters_per_epoch=None, val_queue=None):
        """Train the network, computing the validation loss if
        val_iters_per_epoch and val_queue are provided.

        Args:
            train_iters_per_epoch: Number of iterations in a single epoch of
                train data, as computed by the data process.
            train_queue: Data queue containing minibatches of train data.
            val_iters_per_epoch: Optional input describing the number of
                iterations in a single epoch of validation data, as computed
                by the data process.
            val_queue: Optional input representing the data queue containing
                minibatches of validation data.

        Raises:
            ValueError: If only one of val_iters_per_epoch and val_queue is
                given.
        """
        # The two validation arguments must be supplied together.
        if (val_iters_per_epoch is None) != (val_queue is None):
            raise ValueError('Need to input both val size and val queue.')
        run_validation = val_iters_per_epoch is not None

        print('-------------- BEGIN TRAINING --------------')
        num_train_iter = get_num_iterations(train_iters_per_epoch,
                                            num_epochs=cfg.TRAIN.NUM_EPOCHS,
                                            disp=True)
        # Evaluate on roughly 20000 samples
        num_val_iter = 20000 // cfg.CONST.BATCH_SIZE
        if val_iters_per_epoch is not None:
            num_val_iter = min(num_val_iter, val_iters_per_epoch)

        with tf.Session() as sess:
            sess.run(self.init_ops)
            self.restore_checkpoint(sess)

            # Sentinel: stays at -1 when the loop body never executes, so the
            # final save below does not touch an unbound loop variable.
            step = -1

            # Train loop
            for step in range(num_train_iter):
                self.train_timer.tic()
                print_dict = self.train_step(sess, train_queue, step)
                self.train_timer.toc()

                if (step + 1) % cfg.CONST.PRINT_FREQ == 0:
                    print_dict['queue size'] = (
                        str(train_queue.qsize()) + '/'
                        + str(cfg.CONST.QUEUE_CAPACITY))
                    print_dict['data fetch (sec/step)'] = (
                        '%.2f' % self.data_timer.average_time)
                    print_dict['train step (sec/step)'] = (
                        '%.2f' % self.train_timer.average_time)
                    print_train_step_data(print_dict, step)

                    # Reset timers
                    self.data_timer.reset()
                    self.train_timer.reset()

                if run_validation and (
                        (step + 1) % cfg.TRAIN.VALIDATION_FREQ == 0):
                    validation_val = self.validate(sess, val_queue, step,
                                                   num_val_iter)
                    if validation_val == -1:  # Training termination flag
                        tf.logging.info(
                            'Terminating train loop due to decreasing validation performance.'
                        )
                        break
                    else:
                        val_summary = sess.run(
                            self.val_loss_summary,
                            feed_dict={self.val_loss_ph: validation_val})
                        self.net.summary_writer.add_summary(
                            val_summary, (step + 1))

                if (step + 1) % cfg.TRAIN.CKPT_FREQ == 0:
                    self.save(sess, step)

            # save model after training (only if at least one step ran)
            if step >= 0:
                self.save(sess, step)

    def forward_pass_batches(self, sess, minibatch_generator):
        """Forward pass a series of minibatches from the minibatch generator.

        Args:
            sess: Current session.
            minibatch_generator: Iterable yielding minibatch dictionaries.

        Returns:
            minibatch_list: The (possibly reduced) minibatches, in order.
            outputs_list: The network outputs for each minibatch, in order.
        """
        minibatch_list = []
        outputs_list = []
        for step, minibatch in enumerate(minibatch_generator):
            # Re-seed before every forward pass so each one sees the same
            # NumPy random state.
            np.random.seed(1234)
            try:
                outputs = self.net.forward_pass(sess, minibatch)
            except KeyError:
                # Retry with the full feed dict when the first attempt fails
                # with a KeyError.
                outputs = self.net.forward_pass(sess, minibatch,
                                                full_feed_dict=True)

            # Reduce size of minibatch so we can pass through entire val set
            if isinstance(self.net, LBA):
                minibatch_save = {
                    'raw_embedding_batch': minibatch['raw_embedding_batch'],
                    'caption_label_batch': minibatch['caption_label_batch'],
                    'category_list': minibatch['category_list'],
                    'model_list': minibatch['model_list'],
                }
                minibatch = minibatch_save
            if isinstance(self.net, Classifier):
                minibatch_save = {
                    'class_label_batch': minibatch['class_label_batch'],
                    'model_id_list': minibatch['model_id_list'],
                }
                minibatch = minibatch_save

            minibatch_list.append(minibatch)
            outputs_list.append(outputs)

            if (step + 1) % 100 == 0:
                tf.logging.info('%s Step: %d', datetime.now(), step + 1)

        return minibatch_list, outputs_list

    def val_phase_minibatch_generator(self, val_queue, num_val_iter):
        """Return a minibatch generator for the test phase.

        Pulls `num_val_iter` minibatches from `val_queue` and tags each one
        with `test_queue = True` before yielding it.
        """
        for _ in range(num_val_iter):
            minibatch = val_queue.get()
            minibatch['test_queue'] = True
            yield minibatch

    def evaluate(self, minibatch_list, outputs_list):
        """Do some evaluation of the outputs.

        No-op hook in the base class; subclasses may override it.
        """
        pass

    def test(self, test_process, test_queue, num_minibatches=None,
             save_outputs=False):
        """Compute (and optionally save) the outputs for the test set.

        This function only computes the outputs for num_minibatches
        minibatches.

        Args:
            test_process: Data process for the test data.
            test_queue: Queue containing minibatches of test data.
            num_minibatches: Number of minibatches to compute the outputs
                for. None means no limit.
            save_outputs: Boolean flag for whether or not to save the
                outputs.

        Raises:
            ValueError: If no checkpoint path is configured.
        """
        # Fail before opening a session: testing without a checkpoint is
        # never valid.
        if cfg.DIR.CKPT_PATH is None:
            raise ValueError('Please provide a checkpoint.')

        with tf.Session() as sess:
            self.restore_checkpoint(sess)

            def test_phase_minibatch_generator():
                for step, minibatch in enumerate(
                        get_while_running(test_process, test_queue)):
                    if (num_minibatches is not None) and (
                            step == num_minibatches):
                        break
                    yield minibatch

            minibatch_generator = test_phase_minibatch_generator()
            minibatch_list, outputs_list = self.forward_pass_batches(
                sess, minibatch_generator)
            self.evaluate(minibatch_list, outputs_list)

        if save_outputs:
            self.save_outputs(minibatch_list, outputs_list)

    def save(self, sess, step):
        """Save a checkpoint.

        Args:
            sess: Current session.
            step: 0-based training step; the checkpoint is recorded under
                `step + 1`.
        """
        ckpt_path = os.path.join(cfg.DIR.LOG_PATH, 'model.ckpt')
        tf.logging.info('Saving checkpoint (step %d).', (step + 1))
        self.saver.save(sess, ckpt_path, global_step=(step + 1))

    def save_outputs(self, minibatch_list, outputs_list, filename=None):
        """Save the outputs (from the self.test).

        Raises:
            NotImplementedError: Always; subclasses provide the real logic.
        """
        raise NotImplementedError('Must be implemented by a subclass.')
class TensorboardWriter():
    """Thin wrapper around a Tensorboard `SummaryWriter`.

    When enabled, it tries `torch.utils.tensorboard` and then `tensorboardX`
    and keeps the first writer that imports. The `add_*` functions listed in
    `tb_writer_ftns` are exposed through `__getattr__`; each call receives
    the current step automatically and, for most functions, has the current
    mode appended to its tag. When no writer is available those calls do
    nothing.
    """

    def __init__(self, log_dir, logger, enabled):
        """Creates the underlying writer if visualization is enabled.

        Args:
            log_dir: Directory passed (as a string) to `SummaryWriter`.
            logger: Logger used to warn when no Tensorboard backend can be
                imported.
            enabled: Whether to try to create a writer at all.
        """
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: first importable backend wins.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(
                        module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend that actually provided the writer.
                self.selected_module = module
                break
            else:
                # Loop finished without a break: no backend is installed.
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        # Writer functions that are proxied through __getattr__.
        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images',
            'add_audio', 'add_text', 'add_histogram', 'add_pr_curve',
            'add_embedding'
        }
        # Functions whose tag is passed through without the mode suffix.
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Sets the step/mode used by subsequent `add_*` calls.

        At step 0 the timer is reset; at any other step the time reported by
        `timer.check()` is logged as `steps_per_sec`.
        NOTE(review): presumably `check()` returns the seconds elapsed since
        the previous check -- confirm against the Timer implementation.

        Args:
            step: Current global step.
            mode: Tag suffix such as 'train' or 'valid'.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            # A zero duration would make the rate undefined; skip the scalar
            # instead of raising ZeroDivisionError.
            if duration > 0:
                self.add_scalar('steps_per_sec', 1 / duration)

    def _own_attribute(self, attr, default=None):
        """Reads an instance attribute without going through __getattr__."""
        try:
            return object.__getattribute__(self, attr)
        except AttributeError:
            return default

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing

        Raises:
            AttributeError: If `name` is not one of the proxied writer
                functions.
        """
        # __getattr__ only runs after normal lookup has failed, so own state
        # is read through _own_attribute: on an instance whose __init__ has
        # not run (e.g. during copy/unpickle) a plain `self.x` here would
        # re-enter __getattr__ and recurse without end.
        if name in self._own_attribute('tb_writer_ftns', ()):
            add_data = getattr(self._own_attribute('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)

            return wrapper

        raise AttributeError(
            "type object '{}' has no attribute '{}'".format(
                self._own_attribute('selected_module', ''), name))