def init():
  logging_config.setup_logging(logger,
                               tnt.global_tnt_config.log_level,
                               tnt.get_rank(),
                               tnt.is_master_rank(),
                               tnt.global_tnt_config.log_on_all_devices)
  # the number of GPUs per node can be specified either as a default
  # configuration value or as a `TNT_GPUS_PER_NODE` environment variable
  devices_per_node = tnt.global_tnt_config.gpus_per_node
  setup_gpus(tnt.get_rank(), ngpus=devices_per_node)

def __init__(self, connection_table, num_micro_batches):
  atexit.register(self.close)
  rank = tnt.get_rank()
  self.local_edge_list = extract_local_edges(connection_table, rank)
  self.num_micro_batches = num_micro_batches
  self.pipeline_comm = GPICommLib.PipelineCommunicator(self.local_edge_list,
                                                       self.num_micro_batches)

def __init__(self, model):
  if not tarantella.global_context:
    raise RuntimeError(
        """Cannot initialize a Model before the Tarantella library.
        Please call "tarantella.init()" first.
        """)
  self.rank = tarantella.get_rank()
  self.comm_size = tarantella.get_size()

  self.model = model
  self.input_shapes = None
  self.done_broadcast = False
  self.compiled = False
  self.broadcaster = None
  self.barrier = tarantella.Barrier()

  self.orig_optimizer = None
  self.orig_loss = None
  self.orig_metrics = None
  self.orig_loss_weights = None
  self.orig_sample_weight_mode = None
  self.orig_weighted_metrics = None
  self.dist_optimizer = None
  self.default_shuffle_seed = 42

  # support for TF 2.0 -- 2.3
  self.tf_default_verbose = {'fit': 1,
                             'evaluate': 1,
                             'predict': 0,
                            }

def train_tnt_and_reference_models(model_config, optimizer, micro_batch_size,
                                   nbatches, number_epochs, optimizer_kwargs={}):
  (train_dataset, _) = util.train_test_mnist_datasets(nbatches=nbatches,
                                                      micro_batch_size=micro_batch_size)
  (ref_train_dataset, _) = util.train_test_mnist_datasets(nbatches=nbatches,
                                                          micro_batch_size=micro_batch_size)
  tnt_model_runner, ref_model_runner = get_compiled_models(model_config, optimizer,
                                                           **optimizer_kwargs)

  tnt_history = tnt_model_runner.train_model(train_dataset, number_epochs)
  ref_history = ref_model_runner.train_model(ref_train_dataset, number_epochs)

  rank = tnt.get_rank()
  logging.getLogger().info(f"[Rank {rank}] Tarantella (loss, accuracy) = "
                           f"({tnt_history.history})")
  logging.getLogger().info(f"[Rank {rank}] Reference (loss, accuracy) = "
                           f"({ref_history.history})")
  return tnt_history, ref_history

def to_microbatched(model, micro_batch_size, num_micro_batches,
                    num_batches, num_test_batches):
  rank = tnt.get_rank()
  partition_generator = pgen.GraphPartitionGenerator(model)
  rank_mapper = rmapper.RankMapper(num_ranks = tnt.get_size(),
                                   pipeline_graph = partition_generator.get_pipeline_graph())

  partition_id = rank_mapper.get_partition_for_rank(rank)
  partition_graph = partition_generator.get_partition_graph(partition_id)
  partition_info = pinfo.PartitionInfo(partition_id = partition_id,
                                       partition_graph = partition_graph)

  core_model_builder = cm_builder.CoreModelBuilder(model, partition_id,
                                                   partition_graph)
  core_model = core_model_builder.get_model()

  connection_table = rank_mapper.get_connections_for_rank(rank)
  pipeline_communicator = tnt.PipelineCommunicator(connection_table, num_micro_batches)

  shared_model_builder = shared.SharedModelBuilder(partition_info, core_model,
                                                   pipeline_communicator, micro_batch_size)
  shared_model = shared_model_builder.get_model()

  microbatched_model_builder = microbatched.MicrobatchedModelBuilder(partition_info, shared_model,
                                                                     micro_batch_size, num_micro_batches)

  ds = load_microbatched_datasets(micro_batch_size, num_micro_batches,
                                  num_batches, num_test_batches, partition_info)
  pipeline_communicator.setup_infrastructure(micro_batch_size)
  return microbatched_model_builder, ds

def test_single_value(self):
  inputs = float(tnt.get_rank())
  expected_output = sum(range(tnt.get_size()))

  allreducer = tnt.TensorAllreducer(inputs)
  output = allreducer.allreduce(inputs)

  assert isinstance(output, float)
  assert expected_output == output

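# Hedged usage sketch (not from the source): how the `TensorAllreducer` API
# exercised by the test above might be used directly. It assumes
# `tarantella.init()` has already run and the script is launched on
# multiple ranks (e.g. via `gaspi_run`).
import numpy as np
import tarantella as tnt

local_values = np.full(shape=(4, 1), fill_value=float(tnt.get_rank()), dtype=np.float32)
allreducer = tnt.TensorAllreducer(local_values)  # set up communication resources once
summed = allreducer.allreduce(local_values)      # element-wise sum across all ranks
# each entry now equals sum(range(tnt.get_size())) on every rank
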
def train_and_eval(self):
  """Trains the model."""
  lr_schedule = optimizer.LearningRateSchedule(self.params["learning_rate"],
                                               self.params["hidden_size"],
                                               self.params["learning_rate_warmup_steps"])
  opt = tf.keras.optimizers.Adam(lr_schedule,
                                 self.params["optimizer_adam_beta1"],
                                 self.params["optimizer_adam_beta2"],
                                 epsilon=self.params["optimizer_adam_epsilon"])
  self.train_model.compile(opt)
  self.train_model.summary()

  # create train dataset
  train_ds = data_pipeline.train_input_fn(self.params,
                                          shuffle_seed = 42,
                                          num_ranks = tnt.get_size(),
                                          rank = tnt.get_rank())

  # enable global callbacks
  callbacks = []
  if self.flags_obj.enable_tensorboard and self.flags_obj.model_dir:
    callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=self.flags_obj.model_dir))

  # enable logging callbacks only on the master rank
  if self.flags_obj.enable_time_history:
    time_callback = keras_utils.TimeHistory(self.params["batch_size"],
                                            self.params["num_sentences"],
                                            logdir = None)
    tnt_time_callback = tnt.keras.callbacks.Callback(time_callback,
                                                     aggregate_logs = False,
                                                     run_on_all_ranks = False)
    callbacks.append(tnt_time_callback)

  # print messages only once
  if tnt.is_master_rank():
    logging.info("Start train")

  stats = {}
  for epoch in range(0, self.params["train_epochs"],
                     self.params["epochs_between_evals"]):
    # as our dataset is distributed manually, disable the automatic Tarantella distribution
    history = self.train_model.fit(train_ds,
                                   callbacks = callbacks,
                                   tnt_distribute_dataset = False,
                                   initial_epoch = epoch,
                                   epochs = epoch + min(self.params["epochs_between_evals"],
                                                        self.params["train_epochs"] - epoch),
                                   verbose = 2)
    if tnt.is_master_rank():
      logging.info("Train history: {}".format(history.history))
      stats = misc.build_stats(history, callbacks)

  if tnt.is_master_rank():
    eval_stats = self.eval()
    stats.update(eval_stats)
  return stats

def test_array_inf(self, array_length, index):
  injection_rank = util.same_random_int_all_ranks(0, tnt.get_size())
  input_array = np.ones(shape=(array_length, 1), dtype=np.float32)
  if tnt.get_rank() == injection_rank:
    input_array[index] = math.inf

  allreducer = tnt.TensorAllreducer(input_array)
  output_array = allreducer.allreduce(input_array)

  assert np.isinf(output_array[index])

def test_single_array_different_inputs(self, array_length):
  input_array = np.empty(shape=(array_length, 1), dtype=np.float32)
  input_array.fill(tnt.get_rank())
  expected_output_array = np.empty(input_array.shape, dtype=np.float32)
  expected_output_array.fill(sum(range(tnt.get_size())))

  allreducer = tnt.TensorAllreducer(input_array)
  output_array = allreducer.allreduce(input_array)

  assert isinstance(output_array, np.ndarray)
  assert np.array_equal(output_array, expected_output_array)

def save_setup(request):
  save_all_devices = request.param
  # save model in a shared directory accessible to all ranks
  save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                          "test_save_model")
  if save_all_devices:
    save_dir = save_dir + str(tnt.get_rank())

  yield {'save_dir': save_dir,
         'all_devices': save_all_devices}

  # clean up
  if save_all_devices or tnt.is_master_rank():
    shutil.rmtree(save_dir, ignore_errors=True)

def _create_tnt_model(cls, model: tf.keras.Model,
                      parallel_strategy: tnt.ParallelStrategy = (tnt.ParallelStrategy.ALL
                                                                 if TF_DEFAULT_PIPELINING_FLAG
                                                                 else tnt.ParallelStrategy.DATA),
                      num_pipeline_stages: int = 1):
  replica_group = tnt.Group()

  if (tnt.ParallelStrategy.PIPELINING in parallel_strategy) and isinstance(model, tf.keras.Sequential):
    logger.warn("Cannot pipeline a `tf.keras.Sequential` model; disabling model parallelism.")
    parallel_strategy = parallel_strategy ^ tnt.ParallelStrategy.PIPELINING

  logger.info(f"Creating parallel model using {parallel_strategy}.")

  if tnt.ParallelStrategy.PIPELINING in parallel_strategy:
    rank = tnt.get_rank()

    partition_generator = pgen.GraphPartitionGenerator(model)
    rank_mapper = rmapper.RankMapper(num_ranks = tnt.get_size(),
                                     pipeline_graph = partition_generator.get_pipeline_graph())

    pipeline_group = rank_mapper.get_pipelining_group_for_rank(rank)
    logger.info(f"[Pipelining] Creating pipelined model with {pipeline_group.size} partitions.")

    # get my partition
    model = pm.PartitionedModel(model = model, group = pipeline_group,
                                partition_generator = partition_generator,
                                rank_mapper = rank_mapper,
                                num_pipeline_stages = num_pipeline_stages)
    if tnt.ParallelStrategy.DATA in parallel_strategy:
      replica_group = rank_mapper.get_replica_group_for_rank(rank)
    else:
      if pipeline_group.size != tnt.get_size():
        raise ValueError(f"Provided model has only {pipeline_group.size} partitions; "
                         f"use {pipeline_group.size} ranks or a different parallel strategy.")

  if tnt.ParallelStrategy.DATA in parallel_strategy:
    # replicate my partition across the data parallel group
    logger.info(f"[DataParallel] Replicating local model across ranks {replica_group.group}.")
    model = dpm.DataParallelModel(model = model, group = replica_group)
  return model

def test_single_array(self, array_shape):
  np.random.seed(42)
  input_array = np.random.random_sample(array_shape).astype('float32')

  rank = tnt.get_rank()
  root_rank = tnt.get_size() - 1
  broadcaster = tnt.TensorBroadcaster(input_array, root_rank)
  expected_output_array = input_array

  if rank == root_rank:
    output_array = broadcaster.broadcast(input_array)
  else:
    output_array = broadcaster.broadcast()

  result = (output_array == expected_output_array).all()
  assert isinstance(output_array, np.ndarray)
  assert result

def broadcast(self, inputs=None):
  if tnt.get_rank() == self.root_global_rank:
    # normalize a single array input into a list of arrays
    if utils.is_nonEmptyArray(inputs):
      inputs = [inputs]
    elif not utils.is_nonEmptyList(inputs):
      self._raise_input_error()
    assert len(self.broadcasts) == len(inputs)
    for i, bcast in enumerate(self.broadcasts):
      bcast.start(inputs[i])
  else:
    for bcast in self.broadcasts:
      bcast.start()

  outputs = list()
  for i, bcast in enumerate(self.broadcasts):
    out = bcast.wait_for_completion()
    outputs.append(out.reshape(self.shapes[i]))
  return outputs if len(outputs) > 1 else outputs[0]

def test_compare_accuracy_against_reference(self, model_runners, micro_batch_size,
                                            number_epochs, nbatches, test_nbatches,
                                            remainder_samples_per_batch,
                                            last_incomplete_batch_size):
  (train_dataset, test_dataset) = util.train_test_mnist_datasets(
      nbatches=nbatches, test_nbatches=test_nbatches,
      micro_batch_size=micro_batch_size, shuffle=False,
      remainder_samples_per_batch=remainder_samples_per_batch,
      last_incomplete_batch_size=last_incomplete_batch_size)
  (ref_train_dataset, ref_test_dataset) = util.train_test_mnist_datasets(
      nbatches=nbatches, test_nbatches=test_nbatches,
      micro_batch_size=micro_batch_size, shuffle=False,
      remainder_samples_per_batch=remainder_samples_per_batch,
      last_incomplete_batch_size=last_incomplete_batch_size)
  tnt_model_runner, reference_model_runner = model_runners

  reference_model_runner.train_model(ref_train_dataset, number_epochs)
  tnt_model_runner.train_model(train_dataset, number_epochs)

  tnt_loss_accuracy = tnt_model_runner.evaluate_model(test_dataset)
  ref_loss_accuracy = reference_model_runner.evaluate_model(ref_test_dataset)

  rank = tnt.get_rank()
  logging.getLogger().info(f"[Rank {rank}] Tarantella[loss, accuracy] = {tnt_loss_accuracy}")
  logging.getLogger().info(f"[Rank {rank}] Reference [loss, accuracy] = {ref_loss_accuracy}")

  result = [True, True]
  if tnt.is_master_rank():
    result = [np.isclose(tnt_loss_accuracy[0], ref_loss_accuracy[0], atol=1e-2),  # losses might not be identical
              np.isclose(tnt_loss_accuracy[1], ref_loss_accuracy[1], atol=1e-6)]
  util.assert_on_all_ranks(result)

def test_list_of_arrays(self, array_shape, list_length):
  np.random.seed(42)
  input_array = np.random.random_sample(array_shape).astype('float32')
  inputs = list_length * [input_array]

  rank = tnt.get_rank()
  root_rank = 0
  broadcaster = tnt.TensorBroadcaster(inputs, root_rank)
  expected_output_array = input_array

  if rank == root_rank:
    outputs = broadcaster.broadcast(inputs)
  else:
    outputs = broadcaster.broadcast()

  result = all((array == expected_output_array).all() for array in outputs)
  assert isinstance(outputs, list)
  assert result

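# Hedged usage sketch (not from the source): the broadcast pattern the two
# tests above exercise. Only the root rank passes data to `broadcast()`;
# all other ranks call it without arguments and receive the root's arrays.
import numpy as np
import tarantella as tnt

root_rank = 0
weights = np.ones(shape=(2, 3), dtype=np.float32)
broadcaster = tnt.TensorBroadcaster(weights, root_rank)

if tnt.get_rank() == root_rank:
  output = broadcaster.broadcast(weights)  # root provides the data
else:
  output = broadcaster.broadcast()         # non-root ranks receive a copy
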
def __init__(self, dataset,
             num_ranks = tnt.get_size(),
             rank = tnt.get_rank(),
             shuffle_seed = 42):
  self.num_ranks = num_ranks
  self.rank = rank
  self.shuffle_seed = shuffle_seed
  self.base_dataset, self.dataset_transformations = \
      ops_helpers.gen_dataset_transformations(dataset)
  self.batching_info = ops_helpers.get_batching_info(self.dataset_transformations)

  # convenience attributes computed when the dataset is distributed among ranks
  self._dataset = None
  self._num_samples = None
  self._micro_batch_size = None

def test_train(self, model_generator, num_micro_batches, micro_batch_size,
               num_batches, num_test_batches, number_epochs):
  batch_size = micro_batch_size * num_micro_batches
  fit_params = {'epochs': number_epochs, 'shuffle': False, 'verbose': 0}
  rank = tnt.get_rank()
  # the last partition will be assigned to rank (nranks-1)
  master_rank = tnt.get_size() - 1

  # reference model
  if rank == master_rank:
    reference_ds = load_reference_datasets(batch_size, num_batches, num_test_batches)
    reference_model = model_generator()
    reference_model.compile(**get_reference_compile_params())
    reference_history = reference_model.fit(reference_ds["train"],
                                            validation_data = reference_ds["val"],
                                            **fit_params)
    reference_result = reference_model.evaluate(reference_ds["test"], verbose = 0)

  # pipelined model
  model = model_generator()
  microbatched_model_builder, microbatched_ds = to_microbatched(model,
                                                                micro_batch_size,
                                                                num_micro_batches,
                                                                num_batches,
                                                                num_test_batches)
  microbatched_model = microbatched_model_builder.get_model()
  microbatched_model.summary()
  microbatched_model.compile(**get_microbatched_compile_params(microbatched_model_builder))

  pipeline_history = microbatched_model.fit(microbatched_ds["train"],
                                            validation_data = microbatched_ds["val"],
                                            **fit_params)
  pipeline_result = microbatched_model.evaluate(microbatched_ds["test"], verbose = 0)

  if rank == master_rank:
    print(reference_history.history)
    print(pipeline_history.history)
    check_histories_match(reference_history, pipeline_history, num_micro_batches)
    check_validation_histories_match(reference_history, pipeline_history, num_micro_batches)
    check_predictions_match(reference_result, pipeline_result, num_micro_batches)

def test_send_all_connections(self, partition, num_micro_batches):
  elem_type = np.dtype(np.float32)
  pipeline_comm = tnt.PipelineCommunicator(partition, num_micro_batches)

  micro_batch_size = 4
  pipeline_comm.setup_infrastructure(micro_batch_size)

  # send on all connections
  for micro_batch_id in range(num_micro_batches):
    for conn_id in pipeline_comm.get_local_connection_ids():
      conn_info = partition[conn_id]
      if conn_info.get_other_rank(tnt.get_rank()) < tnt.get_rank():
        continue
      array_length = micro_batch_size * conn_info.get_size_in_bytes() // elem_type.itemsize
      input_array = np.empty(shape=(array_length, 1), dtype=elem_type)
      input_array.fill(tnt.get_rank())
      pipeline_comm.send(input_array,
                         connection_id=conn_id,
                         micro_batch_id=micro_batch_id)

  # receive on all connections
  for micro_batch_id in range(num_micro_batches):
    for conn_id in pipeline_comm.get_local_connection_ids():
      conn_info = partition[conn_id]
      if conn_info.get_other_rank(tnt.get_rank()) > tnt.get_rank():
        continue
      array_length = micro_batch_size * conn_info.get_size_in_bytes() // elem_type.itemsize
      expected_array = np.empty(shape=(array_length, 1), dtype=elem_type)
      expected_array.fill(conn_info.get_other_rank(tnt.get_rank()))

      input_array = np.empty(shape=(array_length, 1))
      result = pipeline_comm.recv(input_array,
                                  connection_id=conn_id,
                                  micro_batch_id=micro_batch_id)
      assert np.allclose(result, expected_array, atol=1e-6)

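# Hedged sketch (not from the source): a single point-to-point transfer with
# the `PipelineCommunicator` API used above, assuming exactly two ranks, one
# connection, and one micro-batch. `cinfo.ConnectionInfo` receives the rank
# pair and the per-sample buffer size in bytes, as in the tests below.
import numpy as np
import tarantella as tnt

elem_type = np.dtype(np.float32)
micro_batch_size = 4
tensor_size = 8  # elements exchanged per sample
connection_table = {0: cinfo.ConnectionInfo((0, 1), tensor_size * elem_type.itemsize)}

pipeline_comm = tnt.PipelineCommunicator(connection_table, num_micro_batches=1)
pipeline_comm.setup_infrastructure(micro_batch_size)

buffer = np.zeros(shape=(micro_batch_size * tensor_size, 1), dtype=elem_type)
if tnt.get_rank() == 0:
  pipeline_comm.send(buffer, connection_id=0, micro_batch_id=0)
elif tnt.get_rank() == 1:
  received = pipeline_comm.recv(buffer, connection_id=0, micro_batch_id=0)
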
fc_units = 20
num_mnist_classes = 10
shuffle_seed = 17
learning_rate = 0.01
elem_type = np.dtype(np.float32)

number_connections = 2
number_partitions = 2
p_0_id = 0
p_1_id = 1

# Results correctness is checked on the `master_rank`, which has to be on rank=0
# to be able to forward the test exit code to `gaspi_run`
p_0_rank = 1
p_1_rank = 0
master_rank = p_1_rank
rank = tnt.get_rank()

def get_reference_model():
  util.set_tf_random_seed()
  reference_input = keras.Input(shape=(28,28,1,), name='reference_input')
  reference_x = layers.Flatten()(reference_input)
  reference_x = layers.Dense(fc_units, activation='relu', name='dense_relu')(reference_x)
  reference_output = layers.Dense(num_mnist_classes, activation='softmax',
                                  name='dense_softmax')(reference_x + reference_x)
  reference_model = keras.Model(inputs=reference_input, outputs=reference_output,
                                name="reference_model")
  return reference_model

def get_partitioned_core_model():
  # --- core model on partition 0

def __init__(self, model, group = tnt.Group()):
  super().__init__()
  self.rank = tnt.get_rank()
  self.group = group
  self.model = model
  atexit.register(self.close)

parser = ArgumentParser()
parser.add_argument('--hpdlf', dest='use_hpdlf', action='store_true', default=False)
args = parser.parse_args()

use_hpdlf = args.use_hpdlf
print(use_hpdlf)

if use_hpdlf:
  import tarantella
  tarantella.init(0)
  rank = tarantella.get_rank()
  comm_size = tarantella.get_size()
else:
  rank = 0
  comm_size = 1

print(('RANK: {}\n'
       'COMM_SIZE: {}').format(rank, comm_size))

model = ReversibleSequential(*IN_SHAPE)

for k in range(5):
  kwargs = {
      'affine_clamping': 1.0,
      'global_affine_init': 0.85,
      'global_affine_type': 'SOFTPLUS',
      'subnet_constructor': subnet.SubnetFactory(64)

def main(_):
  flags_obj = flags.FLAGS

  # get rank and comm_size
  rank = tnt.get_rank()
  comm_size = tnt.get_size()

  # compute the micro batch size if the dataset is not automatically distributed by Tarantella
  if not flags_obj.auto_distributed:
    batch_size = flags_obj.batch_size // comm_size
  else:
    batch_size = flags_obj.batch_size

  # Load and preprocess datasets
  (train_dataset, validation_dataset, _) = dataset_utils.get_tnt_cifar10_dataset(45000, 5000, 10000,
                                                                                 batch_size)

  # Create model and wrap it into a Tarantella model
  model = resnet_model.resnet32(num_classes=10)
  model = tnt.Model(model)

  optimizer = get_optimizer(flags_obj.batch_size)
  model.compile(optimizer=optimizer,
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy'])
  model.summary()

  callbacks = []
  if flags_obj.enable_tensorboard:
    callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir,
                                                    profile_batch=2))
  if flags_obj.profile_runtime:
    callbacks.append(RuntimeProfiler(batch_size=batch_size,
                                     logging_freq=flags_obj.logging_freq,
                                     print_freq=flags_obj.print_freq))
  if flags_obj.enable_checkpoint_and_export and flags_obj.model_dir is not None:
    ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                        save_weights_only=True))

  logging.info("Start training")
  kwargs = {'tnt_distribute_dataset': flags_obj.auto_distributed,
            'tnt_distribute_validation_dataset': flags_obj.auto_distributed}
  history = model.fit(train_dataset,
                      epochs=flags_obj.train_epochs,
                      callbacks=callbacks,
                      validation_data=validation_dataset,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=flags_obj.verbose,
                      **kwargs)
  logging.info("Train history: {}".format(history.history))

  kwargs = {'tnt_distribute_dataset': flags_obj.auto_distributed}
  eval_output = model.evaluate(validation_dataset,
                               verbose=flags_obj.verbose,
                               **kwargs)

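# Hedged, minimal data-parallel sketch (not from the source): a toy Keras model
# and random dataset stand in for the ResNet-32 pipeline above. The pattern is
# the one these snippets share: initialize Tarantella, wrap the model in
# `tnt.Model`, and let `fit` distribute the dataset automatically.
import tensorflow as tf
import tarantella as tnt

tnt.init()  # must run before constructing a tnt.Model

model = tnt.Model(tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,))]))
model.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])

x = tf.random.uniform(shape=(256, 784))
y = tf.random.uniform(shape=(256,), maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

model.fit(dataset, epochs=1, tnt_distribute_dataset=True)
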
def test_send_recv_layers_forward_and_backward(self, test_case):
  rank_0 = 0
  rank_1 = 1
  rank = tnt.get_rank()
  number_tags = 2  # mbatch_id, connection_id
  elem_type = np.dtype(np.float32)
  number_epochs = 3

  connection_id = test_case["connection_id"]
  micro_batch_size = test_case["micro_batch_size"]
  num_micro_batches = test_case["num_micro_batches"]
  data_to_be_sent = test_case["data_to_be_sent"]
  tensor_size = np.array(data_to_be_sent[0]).size

  tags_dataset = []
  for mbatch_id in range(num_micro_batches):
    tags_dataset = tags_dataset + micro_batch_size * [[mbatch_id, connection_id]]
  labels_dataset = micro_batch_size * num_micro_batches * [0.]

  connection_table = {connection_id: cinfo.ConnectionInfo((rank_0, rank_1),
                                                          tensor_size * elem_type.itemsize)}
  pipeline_comm = tnt.PipelineCommunicator(connection_table, num_micro_batches)
  pipeline_comm.setup_infrastructure(micro_batch_size)

  if rank == rank_0:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    input_seq = keras.Input(shape=(1,), name='input_seq')
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.RemoveSeqInput()([inputs, input_seq])  # force execution of backward pass
    outputs = tnt_layers.SendLayer(pipeline_communicator=pipeline_comm)(outputs, input_tags)
    outputs = tnt_layers.IdentityLayer(name="result")(outputs)
    model = keras.Model([inputs, input_tags, input_seq], outputs)
    loss = tnt_losses.ZeroLoss()
    dataset = data_to_be_sent

  if rank == rank_1:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    input_seq = keras.Input(shape=(1,), name='input_seq')
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.RemoveSeqInput()([inputs, input_seq])
    outputs = tnt_layers.RecvLayer(pipeline_communicator=pipeline_comm)(outputs, input_tags)
    outputs = tnt_layers.IdentityLayer(name="result")(outputs)
    model = keras.Model([inputs, input_tags, input_seq], outputs)
    # loss = PlusOneLoss()
    loss = tnt_losses.ZeroLoss()
    dataset = np.zeros_like(data_to_be_sent)

  def generator():
    input_seq_constant = 0
    for data, tag, label in zip(dataset, tags_dataset, labels_dataset):
      yield ({"input": data, "tags": tag, "input_seq": input_seq_constant},
             {"result": label})

  final_dataset = tf.data.Dataset.from_generator(
      generator,
      output_types=({"input": tf.float32, "tags": tf.int32, "input_seq": tf.float32},
                    {"result": tf.float32}))
  final_dataset = final_dataset.batch(micro_batch_size)

  model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1), loss=loss)
  history = model.fit(final_dataset, epochs=number_epochs)
  assert len(history.history['loss']) == number_epochs

def test_send_recv_layers_forward(self, test_case):
  rank_0 = 0
  rank_1 = 1
  rank = tnt.get_rank()
  number_tags = 2  # mbatch_id, connection_id
  elem_type = np.dtype(np.float32)

  connection_id = test_case["connection_id"]
  micro_batch_size = test_case["micro_batch_size"]
  num_micro_batches = test_case["num_micro_batches"]
  data_to_be_sent = test_case["data_to_be_sent"]
  tensor_size = np.array(data_to_be_sent[0]).size

  tags_dataset = []
  for mbatch_id in range(num_micro_batches):
    tags_dataset = tags_dataset + micro_batch_size * [[mbatch_id, connection_id]]

  connection_table = {connection_id: cinfo.ConnectionInfo((rank_0, rank_1),
                                                          tensor_size * elem_type.itemsize)}
  pipeline_comm = tnt.PipelineCommunicator(connection_table, num_micro_batches)
  pipeline_comm.setup_infrastructure(micro_batch_size)

  if rank == rank_0:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.SendLayer(pipeline_communicator=pipeline_comm)(inputs, input_tags)
    model = keras.Model([inputs, input_tags], outputs)
    dataset = data_to_be_sent

  if rank == rank_1:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.RecvLayer(pipeline_communicator=pipeline_comm)(inputs, input_tags)
    model = keras.Model([inputs, input_tags], outputs)
    dataset = np.zeros_like(data_to_be_sent)

  def generator():
    for data, tag in zip(dataset, tags_dataset):
      yield {"input": data, "tags": tag}

  final_dataset = tf.data.Dataset.from_generator(generator,
                                                 output_types={"input": tf.float32,
                                                               "tags": tf.int32})
  data_received = model.predict(final_dataset.batch(micro_batch_size))

  if rank == rank_1:
    assert np.allclose(data_received, data_to_be_sent)

def train(args):
  use_tarantella = eval(args['training']['use_tarantella'])
  ndims_tot = np.prod(eval(args['data']['data_dimensions']))
  output_dir = args['checkpoints']['output_dir']
  sched_milestones = eval(args['training']['milestones_lr_decay'])
  n_epochs = eval(args['training']['N_epochs'])
  optimizer_kwargs = eval(args['training']['optimizer_kwargs'])
  optimizer_type = args['training']['optimizer']
  optimizer_lr = eval(args['training']['lr'])

  if use_tarantella:
    import tarantella
    # no argument (otherwise: ranks per node)
    tarantella.init()
    node_rank = tarantella.get_rank()
    nodes_number = tarantella.get_size()
  else:
    node_rank = 0
    nodes_number = 1

  is_primary_node = (node_rank == 0)
  args['training']['rank'] = repr(node_rank)
  args['training']['comm_size'] = repr(nodes_number)

  model = build_model(args)
  data = Dataset(args)

  print(f'NODE_RANK {node_rank}')
  print(f'N_NODES {nodes_number}')
  print(f'IS_PRIMARY_NODE {str(is_primary_node).upper()}', flush=True)

  def nll_loss_z_part(y, z):
    zz = tf.math.reduce_mean(z**2)
    return 0.5 * zz

  def nll_loss_jac_part(y, jac):
    return -tf.math.reduce_mean(jac) / ndims_tot

  def lr_sched(ep, lr):
    if ep in sched_milestones:
      return 0.1 * lr
    return lr

  # TODO: should this only be for one node, or for each?
  lr_scheduler_callback = kr.callbacks.LearningRateScheduler(lr_sched,
                                                             verbose=is_primary_node)
  callbacks = [lr_scheduler_callback, kr.callbacks.TerminateOnNaN()]

  if is_primary_node:
    #checkpoint_callback = kr.callbacks.ModelCheckpoint(filepath=os.path.join(output_dir, 'checkpoint_best.hdf5'),
    #                                                   save_best_only=True,
    #                                                   save_weights_only=True,
    #                                                   mode='min',
    #                                                   verbose=is_primary_node)
    loss_log_callback = kr.callbacks.CSVLogger(os.path.join(output_dir, 'losses.dat'),
                                               separator=' ')
    #callbacks.append(checkpoint_callback)
    callbacks.append(loss_log_callback)

  try:
    optimizer_type = {'ADAM': kr.optimizers.Adam,
                      'SGD': kr.optimizers.SGD}[optimizer_type]
  except KeyError:
    optimizer_type = eval(optimizer_type)

  optimizer = optimizer_type(optimizer_lr, **optimizer_kwargs)

  if use_tarantella:
    model = tarantella.Model(model)

  model.compile(loss=[nll_loss_z_part, nll_loss_jac_part],
                optimizer=optimizer,
                run_eagerly=False)
  model.build((128, 32, 32, 3))

  try:
    history = model.fit(data.train_dataset,
                        epochs=n_epochs,
                        verbose=is_primary_node,
                        callbacks=callbacks,
                        validation_data=(data.test_dataset if is_primary_node else None))
  except:
    raise