def no_trivial_ax_input(ast_seq, global_input_dict, global_output_dict):
    """Strip trivial (range == 1) axes from every AST in `ast_seq`, in place.

    For each statement: drop range-1 axes from `reduce_axes`; when there is a
    single statement and a single global output, also prune them from
    `data_axes` (always keeping at least one axis so a loop nest remains);
    clear `reduce_type` once no reduce axes are left; and rewrite references
    to eliminated axes inside the expression tree as constant 0.

    `global_input_dict` is accepted for signature parity with the other
    passes but is not read here.
    """
    for i, ast in enumerate(ast_seq):  # NOTE: `i` is unused; enumerate kept as-is
        # Collect names of all axes (reduce + data) whose extent is 1.
        ax_elim = []
        for ax in ast['props']['reduce_axes'] + ast['props']['data_axes']:
            if ax['range'] == 1:
                ax_elim.append(ax['name'])
        ast['props']['reduce_axes'] = [
            x for x in ast['props']['reduce_axes'] if x['name'] not in ax_elim
        ]
        if len(ast_seq) == 1 and len(global_output_dict) == 1:
            ax_rebuld = [
                x for x in ast['props']['data_axes'] if x['name'] not in ax_elim
            ]
            if ax_rebuld:
                ast['props']['data_axes'] = ax_rebuld
            else:
                # Every data axis was trivial: keep one so codegen still has a loop.
                ast['props']['data_axes'] = [ast['props']['data_axes'][0]]
        if not ast['props']['reduce_axes']:
            # No reduce axes left -> this statement is no longer a reduction.
            ast['props']['reduce_type'] = None

        def scan_trivial_axis(root, ancestor, ax_elim):
            # Visitor: an `axis` node naming an eliminated axis always indexes
            # position 0, so replace it with a typed zero constant.
            if root._op == 'axis' and root._value in ax_elim:
                return OpTensor('const', 0, 'int32')

        # walk_in_ast is the project's generic AST walker; presumably it
        # substitutes a node whenever the visitor returns one -- confirm
        # against its definition.
        walk_in_ast(ast, 'root', scan_trivial_axis, [ax_elim])
def scan_items(root, ast, access_book, tensor_nodes):
    """AST visitor that records how each input tensor is indexed.

    Fills `access_book`: tensor name -> per-dimension entry that is either a
    data-axis name (the dimension is swept by that axis over its full range)
    or '*' (partial-range axis or a computed index expression).  The special
    key '*' accumulates every axis name seen inside nested index expressions.
    `tensor_nodes` collects all visited `get_item` nodes for later rewriting.

    Returns '' -- presumably a sentinel in the walk_in_ast visitor protocol;
    confirm against the walker's contract.
    """
    if root._op == 'axis':
        # An axis referenced somewhere other than a plain per-dimension slot.
        access_book['*'].append(root._value['name'])
    if root._op != 'get_item':
        return
    tensor_nodes.append(root)
    tensor_name = root._value['tensor']._value['name']
    tensor_index = []
    for i, sub in enumerate(root._value['index']):
        if sub._op == 'axis':
            rng = get_daxis_range(sub._value['name'], ast)
            shp = get_input_shape(tensor_name, ast)
            if rng == shp[i]:
                # The axis sweeps this entire dimension: record it by name.
                tensor_index.append(sub._value['name'])
            else:
                tensor_index.append('*')
        else:
            tensor_index.append('*')
        # Recurse into the index expression itself.
        walk_in_ast(sub, scan_items, [ast, access_book, tensor_nodes], root._value['index'], i)
    if tensor_name in access_book:
        # Tensor already accessed elsewhere: demote any dimension whose axis
        # disagrees with the earlier access.
        last_index = access_book[tensor_name]
        assert len(last_index) == len(tensor_index)
        for i in range(len(tensor_index)):
            if tensor_index[i] != last_index[i]:
                tensor_index[i] = '*'
        # NOTE(review): the demoted `tensor_index` is not written back into
        # access_book here, so the earlier entry wins -- verify intentional.
    else:
        access_book[tensor_name] = tensor_index
    return ''
