def test_scatter_gather_node_axes(config):
    """Verify the axes and slices reported by scatter/gather communication ops.

    ``config`` is indexed with the keys ``'axes'`` (lengths used to build the
    axes), ``'parallel_axis'`` (index of the axis passed as ``parallel``),
    ``'device_id'`` (device ids for the parallel placeholder) and ``'slices'``
    (expected ``.slices`` of the scatter send and gather receive ops).
    Presumably supplied by a parametrized pytest fixture -- confirm in conftest.
    """
    axes = ng.make_axes([ng.make_axis(n) for n in config['axes']])
    parallel_axis = axes[config['parallel_axis']]
    # Expected op axes: presumably the parallel axis first, followed by the
    # remaining axes (Axis + Axes concatenation).
    hetr_axes = parallel_axis + (axes - parallel_axis)

    # Source/destination placeholders live on transformer 'cpu0'.
    host_md = dict(device=None, device_id='0', transformer='cpu0',
                   host_transformer=None)
    with ng.metadata(**host_md):
        src = ng.placeholder(axes=axes)
        dst = ng.placeholder(axes=axes)

    # Placeholder carrying the parallel metadata under test.
    parallel_md = dict(device=None, device_id=config['device_id'],
                       transformer=None, parallel=parallel_axis,
                       host_transformer=None)
    with ng.metadata(**parallel_md):
        par_node = ng.placeholder(axes=axes)

    # Scatter side.
    sc_send = ScatterSendOp(from_node=src, to_node=par_node)
    assert hetr_axes == sc_send.axes
    assert config['slices'] == sc_send.slices

    sc_recv = ScatterRecvOp(to_node=par_node, send_node=sc_send)
    for got, want in zip(sc_recv.axes, hetr_axes):
        assert got.length == want.length

    # Gather side.
    ga_send = GatherSendOp(from_node=sc_recv)
    assert_axes_eq_len(sc_recv.axes, ga_send.axes)

    ga_recv = GatherRecvOp(from_node=par_node, to_node=dst, send_node=ga_send)
    assert_axes_eq_len(hetr_axes, ga_recv.axes)
    assert config['slices'] == ga_recv.slices
def test_update_comm_deps_scatter_gather():
    """Check control deps added by ``update_comm_deps`` on scatter/gather ops.

    A single scatter send is built under transformer 'cpu0' and received under
    both 'cpu0' and 'cpu1'. Each receiver adds 1 and gathers the result. The
    gather receive for the 'cpu0' side is built back under 'cpu0'. After
    ``update_comm_deps`` runs on each transformer's ops, the test asserts the
    control deps of the 'cpu0' scatter receive and gather receive.
    """
    ax_a = ng.make_axis(length=10, name='A')
    ax_b = ng.make_axis(length=15, name='B')
    axes = ng.make_axes([ax_a, ax_b])

    par_md = dict(parallel=ax_a, device_id=(0, 1), transformer=None,
                  host_transformer=None, device=None)

    # Transformer 'cpu0': scatter source plus the first receiver.
    with ng.metadata(transformer='cpu0'):
        with ng.metadata(**par_md):
            src_a = ng.placeholder(axes)
            dst_a = ng.placeholder(axes)
        scatter_send = ScatterSendOp(from_node=src_a, to_node=dst_a)
        scatter_recv_a = ScatterRecvOp(to_node=dst_a, send_node=scatter_send)
        with ng.metadata(**par_md):
            plus_one_a = scatter_recv_a + 1
        gather_send_a = GatherSendOp(from_node=plus_one_a)

    # Transformer 'cpu1': second receiver of the same scatter send.
    with ng.metadata(transformer='cpu1'):
        with ng.metadata(**par_md):
            dst_b = ng.placeholder(axes)
        scatter_recv_b = ScatterRecvOp(to_node=dst_b, send_node=scatter_send)
        with ng.metadata(**par_md):
            plus_one_b = scatter_recv_b + 1
        gather_send_b = GatherSendOp(from_node=plus_one_b)

    # Back on 'cpu0' (with parallel metadata): receive the gather and use it.
    with ng.metadata(transformer='cpu0'), ng.metadata(**par_md):
        gather_recv_a = GatherRecvOp(from_node=src_a, to_node=dst_a,
                                     send_node=gather_send_a)
        z_a = gather_recv_a + 1

    update_comm_deps((scatter_send, gather_send_a, z_a))
    update_comm_deps((gather_send_b,))

    assert {scatter_send} == set(scatter_recv_a.control_deps)
    assert {scatter_send, gather_send_a} == set(gather_recv_a.control_deps)
def create_scatter_gather_graph():
    """Build a small scatter/gather op graph for use by tests.

    Two placeholders with axes A(10) x B(20) are created under metadata with
    ``parallel`` set to axis 'B' and ``device``/``device_id`` set to (0, 1).
    One scatter send feeds two scatter receives, each receive feeds its own
    gather send, and a single gather receive is paired with the first gather
    send.

    Returns:
        tuple: ``(scatter_send, scatter_recv_a, scatter_recv_b,
        gather_send_a, gather_send_b, gather_recv)``
    """
    ax_a, ax_b = (ng.make_axis(length=10, name='A'),
                  ng.make_axis(length=20, name='B'))
    axes = ng.make_axes([ax_a, ax_b])

    par_md = dict(parallel=ax_b, device=(0, 1), device_id=(0, 1),
                  transformer=None, host_transformer=None)
    with ng.metadata(**par_md):
        src = ng.placeholder(axes)
        dst = ng.placeholder(axes)

    scatter_send = ScatterSendOp(from_node=src, to_node=dst)
    # Two receivers of the same scatter, created in order (a first, then b).
    recv_a, recv_b = (ScatterRecvOp(to_node=dst, send_node=scatter_send)
                      for _ in range(2))
    # One gather send per receiver, again in order a, b.
    g_send_a, g_send_b = (GatherSendOp(from_node=r) for r in (recv_a, recv_b))
    g_recv = GatherRecvOp(from_node=src, to_node=dst, send_node=g_send_a)

    return (scatter_send, recv_a, recv_b, g_send_a, g_send_b, g_recv)