Beispiel #1
0
  def transform_Assign(self, stmt):
    """Rewrite a single assignment statement.

    Three outcomes are possible:

    * the right-hand side is an adverb: it is passed through
      ``self.transform_expr`` and a new assignment to the same target is
      returned;
    * at least one adverb has already been visited: the statement is wrapped
      in a function produced by ``self.gen_unpack_tree`` (one ``syntax.Map``
      per depth of its free variables) and the right-hand side becomes a call
      to that function;
    * otherwise the statement is returned unchanged.

    In the first two cases the target's type and ``self.type_env`` are
    updated to the type of the new right-hand side.  Independently of the
    above, a closure on the right-hand side is recorded in
    ``self.closure_vars`` under the target's name.
    """
    target = stmt.lhs
    value = stmt.rhs

    # Remember which names are bound to closures.
    if isinstance(value, syntax.Closure):
      self.closure_vars[target.name] = value

    # Adverbs are handled by the expression transformer.
    if isinstance(value, syntax.Adverb):
      rewritten = self.transform_expr(value)
      target.type = rewritten.type
      self.type_env[target.name] = target.type
      return syntax.Assign(target, rewritten)

    # Do nothing if we're not inside a nesting of tiled adverbs
    if len(self.adverbs_visited) == 0:
      return stmt

    # Wrap the statement in a tree of Maps, one per depth of its free vars,
    # and replace the right-hand side with a call to the generated function.
    free_names = free_vars(value)
    depth_list = self.get_depths_list(free_names)
    maps = [syntax.Map for _ in depth_list]
    wrapped_body = [stmt, syntax.Return(target)]
    call_args, unpacker = self.gen_unpack_tree(
        maps, depth_list, free_names, wrapped_body, self.fn.type_env)
    call = syntax.Call(unpacker, call_args)
    target.type = call.type
    self.type_env[target.name] = call.type
    return syntax.Assign(target, call)
Beispiel #2
0
    def gen_unpack_fn(depth_idx, arg_order):
      """Recursively build the typed function for nesting level ``depth_idx``.

      This is a closure: besides its two parameters it reads ``depths``,
      ``reg_tiling``, ``inner``, ``type_env``, ``v_names``, ``adverb_tree``,
      ``order_args`` and ``self`` from the enclosing scope.

      Args:
        depth_idx: index into ``depths`` of the level to generate.
        arg_order: names of the arguments of the function generated at this
          level, in the order they are passed.

      Returns:
        Once ``depth_idx`` is past the last depth: ``inner`` itself when
        ``reg_tiling`` is set, otherwise a fresh ``syntax.TypedFn`` whose body
        is ``inner``.  For every other level: a ``syntax.TypedFn`` whose body
        returns one adverb wrapping a closure over the function generated for
        ``depth_idx + 1``.
      """
      if depth_idx >= len(depths):
        if reg_tiling:
          return inner

        # For each stmt in body, add its lhs free vars to the type env
        inner_type_env = copy.copy(type_env)
        return_t = Int32 # Dummy type
        for s in inner:
          if isinstance(s, syntax.Assign):
            for name in free_vars(s.lhs):
              inner_type_env[name] = type_env[name]
          elif isinstance(s, syntax.Return):
            if isinstance(s.value, str):
              # A bare string is itself the variable name; it has no `.name`
              # attribute, so it is looked up in the type env directly.
              return_t = type_env[s.value]
            else:
              return_t = s.value.type

        # The innermost function always uses all the variables
        # NOTE(review): arg_names come from v_names while input_types follow
        # arg_order; this presumably relies on both listing the same
        # variables in the same order -- confirm against the caller.
        input_types = [type_env[arg] for arg in arg_order]
        return syntax.TypedFn(name = names.fresh("inner_block"),
                              arg_names = tuple(v_names),
                              body = inner,
                              input_types = input_types,
                              return_type = return_t,
                              type_env = inner_type_env)

      # Get the current depth
      depth = depths[depth_idx]

      # Order the arguments for the current depth, i.e. for the nested fn.
      # Fixed args come first because they are bound by the closure below.
      cur_arg_names, fixed_arg_names = order_args(depth)
      nested_arg_names = fixed_arg_names + cur_arg_names

      # Instantiate the adverb for this level from the recorded adverb args;
      # its fn and args are placeholders that are overwritten further down.
      adv_args = self.adverb_args[depth_idx]
      if reg_tiling:
        new_adverb = adverb_tree[depth_idx](fn = adv_args.fn,
                                            args = adv_args.args,
                                            axes = adv_args.axes,
                                            fixed_tile_size = True)
      else:
        new_adverb = adverb_tree[depth_idx](fn = adv_args.fn,
                                            args = adv_args.args,
                                            axis = adv_args.axis)

      # Make a type env for this function.  With register tiling the types are
      # taken unchanged from inner's type env; otherwise the rank of each arg
      # is increased by the number of nested expansions (i.e. the expansions
      # of that arg that occur at this depth or deeper in the nesting).
      new_type_env = {}
      if reg_tiling:
        for arg in nested_arg_names:
          new_type_env[arg] = inner.type_env[arg]
      else:
        for arg in nested_arg_names:
          exps = self.get_expansions(arg)
          rank_increase = 0
          for i, e in enumerate(exps):
            if e >= depth:
              # Everything from the first expansion at or past this depth
              # onwards counts towards the rank increase.
              rank_increase = len(exps) - i
              break
          new_type_env[arg] = \
              array_type.increase_rank(type_env[arg], rank_increase)

      cur_arg_types = [new_type_env[arg] for arg in cur_arg_names]
      fixed_arg_types = [new_type_env[arg] for arg in fixed_arg_names]

      # Generate the nested function with the proper arg order and wrap it
      # in a closure
      nested_fn = gen_unpack_fn(depth_idx + 1, nested_arg_names)
      nested_args = [syntax.Var(name, type = t)
                     for name, t in zip(cur_arg_names, cur_arg_types)]
      nested_fixed_args = \
          [syntax.Var(name, type = t)
           for name, t in zip(fixed_arg_names, fixed_arg_types)]
      nested_closure = self.closure(nested_fn, nested_fixed_args)

      # Make an adverb that wraps the nested fn
      new_adverb.fn = nested_closure
      new_adverb.args = nested_args
      return_t = nested_fn.return_type
      if isinstance(new_adverb, syntax.Reduce):
        if reg_tiling:
          # The combiner is unpacked over every depth except the current one.
          ds = copy.copy(depths)
          ds.remove(depth)
          new_adverb.combine = self.unpack_combine(adv_args.combine, ds)
        else:
          new_adverb.combine = adv_args.combine
        new_adverb.init = adv_args.init
      elif not reg_tiling:
        # A non-reduce adverb adds one dimension to the nested result.
        return_t = array_type.increase_rank(nested_fn.return_type, 1)
      new_adverb.type = return_t

      # Add the adverb to the body of the current fn and return the fn
      name = names.fresh("reg_tile" if reg_tiling else "intermediate_depth")
      arg_types = [new_type_env[arg] for arg in arg_order]
      return syntax.TypedFn(name = name,
                            arg_names = arg_order,
                            body = [syntax.Return(new_adverb)],
                            input_types = arg_types,
                            return_type = return_t,
                            type_env = new_type_env)