def eliminate_trivial_axis(ast):
    """Drop trivial (range == 1) axes from a single AST, in place.

    Rewrites uses of trivial axes in the expression tree (via the
    scan_trivial_axis visitor), prunes them from `reduce_axes` and --
    depending on the total output size -- from `data_axes`, clears
    `reduce_type` when no reduce axis remains, and shrinks the recorded
    output shapes accordingly.
    """
    # Replace each reference to a trivial axis inside the expression tree
    # (scan_trivial_axis is the visitor; walk_in_ast the project AST walker).
    walk_in_ast(ast['root'], scan_trivial_axis, [ast], ast, 'root')

    def update(axes, start=0):
        # Keep the first `start` axes unconditionally; drop range-1 axes after.
        return axes[:start] + [ax for ax in axes[start:] if ax['range'] != 1]

    ast['props']['reduce_axes'] = update(ast['props']['reduce_axes'])
    if not ast['props']['reduce_axes']:
        # No reduce axes left -> this is no longer a reduction op.
        ast['props']['reduce_type'] = None

    # Element count of the first output tensor decides how data axes shrink.
    first_out = next(iter(ast['props']['output_dict']))
    # np.prod: np.product is deprecated and removed in NumPy 2.0.
    num_outputs = int(np.prod(ast['props']['output_dict'][first_out]['shape']))
    if num_outputs > 1:
        ast['props']['data_axes'] = update(ast['props']['data_axes'])
        for k in ast['props']['output_dict']:
            ast['props']['output_dict'][k]['shape'] = [
                x for x in ast['props']['output_dict'][k]['shape'] if x != 1
            ]
    elif len(ast['props']['data_axes']) > 1:
        # Fully-trivial output: keep exactly one data axis so codegen still
        # has a loop nest, and collapse every output shape to [1].
        ast['props']['data_axes'] = update(ast['props']['data_axes'], 1)
        for k in ast['props']['output_dict']:
            ast['props']['output_dict'][k]['shape'] = [1]
def run_pass_v2(ast_seq, global_input_dict, global_output_dict):
    """Shard the (single) IR statement for the Graphcore ('c-gc') backend.

    Parses per-axis split factors from the CONFIG environment variable,
    divides each data-axis range by its factor, shrinks every input tensor's
    shape according to the access ranges recorded in the range book, and
    dumps the shard plan to 'range_book.json'.  Mutates `ast_seq[0]` in
    place; `global_input_dict` / `global_output_dict` are not used here.
    """
    if backend not in ['c-gc']:
        return
    if len(ast_seq) > 1:
        raise Exception(
            "TODO: Graphcore backend not handling multiple IR statements.")
    ast = ast_seq[0]
    steps = int(os.environ.get('STEP', '0'))
    pieces = os.environ.get('CONFIG', '').strip()
    data_axes = ast['props']['data_axes']
    if not pieces and steps > 0:
        # Tuning is active but no concrete config was supplied yet.
        return
    try:
        pieces = json.loads(pieces)
        pieces = [pieces['axis_%d' % i][-1] for i in range(len(data_axes))]
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
        # not swallowed.  Absent/malformed config means "no sharding".
        pieces = [1] * len(data_axes)
    assert 'injective' not in ast, "Unhandled injective case for graphcore."
    range_book = {}
    walk_in_ast(ast['root'], scan_items, [ast, range_book], ast, 'root')
    ast['props']['shard'] = {'nparts': pieces, 'book': range_book}
    # AST props touched below: ast['props']['data_axes'], ast['props']['input_dict']
    for i in range(len(pieces)):
        assert data_axes[i]['range'] % pieces[
            i] == 0, "Axis sharding must be exactly divided, while requesting %d // %d." % (
                data_axes[i]['range'], pieces[i])
        data_axes[i]['range'] //= pieces[i]
    for k in ast['props']['input_dict']:
        input_item = ast['props']['input_dict'][k]
        sub_shape = []
        for it in range_book[k]:
            # Book entry layout appears to be (stride, axis_id, lo_bias,
            # hi_bias) -- confirm against the scan_items variant filling it.
            bias_diff = it[3] - it[2] + 1
            if it[1] < 0 or it[0] == 0:
                sub_shape.append(bias_diff)
            elif it[0] > 0:
                sub_shape.append(it[0] * (data_axes[it[1]]['range'] - 1) + bias_diff)
            else:
                raise Exception('Unhandled book case:', it)
        input_item['shape'] = sub_shape
    from antares.common import local_get_dir_file
    # Per-shard output shape is the (already divided) data-axis ranges.
    # (Removed an unused `output_key` local present in the original.)
    ast['props']['shard']['local_shape'] = [
        x['range'] for x in ast['props']['data_axes']
    ]
    with open(local_get_dir_file('range_book.json'), 'w') as fp:
        json.dump(ast['props']['shard'], fp)
