コード例 #1
0
def no_trivial_ax_input(ast_seq, global_input_dict, global_output_dict):
    """Remove axes whose range is 1 from every statement in ``ast_seq``.

    For each AST the names of all reduce/data axes with ``range == 1`` are
    collected, then:

    * those axes are removed from ``props['reduce_axes']``;
    * ``props['data_axes']`` is pruned as well, but only when there is exactly
      one statement and exactly one global output, and the first data axis is
      kept if pruning would leave none;
    * ``props['reduce_type']`` is reset to ``None`` when no reduce axis is
      left;
    * the tree is walked with a callback that answers a constant-0 ``int32``
      ``OpTensor`` for every axis node naming an eliminated axis.

    The ASTs are modified in place and nothing is returned.
    ``global_input_dict`` is not read; only the length of
    ``global_output_dict`` is used.
    """

    def zero_for_trivial_axis(root, ancestor, ax_elim):
        # Non-matching nodes fall through and yield None.
        # NOTE(review): presumably walk_in_ast substitutes a non-None return
        # value for the visited node -- confirm against walk_in_ast.
        if root._op == 'axis' and root._value in ax_elim:
            return OpTensor('const', 0, 'int32')

    for ast in ast_seq:
        props = ast['props']

        # Kept as a list (not a set): node values tested against it are not
        # guaranteed to be hashable.
        trivial_names = [
            axis['name']
            for axis in props['reduce_axes'] + props['data_axes']
            if axis['range'] == 1
        ]

        props['reduce_axes'] = [
            axis for axis in props['reduce_axes']
            if axis['name'] not in trivial_names
        ]

        if len(ast_seq) == 1 and len(global_output_dict) == 1:
            surviving = [
                axis for axis in props['data_axes']
                if axis['name'] not in trivial_names
            ]
            # Never leave the statement without a data axis.
            props['data_axes'] = surviving if surviving else [props['data_axes'][0]]

        if not props['reduce_axes']:
            props['reduce_type'] = None

        walk_in_ast(ast, 'root', zero_for_trivial_axis, [trivial_names])
コード例 #2
0
ファイル: simplify.py プロジェクト: SerailHydra/antares
def scan_items(root, ast, access_book, tensor_nodes):
  """Walker callback recording how each tensor is indexed.

  Args:
    root: the node being visited (reads ``_op`` and ``_value``).
    ast: the statement AST, forwarded to ``get_daxis_range`` and
      ``get_input_shape``.
    access_book: dict updated in place. ``access_book['*']`` (must already
      exist) collects the names of every ``axis`` node this callback is
      invoked on. ``access_book[tensor_name]`` holds one entry per index
      position: the axis name when that position is a bare axis whose range
      equals the tensor's extent in that dimension, otherwise ``'*'``.
    tensor_nodes: list that every visited ``get_item`` node is appended to.

  Returns:
    ``None`` for nodes that are not ``get_item``; ``''`` otherwise.
    NOTE(review): the meaning of the return value is defined by
    ``walk_in_ast`` -- presumably it controls descent/replacement; confirm.

  When the same tensor is seen more than once, positions whose entries
  disagree between accesses are demoted to ``'*'`` and the merged pattern is
  stored back into ``access_book``.
  """
  if root._op == 'axis':
    access_book['*'].append(root._value['name'])
  if root._op != 'get_item':
    return
  tensor_nodes.append(root)
  tensor_name = root._value['tensor']._value['name']
  tensor_index = []
  for i, sub in enumerate(root._value['index']):
    if sub._op == 'axis':
      rng = get_daxis_range(sub._value['name'], ast)
      shp = get_input_shape(tensor_name, ast)
      # Only a bare axis spanning the whole dimension keeps its name.
      tensor_index.append(sub._value['name'] if rng == shp[i] else '*')
    else:
      tensor_index.append('*')
      # Compound index expression: scan it too so nested nodes are recorded.
      walk_in_ast(sub, scan_items, [ast, access_book, tensor_nodes], root._value['index'], i)
  if tensor_name in access_book:
    last_index = access_book[tensor_name]
    assert len(last_index) == len(tensor_index)
    tensor_index = [
      cur if cur == prev else '*'
      for cur, prev in zip(tensor_index, last_index)
    ]
  # Always store the result: previously the merged pattern of a repeated
  # access was computed and then thrown away, so only the first access of a
  # tensor was ever reflected in the book.
  access_book[tensor_name] = tensor_index
  return ''
