예제 #1
0
def optimize_OpLogicalOr(module, inst):
    """Simplify an OpLogicalOr instruction.

    Only the second operand is inspected for constants.

    Returns the instruction that should replace `inst`. This is either an
    existing operand instruction, a newly created instruction inserted
    before `inst`, or `inst` itself when no rule applies.
    """
    lhs_id, rhs_id = inst.operands[0], inst.operands[1]
    lhs, rhs = lhs_id.inst, rhs_id.inst

    # x or true -> true
    if rhs.is_constant_value(True):
        return rhs
    # x or false -> x
    if rhs.is_constant_value(False):
        return lhs
    # x or x -> x
    if lhs_id == rhs_id:
        return lhs
    # undef or undef -> undef
    if lhs.op_name == 'OpUndef' and rhs.op_name == 'OpUndef':
        return lhs
    # (not x) or (not y) -> not (x and y)   (De Morgan)
    if lhs.op_name == 'OpLogicalNot' and rhs.op_name == 'OpLogicalNot':
        and_inst = ir.Instruction(module, 'OpLogicalAnd', inst.type_id,
                                  [lhs.operands[0], rhs.operands[0]])
        and_inst.insert_before(inst)
        negated = ir.Instruction(module, 'OpLogicalNot', inst.type_id,
                                 [and_inst.result_id])
        negated.insert_after(and_inst)
        return negated
    return inst
예제 #2
0
def canonicalize_inst(module, inst):
    """Canonicalize operand order if instruction is commutative.

    The canonical form is that a commutative instruction with one constant
    operand always has the constant as its second operand.

    For OpExtInst, operand 0 is the OpExtInstImport and operand 1 is the
    extended-instruction number. The number is used to look up
    'is_commutative' in ir.EXT_INST. The swappable operands are operands
    2 and 3; any further operands are kept in place.

    Returns a new instruction when the operands were swapped, otherwise
    `inst` itself. The new instruction is inserted before `inst` and
    carries a copy of its decorations.
    """
    if inst.op_name == 'OpExtInst':
        extset_inst = inst.operands[0].inst
        assert extset_inst.op_name == 'OpExtInstImport'
        if extset_inst.operands[0] in ir.EXT_INST:
            ext_ops = ir.EXT_INST[extset_inst.operands[0]]
            # Only swap when that moves a constant into the second
            # position.  Swapping unconditionally would turn an already
            # canonical instruction into a non-canonical one, and
            # re-running canonicalization would swap it back again.
            if (ext_ops[inst.operands[1]]['is_commutative']
                    and (inst.operands[2].inst.op_name
                         in ir.CONSTANT_INSTRUCTIONS)
                    and (inst.operands[3].inst.op_name
                         not in ir.CONSTANT_INSTRUCTIONS)):
                new_operands = [
                    inst.operands[0], inst.operands[1], inst.operands[3],
                    inst.operands[2]
                ] + inst.operands[4:]
                new_inst = ir.Instruction(module, 'OpExtInst', inst.type_id,
                                          new_operands)
                new_inst.copy_decorations(inst)
                new_inst.insert_before(inst)
                return new_inst
    elif (inst.is_commutative()
          and inst.operands[0].inst.op_name in ir.CONSTANT_INSTRUCTIONS
          and inst.operands[1].inst.op_name not in ir.CONSTANT_INSTRUCTIONS):
        new_inst = ir.Instruction(module, inst.op_name, inst.type_id,
                                  [inst.operands[1], inst.operands[0]])
        # Swapping operands does not change semantics, so keep decorations.
        new_inst.copy_decorations(inst)
        new_inst.insert_before(inst)
        return new_inst
    return inst
