def __init__(self, sm_writer, model_helper):
    """Set up the learner's shared state.

    Args:
    * sm_writer: TensorFlow's summary writer
    * model_helper: model helper with definitions of model & dataset; it must
      expose build_dataset_train, build_dataset_eval, forward_train,
      forward_eval, calc_loss, model_name and dataset_name
    """
    # summary writer and the variable-scope names used for data / model
    self.sm_writer = sm_writer
    self.data_scope = 'data'
    self.model_scope = 'model'

    # multi-GPU training: initialize Horovod / TF-Plus, then keep a handle to
    # the MPI world communicator; single-device runs have no communicator
    if FLAGS.enbl_multi_gpu:
        mgw.init()
        from mpi4py import MPI
        communicator = MPI.COMM_WORLD
    else:
        communicator = None
    self.mpi_comm = communicator

    # mirror the model helper's function interface and identifying names onto
    # this object, attribute by attribute, under the same names
    for member in ('build_dataset_train', 'build_dataset_eval',
                   'forward_train', 'forward_eval', 'calc_loss',
                   'model_name', 'dataset_name'):
        setattr(self, member, getattr(model_helper, member))

    # the checkpoint filename is derived from the model's and dataset's names
    self.ckpt_file = 'models_%s_at_%s.tar.gz' % (self.model_name, self.dataset_name)
def main(unused_argv):
    """Program entry point: build the graphs, then either evaluate or train.

    Args:
    * unused_argv: leftover command-line arguments (after FLAGS is parsed); ignored
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # multi-GPU runs need the wrapper initialized before the trainer is built
    if FLAGS.enbl_multi_gpu:
        mgw.init()

    trainer = Trainer(data_path=FLAGS.data_path, netcfg=FLAGS.net_cfg)

    # build the training graph first, then the evaluation graph
    for is_train in (True, False):
        trainer.build_graph(is_train=is_train)

    # --eval_only skips training and only runs evaluation
    run = trainer.eval if FLAGS.eval_only else trainer.train
    run()