def from_params(  # type: ignore
    cls, serialization_dir: str, params: Params, **extras
) -> "LogToTensorboard":
    """Construct a ``LogToTensorboard`` callback from a ``Params`` config.

    Parameters
    ----------
    serialization_dir : str
        Directory handed to ``TensorboardWriter.from_params``.
    params : Params
        Configuration. The optional integer key ``"log_batch_size_period"``
        (default ``None``) is popped first. Whatever remains is passed on to
        ``TensorboardWriter.from_params``.
    **extras
        Extra keyword arguments. They are accepted so callers that forward
        additional construction arguments do not hit a ``TypeError``, but they
        are not used here.

    Returns
    -------
    LogToTensorboard
        A callback wrapping the constructed writer and the batch-size logging
        period.
    """
    # Pop our own key before handing ``params`` to the writer. Otherwise the
    # writer's ``from_params`` would see it (presumably it validates or
    # consumes the remaining keys; confirm).
    log_batch_size_period = params.pop_int("log_batch_size_period", None)

    # The writer is given a batch counter that always returns ``None``. This is
    # a placeholder, so no real batch count is supplied at construction time.
    tensorboard = TensorboardWriter.from_params(
        params=params,
        serialization_dir=serialization_dir,
        get_batch_num_total=lambda: None,
    )
    return LogToTensorboard(tensorboard, log_batch_size_period)
def from_params(  # type: ignore
    cls, serialization_dir: str, params: Params, **extras
) -> "LogToTensorboard":
    """Build a ``LogToTensorboard`` callback from configuration.

    The optional integer ``"log_batch_size_period"`` (default ``None``) is read
    and removed from ``params``. The remaining ``params``, together with
    ``serialization_dir``, are used to build a ``TensorboardWriter``. Any
    ``**extras`` are accepted but ignored.

    Returns a ``LogToTensorboard`` built from that writer and the period.
    """
    # This must run before ``params`` is handed to the writer, so that the
    # writer never sees this key.
    period = params.pop_int("log_batch_size_period", None)

    # TODO(mattg): remove get_batch_num_total from TensorboardWriter, and instead
    # just add a method / arguments to tell the writer what batch num we're at.
    placeholder_batch_counter = lambda: None  # noqa: E731 - always yields None
    writer = TensorboardWriter.from_params(
        params=params,
        serialization_dir=serialization_dir,
        get_batch_num_total=placeholder_batch_counter,
    )

    return LogToTensorboard(writer, period)