def simple_partition_info(ref_model, rank):
    """Return the PartitionInfo for a model that is treated as one partition.

    Partition '0' owns the reference model's first input and first output as
    real (non-edge) endpoints, both declared as ``tf.float32``.

    Args:
        ref_model: Reference model supplying ``inputs[0]`` / ``outputs[0]``
            shapes (presumably a Keras model).
        rank: Accepted for signature parity with the other builders; unused.

    Returns:
        The populated ``pinfo.PartitionInfo``.
    """
    info = pinfo.PartitionInfo('0')
    info.real_input_infos = {
        0: pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32),
    }
    info.real_output_infos = {
        0: pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32),
    }
    return info
def multi_output_partition_info(ref_model, rank):
    """Return the PartitionInfo for a two-partition model with two cut edges.

    Partition '0' (on ``p_0_rank``) takes the reference model's first input as
    a real input. It declares the outputs of the 'ten_classes' and
    'two_classes' layers as edge outputs 0 and 1. Partition '1' (on
    ``p_1_rank``) takes those same two endpoints as edge inputs and owns the
    model's first two outputs as real outputs. All endpoints are
    ``tf.float32``.

    Args:
        ref_model: Reference (unpartitioned) model supplying shapes via
            ``inputs``, ``outputs`` and ``get_layer`` (presumably Keras).
        rank: Rank of the calling process.

    Returns:
        The ``pinfo.PartitionInfo`` for the partition owned by ``rank``.

    Raises:
        ValueError: If ``rank`` is neither ``p_0_rank`` nor ``p_1_rank``.
    """
    # Without this guard an unknown rank falls off the if/elif below and the
    # return raises a confusing UnboundLocalError instead.
    if rank not in (p_0_rank, p_1_rank):
        raise ValueError(
            f"multi_output_partition_info: rank {rank!r} is not one of "
            f"{p_0_rank!r}, {p_1_rank!r}")

    def edge(endpoint_id, layer_name):
        # Edge endpoint whose shape is the named layer's output shape.
        return pinfo.EndpointInfo(
            endpoint_id, ref_model.get_layer(layer_name).output.shape,
            tf.float32)

    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')
        partition_info.real_input_infos = {
            0: pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32),
        }
        partition_info.edge_output_infos = {
            0: edge(0, 'ten_classes'),
            1: edge(1, 'two_classes'),
        }
    else:  # rank == p_1_rank
        partition_info = pinfo.PartitionInfo('1')
        partition_info.edge_input_infos = {
            0: edge(0, 'ten_classes'),
            1: edge(1, 'two_classes'),
        }
        partition_info.real_output_infos = {
            0: pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32),
            1: pinfo.EndpointInfo(1, ref_model.outputs[1].shape, tf.float32),
        }
    return partition_info
def fc_partition_info(ref_model, rank):
    """Return the PartitionInfo for a two-partition model cut at 'split_layer1'.

    Partition '0' (on ``p_0_rank``) takes the reference model's first input as
    a real input and declares the output of layer 'split_layer1' as edge
    output 0. Partition '1' (on ``p_1_rank``) takes that endpoint as edge
    input 0 and owns the model's first output as a real output. All endpoints
    are ``tf.float32``.

    Args:
        ref_model: Reference (unpartitioned) model supplying shapes via
            ``inputs``, ``outputs`` and ``get_layer`` (presumably Keras).
        rank: Rank of the calling process.

    Returns:
        The ``pinfo.PartitionInfo`` for the partition owned by ``rank``.

    Raises:
        ValueError: If ``rank`` is neither ``p_0_rank`` nor ``p_1_rank``.
    """
    # Without this guard an unknown rank falls off the if/elif below and the
    # return raises a confusing UnboundLocalError instead.
    if rank not in (p_0_rank, p_1_rank):
        raise ValueError(
            f"fc_partition_info: rank {rank!r} is not one of "
            f"{p_0_rank!r}, {p_1_rank!r}")

    def edge(endpoint_id, layer_name):
        # Edge endpoint whose shape is the named layer's output shape.
        return pinfo.EndpointInfo(
            endpoint_id, ref_model.get_layer(layer_name).output.shape,
            tf.float32)

    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')
        partition_info.real_input_infos = {
            0: pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32),
        }
        partition_info.edge_output_infos = {0: edge(0, 'split_layer1')}
    else:  # rank == p_1_rank
        partition_info = pinfo.PartitionInfo('1')
        partition_info.edge_input_infos = {0: edge(0, 'split_layer1')}
        partition_info.real_output_infos = {
            0: pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32),
        }
    return partition_info