예제 #3
0
def create_id(module, token, tag, type_id=None):
    """Return the 'real' ID for an ID token, creating it when necessary.

    Tokens are generalized. Besides '%name' and '%<number>' IDs, the
    function accepts:
      * type names such as 'f32', for which the ID of the corresponding type
        instruction (e.g. 'OpTypeFloat') is returned;
      * integer scalar constants (decimal, binary, or hexadecimal) and the
        tokens 'true'/'false', for which the result ID of the constant is
        returned.

    A new '%name' ID is registered in module.symbol_name_to_id, and an
    OpName instruction is emitted for it. A new '%<number>' ID is
    registered in module.value_to_id.
    """
    symbols = module.symbol_name_to_id
    if token in symbols:
        return symbols[token]

    if tag == 'ID':
        assert token[0] == '%'
        if token[1].isdigit():
            number = int(token[1:])
            if number in module.value_to_id:
                return module.value_to_id[number]
            numbered_id = ir.Id(module, number)
            module.value_to_id[number] = numbered_id
            return numbered_id
        named_id = ir.Id(module)
        symbols[token] = named_id
        name_inst = ir.Instruction(module, 'OpName', None,
                                   [named_id, token[1:]])
        module.insert_global_inst(name_inst)
        return named_id

    if tag == 'INT' or token in ['true', 'false']:
        scalar = get_scalar_value(token, tag, type_id)
        return module.get_constant(type_id, scalar).result_id

    if token in module.type_name_to_id:
        return module.type_name_to_id[token]
    return get_or_create_type(module, token)
예제 #4
0
def parse_instruction(binary, module):
    """Parse one instruction from `binary`.

    Reads the opcode, then the result type and result ID when the opcode's
    format has them, then every operand kind listed in the format. Finally
    it checks that the instruction ends there.

    Returns an ir.Function for OpFunction and an ir.Instruction otherwise.

    Raises ParseError if the result ID is already defined.
    """
    op_name, op_format = binary.get_next_opcode()

    inst_type_id = parse_id(binary, module) if op_format['type'] else None

    result_id = None
    if op_format['result']:
        result_id = parse_id(binary, module)
        if result_id.inst is not None:
            raise ParseError('ID ' + str(result_id) + ' is already defined')

    operands = []
    for operand_kind in op_format['operands']:
        operands = operands + parse_operand(binary, module, operand_kind)
    binary.expect_eol()

    if op_name != 'OpFunction':
        return ir.Instruction(module, op_name, inst_type_id, operands,
                              result_id=result_id)
    return ir.Function(module, operands[0], operands[1], result_id=result_id)
예제 #5
0
def update_conditional_branch(module, inst, dest_id):
    """Change the OpBranchConditional or OpSwitch to a branch to dest_id.

    `inst` is replaced by an unconditional OpBranch to `dest_id`. If the
    instruction right before the new branch in the basic block is an
    OpSelectionMerge or OpLoopMerge, that instruction is destroyed too.
    """
    assert inst.op_name == 'OpBranchConditional' or inst.op_name == 'OpSwitch'
    basic_block = inst.basic_block
    branch_inst = ir.Instruction(module, 'OpBranch', None, [dest_id])
    inst.replace_with(branch_inst)
    # The branch may be the only instruction in the block, in which case
    # there is no preceding instruction.  The original insts[-2] indexing
    # raised IndexError for such blocks.
    last_insts = basic_block.insts[-2:]
    if (len(last_insts) == 2
            and last_insts[0].op_name in ['OpSelectionMerge', 'OpLoopMerge']):
        last_insts[0].destroy()
예제 #6
0
def optimize_OpBitcast(module, inst):
    """Simplify an OpBitcast instruction.

    Returns the replacement instruction, or `inst` itself if no rule
    applies. When a new instruction is needed, it is inserted before
    `inst`.
    """
    source = inst.operands[0].inst

    # bitcast(bitcast(x)) -> x if the types round-trip, else bitcast(x)
    if source.op_name == 'OpBitcast':
        inner_id = source.operands[0]
        if inst.type_id == inner_id.inst.type_id:
            return inner_id.inst
        folded = ir.Instruction(module, 'OpBitcast', inst.type_id, [inner_id])
        folded.copy_decorations(inst)
        folded.insert_before(inst)
        return folded

    # bitcast(undef) -> undef of the result type
    if source.op_name == 'OpUndef':
        undef = ir.Instruction(module, 'OpUndef', inst.type_id, [])
        undef.insert_before(inst)
        return undef

    return inst