def compute(ast):
    """Axis-fusion simplification pass over a single AST.

    Builds an access book describing which data axes index each tensor, then
    searches (largest first) for groups of axes that always appear together
    and contiguously in every access pattern; each such group is handed to
    update_ast_axis, presumably to be fused into one axis -- confirm against
    its definition.  Mutates `ast` in place; disabled when SIMPLE=0, for
    injective/sharded ASTs, or for non-default tuning plans.
    """
    if os.environ.get('SIMPLE', '1') == '0':
        return
    if 'injective' in ast or 'shard' in ast['props']:
        # FIXME: Unhandled case yet
        return
    annotation = os.environ.get('COMPUTE_V1', '').split('##')[-1]
    # FIXME: Just a rough check
    if 'plan/' in annotation and 'default' not in annotation:
        return
    eliminate_trivial_axis(ast)
    access_book, tensor_nodes = {}, []
    access_book['*'] = []
    walk_in_ast(ast['root'], scan_items, [ast, access_book, tensor_nodes], ast, 'root')
    data_axes = ast['props']['data_axes']
    # '=' records the output's own access pattern: all data axes in order.
    access_book['='] = [x['name'] for x in data_axes]
    unique_axes = set()
    for k in access_book:
        if k != '*':
            # Demote axes that also appear inside complex index expressions.
            access_book[k] = [
                x if x not in access_book['*'] else '*' for x in access_book[k]
            ]
            for x in access_book[k]:
                if x != '*':
                    unique_axes.add(x)
    access_book.pop('*')
    visited = set()
    # Try larger fusion groups first; a fused axis is never reconsidered.
    for size in reversed(range(2, len(unique_axes) + 1)):
        for k in itertools.permutations(unique_axes, size):
            if sum([1 if x in visited else 0 for x in k]) > 0:
                continue
            # ':'-delimited pattern so axis names match only as whole tokens.
            this_pattern = ':%s:' % ':'.join(k)
            access_pattern = [
                ':%s:' % ':'.join(access_book[x]) for x in access_book
            ]
            can_simplify = True
            for acc in access_pattern:
                # After removing every contiguous occurrence of the group, no
                # member axis may remain anywhere -- i.e. the group's axes
                # only ever appear together, in this order.
                rest_acc = ''.join(acc.split(this_pattern)).split(':')
                if sum([1 if x in k else 0 for x in rest_acc]) > 0:
                    can_simplify = False
            if can_simplify:
                for x in k:
                    visited.add(x)
                update_ast_axis(ast, k, tensor_nodes)
    return
def compute(ast):
    """Graphcore ('c-gc') sharding pass over a single AST.

    Like run_pass_v2, but additionally shards the recorded output shapes:
    parses per-axis split factors from CONFIG, divides data-axis ranges and
    output-shape dimensions by those factors, rewrites input shapes from the
    recorded range book, and dumps the shard plan to 'range_book.json'.
    Mutates `ast` in place.
    """
    if backend not in ['c-gc']:
        return
    steps = int(os.environ.get('STEP', '0'))
    pieces = os.environ.get('CONFIG', '').strip()
    data_axes = ast['props']['data_axes']
    if not pieces and steps > 0:
        # Tuning is active but no concrete config was supplied yet.
        return
    try:
        pieces = json.loads(pieces)
        pieces = [pieces['axis_%d' % i][-1] for i in range(len(data_axes))]
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
        # not swallowed.  Absent/malformed config means "no sharding".
        pieces = [1] * len(data_axes)
    assert 'injective' not in ast, "Unhandled injective case for graphcore."
    range_book = {}
    walk_in_ast(ast['root'], scan_items, [ast, range_book], ast, 'root')
    ast['props']['shard'] = {'nparts': pieces, 'book': range_book}
    # AST props touched below: data_axes, output_dict, input_dict.
    for i in range(len(pieces)):
        assert data_axes[i]['range'] % pieces[i] == 0
        data_axes[i]['range'] //= pieces[i]
        for k in ast['props']['output_dict']:
            output_item = ast['props']['output_dict'][k]
            assert output_item['shape'][i] % pieces[i] == 0
            output_item['shape'][i] //= pieces[i]
    for k in ast['props']['input_dict']:
        input_item = ast['props']['input_dict'][k]
        sub_shape = []
        for it in range_book[k]:
            # Book entry layout appears to be (stride, axis_id, lo_bias,
            # hi_bias) -- confirm against the scan_items variant filling it.
            bias_diff = it[3] - it[2] + 1
            if it[1] < 0 or it[0] == 0:
                sub_shape.append(bias_diff)
            elif it[0] > 0:
                sub_shape.append(it[0] * (data_axes[it[1]]['range'] - 1) + bias_diff)
            else:
                raise Exception('Unhandled book case:', it)
        input_item['shape'] = sub_shape
    from antares.common import local_get_dir_file
    # Per-shard output shape comes from the first (sharded) output tensor.
    output_key = next(iter(ast['props']['output_dict']))
    ast['props']['shard']['local_shape'] = ast['props']['output_dict'][output_key]['shape']
    with open(local_get_dir_file('range_book.json'), 'w') as fp:
        json.dump(ast['props']['shard'], fp)