コード例 #3
0
ファイル: simplify.py プロジェクト: SerailHydra/antares
def eliminate_trivial_axis(ast):
  """Remove range-1 axes from ``ast['props']`` and shrink output shapes.

  Steps, all performed in place on ``ast``:

  1. Walk ``ast['root']`` with ``scan_trivial_axis``.
  2. Drop every reduce axis whose ``range`` is 1; reset ``reduce_type`` to
     ``None`` when no reduce axis is left.
  3. Take the element count of the first entry of ``props['output_dict']``:
     * more than one element: drop range-1 data axes and remove all
       1-sized dimensions from every output shape;
     * otherwise, if there are several data axes: keep the first data axis,
       drop the remaining range-1 ones and set every output shape to ``[1]``.

  Raises:
    ValueError: if ``props['output_dict']`` is empty (this used to surface
      as an accidental ``UnboundLocalError``).
  """
  # NOTE(review): debug output on every call; consider routing it through
  # logging or removing it.
  print(ast['props'])
  walk_in_ast(ast['root'], scan_trivial_axis, [ast], ast, 'root')

  def update(axes, start=0):
    # The first `start` axes are kept unconditionally; range-1 axes are
    # filtered out of the rest.
    return axes[:start] + [ax for ax in axes[start:] if ax['range'] != 1]

  props = ast['props']
  props['reduce_axes'] = update(props['reduce_axes'])
  if not props['reduce_axes']:
    props['reduce_type'] = None

  output_dict = props['output_dict']
  if not output_dict:
    raise ValueError('eliminate_trivial_axis requires at least one output')

  # np.prod instead of np.product: the latter alias was removed in NumPy 2.0.
  first_shape = next(iter(output_dict.values()))['shape']
  num_outputs = int(np.prod(first_shape))

  if num_outputs > 1:
    props['data_axes'] = update(props['data_axes'])
    for item in output_dict.values():
      item['shape'] = [dim for dim in item['shape'] if dim != 1]
  elif len(props['data_axes']) > 1:
    # Single-element output: collapse to one data axis and a [1] shape.
    props['data_axes'] = update(props['data_axes'], 1)
    for item in output_dict.values():
      item['shape'] = [1]
