# Imports assumed by the snippets below; the module paths follow Tarantella's
# pipelining package layout and may need adjusting for a different installation.
import tensorflow as tf

import tarantella as tnt
import tarantella.strategy.pipelining.partition_info as pinfo
import tarantella.strategy.pipelining.partition_generator as pgen
import tarantella.strategy.pipelining.rank_mapper as rmapper
import tarantella.strategy.pipelining.core_model_builder as cm_builder
import tarantella.strategy.pipelining.shared_model_builder as shared
import tarantella.strategy.pipelining.microbatched_model_builder as microbatched


def multi_output_partition_info(ref_model, rank):
    # Expected partition layout for a two-partition model with two output heads;
    # p_0_rank and p_1_rank are rank constants defined in the surrounding test module.
    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')

        in_0 = pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32)
        partition_info.real_input_infos = {0: in_0}
        out_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('ten_classes').output.shape, tf.float32)
        out_edge_1 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('two_classes').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_0, 1: out_edge_1}

    elif rank == p_1_rank:
        partition_info = pinfo.PartitionInfo('1')
        in_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('ten_classes').output.shape, tf.float32)
        in_edge_1 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('two_classes').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_0, 1: in_edge_1}

        out_0 = pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32)
        out_1 = pinfo.EndpointInfo(1, ref_model.outputs[1].shape, tf.float32)
        partition_info.real_output_infos = {0: out_0, 1: out_1}
    return partition_info
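
For context, a minimal sketch of the kind of reference model the helper above assumes: one input and two named classification heads, 'ten_classes' and 'two_classes'. Layer sizes are illustrative only, and any split-point annotations the real test model carries are omitted here.

def build_two_head_reference_model():
    # Illustrative stand-in for `ref_model`; only the output head names matter.
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Flatten()(inputs)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    ten = tf.keras.layers.Dense(10, activation='softmax', name='ten_classes')(x)
    two = tf.keras.layers.Dense(2, activation='softmax', name='two_classes')(x)
    return tf.keras.Model(inputs=inputs, outputs=[ten, two])
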
def alexnet_partition_info(ref_model, rank):
    # Expected partition layout for an AlexNet-style model split into three
    # pipeline stages at the layers named 'split_layer0' and 'split_layer1';
    # p_0_rank, p_1_rank and p_2_rank come from the surrounding test module.
    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')

        in_0 = pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32)
        partition_info.real_input_infos = {0: in_0}

        out_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('split_layer0').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_0}

    elif rank == p_1_rank:
        partition_info = pinfo.PartitionInfo('1')
        in_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('split_layer0').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_0}

        out_edge_0 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('split_layer1').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_0}

    elif rank == p_2_rank:
        partition_info = pinfo.PartitionInfo('2')
        in_edge_0 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('split_layer1').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_0}

        out_0 = pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32)
        partition_info.real_output_infos = {0: out_0}
    return partition_info
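
A hedged sketch of how these hand-written expectations can be checked against the partitions generated from the annotated reference model (here called alexnet_ref_model as a placeholder), mirroring test_partition_info further below and assuming p_0_rank, p_1_rank, p_2_rank are 0, 1, 2:

partition_gen = pgen.GraphPartitionGenerator(alexnet_ref_model)
rank_mapper = rmapper.RankMapper(num_ranks = 3,
                                 pipeline_graph = partition_gen.get_pipeline_graph())
for rank in range(3):
    partition_id = rank_mapper.get_partition_for_rank(rank)
    generated = pinfo.PartitionInfo(
        partition_id = partition_id,
        partition_graph = partition_gen.get_partition_graph(partition_id))
    assert generated == alexnet_partition_info(alexnet_ref_model, rank)
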
Example #3
def to_microbatched(model, micro_batch_size, num_micro_batches, num_batches, num_test_batches):
  # Build the micro-batched (pipeline-parallel) model and datasets for this rank.
  rank = tnt.get_rank()
  partition_generator = pgen.GraphPartitionGenerator(model)
  rank_mapper = rmapper.RankMapper(num_ranks = tnt.get_size(),
                                   pipeline_graph = partition_generator.get_pipeline_graph())

  # Extract the partition assigned to this rank and its metadata.
  partition_id = rank_mapper.get_partition_for_rank(rank)
  partition_graph = partition_generator.get_partition_graph(partition_id)
  partition_info = pinfo.PartitionInfo(partition_id = partition_id,
                                       partition_graph = partition_graph)

  # Core Keras model containing only this partition's layers.
  core_model_builder = cm_builder.CoreModelBuilder(model, partition_id, partition_graph)
  core_model = core_model_builder.get_model()

  # Communicator for the pipeline connections (edges) this rank participates in.
  connection_table = rank_mapper.get_connections_for_rank(rank)
  pipeline_communicator = tnt.PipelineCommunicator(connection_table, num_micro_batches)

  # The shared model adds the edge communication around the core model;
  # the micro-batched builder then assembles the per-rank training model
  # over num_micro_batches micro-batches.
  shared_model_builder = shared.SharedModelBuilder(partition_info, core_model,
                                                   pipeline_communicator, micro_batch_size)
  shared_model = shared_model_builder.get_model()

  microbatched_model_builder = microbatched.MicrobatchedModelBuilder(partition_info, shared_model,
                                                                     micro_batch_size, num_micro_batches)
  ds = load_microbatched_datasets(micro_batch_size, num_micro_batches,
                                  num_batches, num_test_batches, partition_info)
  pipeline_communicator.setup_infrastructure(micro_batch_size)
  return microbatched_model_builder, ds
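
A hypothetical driver for to_microbatched: `model` is the full, unpartitioned Keras model, get_model() on the micro-batched builder is assumed by analogy with the other builders, and the optimizer, loss, batch parameters, and the way ds is passed to fit are illustrative only.

builder, ds = to_microbatched(model, micro_batch_size=32, num_micro_batches=4,
                              num_batches=100, num_test_batches=10)
pipeline_model = builder.get_model()   # assumed to exist, by analogy with the other builders
pipeline_model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')
pipeline_model.fit(ds, epochs=1)       # each rank trains only its own pipeline stage
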
def simple_partition_info(ref_model, rank):
    # Single-partition case: partition '0' owns the real model input and output
    # and has no pipeline edge connections.
    partition_info = pinfo.PartitionInfo('0')

    in_0 = pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32)
    partition_info.real_input_infos = {0: in_0}
    out_0 = pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32)
    partition_info.real_output_infos = {0: out_0}
    return partition_info
    def test_partition_info(self, model_and_partitions):
        # The PartitionInfo generated for each rank must match the hand-written
        # expectation (e.g. the helper functions above).
        model, partition_gen, expected_num_partitions, expected_partition_gen, _ = model_and_partitions
        rank_mapper = rmapper.RankMapper(
            num_ranks=expected_num_partitions,
            pipeline_graph=partition_gen.get_pipeline_graph())

        for rank in range(expected_num_partitions):
            partition_id = rank_mapper.get_partition_for_rank(rank)
            partition_info = pinfo.PartitionInfo(
                partition_id=partition_id,
                partition_graph=partition_gen.get_partition_graph(
                    partition_id))
            assert partition_info == expected_partition_gen(model, rank)
Example #6
  def __init__(self, model, group, partition_generator, rank_mapper,
               num_pipeline_stages = None):
    super().__init__(model = model, group = group)
    self._model_name = model.name
    self.built = False
    self.compile_properties = None
    self.num_pipeline_stages = num_pipeline_stages

    # Communicator for the pipeline connections assigned to this rank.
    connection_table = rank_mapper.get_connections_for_rank(self.rank)
    self.pipeline_communicator = tnt.PipelineCommunicator(connection_table, self.num_pipeline_stages)
    self.initialized = False

    # Partition metadata and the core (per-rank) Keras model for this partition.
    partition_id = rank_mapper.get_partition_for_rank(self.rank)
    partition_graph = partition_generator.get_partition_graph(partition_id)
    self.partition_info = pinfo.PartitionInfo(partition_id, partition_graph)

    core_model_builder = cm_builder.CoreModelBuilder(model, partition_id,
                                                     partition_graph)
    self.model = core_model_builder.get_model()
    self.nano_batch_size = None
    self.built = False
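
Constructing the partitioned model whose __init__ is shown above might look roughly like this; PartitionedModel and group are placeholders (the actual class name and the Tarantella group object do not appear in the snippet), and num_pipeline_stages is illustrative.

partition_generator = pgen.GraphPartitionGenerator(model)
rank_mapper = rmapper.RankMapper(num_ranks = tnt.get_size(),
                                 pipeline_graph = partition_generator.get_pipeline_graph())
partitioned_model = PartitionedModel(model = model, group = group,
                                     partition_generator = partition_generator,
                                     rank_mapper = rank_mapper,
                                     num_pipeline_stages = 4)
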