예제 #7
0
def parse_instruction(lexer, module):
    """Parse one instruction from the pretty-printed text.

    The line may start with an '<id> =' prefix. It is followed by the
    operation name, the result type (when the operation's format has one),
    any decorations, and the operands.

    Returns an ir.Function for OpFunction, an ir.BasicBlock for OpLabel,
    and an ir.Instruction otherwise. The current line number is recorded
    in module.inst_to_line for every instruction created.

    Raises ParseError for a redefined ID, a missing operation name, or an
    unknown operation.
    """
    _, first_tag = lexer.get_next_token(peek=True)
    result_id = None
    if first_tag == 'ID':
        result_id = parse_id(lexer, module)
        if result_id.inst is not None:
            id_name = get_id_name(module, result_id)
            raise ParseError(id_name + ' is already defined')
        lexer.get_next_token('=')

    op_name, op_tag = lexer.get_next_token()
    if op_tag != 'NAME':
        raise ParseError('Expected an operation name')
    if op_name not in ir.INST_FORMAT:
        raise ParseError('Invalid operation ' + op_name)
    op_format = ir.INST_FORMAT[op_name]
    type_id = parse_type(lexer, module) if op_format['type'] else None

    parse_decorations(lexer, module, result_id, op_name)
    if op_name == 'OpExtInst':
        operands = parse_extinst_operands(lexer, module, type_id)
    else:
        operands = parse_operands(lexer, module, op_format, type_id)
    lexer.done_with_line()

    line_no = lexer.line_no
    if op_name == 'OpFunction':
        function = ir.Function(module, operands[0], operands[1],
                               result_id=result_id)
        module.inst_to_line[function.inst] = line_no
        module.inst_to_line[function.end_inst] = line_no
        return function
    if op_name == 'OpLabel':
        basic_block = ir.BasicBlock(module, result_id)
        module.inst_to_line[basic_block.inst] = line_no
        return basic_block
    inst = ir.Instruction(module, op_name, type_id, operands,
                          result_id=result_id)
    module.inst_to_line[inst] = line_no
    return inst
예제 #8
0
def optimize_OpIMul(module, inst):
    """Simplify an OpIMul instruction.

    Only the second operand is checked for constants. This presumably
    relies on commutative instructions having been canonicalized so that
    the constant is second (TODO confirm with the caller). The undef rules
    check both operands.

    Returns the replacement instruction, or `inst` itself if no rule
    applies. A new OpSNegate is inserted before `inst` when it is created.
    """
    # x * 0 -> 0
    if inst.operands[1].inst.is_constant_value(0):
        return inst.operands[1].inst
    # x * 1 -> x
    if inst.operands[1].inst.is_constant_value(1):
        return inst.operands[0].inst
    # x * -1 -> -x
    if inst.operands[1].inst.is_constant_value(-1):
        new_inst = ir.Instruction(module, 'OpSNegate', inst.type_id,
                                  [inst.operands[0]])
        new_inst.insert_before(inst)
        return new_inst
    # x * undef -> undef
    if inst.operands[1].inst.op_name == 'OpUndef':
        return inst.operands[1].inst
    # undef * x -> undef
    if inst.operands[0].inst.op_name == 'OpUndef':
        return inst.operands[0].inst
    return inst
예제 #9
0
def optimize_OpLogicalNotEqual(module, inst):
    """Simplify an OpLogicalNotEqual instruction.

    Returns the replacement instruction, or `inst` itself when no rule
    applies. The replacement may be an operand, a constant from
    module.get_constant, or a new OpLogicalNot inserted before `inst`.
    """
    lhs_id, rhs_id = inst.operands[0], inst.operands[1]
    lhs, rhs = lhs_id.inst, rhs_id.inst

    # NotEqual(x, false) -> x
    if rhs.is_constant_value(False):
        return lhs
    # NotEqual(x, true) -> not(x)
    if rhs.is_constant_value(True):
        negated = ir.Instruction(module, 'OpLogicalNot', inst.type_id,
                                 [lhs_id])
        negated.insert_before(inst)
        return negated
    # NotEqual(x, x) -> false
    if lhs_id == rhs_id:
        return module.get_constant(inst.type_id, False)
    # NotEqual(x, undef) -> undef, then NotEqual(undef, x) -> undef
    for side in (rhs, lhs):
        if side.op_name == 'OpUndef':
            return side
    return inst
