Exemplo n.º 1
0
    def __init__(self, sm_writer, model_helper):
        """Store the summary writer and bind the model helper's interface.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # summary writer, plus the fixed 'data' / 'model' scope names
        # (presumably used as TF name/variable scopes elsewhere -- confirm)
        self.sm_writer = sm_writer
        self.data_scope, self.model_scope = 'data', 'model'

        # multi-GPU mode (Horovod / TF-Plus via mgw): initialize the wrapper and
        # keep a handle to the global MPI communicator; single-GPU mode has none
        if not FLAGS.enbl_multi_gpu:
            self.mpi_comm = None
        else:
            mgw.init()
            from mpi4py import MPI  # imported lazily: only needed in multi-GPU mode
            self.mpi_comm = MPI.COMM_WORLD

        # mirror the model helper's functions and names onto this object
        helper_attrs = (
            'build_dataset_train',
            'build_dataset_eval',
            'forward_train',
            'forward_eval',
            'calc_loss',
            'model_name',
            'dataset_name',
        )
        for attr_name in helper_attrs:
            setattr(self, attr_name, getattr(model_helper, attr_name))

        # checkpoint file name is derived from the model's & dataset's names
        name_parts = (self.model_name, self.dataset_name)
        self.ckpt_file = 'models_%s_at_%s.tar.gz' % name_parts
Exemplo n.º 2
0
def main(unused_argv):
    """Entry point, invoked after FLAGS have been parsed.

    Args:
    * unused_argv: leftover command-line arguments (ignored)
    """

    tf.logging.set_verbosity(tf.logging.INFO)

    # multi-GPU mode: initialize the mgw wrapper before the trainer is created
    if FLAGS.enbl_multi_gpu:
        mgw.init()

    trainer = Trainer(data_path=FLAGS.data_path, netcfg=FLAGS.net_cfg)

    # build the training graph first, then the evaluation graph (both, in either mode)
    for is_train in (True, False):
        trainer.build_graph(is_train=is_train)

    # evaluation-only runs skip training entirely
    run = trainer.eval if FLAGS.eval_only else trainer.train
    run()