def to_microbatched(model, micro_batch_size, num_micro_batches, num_batches,
                    num_test_batches):
  rank = tnt.get_rank()

  # partition the model graph and map each partition to a rank
  partition_generator = pgen.GraphPartitionGenerator(model)
  rank_mapper = rmapper.RankMapper(
      num_ranks = tnt.get_size(),
      pipeline_graph = partition_generator.get_pipeline_graph())

  partition_id = rank_mapper.get_partition_for_rank(rank)
  partition_graph = partition_generator.get_partition_graph(partition_id)
  partition_info = pinfo.PartitionInfo(partition_id = partition_id,
                                       partition_graph = partition_graph)

  # build the core model for the local partition
  core_model_builder = cm_builder.CoreModelBuilder(model, partition_id,
                                                   partition_graph)
  core_model = core_model_builder.get_model()

  # set up point-to-point communication between adjacent partitions
  connection_table = rank_mapper.get_connections_for_rank(rank)
  pipeline_communicator = tnt.PipelineCommunicator(connection_table,
                                                   num_micro_batches)

  shared_model_builder = shared.SharedModelBuilder(partition_info, core_model,
                                                   pipeline_communicator,
                                                   micro_batch_size)
  shared_model = shared_model_builder.get_model()
  microbatched_model_builder = microbatched.MicrobatchedModelBuilder(
      partition_info, shared_model, micro_batch_size, num_micro_batches)

  ds = load_microbatched_datasets(micro_batch_size, num_micro_batches,
                                  num_batches, num_test_batches, partition_info)
  pipeline_communicator.setup_infrastructure(micro_batch_size)
  return microbatched_model_builder, ds
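# A minimal usage sketch for to_microbatched (not part of the original file).
# It assumes MicrobatchedModelBuilder exposes get_model(), matching the
# get_model() pattern of the other builders above; the exact structure of the
# returned datasets `ds` depends on load_microbatched_datasets and is left
# opaque here.
def run_microbatched_sketch(model):
  microbatched_model_builder, ds = to_microbatched(model,
                                                   micro_batch_size=8,
                                                   num_micro_batches=2,
                                                   num_batches=10,
                                                   num_test_batches=2)
  microbatched_model = microbatched_model_builder.get_model()
  # from here on, the partition-local model would be compiled and trained
  return microbatched_model, ds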
def test_send_all_connections(self, partition, num_micro_batches):
  elem_type = np.dtype(np.float32)
  pipeline_comm = tnt.PipelineCommunicator(partition, num_micro_batches)

  micro_batch_size = 4
  pipeline_comm.setup_infrastructure(micro_batch_size)

  # send on all connections (each rank sends only to higher-ranked peers)
  for micro_batch_id in range(num_micro_batches):
    for conn_id in pipeline_comm.get_local_connection_ids():
      conn_info = partition[conn_id]
      if conn_info.get_other_rank(tnt.get_rank()) < tnt.get_rank():
        continue
      array_length = (micro_batch_size * conn_info.get_size_in_bytes()
                      // elem_type.itemsize)
      input_array = np.empty(shape=(array_length, 1), dtype=elem_type)
      input_array.fill(tnt.get_rank())
      pipeline_comm.send(input_array,
                         connection_id=conn_id,
                         micro_batch_id=micro_batch_id)

  # receive on all connections (each rank receives only from lower-ranked peers)
  for micro_batch_id in range(num_micro_batches):
    for conn_id in pipeline_comm.get_local_connection_ids():
      conn_info = partition[conn_id]
      if conn_info.get_other_rank(tnt.get_rank()) > tnt.get_rank():
        continue
      array_length = (micro_batch_size * conn_info.get_size_in_bytes()
                      // elem_type.itemsize)
      expected_array = np.empty(shape=(array_length, 1), dtype=elem_type)
      expected_array.fill(conn_info.get_other_rank(tnt.get_rank()))

      # receive buffer dtype must match the dtype of the sent data
      input_array = np.empty(shape=(array_length, 1), dtype=elem_type)
      result = pipeline_comm.recv(input_array,
                                  connection_id=conn_id,
                                  micro_batch_id=micro_batch_id)
      assert np.allclose(result, expected_array, atol=1e-6)
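# A sketch of a connection table that could serve as the `partition` fixture
# above (illustrative only; the real fixture comes from the test parameters).
# Each ConnectionInfo names the pair of communicating ranks and the
# per-sample message size in bytes, as in the tests further below:
def make_partition_sketch(tensor_size=16):
  elem_type = np.dtype(np.float32)
  return {
      0: cinfo.ConnectionInfo((0, 1), tensor_size * elem_type.itemsize),
      1: cinfo.ConnectionInfo((0, 1), tensor_size * elem_type.itemsize)
  }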
def __init__(self, model, group, partition_generator, rank_mapper,
             num_pipeline_stages = None):
  super().__init__(model = model, group = group)
  self._model_name = model.name
  self.built = False
  self.compile_properties = None
  self.num_pipeline_stages = num_pipeline_stages

  connection_table = rank_mapper.get_connections_for_rank(self.rank)
  self.pipeline_communicator = tnt.PipelineCommunicator(
      connection_table, self.num_pipeline_stages)
  self.initialized = False

  partition_id = rank_mapper.get_partition_for_rank(self.rank)
  partition_graph = partition_generator.get_partition_graph(partition_id)
  self.partition_info = pinfo.PartitionInfo(partition_id, partition_graph)

  # replace the full model by the core model of the local partition
  core_model_builder = cm_builder.CoreModelBuilder(model, partition_id,
                                                   partition_graph)
  self.model = core_model_builder.get_model()
  self.nano_batch_size = None
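# A construction sketch for the class above (not from the original file).
# "PipelinedModel" is a stand-in name for whatever class this __init__
# belongs to, and `group` is assumed to be a tnt group object obtained
# elsewhere; the generator and mapper setup mirrors to_microbatched above.
partition_generator = pgen.GraphPartitionGenerator(model)
rank_mapper = rmapper.RankMapper(
    num_ranks=tnt.get_size(),
    pipeline_graph=partition_generator.get_pipeline_graph())
pipelined_model = PipelinedModel(model=model,
                                 group=group,  # assumed tnt group object
                                 partition_generator=partition_generator,
                                 rank_mapper=rank_mapper,
                                 num_pipeline_stages=2)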
def test_send_recv_layers_forward(self, test_case):
  rank_0 = 0
  rank_1 = 1
  rank = tnt.get_rank()
  number_tags = 2  # mbatch_id, connection_id
  elem_type = np.dtype(np.float32)

  connection_id = test_case["connection_id"]
  micro_batch_size = test_case["micro_batch_size"]
  num_micro_batches = test_case["num_micro_batches"]
  data_to_be_sent = test_case["data_to_be_sent"]
  tensor_size = np.array(data_to_be_sent[0]).size

  tags_dataset = []
  for mbatch_id in range(num_micro_batches):
    tags_dataset = tags_dataset + micro_batch_size * [[mbatch_id, connection_id]]

  connection_table = {
      connection_id:
      cinfo.ConnectionInfo((rank_0, rank_1), tensor_size * elem_type.itemsize)
  }
  pipeline_comm = tnt.PipelineCommunicator(connection_table, num_micro_batches)
  pipeline_comm.setup_infrastructure(micro_batch_size)

  if rank == rank_0:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.SendLayer(pipeline_communicator=pipeline_comm)(
        inputs, input_tags)
    model = keras.Model([inputs, input_tags], outputs)
    dataset = data_to_be_sent

  if rank == rank_1:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.RecvLayer(pipeline_communicator=pipeline_comm)(
        inputs, input_tags)
    model = keras.Model([inputs, input_tags], outputs)
    dataset = np.zeros_like(data_to_be_sent)

  def generator():
    for data, tag in zip(dataset, tags_dataset):
      yield {"input": data, "tags": tag}

  final_dataset = tf.data.Dataset.from_generator(generator,
                                                 output_types={
                                                     "input": tf.float32,
                                                     "tags": tf.int32
                                                 })
  data_received = model.predict(final_dataset.batch(micro_batch_size))

  if rank == rank_1:
    assert np.allclose(data_received, data_to_be_sent)
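# For reference, a test_case with the exact keys the test reads could look
# like this (values are illustrative, not taken from the original test
# parameters); data_to_be_sent must hold micro_batch_size * num_micro_batches
# samples, each of size tensor_size:
example_test_case = {
    "connection_id": 0,
    "micro_batch_size": 2,
    "num_micro_batches": 3,
    "data_to_be_sent": np.random.rand(6, 4).astype(np.float32)  # 6 samples of size 4
}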
def test_send_recv_layers_forward_and_backward(self, test_case):
  rank_0 = 0
  rank_1 = 1
  rank = tnt.get_rank()
  number_tags = 2  # mbatch_id, connection_id
  elem_type = np.dtype(np.float32)
  number_epochs = 3

  connection_id = test_case["connection_id"]
  micro_batch_size = test_case["micro_batch_size"]
  num_micro_batches = test_case["num_micro_batches"]
  data_to_be_sent = test_case["data_to_be_sent"]
  tensor_size = np.array(data_to_be_sent[0]).size

  tags_dataset = []
  for mbatch_id in range(num_micro_batches):
    tags_dataset = tags_dataset + micro_batch_size * [[mbatch_id, connection_id]]
  labels_dataset = micro_batch_size * num_micro_batches * [0.]

  connection_table = {
      connection_id:
      cinfo.ConnectionInfo((rank_0, rank_1), tensor_size * elem_type.itemsize)
  }
  pipeline_comm = tnt.PipelineCommunicator(connection_table, num_micro_batches)
  pipeline_comm.setup_infrastructure(micro_batch_size)

  if rank == rank_0:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    input_seq = keras.Input(shape=(1,), name='input_seq')
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.RemoveSeqInput()(
        [inputs, input_seq])  # force execution of backward pass
    outputs = tnt_layers.SendLayer(pipeline_communicator=pipeline_comm)(
        outputs, input_tags)
    outputs = tnt_layers.IdentityLayer(name="result")(outputs)
    model = keras.Model([inputs, input_tags, input_seq], outputs)
    loss = tnt_losses.ZeroLoss()
    dataset = data_to_be_sent

  if rank == rank_1:
    input_tags = keras.Input(shape=(number_tags,), name="tags", dtype=tf.int32)
    input_seq = keras.Input(shape=(1,), name='input_seq')
    inputs = keras.Input(shape=(tensor_size,), name='input')
    outputs = tnt_layers.RemoveSeqInput()([inputs, input_seq])
    outputs = tnt_layers.RecvLayer(pipeline_communicator=pipeline_comm)(
        outputs, input_tags)
    outputs = tnt_layers.IdentityLayer(name="result")(outputs)
    model = keras.Model([inputs, input_tags, input_seq], outputs)
    # loss = PlusOneLoss()
    loss = tnt_losses.ZeroLoss()
    dataset = np.zeros_like(data_to_be_sent)

  def generator():
    input_seq_constant = 0
    for data, tag, label in zip(dataset, tags_dataset, labels_dataset):
      yield ({"input": data, "tags": tag, "input_seq": input_seq_constant},
             {"result": label})

  final_dataset = tf.data.Dataset.from_generator(generator,
                                                 output_types=({
                                                     "input": tf.float32,
                                                     "tags": tf.int32,
                                                     "input_seq": tf.float32
                                                 }, {
                                                     "result": tf.float32
                                                 }))
  final_dataset = final_dataset.batch(micro_batch_size)

  model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1), loss=loss)
  history = model.fit(final_dataset, epochs=number_epochs)
  assert len(history.history['loss']) == number_epochs
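# tnt_losses.ZeroLoss is used above so that fit() exercises the backward pass
# through SendLayer/RecvLayer without producing a meaningful gradient signal.
# A plain-Keras sketch of that behavior (an assumption about its semantics,
# not Tarantella's implementation):
class ZeroLossSketch(keras.losses.Loss):
  def call(self, y_true, y_pred):
    # scale predictions by zero: keeps the graph differentiable, yields zero loss
    return 0.0 * tf.reduce_sum(y_pred, axis=-1)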
def get_pipeline_communicator(num_micro_batches):
  connection_table = {
      0: cinfo.ConnectionInfo((p_0_rank, p_1_rank),
                              fc_units * elem_type.itemsize),
      1: cinfo.ConnectionInfo((p_0_rank, p_1_rank),
                              fc_units * elem_type.itemsize)
  }
  ppl_comm = tnt.PipelineCommunicator(connection_table, num_micro_batches)
  return ppl_comm
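# Usage sketch for the helper above (assumes the module-level constants it
# relies on are defined: p_0_rank, p_1_rank, fc_units, elem_type):
#   ppl_comm = get_pipeline_communicator(num_micro_batches=2)
#   ppl_comm.setup_infrastructure(micro_batch_size=4)  # required before send/recv
#   ppl_comm.send(array, connection_id=0, micro_batch_id=0)            # on p_0_rank
#   result = ppl_comm.recv(buffer, connection_id=0, micro_batch_id=0)  # on p_1_rank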