예제 #10
0
def parse_function_definition(lexer, module):
    """Parse the 'define' line that starts a pretty-printed function.

    The line consists of the keyword 'define', the return type, the
    function's ID, and its parameter list.

    Returns a new ir.Function whose type is the matching OpTypeFunction.
    One OpFunctionParameter is appended per parsed parameter.

    Raises ParseError if the function's ID is already defined.
    """
    lexer.get_next_token('define')
    return_type = parse_type(lexer, module)
    result_id = parse_id(lexer, module)
    parameters = parse_parameters(lexer, module)

    if result_id.inst is not None:
        raise ParseError(get_id_name(module, result_id) + ' is already defined')

    signature = [return_type]
    signature.extend(param[0] for param in parameters)
    function_type_inst = module.get_global_inst('OpTypeFunction', None,
                                                signature)
    # XXX (from the original): the first argument is passed as an empty
    # list; presumably this means "no function control" — TODO confirm.
    function = ir.Function(module, [], function_type_inst.result_id,
                           result_id=result_id)
    for param_type, param_id in parameters:
        function.append_parameter(
            ir.Instruction(module, 'OpFunctionParameter', param_type, [],
                           result_id=param_id))
    return function
예제 #11
0
def parse_decorations(lexer, module, variable_name, op_name):
    """Parse pretty-printed decorations and add them to the module.

    Decorations are read until the end of the line or until the next token
    is not a decoration name. A decoration may be followed by a
    parenthesized, comma-separated list of literal numbers. For each
    decoration, an OpDecorate instruction with operands
    [variable_name, decoration, *literals] is inserted as a global
    instruction.

    Raises ParseError for an unknown decoration or a malformed operand list.
    """
    while True:
        token, _ = lexer.get_next_token(peek=True, accept_eol=True)
        if token == '' or token not in spirv.spv['Decoration']:
            return

        # XXX We should check that the decorations are valid for the
        # operation.
        #
        # In particular 'Uniform' is both a decoration and a storage class,
        # so instructions that have 'StorageClass' as first operand must
        # not parse 'Uniform' as a decoration (and 'Uniform' is not a valid
        # decoration for those operations).  Add this as a special case
        # here until the real decoration check has been implemented.
        if op_name in ['OpTypePointer', 'OpVariable'] and token == 'Uniform':
            return

        decoration, _ = lexer.get_next_token()
        if decoration not in spirv.spv['Decoration']:
            raise ParseError('Unknown decoration ' + decoration)
        token, _ = lexer.get_next_token(peek=True, accept_eol=True)
        operands = [variable_name, decoration]
        if token == '(':
            lexer.get_next_token()  # consume '('
            while True:
                operands.append(parse_literal_number(lexer))
                token, _ = lexer.get_next_token()
                if token == ')':
                    break
                if token != ',':
                    raise ParseError('Syntax error in decoration')
        inst = ir.Instruction(module, 'OpDecorate', None, operands)
        module.insert_global_inst(inst)
예제 #12
0
def _composite_shuffle_sources(operands):
    """Return the distinct source vector IDs of OpCompositeExtract operands.

    Returns None unless every operand is an OpCompositeExtract whose source
    has an OpTypeVector type and at most two distinct sources are used.
    """
    sources = []
    for operand in operands:
        extract = operand.inst
        if extract.op_name != 'OpCompositeExtract':
            return None
        src_inst = extract.operands[0].inst
        if src_inst.result_id not in sources:
            if src_inst.type_id.inst.op_name != 'OpTypeVector':
                return None
            sources.append(src_inst.result_id)
        if len(sources) > 2:
            return None
    return sources


