def test_switch_cf_false_no_exec(self):
    # Scenario: the predicate input ('pred_id') holds a constant
    # np.array(False) and only output port 1 of the Switch is connected.
    # The callback handed to control_flow_infer must then be invoked for
    # 'switch_data_1' with False.
    edge_list = [
        ('tensor', 'switch', {'in': 0}),
        ('pred_id', 'switch', {'in': 1}),
        ('switch', 'switch_data_1', {'out': 1}),
        ('switch_data_1', 'result_1', {'in': 0}),
    ]
    node_attrs = {
        'tensor': {'value': True, 'kind': 'data', 'executable': True},
        'pred_id': {
            'value': np.array(False),
            'kind': 'data',
            'executable': True,
        },
        'switch': {
            'type': 'Switch',
            'kind': 'op',
            'op': 'Switch',
            'control_flow_infer': Switch.control_flow_infer,
        },
        'switch_data_1': {'value': None, 'kind': 'data', 'executable': True},
        'result_1': {
            'value': None,
            'kind': 'op',
            'executable': True,
            'type': 'Result',
            'op': 'Result',
        },
    }

    switch_graph = build_graph_with_edge_attrs(node_attrs, edge_list)
    switch_node = Node(switch_graph, 'switch')

    # The mock stands in for the callback passed as the third argument.
    callback = Mock()
    switch_node.control_flow_infer(switch_node, True, callback)

    expected_calls = [call('switch_data_1', False)]
    callback.assert_has_calls(expected_calls, any_order=True)
def test_switch_cf_infer_no_condition(self):
    # Scenario: the predicate input ('pred_id') has no value (None) and
    # both output ports of the Switch are connected.
    edge_list = [
        ('tensor', 'switch', {'in': 0}),
        ('pred_id', 'switch', {'in': 1}),
        ('switch', 'switch_data_0', {'out': 0}),
        ('switch', 'switch_data_1', {'out': 1}),
        ('switch_data_0', 'result_0', {'in': 0}),
        ('switch_data_1', 'result_1', {'in': 0}),
    ]
    node_attrs = {
        'tensor': {'value': True, 'kind': 'data', 'executable': True},
        'pred_id': {'value': None, 'kind': 'data', 'executable': True},
        'switch': {
            'type': 'Switch',
            'kind': 'op',
            'op': 'Switch',
            'control_flow_infer': Switch.control_flow_infer,
        },
        'switch_data_0': {'value': None, 'kind': 'data', 'executable': True},
        'switch_data_1': {'value': None, 'kind': 'data', 'executable': True},
        'result_0': {
            'value': None,
            'kind': 'op',
            'executable': True,
            'type': 'Result',
            'op': 'Result',
        },
        'result_1': {
            'value': None,
            'kind': 'op',
            'executable': True,
            'type': 'Result',
            'op': 'Result',
        },
    }

    switch_graph = build_graph_with_edge_attrs(node_attrs, edge_list)
    switch_node = Node(switch_graph, 'switch')

    # The mock stands in for the callback passed as the third argument.
    callback = Mock()
    switch_node.control_flow_infer(switch_node, True, callback)

    # Without a known condition every output port should be marked as
    # executable, so both data nodes are expected with True.
    expected_calls = [
        call('switch_data_0', True),
        call('switch_data_1', True),
    ]
    callback.assert_has_calls(expected_calls, any_order=True)