コード例 #4
0
def run_pass_v2(ast_seq, global_input_dict, global_output_dict):
    """Shard the single statement in ``ast_seq`` for the 'c-gc' backend.

    Does nothing unless the module-level ``backend`` is ``'c-gc'``, and also
    returns early when the ``CONFIG`` environment variable is empty while
    ``STEP`` is greater than 0.

    The number of parts per data axis is read from ``CONFIG`` (JSON, last
    element of each ``axis_<i>`` entry); if ``CONFIG`` cannot be used, every
    axis gets one part. The pass then, in place:

    * stores ``{'nparts', 'book', 'local_shape'}`` in ``props['shard']``;
    * divides every data-axis ``range`` by its part count;
    * rewrites every input's ``shape`` from the scanned range book;
    * dumps ``props['shard']`` to ``range_book.json``.

    ``global_input_dict`` and ``global_output_dict`` are not used here.

    Raises:
        Exception: for more than one statement, or an unhandled book entry.
        AssertionError: for an injective AST or a non-divisible axis range.
    """
    if backend not in ['c-gc']:
        return

    if len(ast_seq) > 1:
        raise Exception(
            "TODO: Graphcore backend not handling multiple IR statements.")
    ast = ast_seq[0]

    steps = int(os.environ.get('STEP', '0'))
    pieces = os.environ.get('CONFIG', '').strip()

    data_axes = ast['props']['data_axes']
    if not pieces and steps > 0:
        return

    try:
        pieces = json.loads(pieces)
        pieces = [pieces['axis_%d' % i][-1] for i in range(len(data_axes))]
    except (ValueError, KeyError, IndexError, TypeError):
        # Missing/malformed CONFIG: fall back to no sharding. Only the errors
        # the parse above can produce are caught, so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        pieces = [1] * len(data_axes)

    assert 'injective' not in ast, "Unhandled injective case for graphcore."
    range_book = {}
    walk_in_ast(ast['root'], scan_items, [ast, range_book], ast, 'root')
    ast['props']['shard'] = {'nparts': pieces, 'book': range_book}

    # AST props: ast['props']['data_axes'], ast['props']['input_dict']
    for axis, nparts in zip(data_axes, pieces):
        assert axis['range'] % nparts == 0, "Axis sharding must be exactly divided, while requesting %d // %d." % (
            axis['range'], nparts)
        axis['range'] //= nparts

    for name, input_item in ast['props']['input_dict'].items():
        sub_shape = []
        for it in range_book[name]:
            # NOTE(review): entry layout is produced by scan_items, which is
            # not visible here -- it[0] acts as a coefficient on the data
            # axis it[1] (negative: no axis), it[2]..it[3] as an inclusive
            # offset span; confirm.
            bias_diff = it[3] - it[2] + 1
            if it[1] < 0 or it[0] == 0:
                sub_shape.append(bias_diff)
            elif it[0] > 0:
                sub_shape.append(it[0] * (data_axes[it[1]]['range'] - 1) +
                                 bias_diff)
            else:
                raise Exception('Unhandled book case:', it)
        input_item['shape'] = sub_shape

    from antares.common import local_get_dir_file
    ast['props']['shard']['local_shape'] = [
        axis['range'] for axis in ast['props']['data_axes']
    ]
    with open(local_get_dir_file('range_book.json'), 'w') as fp:
        json.dump(ast['props']['shard'], fp)
コード例 #5
0
ファイル: simplify.py プロジェクト: MyPandaShaoxiang/antares
def compute(ast):
    """Find groups of axes that are always accessed together and fuse them.

    The pass is skipped (returns ``None`` without touching ``ast``) when:

    * the ``SIMPLE`` environment variable is ``'0'``;
    * the AST is injective or already carries ``props['shard']``;
    * the last ``##``-separated part of ``COMPUTE_V1`` mentions ``plan/``
      without ``default``.

    Otherwise trivial axes are eliminated, tensor accesses are scanned into
    an access book, and for every permutation of the usable axes (longest
    first) that appears as one contiguous run wherever any of its axes is
    used, ``update_ast_axis`` is called. Each axis takes part in at most one
    such group.
    """
    if os.environ.get('SIMPLE', '1') == '0':
        return
    if 'injective' in ast or 'shard' in ast['props']:
        # FIXME: Unhandled case yet
        return

    # FIXME: Just a rough check
    annotation = os.environ.get('COMPUTE_V1', '').split('##')[-1]
    if 'plan/' in annotation and 'default' not in annotation:
        return

    eliminate_trivial_axis(ast)

    tensor_nodes = []
    access_book = {'*': []}
    walk_in_ast(ast['root'], scan_items, [ast, access_book, tensor_nodes], ast,
                'root')

    # '=' stands for the output: it is indexed by the data axes in order.
    access_book['='] = [axis['name'] for axis in ast['props']['data_axes']]

    # Axes recorded under '*' are masked out everywhere; every remaining
    # named axis is a candidate for fusing.
    masked_axes = access_book['*']
    unique_axes = set()
    for key in access_book:
        if key == '*':
            continue
        pattern = [
            '*' if name in masked_axes else name for name in access_book[key]
        ]
        access_book[key] = pattern
        unique_axes.update(name for name in pattern if name != '*')
    del access_book['*']

    visited = set()
    for size in range(len(unique_axes), 1, -1):
        for k in itertools.permutations(unique_axes, size):
            if any(name in visited for name in k):
                continue
            this_pattern = ':%s:' % ':'.join(k)
            access_pattern = [
                ':%s:' % ':'.join(access_book[key]) for key in access_book
            ]
            for acc in access_pattern:
                # Whatever is left after cutting out the contiguous run must
                # not mention any axis of the group.
                leftover = acc.replace(this_pattern, '').split(':')
                if any(name in k for name in leftover):
                    break
            else:
                visited.update(k)
                update_ast_axis(ast, k, tensor_nodes)
