def train(model_fn, datasets, logdir, config):
    """Build the model graph, register the phases, and run the trainer.

    The stored configuration in `logdir` takes precedence: `load_config` is
    tried first, and only if it raises `IOError` is the passed `config`
    handed to `save_config`. Whatever those helpers return replaces `config`
    for the rest of the function.

    `model_fn` is called as `model_fn(data, trainer, config)` and must return
    a `(score, summary)` pair. Besides whatever the model function needs, the
    configuration is read for `train_steps`, `test_steps`, `batch_shape`,
    `train_log_every`, `train_checkpoint_every`, `test_checkpoint_every`,
    `savers`, and `max_steps`.

    Args:
      model_fn: Function creating the model graph.
      datasets: Passed through to `get_batch`; documented by the original
        author as a dictionary with keys `train` and `test` and datasets as
        values.
      logdir: Optional logging directory for summaries and checkpoints. A
        truthy value is expanded with `os.path.expanduser`.
      config: Configuration object.

    Yields:
      Every value produced by `trainer.iterate(config.max_steps)`; described
      by the original author as the test score of every epoch.

    Raises:
      KeyError: If `config` is falsy. Because this function is a generator,
        the check runs when iteration starts, not when `train` is called.
    """
    if not config:
        raise KeyError('You must specify a configuration.')
    if logdir:
        logdir = os.path.expanduser(logdir)
    try:
        config = load_config(logdir)
    except IOError:
        # Nothing could be loaded: fall back to the caller's configuration.
        config = save_config(config, logdir)
    trainer = trainer_.Trainer(logdir, config=config)
    # NOTE(review): the extent of this scope was reconstructed; confirm that
    # the phase registration is meant to happen inside it.
    with tf.variable_scope('graph', use_resource=True):
        batch = get_batch(datasets, trainer.phase, trainer.reset)
        score, summary = model_fn(batch, trainer, config)
        tf.logging.info(
            'Graph contains {} trainable variables.'.format(
                tools.count_weights()))
        if config.train_steps:
            trainer.add_phase(
                'train',
                config.train_steps,
                score,
                summary,
                batch_size=config.batch_shape[0],
                report_every=None,
                # Original note: "Keeps running average logged every..."
                log_every=config.train_log_every,
                checkpoint_every=config.train_checkpoint_every)
        if config.test_steps:
            trainer.add_phase(
                'test',
                config.test_steps,
                score,
                summary,
                batch_size=config.batch_shape[0],
                # The test phase reports and logs once per `test_steps`.
                report_every=config.test_steps,
                log_every=config.test_steps,
                checkpoint_every=config.test_checkpoint_every)
    for saver_kwargs in config.savers:
        trainer.add_saver(**saver_kwargs)
    for epoch_score in trainer.iterate(config.max_steps):
        yield epoch_score
def train(model_fn, datasets, logdir, config):
    """Create the model graph and drive the trainer, yielding its scores.

    An existing configuration is preferred over the one passed in:
    `load_config(logdir)` is attempted first and `save_config(config, logdir)`
    is used only when that raises `IOError`. The result of either helper
    replaces `config` from then on.

    `model_fn` is invoked as `model_fn(data, trainer, config)` and is expected
    to return a `(score, summary)` pair. In addition to the attributes the
    model function itself needs, the configuration is read for `train_steps`,
    `test_steps`, `batch_shape`, `train_log_every`, `train_checkpoint_every`,
    `test_checkpoint_every`, `savers`, and `max_steps`.

    Args:
      model_fn: Function creating the model graph.
      datasets: Forwarded to `get_batch`; documented by the original author
        as a dictionary with keys `train` and `test` and datasets as values.
      logdir: Optional logging directory for summaries and checkpoints. A
        truthy value is expanded with `os.path.expanduser`.
      config: Configuration object.

    Yields:
      Each value produced by `trainer.iterate(config.max_steps)`; described
      by the original author as the test score of every epoch.

    Raises:
      KeyError: If `config` is falsy. Since this function is a generator, the
        check is performed when iteration begins, not at call time.
    """
    if not config:
        raise KeyError('You must specify a configuration.')
    if logdir:
        logdir = os.path.expanduser(logdir)
    try:
        config = load_config(logdir)
    except IOError:
        # No configuration could be loaded: store and use the caller's one.
        config = save_config(config, logdir)
    trainer = trainer_.Trainer(logdir, config=config)
    # NOTE(review): the extent of this scope was reconstructed; confirm that
    # the phase registration is meant to happen inside it.
    with tf.variable_scope('graph', use_resource=True):
        # Batch recorded in the original author's annotations for one run
        # (shapes presumably depend on the configuration -- verify):
        #   'state':  float32, shape (50, 50, 1)
        #   'image':  float32, shape (50, 50, 64, 64, 3)
        #   'action': float32, shape (50, 50, 2)
        #   'reward': float32, shape (50, 50)
        #   'length': int32,   shape (50,)
        batch = get_batch(datasets, trainer.phase, trainer.reset)
        # Per the original annotation, model_fn was training.define_model.
        score, summary = model_fn(batch, trainer, config)
        template = 'Graph contains {} trainable variables.'
        tf.logging.info(template.format(tools.count_weights()))
        if config.train_steps:  # Annotated as 50000 in the original run.
            trainer.add_phase(
                'train',
                config.train_steps,
                score,
                summary,
                batch_size=config.batch_shape[0],
                report_every=None,
                log_every=config.train_log_every,
                checkpoint_every=config.train_checkpoint_every)
        if config.test_steps:  # Annotated as 100 in the original run.
            trainer.add_phase(
                'test',
                config.test_steps,
                score,
                summary,
                batch_size=config.batch_shape[0],
                # The test phase reports and logs once per `test_steps`.
                report_every=config.test_steps,
                log_every=config.test_steps,
                checkpoint_every=config.test_checkpoint_every)
    for saver_kwargs in config.savers:
        trainer.add_saver(**saver_kwargs)
    for epoch_score in trainer.iterate(config.max_steps):
        yield epoch_score