def get_partition_info(core_model):
    """Return the PartitionInfo describing ``core_model`` on the current rank.

    Unlike the other builders, this one reads ``rank`` from module scope
    instead of taking it as an argument. NOTE(review): confirm the global
    ``rank`` is set before this is called.

    On ``p_0_rank`` (partition ``p_0_id``), ``core_model.inputs[0]`` is the
    single real input. ``core_model.outputs[0]`` and ``outputs[1]`` become
    edge outputs 0 and 1.

    On ``p_1_rank`` (partition ``p_1_id``), ``core_model.inputs[0]`` and
    ``inputs[1]`` are edge inputs 0 and 1, and ``core_model.outputs[0]`` is
    the single real output.

    Unused endpoint maps are set explicitly to empty dicts. All endpoints are
    ``tf.float32``.

    Args:
        core_model: The per-partition model whose ``inputs`` / ``outputs``
            supply the endpoint shapes.

    Returns:
        The populated ``pinfo.PartitionInfo``.

    Raises:
        ValueError: If the global ``rank`` is neither ``p_0_rank`` nor
            ``p_1_rank``.
    """
    # Without this guard an unknown rank falls off the if/elif below and the
    # return raises a confusing UnboundLocalError instead.
    if rank not in (p_0_rank, p_1_rank):
        raise ValueError(
            f"get_partition_info: rank {rank!r} is not one of "
            f"{p_0_rank!r}, {p_1_rank!r}")

    def endpoint(endpoint_id, tensor):
        # float32 endpoint shaped like the given model input/output tensor.
        return pinfo.EndpointInfo(endpoint_id, tensor.shape, tf.float32)

    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo(p_0_id)
        partition_info.real_input_infos = {0: endpoint(0, core_model.inputs[0])}
        partition_info.edge_input_infos = {}
        partition_info.real_output_infos = {}
        partition_info.edge_output_infos = {
            0: endpoint(0, core_model.outputs[0]),
            1: endpoint(1, core_model.outputs[1]),
        }
    else:  # rank == p_1_rank
        partition_info = pinfo.PartitionInfo(p_1_id)
        partition_info.real_input_infos = {}
        partition_info.edge_input_infos = {
            0: endpoint(0, core_model.inputs[0]),
            1: endpoint(1, core_model.inputs[1]),
        }
        partition_info.real_output_infos = {
            0: endpoint(0, core_model.outputs[0]),
        }
        partition_info.edge_output_infos = {}
    return partition_info
def skip_connection_partition_info(ref_model, rank):
    """Return the PartitionInfo for a three-partition model with a skip edge.

    Edge endpoints are identified by the first ``EndpointInfo`` argument. The
    dict keys are per-partition slot indices and do not always equal that id.
    The topology is:

    * Partition '0' (``p_0_rank``): takes the model's first input as a real
      input. Emits edge 0 ('split_layer0' output) and edge 1 ('split_layer1'
      output).
    * Partition '1' (``p_1_rank``): consumes edge 0 and emits edge 2
      ('split_layer2' output) in its slot 0.
    * Partition '2' (``p_2_rank``): consumes edge 1 (the skip connection
      coming directly from partition '0') in slot 0 and edge 2 in slot 1.
      Owns the model's first output as a real output.

    All endpoints are ``tf.float32``.

    Args:
        ref_model: Reference (unpartitioned) model supplying shapes via
            ``inputs``, ``outputs`` and ``get_layer`` (presumably Keras).
        rank: Rank of the calling process.

    Returns:
        The ``pinfo.PartitionInfo`` for the partition owned by ``rank``.

    Raises:
        ValueError: If ``rank`` is not one of ``p_0_rank``, ``p_1_rank`` or
            ``p_2_rank``.
    """
    # Without this guard an unknown rank falls off the if/elif below and the
    # return raises a confusing UnboundLocalError instead.
    if rank not in (p_0_rank, p_1_rank, p_2_rank):
        raise ValueError(
            f"skip_connection_partition_info: rank {rank!r} is not one of "
            f"{p_0_rank!r}, {p_1_rank!r}, {p_2_rank!r}")

    def edge(endpoint_id, layer_name):
        # Edge endpoint whose shape is the named layer's output shape.
        return pinfo.EndpointInfo(
            endpoint_id, ref_model.get_layer(layer_name).output.shape,
            tf.float32)

    if rank == p_0_rank:
        partition_info = pinfo.PartitionInfo('0')
        partition_info.real_input_infos = {
            0: pinfo.EndpointInfo(0, ref_model.inputs[0].shape, tf.float32),
        }
        partition_info.edge_output_infos = {
            0: edge(0, 'split_layer0'),
            1: edge(1, 'split_layer1'),
        }
    elif rank == p_1_rank:
        partition_info = pinfo.PartitionInfo('1')
        partition_info.edge_input_infos = {0: edge(0, 'split_layer0')}
        # Slot 0 of this partition carries edge id 2.
        partition_info.edge_output_infos = {0: edge(2, 'split_layer2')}
    else:  # rank == p_2_rank
        partition_info = pinfo.PartitionInfo('2')
        partition_info.edge_input_infos = {
            0: edge(1, 'split_layer1'),
            1: edge(2, 'split_layer2'),
        }
        partition_info.real_output_infos = {
            0: pinfo.EndpointInfo(0, ref_model.outputs[0].shape, tf.float32),
        }
    return partition_info