Ejemplo n.º 1
0
def simple_partition_info(ref_model, rank):
    """Describe a single partition, labelled '0', for `ref_model`.

    The partition gets one real input and one real output. Their shapes are
    taken from ``ref_model.inputs[0]`` and ``ref_model.outputs[0]`` and both
    are declared as ``tf.float32``.

    Args:
        ref_model: Object exposing indexable ``inputs`` and ``outputs`` whose
            elements have a ``shape`` attribute.
        rank: Not used by this function.
            NOTE(review): presumably kept so the signature matches the other
            partition-info builders -- confirm.

    Returns:
        The populated ``pinfo.PartitionInfo``.
    """
    info = pinfo.PartitionInfo('0')
    info.real_input_infos = {
        0: pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32),
    }
    info.real_output_infos = {
        0: pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32),
    }
    return info
Ejemplo n.º 2
0
def multi_output_partition_info(ref_model, rank):
    """Build the partition description owned by `rank` for a two-output model.

    Two partitions are described:

    * ``p_0_rank`` -> partition '0': real input 0 (shape of
      ``ref_model.inputs[0]``) and two edge outputs shaped like the outputs
      of the 'ten_classes' and 'two_classes' layers.
    * ``p_1_rank`` -> partition '1': two edge inputs with those same layer
      shapes, and real outputs 0 and 1 shaped like ``ref_model.outputs``.

    All endpoints are declared as ``tf.float32``.
    NOTE(review): the first ``EndpointInfo`` argument appears to be an
    endpoint id that pairs an edge output with the matching edge input --
    confirm against ``pinfo.EndpointInfo``.

    Args:
        ref_model: Model exposing ``inputs``, ``outputs`` and ``get_layer``.
        rank: Rank whose partition should be described.

    Returns:
        The populated ``pinfo.PartitionInfo`` for `rank`.

    Raises:
        ValueError: If `rank` is neither ``p_0_rank`` nor ``p_1_rank``.
    """
    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')

        in_0 = pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32)
        partition_info.real_input_infos = {0: in_0}

        out_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('ten_classes').output.shape, tf.float32)
        out_edge_1 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('two_classes').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_0, 1: out_edge_1}

    elif rank == p_1_rank:
        partition_info = pinfo.PartitionInfo('1')

        in_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('ten_classes').output.shape, tf.float32)
        in_edge_1 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('two_classes').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_0, 1: in_edge_1}

        out_0 = pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32)
        out_1 = pinfo.EndpointInfo(1, ref_model.outputs[1].shape, tf.float32)
        partition_info.real_output_infos = {0: out_0, 1: out_1}

    else:
        # Without this branch an unknown rank fell through to the return
        # statement and failed with an UnboundLocalError on `partition_info`.
        raise ValueError(
            'multi_output_partition_info: no partition defined for rank '
            '{!r}'.format(rank))

    return partition_info
Ejemplo n.º 3
0
def fc_partition_info(ref_model, rank):
    """Build the partition description owned by `rank` for a model cut once.

    The model is cut at the output of the layer named 'split_layer1':

    * ``p_0_rank`` -> partition '0': real input 0 (shape of
      ``ref_model.inputs[0]``) and one edge output with the shape of the
      'split_layer1' output.
    * ``p_1_rank`` -> partition '1': one edge input with that same shape and
      real output 0 (shape of ``ref_model.outputs[0]``).

    All endpoints are declared as ``tf.float32``.

    Args:
        ref_model: Model exposing ``inputs``, ``outputs`` and ``get_layer``.
        rank: Rank whose partition should be described.

    Returns:
        The populated ``pinfo.PartitionInfo`` for `rank`.

    Raises:
        ValueError: If `rank` is neither ``p_0_rank`` nor ``p_1_rank``.
    """
    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')

        in_0 = pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32)
        partition_info.real_input_infos = {0: in_0}

        out_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('split_layer1').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_0}

    elif rank == p_1_rank:
        partition_info = pinfo.PartitionInfo('1')

        in_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('split_layer1').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_0}

        out_0 = pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32)
        partition_info.real_output_infos = {0: out_0}

    else:
        # Without this branch an unknown rank fell through to the return
        # statement and failed with an UnboundLocalError on `partition_info`.
        raise ValueError(
            'fc_partition_info: no partition defined for rank '
            '{!r}'.format(rank))

    return partition_info