コード例 #6
0
ファイル: auto_shard.py プロジェクト: SerailHydra/antares
def compute(ast):
  """Shard ``ast`` for the 'c-gc' backend and dump the resulting plan.

  Does nothing unless the module-level ``backend`` is ``'c-gc'``, and also
  returns early when the ``CONFIG`` environment variable is empty while
  ``STEP`` is greater than 0.

  The number of parts per data axis is read from ``CONFIG`` (JSON, last
  element of each ``axis_<i>`` entry); if ``CONFIG`` cannot be used, every
  axis gets one part. The pass then, in place:

  * stores ``{'nparts', 'book', 'local_shape'}`` in ``props['shard']``;
  * divides every data-axis ``range`` and the matching dimension of every
    output shape by the part count;
  * rewrites every input's ``shape`` from the scanned range book;
  * dumps ``props['shard']`` to ``range_book.json``.

  Raises:
    AssertionError: for an injective AST or a non-divisible axis/dimension.
    Exception: for an unhandled book entry.
  """
  if backend not in ['c-gc']:
    return

  steps = int(os.environ.get('STEP', '0'))
  pieces = os.environ.get('CONFIG', '').strip()

  data_axes = ast['props']['data_axes']
  if not pieces and steps > 0:
    return

  try:
    pieces = json.loads(pieces)
    pieces = [pieces['axis_%d' % i][-1] for i in range(len(data_axes))]
  except (ValueError, KeyError, IndexError, TypeError):
    # Missing/malformed CONFIG: fall back to no sharding. Only the errors the
    # parse above can produce are caught, so KeyboardInterrupt and SystemExit
    # are no longer swallowed.
    pieces = [1] * len(data_axes)

  assert 'injective' not in ast, "Unhandled injective case for graphcore."
  range_book = {}
  walk_in_ast(ast['root'], scan_items, [ast, range_book], ast, 'root')
  ast['props']['shard'] = {'nparts': pieces, 'book': range_book}

  # AST props: ast['props']['data_axes'], ast['props']['output_dict'], ast['props']['input_dict']
  for i, nparts in enumerate(pieces):
    assert data_axes[i]['range'] % nparts == 0
    data_axes[i]['range'] //= nparts
    for output_item in ast['props']['output_dict'].values():
      assert output_item['shape'][i] % nparts == 0
      output_item['shape'][i] //= nparts

  for name, input_item in ast['props']['input_dict'].items():
    sub_shape = []
    for it in range_book[name]:
      # NOTE(review): entry layout is produced by scan_items, which is not
      # visible here -- it[0] acts as a coefficient on the data axis it[1]
      # (negative: no axis), it[2]..it[3] as an inclusive offset span; confirm.
      bias_diff = it[3] - it[2] + 1
      if it[1] < 0 or it[0] == 0:
        sub_shape.append(bias_diff)
      elif it[0] > 0:
        sub_shape.append(it[0] * (data_axes[it[1]]['range'] - 1) + bias_diff)
      else:
        raise Exception('Unhandled book case:', it)
    input_item['shape'] = sub_shape

  from antares.common import local_get_dir_file
  # The local shape is taken from the first output entry.
  output_key = next(iter(ast['props']['output_dict']))
  ast['props']['shard']['local_shape'] = ast['props']['output_dict'][output_key]['shape']
  with open(local_get_dir_file('range_book.json'), 'w') as fp:
    json.dump(ast['props']['shard'], fp)
