def test_batch_size(self):
  """Check per_replica_batch_size for GPU counts that divide the batch evenly.

  A global batch of 147 is expected to come back unchanged for 0 or 1 GPUs
  and to come back as 21 when num_gpus=7.
  """
  # assertEqual replaces the deprecated assertEquals alias, which newer
  # Python releases no longer provide on unittest.TestCase.
  self.assertEqual(
      distribution_utils.per_replica_batch_size(147, num_gpus=0), 147)
  self.assertEqual(
      distribution_utils.per_replica_batch_size(147, num_gpus=1), 147)
  self.assertEqual(
      distribution_utils.per_replica_batch_size(147, num_gpus=7), 21)
def input_fn_eval():
  """Return the evaluation input produced by input_function.

  Calls input_function with is_training=False and a single epoch. The batch
  size handed over is the result of per_replica_batch_size applied to
  flags_obj.batch_size and the GPU count read from flags_obj.
  """
  eval_data_dir = flags_obj.data_dir
  gpu_count = flags_core.get_num_gpus(flags_obj)
  replica_batch = distribution_utils.per_replica_batch_size(
      flags_obj.batch_size, gpu_count)
  tensor_dtype = flags_core.get_tf_dtype(flags_obj)
  return input_function(
      is_training=False,
      data_dir=eval_data_dir,
      batch_size=replica_batch,
      num_epochs=1,
      dtype=tensor_dtype)
def input_fn_train(num_epochs, input_context=None):
  """Return the training input produced by input_function.

  Args:
    num_epochs: Forwarded unchanged as input_function's num_epochs.
    input_context: Forwarded unchanged as input_function's input_context;
      defaults to None.

  Returns:
    Whatever input_function returns when called with is_training=True, the
    data directory, dtype and private-thread setting read from flags_obj,
    and a batch size obtained from per_replica_batch_size.
  """
  train_data_dir = flags_obj.data_dir
  gpu_count = flags_core.get_num_gpus(flags_obj)
  replica_batch = distribution_utils.per_replica_batch_size(
      flags_obj.batch_size, gpu_count)
  tensor_dtype = flags_core.get_tf_dtype(flags_obj)
  private_threads = flags_obj.datasets_num_private_threads
  return input_function(
      is_training=True,
      data_dir=train_data_dir,
      batch_size=replica_batch,
      num_epochs=num_epochs,
      dtype=tensor_dtype,
      datasets_num_private_threads=private_threads,
      input_context=input_context)
def test_batch_size_with_remainder(self):
  """Check that per_replica_batch_size(147, num_gpus=5) raises ValueError."""
  # Callable form of assertRaises: the remaining arguments are passed to
  # per_replica_batch_size, and the test fails unless ValueError is raised.
  self.assertRaises(
      ValueError,
      distribution_utils.per_replica_batch_size,
      147,
      num_gpus=5)