예제 #1
0
        def TrainerBuilds(self, model):
            """Check that a Trainer can be constructed for ``model``.

            Instantiates ``model`` to obtain its params, configures them as a
            synchronous ``trainer_client`` job reading the 'Train' dataset, and
            constructs a ``trainer_lib.Trainer`` under a per-model directory
            inside ``FLAGS.test_tmpdir``. The resulting Trainer is discarded;
            the check is that construction does not raise.

            Args:
              model: A model-params class; it is called with no arguments and
                its ``__name__`` names the log directory.
            """
            tmp_root = FLAGS.test_tmpdir
            log_dir = os.path.join(tmp_root, model.__name__)

            # The Trainer does not create its 'train' subdirectory itself yet,
            # so create it before constructing the Trainer.
            train_dir = os.path.join(log_dir, 'train')
            tf.io.gfile.makedirs(train_dir)

            params = model()
            trainer_cfg = params.Model()
            trainer_cfg.input = params.GetDatasetParams('Train')
            # Configure the cluster as a synchronous trainer client.
            trainer_cfg.cluster.mode = 'sync'
            trainer_cfg.cluster.job = 'trainer_client'
            _ = trainer_lib.Trainer(trainer_cfg, '', log_dir, tf_master='')
예제 #2
0
 def _CreateTrainer(self, cfg):
     """Return a new ``trainer.Trainer`` for ``cfg``.

     The task name, log directory and TF master address are read from
     ``FLAGS``; ``self._trial`` is forwarded as the final positional argument.

     Args:
       cfg: The model configuration handed to the Trainer.

     Returns:
       The constructed ``trainer.Trainer`` instance.
     """
     make_trainer = trainer.Trainer
     # Arguments stay positional to match the Trainer constructor's order.
     ctor_args = (cfg, FLAGS.model_task_name, FLAGS.logdir, FLAGS.tf_master,
                  self._trial)
     return make_trainer(*ctor_args)