コード例 #7
0
def update_ast(config, ast_seq, global_input_dict, global_output_dict):
    """Shard the single statement in ``ast_seq`` according to ``config``.

    For data axis ``i`` the tile size is ``config['tile_<i>'][1] *
    config['tile_<i>'][2]`` and the number of parts is the axis range divided
    by that tile size. If ``config`` cannot be read that way, every tile size
    falls back to 1 (so each axis is split into ``range`` parts).

    In place, the pass:

    * stores ``{'nparts', 'book', 'local_shape'}`` in ``props['shard']``;
    * divides every data-axis ``range`` by its part count;
    * rewrites every input's ``shape`` from the scanned range book;
    * dumps ``props['shard']`` to ``range_book.json``;
    * copies the updated inputs into ``global_input_dict`` (existing keys
      only) and sets the shape of the single global output to the local
      data-axis ranges.

    Raises:
        Exception: for more than one statement, or an unhandled book entry.
        AssertionError: for a non-divisible axis range, an injective AST, or
            when ``global_output_dict`` does not hold exactly one entry.
    """
    if len(ast_seq) > 1:
        raise Exception(
            "TODO: Graphcore backend not handling multiple IR statements.")
    ast = ast_seq[0]

    data_axes = ast['props']['data_axes']

    try:
        tile_sizes = [(config['tile_%d' % i][1] * config['tile_%d' % i][2])
                      for i in range(len(data_axes))]
    except Exception:
        # Best-effort fallback when the config has no usable tile entries.
        # `Exception` rather than a bare except so that KeyboardInterrupt /
        # SystemExit propagate; kept broad because the concrete type of
        # `config` (and what its lookup may raise) is not visible here.
        tile_sizes = [1] * len(data_axes)

    # Convert per-axis tile sizes into per-axis part counts.
    pieces = []
    for axis, tile in zip(data_axes, tile_sizes):
        assert axis['range'] % tile == 0
        pieces.append(axis['range'] // tile)

    assert 'injective' not in ast, "Unhandled injective case for graphcore."
    range_book = {}
    walk_in_ast(ast, 'root', scan_items, [ast, range_book])
    ast['props']['shard'] = {'nparts': pieces, 'book': range_book}

    # AST props: ast['props']['data_axes'], ast['props']['input_dict']
    for axis, nparts in zip(data_axes, pieces):
        assert axis['range'] % nparts == 0, "Axis sharding must be exactly divided, while requesting %d // %d." % (
            axis['range'], nparts)
        axis['range'] //= nparts

    for name, input_item in ast['props']['input_dict'].items():
        sub_shape = []
        for it in range_book[name]:
            # NOTE(review): entry layout is produced by scan_items, which is
            # not visible here -- it[0] acts as a coefficient on the data
            # axis it[1] (negative: no axis), it[2]..it[3] as an inclusive
            # offset span; confirm.
            bias_diff = it[3] - it[2] + 1
            if it[1] < 0 or it[0] == 0:
                sub_shape.append(bias_diff)
            elif it[0] > 0:
                sub_shape.append(it[0] * (data_axes[it[1]]['range'] - 1) +
                                 bias_diff)
            else:
                raise Exception('Unhandled book case:', it)
        input_item['shape'] = sub_shape

    from antares.common import local_get_dir_file
    ast['props']['shard']['local_shape'] = [
        axis['range'] for axis in ast['props']['data_axes']
    ]
    with open(local_get_dir_file('range_book.json'), 'w') as fp:
        json.dump(ast['props']['shard'], fp)

    # Propagate the per-shard input descriptions to the global view.
    for name in global_input_dict:
        if name in ast['props']['input_dict']:
            global_input_dict[name] = ast['props']['input_dict'][name]

    assert len(global_output_dict) == 1
    only_output = next(iter(global_output_dict.values()))
    only_output['shape'] = [
        axis['range'] for axis in ast['props']['data_axes']
    ]