def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
  """Initialize reversible blocks and the loss layer and place weights on CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
  sig_stack = input_signature
  process = psutil.Process(os.getpid())
  mem_use = process.memory_info().rss
  for (std_layers, rev_layers) in blocks:
    rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1)
    rng = rngs[0]
    for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
      sig = cb.inputs_from_stack(sig_stack, layer.n_in)
      layer.init(sig, rng=layer_rng)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      layer.state = tl.on_cpu(layer.state)  # store state in cpu memory
      logging.info('init: layer %s\nadded cpu memory (MB): %.2f', str(layer),
                   (process.memory_info().rss - mem_use) / float(1024 * 1024))
      mem_use = process.memory_info().rss
      logging.info('init: cpu memory use (MB): %.2f',
                   mem_use / float(1024 * 1024))
      out_sig = layer.output_signature(sig)
      sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
  loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
  loss_layer.weights = tl.on_cpu(loss_layer.weights)
  loss_layer.state = tl.on_cpu(loss_layer.state)
def new_rng(self):
  """Returns a new single-use random number generator (JAX PRNG key)."""
  self._rng, rng = fastmath.random.split(self._rng)
  if self._use_memory_efficient_trainer:
    self._rng = tl.on_cpu(self._rng)
    rng = tl.on_cpu(rng)
  return rng
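# Illustration only (not library code): the split-and-consume PRNG pattern that
# `new_rng` wraps, using only helpers already present in this file
# (`fastmath.random.get_prng`, `fastmath.random.split`, `tl.on_cpu`). The names
# below are ours.
def _example_new_rng_pattern():
  state_rng = fastmath.random.get_prng(0)                 # persistent key
  state_rng, step_rng = fastmath.random.split(state_rng)  # advance + single-use key
  step_rng = tl.on_cpu(step_rng)  # memory-efficient mode pins keys to host memory
  return state_rng, step_rng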
def load_checkpoint(self, directory=None, filename=None):
  """Loads model weights and step from a checkpoint on disk.

  Args:
    directory: Directory with the checkpoint (self._output_dir by default).
    filename: Checkpoint file name (model.pkl.gz by default).
  """
  directory = directory or self._output_dir
  if directory is None:
    _log('Not loading as both directory and output_dir are None.',
         stdout=False)
    return
  filename = filename or 'model.pkl.gz'
  path = os.path.join(directory, filename)
  if not tf.io.gfile.exists(path):
    _log(f'Not loading as checkpoint file does not exist: {path}.',
         stdout=False)
    return
  d = unpickle_from_file(path, gzip=True)
  # For large models, load weights from sharded files.
  if self._use_memory_efficient_trainer:
    weights = []
    n_shards = d['flat_weights']  # We store the number of shards in d here.
    for i in range(n_shards):
      w = unpickle_from_file(path + '.shard%d' % i, gzip=True)
      w = self._from_bits(w)  # bit-casting may put w on accelerator, go back
      weights.extend([tl.on_cpu(x) for x in w])
    d['flat_weights'] = weights
  else:
    d['flat_weights'] = self._from_bits(d['flat_weights'])
  self._step = d['step']
  if 'slots' in d:
    if len(self._tasks) != 1:
      raise ValueError(
          'Can\'t load a single-task checkpoint into a multitask Loop.'
      )
    d['slots_per_task'] = [d['slots']]
  if self._use_memory_efficient_trainer:
    for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
      trainer.slots = slots
  else:
    for (task, slots) in zip(self._tasks, d['slots_per_task']):
      task.optimizer.slots = slots
  # This is self._model.init_from_file but optimized to not re-read.
  input_signature = d['input_signature']
  weights_and_state_sig = self._model.weights_and_state_signature(
      input_signature)
  weights, state = tl.unflatten_weights_and_state(
      d['flat_weights'], d['flat_state'], weights_and_state_sig)
  self._model.state = state
  self._model.weights = weights
  self._eval_model.weights = self._model.weights
  # Restore eval model state; note: it's not always the same as train state.
  if 'flat_eval_state' in d:
    flat_eval_state = d['flat_eval_state']
  else:  # It wasn't saved in old checkpoints; remove this branch once ported.
    flat_eval_state = d['flat_state']
  _, eval_state = tl.unflatten_weights_and_state(
      d['flat_weights'], flat_eval_state, weights_and_state_sig)
  self._eval_model.state = eval_state
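# For reference, a sketch (inferred from `load_checkpoint` above, not a formal
# spec) of the checkpoint dictionary layout it expects:
#
#   d = {
#       'step': ...,               # training step to resume from
#       'flat_weights': ...,       # flat weights; in the memory-efficient case
#                                  #   this holds the shard count and the actual
#                                  #   weights live in '<path>.shard<i>' files
#       'flat_state': ...,         # flat model state
#       'flat_eval_state': ...,    # optional; absent in older checkpoints
#       'input_signature': ...,    # used to rebuild the weights/state signature
#       'slots' / 'slots_per_task': ...,  # optimizer slots (single / multi-task)
#   }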
def _init_evaluator(self, eval_task):
  """Initializes the per-task evaluator."""
  model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
  if self._use_memory_efficient_trainer:
    return _Evaluator(
        weights=tl.on_cpu(model_with_metrics.weights[1]),
        state=tl.on_cpu(model_with_metrics.state[1]),
        metrics_fn=_accelerate_model_with_metrics(model_with_metrics, 0)
    )
  else:
    return _Evaluator(
        # Replicate the eval part of weights and state.
        weights=self._for_n_devices(model_with_metrics.weights[1]),
        state=self._for_n_devices(model_with_metrics.state[1]),
        metrics_fn=_accelerate_model_with_metrics(
            model_with_metrics, self.n_devices)
    )
def _replicate_cpu(self, x):
  # TODO(lukaszkaiser): move it to layers/acceleration to be together with
  #   tl.for_n_devices and other functions like that, possibly refactor them.
  def f(x):
    if self._n_devices > 1:
      return np.broadcast_to(x, (self._n_devices,) + np.asarray(x).shape)
    else:
      return x
  return tl.on_cpu(fastmath.nested_map(f, x))
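# Illustration only (not library code): what the broadcast inside
# `_replicate_cpu` does to a toy array when more than one local device is
# present. Assumes plain numpy as `np`; the names are ours.
def _example_replicate_cpu_broadcast():
  n_devices = 4
  x = np.zeros((2, 3))
  replicated = np.broadcast_to(x, (n_devices,) + x.shape)
  assert replicated.shape == (4, 2, 3)  # leading axis indexes devices
  return replicated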
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
  """Initialize reversible blocks and the loss layer and place weights on CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
  sig_stack = input_signature
  for (std_layers, rev_layers) in blocks:
    rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1)
    rng = rngs[0]
    for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
      sig = cb.inputs_from_stack(sig_stack, layer.n_in)
      layer.init(sig, rng=layer_rng)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      out_sig = layer.output_signature(sig)
      sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
  loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
  loss_layer.weights = tl.on_cpu(loss_layer.weights)
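# Hedged usage sketch (not from the library): one way `init_reversible_blocks`
# might be driven, mirroring the reversible-trainer tests later in this file
# but with much smaller layers. The sizes and names are illustrative only.
def _example_init_reversible_blocks():
  d_model = 16
  std_layer = tl.Serial(tl.Embedding(9, d_model), tl.Dup())
  rev_layers = [tl.ReversibleHalfResidual(tl.Dense(d_model)),
                tl.ReversibleSwap()]
  blocks = [([std_layer], rev_layers)]  # one (std_layers, rev_layers) pair
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  inputs_batch = np.arange(8).reshape((2, 4))
  labeled_batch = (inputs_batch, inputs_batch, np.ones_like(inputs_batch))
  init_reversible_blocks(blocks, loss_layer, shapes.signature(labeled_batch),
                         fastmath.random.get_prng(0))
  return blocks, loss_layer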
def test_run_reversible_large_weights(self):
  """Runs the reversible trainer with a lot of weights to test memory use."""
  # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.global_device_count() == 1 and ram_limited:
    return

  # Create inputs and rngs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
  rev_layers = []
  int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
  shape = shapes.ShapeDtype((2, 4, 16 * 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.Adafactor

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  # Set to true to run again, e.g., for profiling.
  run_twice = False
  if run_twice:
    t = time.time()
    loss, _ = trainer.one_step(labeled_batch, rng_step)
    self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
    print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
def test_run_reversible_weights_trainsfer_xprof(self):
  """Runs the reversible trainer and profiles weight transfer stats."""
  run_this_test = False  # We only run this test manually.
  if not run_this_test or fastmath.global_device_count() == 1:  # TPU only
    return

  # Create inputs and rngs.
  inputs_batch = np.ones((1024, 128), dtype=np.int32)
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 6
  rev_layers = []
  int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
  shape = shapes.ShapeDtype((1024, 128, 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.SGD

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.

  # We profile here.
  t = time.time()
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
def test_run_reversible_large_weights(self):
  """Runs the reversible trainer with a lot of weights to test memory use."""
  # This test requires > 20GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.device_count() == 1 and ram_limited:
    return

  # Create inputs and rngs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 20  # 20 layers each 16K x 16K = 256M weights ~= 1GB, 20GB ram
  rev_layers = []
  int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
  shape = shapes.ShapeDtype((2, 4, 16 * 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.Adafactor

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      first_layer, rev_layers, loss_layer, optimizer_fn)
  trainer.one_step(labeled_batch, rng_step)
def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False):
  self._model_with_loss = model_with_loss
  self._optimizer = optimizer
  self._n_devices = n_devices or fastmath.local_device_count()
  self._adasum = adasum

  # optimizer slots and opt_params may need to be replicated
  self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
      (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))

  # accelerated version of model+loss to replicate weights and state
  self._accelerated_model_with_loss = tl.Accelerate(
      model_with_loss, n_devices=n_devices)

  # Signature:
  # (batch, weights, state, rng) -> ((loss, state), gradients)
  self._forward_and_backward_fn = (
      fastmath.value_and_grad(
          model_with_loss.pure_fn,
          argnums=1,  # arg1 of pure_fn: weights
          has_aux=True))  # return (loss, state), gradients

  # Signature:
  # (weights, slots), step, opt_params, batch, state, rng ->
  # (weights, slots), state, stats
  self._accelerated_update_fn = (
      _accelerate_update_fn(
          self._forward_and_backward_fn,
          self._optimizer,
          n_devices=self._n_devices,
          accelerate=True,
          adasum=self._adasum))
def _unreplicate(self, x):
  if self._n_devices == 1:
    return tl.on_cpu(x)
  return tl.on_cpu(fastmath.nested_map(lambda x: x[0], x))
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Updates layers weights/state and optimizers slots by running one step.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: Learning rate to use instead of the default one.

  Returns:
    Tuple (loss, stats) with new values from one step of training, where
    stats are all optimizer statistics.
  """
  # Update the learning rate if needed.
  if learning_rate is not None:
    self._replicated_loss_opt_params['learning_rate'] = self._replicate_cpu(
        learning_rate)
    for (std_op, rev_ops) in self._replicated_opt_params:
      std_op['learning_rate'] = self._replicate_cpu(learning_rate)
      for op in rev_ops:
        op['learning_rate'] = self._replicate_cpu(learning_rate)

  # Batch needs to be split across the local devices -- the difference
  # between _for_n_devices and _reshape_by_device is that the latter splits
  # the batch dim to batch // n_devices, vs _for_n_devices
  # broadcasts/replicates to n_devices dimension.
  step_int = step
  if self._n_devices > 1:
    batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True)
    step = np.repeat(step, self._n_devices)

  # Create separate rng for each device and layer.
  if self._n_devices == 1:
    rngs = fastmath.random.split(rng, self._n_layers)
  else:
    # JIT the function and run it on CPU to avoid memory fragmentation.
    rngs = self._jit_per_device_rngs(tl.on_cpu(rng))
  # Group rngs by layer blocks.
  rng_blocks, rng_i = [], 0
  for _, rev_layers in self._blocks:
    l = len(rev_layers)
    rng_blocks.append((rngs[rng_i], rngs[rng_i + 1:rng_i + l + 1]))
    rng_i += l + 1

  # Run the layers forward up to the loss layer.
  if self._do_free:
    self._free_accelerators()
  process = psutil.Process(os.getpid())
  if isinstance(batch, (list, tuple)):
    batch_shapes = [x.shape for x in batch]
  else:
    batch_shapes = batch.shape
  logging.info('running step %d on shapes %s', step_int, str(batch_shapes))
  if step_int % self._n_steps_per_log == 1:
    logging.info('run fwd: cpu memory use (MB): %.2f',
                 process.memory_info().rss / float(1024 * 1024))
  stack = batch
  block_inputs_states = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
    std_rng, rev_rngs = rng_blocks[i]
    # Run the standard layer.
    stack, std_inputs, std_state = self._run_forward_standard(
        stack, std_layer, acc_std_layer_fn, std_rng, step_int)

    # Run the reversible layers and collect old and new states.
    stack, rev_old_states, rev_new_states = self._run_forward_reversible(
        stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int)
    block_inputs_states.append(tl.on_cpu(
        ((std_inputs, std_state), (rev_old_states, rev_new_states))))

  # Run the loss layer forward and backward with optimizer update.
  if step_int % self._n_steps_per_log == 1:
    logging.info('run loss: cpu memory use (MB): %.2f',
                 process.memory_info().rss / float(1024 * 1024))
  loss_state = self._replicate(self._loss_layer.state)
  loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
  loss_stats, grad_stack = self._run_backward_standard(
      None, step, self._loss_layer, loss_inputs, loss_state, self._loss_fbo,
      rngs[-1], self._loss_opt, self._replicated_loss_opt_params)
  self._collect_weights(self._loss_layer)
  stats = [tl.on_cpu(loss_stats)]

  # De-fragment memory.
  if self._do_free:
    stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack)
    self._free_accelerators()

  # Run the layers backward and run optimizer updates.
  if step_int % self._n_steps_per_log == 1:
    logging.info('run bwd: cpu memory use (MB): %.2f',
                 process.memory_info().rss / float(1024 * 1024))
  for i in range(len(self._blocks) - 1, -1, -1):
    std_layer, rev_layers = self._blocks[i]
    (std_inputs, std_state), (rev_old_states,
                              rev_new_states) = block_inputs_states[i]
    std_fbo, rev_fbos = self._fbos[i]
    std_opt, rev_opts = self._optimizers[i]
    std_rng, rev_rngs = rng_blocks[i]
    repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

    # Run reversible layers backward with optimizer update.
    stack, grad_stack, new_stats = self._run_backward_reversible(
        stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
        rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
    stats.extend(tl.on_cpu(new_stats))

    # Run the standard layer forward-and-backward pass and optimizer update.
    std_layer_stats, grad_stack = self._run_backward_standard(
        grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
        std_opt, repl_std_opt_params)
    stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
        std_inputs, stack, std_layer.n_out)
    stats.append(tl.on_cpu(std_layer_stats))

    # Collect lazily unreplicated layer weights.
    for rev_layer_id in range(self._n_async_layers):
      self._collect_weights(rev_layers[rev_layer_id])
    self._collect_weights(std_layer)

  # Join stats from different optimizers into one.
  joint_stats = {}
  for i, stat in enumerate(reversed(stats)):
    for k, v in stat.items():
      joint_stats[f'layer{i}/' + k] = v
  return stats[0]['loss'], joint_stats
def _make_optimizer(layer):
  opt = optimizer_fn()
  opt.tree_init(layer.weights)
  opt.slots = tl.on_cpu(opt.slots)
  return opt
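# Hedged note (not from the library): `_make_optimizer` closes over
# `optimizer_fn`, and one optimizer instance is typically built per layer so
# that its slots can sit in host memory next to that layer's CPU-stored
# weights, e.g. roughly:
#
#   rev_opts = [_make_optimizer(layer) for layer in rev_layers]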
def slots(self, slots):
  """Sets the slots of the optimizers and this class (replicated)."""
  self._optimizer.slots = slots
  self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices))
def __init__(
    self,
    model,
    tasks,
    eval_model=None,
    eval_tasks=None,
    output_dir=None,
    checkpoint_at=None,
    permanent_checkpoint_at=None,
    eval_at=None,
    which_task=None,
    n_devices=None,
    random_seed=None,
    loss_chunk_size=0,
    use_memory_efficient_trainer=False,
    callbacks=None,
):
  """Configures a training ``Loop``, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
        functions and eval functions (a.k.a. metrics) are considered to be
        outside the core model, taking core model output and data labels as
        their two inputs.
    tasks: List of :py:class:`TrainTask` instances, which define the training
        data, loss function, and optimizer to be used in respective tasks in
        this training loop. It can also be a single :py:class:`TrainTask`
        instance which is treated in the same way as a singleton list.
    eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If ``None``, the training model (model)
        will be used.
    eval_tasks: List of :py:class:`EvalTask` instances which define how to
        evaluate the model: which validation data to use and which metrics to
        report. Evaluation on each of the tasks will run and be reported
        separately, which allows scoring a model on different subtasks. This
        argument can also be ``None``, in which case no evals will be run, or
        a single :py:class:`EvalTask`, which will be treated in the same way
        as a singleton list.
    output_dir: Path telling where to save outputs (evals and checkpoints).
        Can be ``None`` if both ``eval_task`` and ``checkpoint_at`` are
        ``None``.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
        that step should have its checkpoint saved. If ``None``, the default
        is periodic checkpointing at ``task.n_steps_per_checkpoint``.
    permanent_checkpoint_at: Function (integer --> boolean) telling, for step
        n, whether that step should have its checkpoint saved permanently. If
        ``None``, the default is periodic checkpointing at
        ``task.n_steps_per_permanent_checkpoint``.
    eval_at: Function (integer --> boolean) that says, for training step n,
        whether that step should run evals. If ``None``, run when
        checkpointing.
    which_task: Function (integer --> integer) indicating which task should be
        used at which training step. Can be set to ``None`` in single-task
        training.
    n_devices: integer or ``None``, the number of devices for this
        computation.
    random_seed: the random seed to use; time/os dependent if ``None``
        (default).
    loss_chunk_size: int, if > 0 use chunks of this size to make loss
        computation more memory-efficient.
    use_memory_efficient_trainer: whether to use a special memory-efficient
        trainer; if set to 2, the memory efficiency is very aggressive.
    callbacks: List of subclasses of StepCallback to call on training steps.
  """
  self._is_chief, self._n_hosts, self._n_devices, self._rng = (
      init_host_and_devices(n_devices, random_seed))
  if use_memory_efficient_trainer:
    self._rng = tl.on_cpu(self._rng)

  # Handle single task case without lists too.
  if not isinstance(tasks, (list, tuple)):
    tasks = [tasks]

  if not tasks:
    raise ValueError('Must provide at least one training task.')
  if eval_tasks is None:
    eval_tasks = []
    eval_at = _never
  else:
    if not isinstance(eval_tasks, (list, tuple)):
      eval_tasks = [eval_tasks]

  self._tasks = tasks
  self._model = model
  self._eval_model = eval_model or model

  self._use_memory_efficient_trainer = use_memory_efficient_trainer
  self._loss_chunk_size = loss_chunk_size
  # TODO(lukaszkaiser): can we have different eval models and save memory?
  if use_memory_efficient_trainer:
    assert len(tasks) == 1, 'only single task supported for now'
    self._eval_model = model

  default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
  permanent_default_at = _at_step_1_and_every_nth_step(
      tasks[0].n_steps_per_permanent_checkpoint)
  if output_dir is not None:
    self._output_dir = os.path.expanduser(output_dir)
    tf.io.gfile.makedirs(self._output_dir)
    inputs.load_data_counters(self._output_dir)
  else:
    self._output_dir = None

  # Prepare training components.
  self._step = 0
  self._history = trax_history.History()
  self._checkpoint_at = checkpoint_at or default_at
  self._permanent_checkpoint_at = (
      permanent_checkpoint_at or permanent_default_at)
  if which_task is None:
    # If which_task is not passed, cycle through the tasks one by one;
    # if len(tasks) == 1, which_task is a constant function equal to 0.
    which_task = lambda n: n % len(tasks)
  self._which_task = which_task

  # Initialize using the given random seed.
  # NOTE: If random_seed is None then self._rng will be different on
  # different hosts, leading to different weights on the different hosts.
  self._batch_signature = shapes.signature(tasks[0].sample_batch)
  self._model.rng = self.new_rng()
  # In the memory-efficient case, we initialize in init_trainer.
  if not use_memory_efficient_trainer:
    if _is_uninitialized(self._model):
      self._model.init(self._batch_signature)
    self._eval_model.rng = self.new_rng()
    if _is_uninitialized(self._eval_model):
      self._eval_model.init(self._batch_signature)

  # To handle the above case (i.e. random_seed = None), we psum the weights
  # and state and average them.
  # NOTE: This adds time (how much?) so we prefer not to do it if it is
  # unnecessary, i.e. random_seed was set.
  # NOTE: Averaging the weights across devices can screw up the initial weight
  # statistics.
  # TODO(pkozakowski): Broadcast from one of the devices instead?
  # TODO(lukaszkaiser): make it work for the memory-efficient trainer too.
  if (random_seed is None and self._n_hosts > 1 and
      not use_memory_efficient_trainer):
    logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
    self._sync_weights_and_state_across_hosts()

  # Create the optimizer for the training loss function.
  self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
  self.load_checkpoint()

  # Prepare eval components.
  self._eval_at = eval_at or default_at
  self._eval_tasks = eval_tasks
  loss_names = [task.loss_name for task in self._tasks]
  metric_names = [
      name  # pylint: disable=g-complex-comprehension
      for eval_task in self._eval_tasks
      for name in eval_task.metric_names
  ]
  self._rjust_len = max(map(len, loss_names + metric_names))
  self._evaluator_per_task = tuple(
      self._init_evaluator(eval_task) for eval_task in self._eval_tasks)
  if self._output_dir is None:
    _log('Will not write evaluation metrics, because output_dir is None.')

  def task_output_dir(task_index, task_list):
    if self._output_dir is not None:
      if len(task_list) < 2:
        output_dir = self._output_dir
      else:
        output_dir = os.path.join(self._output_dir, str(task_index))
      tf.io.gfile.makedirs(output_dir)
      return output_dir
    else:
      return None
  self._output_dir_per_eval_task = [
      task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))]
  self._output_dir_per_train_task = [
      task_output_dir(i, tasks) for i in range(len(tasks))]

  callbacks = callbacks or []
  self._callbacks = [
      callback_class(self) for callback_class in callbacks
  ]