def update_ast(config, ast_seq, global_input_dict, global_output_dict):
    """Apply a tuned `config` to shard the single IR statement.

    `config` supplies 'tile_%d' entries whose [1]*[2] product is the tile
    size of each data axis; the part count per axis is derived as
    range // tile.  Falls back to no sharding when the config is absent or
    malformed.  Shards data-axis ranges, rewrites input shapes from the
    recorded range book, dumps the plan to 'range_book.json', and updates
    `global_input_dict` / `global_output_dict` to the per-shard shapes.
    Mutates `ast_seq[0]` and both global dicts in place.
    """
    if len(ast_seq) > 1:
        raise Exception(
            "TODO: Graphcore backend not handling multiple IR statements.")
    ast = ast_seq[0]
    data_axes = ast['props']['data_axes']
    try:
        pieces = config
        pieces = [(pieces['tile_%d' % i][1] * pieces['tile_%d' % i][2])
                  for i in range(len(data_axes))]
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
        # not swallowed.  Absent/malformed config means "no sharding".
        pieces = [1] * len(data_axes)
    # Convert tile sizes into part counts; tiles must divide axes exactly.
    for i in range(len(pieces)):
        assert data_axes[i]['range'] % pieces[i] == 0
        pieces[i] = data_axes[i]['range'] // pieces[i]
    assert 'injective' not in ast, "Unhandled injective case for graphcore."
    range_book = {}
    walk_in_ast(ast, 'root', scan_items, [ast, range_book])
    ast['props']['shard'] = {'nparts': pieces, 'book': range_book}
    # AST props touched below: ast['props']['data_axes'], ast['props']['input_dict']
    for i in range(len(pieces)):
        assert data_axes[i]['range'] % pieces[
            i] == 0, "Axis sharding must be exactly divided, while requesting %d // %d." % (
                data_axes[i]['range'], pieces[i])
        data_axes[i]['range'] //= pieces[i]
    for k in ast['props']['input_dict']:
        input_item = ast['props']['input_dict'][k]
        sub_shape = []
        for it in range_book[k]:
            # Book entry layout appears to be (stride, axis_id, lo_bias,
            # hi_bias) -- confirm against the scan_items variant filling it.
            bias_diff = it[3] - it[2] + 1
            if it[1] < 0 or it[0] == 0:
                sub_shape.append(bias_diff)
            elif it[0] > 0:
                sub_shape.append(it[0] * (data_axes[it[1]]['range'] - 1) + bias_diff)
            else:
                raise Exception('Unhandled book case:', it)
        input_item['shape'] = sub_shape
    from antares.common import local_get_dir_file
    # Per-shard output shape is the (already divided) data-axis ranges.
    # (Removed an unused `output_key` local present in the original.)
    ast['props']['shard']['local_shape'] = [
        x['range'] for x in ast['props']['data_axes']
    ]
    with open(local_get_dir_file('range_book.json'), 'w') as fp:
        json.dump(ast['props']['shard'], fp)
    # Propagate the per-shard shapes back to the global I/O dictionaries.
    for k in global_input_dict:
        if k in ast['props']['input_dict']:
            global_input_dict[k] = ast['props']['input_dict'][k]
    assert len(global_output_dict) == 1
    for k in global_output_dict:
        global_output_dict[k]['shape'] = [
            x['range'] for x in ast['props']['data_axes']
        ]
        break