def optimize_OpCompositeConstruct(module, inst):
    """Turn a vector construct from extracted components into a shuffle.

    Code of the form
      %20 = OpCompositeExtract f32 %19, 0
      %21 = OpCompositeExtract f32 %19, 1
      %22 = OpCompositeExtract f32 %19, 2
      %23 = OpCompositeConstruct <3 x f32> %20, %21, %22
    is changed to an OpVectorShuffle when every OpCompositeExtract comes
    from one or two vectors.

    Returns the new OpVectorShuffle, inserted before `inst` and carrying
    its decorations, or `inst` itself when the pattern does not match.
    """
    if inst.type_id.inst.op_name != 'OpTypeVector':
        return inst
    sources = _composite_shuffle_sources(inst.operands)
    if sources is None:
        return inst

    vec1_id = sources[0]
    vec2_id = sources[1] if len(sources) > 1 else vec1_id
    vec1_len = vec1_id.inst.type_id.inst.operands[1]
    shuffle_operands = [vec1_id, vec2_id]
    for operand in inst.operands:
        component = operand.inst.operands[1]
        # Indices into the second vector are offset by the first's length.
        if operand.inst.operands[0] != vec1_id:
            component += vec1_len
        shuffle_operands.append(component)

    shuffle = ir.Instruction(module, 'OpVectorShuffle', inst.type_id,
                             shuffle_operands)
    shuffle.copy_decorations(inst)
    shuffle.insert_before(inst)
    return shuffle
예제 #13
0
def optimize_variable(module, func, var_inst):
    """Promote/eliminate the var_inst variable if possible.

    An unused variable is simply destroyed. If any use is something other
    than an OpLoad or OpStore, nothing is changed. Otherwise:
      * every OpLoad is replaced by the value last stored on the path to it;
      * OpPhi instructions are created at blocks with several predecessors;
      * OpUndef instructions stand in for values read before any store;
      * all OpLoad/OpStore uses are destroyed, and finally the variable.

    Args:
        module: module in which new OpPhi/OpUndef instructions are created.
        func: function whose basic blocks contain the variable's uses.
        var_inst: the variable instruction (presumably an OpVariable). The
            operand 1 of its type instruction is used as the value type for
            the created OpPhi/OpUndef instructions.
    """
    # Delete variable if it is not used.
    if not var_inst.result_id.uses:
        var_inst.destroy()
        return

    # We only handle simple loads and stores.
    for inst in var_inst.uses():
        if inst.op_name not in ['OpLoad', 'OpStore']:
            return

    # Eliminate loads/store instructions for the variable
    # pred: predecessor basic blocks of each block in func (as used below).
    pred = calculate_pred(func)
    # exit_value: basic block -> instruction holding the variable's value
    # at the end of that block (None if no value has been stored yet).
    exit_value = {}
    phi_nodes = []
    undef_insts = []
    # Value type of the variable: operand 1 of its type instruction
    # (presumably the pointee type of an OpTypePointer).
    var_type_id = var_inst.type_id.inst.operands[1]
    # NOTE(review): blocks with a single predecessor read
    # exit_value[pred] directly, so this assumes func.basic_blocks lists
    # such a predecessor before the block (e.g. dominator order) — confirm.
    for basic_block in func.basic_blocks:
        # Get the variable's value at start of the basic block.
        if not pred[basic_block]:
            stored_inst = None
        elif len(pred[basic_block]) == 1:
            stored_inst = exit_value[pred[basic_block][0]]
        else:
            # Several predecessors: merge their values with a phi whose
            # operands are filled in after all blocks have been processed.
            stored_inst = ir.Instruction(module, 'OpPhi', var_type_id, [])
            basic_block.prepend_inst(stored_inst)
            phi_nodes.append(stored_inst)

        # Eliminate loads/store instructions.
        # Collect the variable's uses in instruction order within the block.
        ordered_uses = [
            inst for inst in basic_block.insts
            if inst in var_inst.result_id.uses
        ]
        for inst in ordered_uses:
            if inst.op_name == 'OpLoad':
                if stored_inst is None:
                    # Load before any store: use an OpUndef, which later
                    # loads in this block then reuse via stored_inst.
                    stored_inst = ir.Instruction(module, 'OpUndef',
                                                 inst.type_id, [])
                    undef_insts.append(stored_inst)
                    stored_inst.insert_before(inst)
                inst.replace_uses_with(stored_inst)
                inst.destroy()
            elif inst.op_name == 'OpStore':
                # Operand 1 of OpStore is the stored value.
                stored_inst = inst.operands[1].inst
                inst.destroy()

        # Save the variable's value at end of the basic block.
        exit_value[basic_block] = stored_inst

    # Add operands to the phi-nodes.
    for inst in phi_nodes:
        for pred_bb in pred[inst.basic_block]:
            if exit_value[pred_bb] is None:
                # No value reaches the end of this predecessor: create an
                # OpUndef there (cached in exit_value so other phis reuse it).
                undef_inst = ir.Instruction(module, 'OpUndef', var_type_id, [])
                undef_insts.append(undef_inst)
                # Place the OpUndef just before the block's final
                # instruction (presumably its terminator).  If the
                # second-to-last instruction is a merge, insert before it.
                # If the block holds a single branch, insert before that
                # branch.  Otherwise insert after the second-to-last
                # instruction.
                last_insts = pred_bb.insts[-2:]
                if (last_insts[0].op_name
                        in ['OpLoopMerge', 'OpSelectionMerge']
                        or last_insts[0].op_name in ir.BRANCH_INSTRUCTIONS):
                    undef_inst.insert_before(last_insts[0])
                else:
                    undef_inst.insert_after(last_insts[0])
                exit_value[pred_bb] = undef_inst
            inst.add_to_phi(exit_value[pred_bb], pred_bb.inst)

    # Destroy obviously dead instructions.
    # Phis are visited newest-first; presumably so that a phi used only by
    # later, dead phis loses those uses before it is checked.  Undefs are
    # checked last for the same reason.
    for inst in reversed(phi_nodes):
        if not inst.result_id.uses:
            inst.destroy()
    for inst in undef_insts:
        if not inst.result_id.uses:
            inst.destroy()
    var_inst.destroy()