Ejemplo n.º 4
0
def get_partition_info(core_model):
  """Build the partition description for the current rank from `core_model`.

  `rank` is not a parameter: it is read from the enclosing scope, together
  with ``p_0_rank``, ``p_1_rank``, ``p_0_id`` and ``p_1_id``.

  * ``rank == p_0_rank`` -> partition ``p_0_id``: real input 0 shaped like
    ``core_model.inputs[0]``, edge outputs 0 and 1 shaped like
    ``core_model.outputs[0]`` and ``core_model.outputs[1]``; no edge inputs
    and no real outputs.
  * ``rank == p_1_rank`` -> partition ``p_1_id``: edge inputs 0 and 1 shaped
    like ``core_model.inputs[0]`` and ``core_model.inputs[1]``, real output 0
    shaped like ``core_model.outputs[0]``; no real inputs and no edge
    outputs.

  All endpoints are declared as ``tf.float32``.
  NOTE(review): `core_model` is presumably the model of the local partition
  only (its inputs/outputs are used directly as endpoints) -- confirm.

  Args:
    core_model: Model exposing indexable ``inputs`` and ``outputs`` whose
      elements have a ``shape`` attribute.

  Returns:
    The populated ``pinfo.PartitionInfo`` for the current rank.

  Raises:
    ValueError: If `rank` is neither ``p_0_rank`` nor ``p_1_rank``.
  """
  if rank == p_0_rank:
    partition_info = pinfo.PartitionInfo(p_0_id)

    in_0 = pinfo.EndpointInfo(0, core_model.inputs[0].shape, tf.float32)
    partition_info.real_input_infos = {0: in_0}
    partition_info.edge_input_infos = {}
    partition_info.real_output_infos = {}

    out_edge_0 = pinfo.EndpointInfo(0, core_model.outputs[0].shape, tf.float32)
    out_edge_1 = pinfo.EndpointInfo(1, core_model.outputs[1].shape, tf.float32)
    partition_info.edge_output_infos = {0: out_edge_0, 1: out_edge_1}

  elif rank == p_1_rank:
    partition_info = pinfo.PartitionInfo(p_1_id)
    partition_info.real_input_infos = {}

    in_edge_0 = pinfo.EndpointInfo(0, core_model.inputs[0].shape, tf.float32)
    in_edge_1 = pinfo.EndpointInfo(1, core_model.inputs[1].shape, tf.float32)
    partition_info.edge_input_infos = {0: in_edge_0, 1: in_edge_1}

    out_0 = pinfo.EndpointInfo(0, core_model.outputs[0].shape, tf.float32)
    partition_info.real_output_infos = {0: out_0}
    partition_info.edge_output_infos = {}

  else:
    # Without this branch an unknown rank fell through to the return
    # statement and failed with an UnboundLocalError on `partition_info`.
    raise ValueError(
        'get_partition_info: no partition defined for rank '
        '{!r}'.format(rank))

  return partition_info
Ejemplo n.º 5
0
def skip_connection_partition_info(ref_model, rank):
    """Build the partition description owned by `rank` for a 3-way split.

    Three partitions are described; the numbers in parentheses are the first
    argument passed to ``pinfo.EndpointInfo``:

    * ``p_0_rank`` -> partition '0': real input (0) shaped like
      ``ref_model.inputs[0]``; edge outputs at 'split_layer0' (0) and
      'split_layer1' (1).
    * ``p_1_rank`` -> partition '1': edge input at 'split_layer0' (0); edge
      output at 'split_layer2' (2).
    * ``p_2_rank`` -> partition '2': edge inputs at 'split_layer1' (1) and
      'split_layer2' (2); real output (0) shaped like
      ``ref_model.outputs[0]``.

    All endpoints are declared as ``tf.float32``.
    NOTE(review): the ``EndpointInfo`` number appears to identify an edge
    across partitions (producer and consumer use the same value), while the
    dict keys are per-partition positions -- confirm against ``pinfo``.

    Args:
        ref_model: Model exposing ``inputs``, ``outputs`` and ``get_layer``.
        rank: Rank whose partition should be described.

    Returns:
        The populated ``pinfo.PartitionInfo`` for `rank`.

    Raises:
        ValueError: If `rank` is none of ``p_0_rank``, ``p_1_rank`` or
            ``p_2_rank``.
    """
    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')

        in_0 = pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32)
        partition_info.real_input_infos = {0: in_0}

        out_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('split_layer0').output.shape, tf.float32)
        out_edge_1 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('split_layer1').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_0, 1: out_edge_1}

    elif rank == p_1_rank:
        partition_info = pinfo.PartitionInfo('1')

        in_edge_0 = pinfo.EndpointInfo(
            0,
            ref_model.get_layer('split_layer0').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_0}

        # Stored under local key 0 but created with endpoint number 2.
        out_edge_2 = pinfo.EndpointInfo(
            2,
            ref_model.get_layer('split_layer2').output.shape, tf.float32)
        partition_info.edge_output_infos = {0: out_edge_2}

    elif rank == p_2_rank:
        partition_info = pinfo.PartitionInfo('2')

        in_edge_1 = pinfo.EndpointInfo(
            1,
            ref_model.get_layer('split_layer1').output.shape, tf.float32)
        in_edge_2 = pinfo.EndpointInfo(
            2,
            ref_model.get_layer('split_layer2').output.shape, tf.float32)
        partition_info.edge_input_infos = {0: in_edge_1, 1: in_edge_2}

        out_0 = pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32)
        partition_info.real_output_infos = {0: out_0}

    else:
        # Without this branch an unknown rank fell through to the return
        # statement and failed with an UnboundLocalError on `partition_info`.
        raise ValueError(
            'skip_connection_partition_info: no partition defined for rank '
            '{!r}'.format(rank))

    return partition_info