예제 #14
0
def optimize_OpVectorShuffle(module, inst):
    """Simplify an OpVectorShuffle instruction.

    A component value of 0xffffffff marks an undefined component; it never
    counts as a use of either input and is never rewritten. The rules, in
    order:
      * shuffling two OpUndef vectors gives OpUndef;
      * if no component selects from either input, the result is OpUndef;
      * if only one input is referenced, it is used for both inputs
        ("A, A").  This avoids an extra OpUndef for the unused input and
        lets the constant folder handle the shuffle of a constant A
        without special-casing a constant/OpUndef pair;
      * for "A, A", indices into the second copy are rewritten to point
        into the first copy;
      * an identity swizzle of "A, A" with the same type as A gives A.

    Returns the replacement instruction, or `inst` itself if nothing
    changed. A new instruction is inserted before `inst`, and a new
    OpVectorShuffle also copies `inst`'s decorations.
    """
    first = inst.operands[0].inst
    second = inst.operands[1].inst
    selectors = inst.operands[2:]

    if first.op_name == 'OpUndef' and second.op_name == 'OpUndef':
        undef = ir.Instruction(module, 'OpUndef', inst.type_id, [])
        undef.insert_before(inst)
        return undef

    first_type = first.type_id.inst
    assert first_type.op_name == 'OpTypeVector'
    first_len = first_type.operands[1]
    defined = [sel for sel in selectors if sel != 0xffffffff]
    uses_first = any(sel < first_len for sel in defined)
    uses_second = any(sel >= first_len for sel in defined)

    if not uses_first and not uses_second:
        undef = ir.Instruction(module, 'OpUndef', inst.type_id, [])
        undef.insert_before(inst)
        return undef
    if not uses_second:
        second = first
    elif not uses_first:
        selectors = [sel if sel == 0xffffffff else sel - first_len
                     for sel in selectors]
        first = second

    if first == second:
        # "A, A": fold indices into the second copy onto the first.
        first_type = first.type_id.inst
        assert first_type.op_name == 'OpTypeVector'
        first_len = first_type.operands[1]
        selectors = [
            sel - first_len if sel != 0xffffffff and sel >= first_len else sel
            for sel in selectors
        ]
        # Identity swizzle: every defined component keeps its position.
        if inst.type_id == first.type_id and all(
                sel == 0xffffffff or sel == pos
                for pos, sel in enumerate(selectors)):
            return first

    new_operands = [first.result_id, second.result_id] + selectors
    if new_operands == inst.operands:
        return inst
    shuffle = ir.Instruction(module, 'OpVectorShuffle', inst.type_id,
                             new_operands)
    shuffle.copy_decorations(inst)
    shuffle.insert_before(inst)
    return shuffle