# Example no. 1
# 0
class Contract:
    """An Etch smart contract together with its owner and address nonce.

    The contract address is ``sha256(bytes(owner) + nonce)``. The source is
    parsed on construction so that its ``@init``, ``@action`` and ``@query``
    entry points are known and shard masks can be derived for transactions.
    """

    def __init__(self,
                 source: str,
                 owner: AddressLike,
                 nonce: Optional[bytes] = None):
        """Create a contract definition.

        :param source: Etch source code of the contract.
        :param owner: address (or address-like value) of the contract owner.
        :param nonce: nonce mixed into the contract address; when omitted,
            8 bytes from ``os.urandom`` are used.
        :raises RuntimeError: if the source declares more than one ``@init``
            function.
        """
        self._source = str(source)
        self._digest = _compute_digest(self._source)
        self._owner = Address(owner)
        self._nonce = bytes(urandom(8)) if nonce is None else nonce

        # The contract address depends only on the owner and the nonce
        hasher = hashlib.sha256()
        hasher.update(bytes(self._owner))
        hasher.update(self._nonce)

        self._address = Address(hasher.digest())

        # Etch parser for analysing contract
        self._parser = EtchParser(self._source)

        # Generate (de-duplicated) sets of action and query entry points
        entries = self._parser.entry_points(['init', 'action', 'query'])
        self._actions = list(set(entries.get('action', [])))
        self._queries = list(set(entries.get('query', [])))

        init = entries.get('init', [])
        if len(init) > 1:
            raise RuntimeError(
                'Contract may not have more than one @init function, found: {}'
                .format(', '.join(init)))
        self._init = init[0] if init else None

    @property
    def name(self) -> str:
        """The contract name, i.e. the string form of its address."""
        return str(self.address)

    def dumps(self):
        """Serialise the contract to a JSON string."""
        return json.dumps(self._to_json_object())

    def dump(self, fp):
        """Serialise the contract as JSON into the file-like object ``fp``."""
        return json.dump(self._to_json_object(), fp)

    @classmethod
    def loads(cls, s):
        """Deserialise a contract from a JSON string produced by ``dumps``."""
        return cls._from_json_object(json.loads(s))

    @classmethod
    def load(cls, fp):
        """Deserialise a contract from a file-like object written by ``dump``."""
        return cls._from_json_object(json.load(fp))

    @property
    def owner(self) -> Address:
        """Address of the contract owner."""
        return self._owner

    @property
    def source(self) -> str:
        """The Etch source code."""
        return self._source

    @property
    def digest(self) -> str:
        """Hex string of the source digest."""
        return bytes(self._digest).hex()

    @property
    def nonce(self) -> str:
        """Base64 string of the nonce."""
        return base64.b64encode(self._nonce).decode()

    @property
    def nonce_bytes(self) -> bytes:
        """The raw nonce bytes."""
        return self._nonce

    @property
    def address(self) -> Address:
        """The contract address derived from owner and nonce."""
        return self._address

    @property
    def encoded_source(self) -> str:
        """Base64 string of the source, which must be ASCII-encodable."""
        # NOTE(review): ASCII-only encoding means non-ASCII source raises
        # UnicodeEncodeError here, while _from_json_object decodes as UTF-8.
        return base64.b64encode(self.source.encode('ascii')).decode()

    def create_as_tx(self, api: LedgerApi, from_address: AddressLike, fee: int,
                     signers: Iterable[Identity]) -> 'Transaction':
        """Build (without submitting) the contract creation transaction.

        :param api: ledger API used for the lane count and validity period.
        :param from_address: address the transaction is sent from.
        :param fee: transaction fee.
        :param signers: identities that will sign the transaction.
        :return: the transaction with its validity period set.
        """
        # build the shard mask for the @init entry point; as in ``create``,
        # the owner is the argument passed to the init function
        shard_mask = self._build_shard_mask(api.server.num_lanes(), self._init,
                                            [self._owner])
        tx = ContractTxFactory.create(Address(from_address),
                                      self,
                                      fee,
                                      signers,
                                      shard_mask=shard_mask)
        api.set_validity_period(tx)

        return tx

    def create(self, api: LedgerApi, owner: Entity, fee: int):
        """Deploy the contract through ``api.contracts.create``.

        :param api: ledger API.
        :param owner: entity passed to the API create call.
        :param fee: transaction fee.
        :return: whatever ``api.contracts.create`` returns.
        """
        # build the shard mask for the @init entry point, called with the owner
        shard_mask = self._build_shard_mask(api.server.num_lanes(), self._init,
                                            [self._owner])
        return api.contracts.create(owner, self, fee, shard_mask=shard_mask)

    def query(self, api: LedgerApi, name: str, **kwargs):
        """Run the query ``name`` on this contract and return its result.

        :raises RuntimeError: if the API reports that the query failed.
        """
        # TODO(WK): Reinstate without breaking contract-to-contract calls
        # if name not in self._queries:
        #     raise RuntimeError(
        #         '{} is not an valid query name. Valid options are: {}'.format(name, ','.join(list(self._queries))))

        # make the required query on the API
        success, response = api.contracts.query(self.address, name, **kwargs)

        if not success:
            if response is not None and "msg" in response:
                raise RuntimeError('Failed to make requested query: ' +
                                   response["msg"])
            else:
                raise RuntimeError(
                    'Failed to make requested query with no error message.')

        return response['result']

    def action(self, api: LedgerApi, name: str, fee: int, signer: Entity,
               *args):
        """Submit the action ``name`` with positional ``args``.

        :param signer: a single Entity, or an iterable containing exactly one
            Entity (kept for backwards compatibility).
        :raises ValueError: if ``signer`` cannot be reduced to one Entity.
        :return: whatever ``api.contracts.action`` returns.
        """
        # TODO(WK): Reinstate without breaking contract-to-contract calls
        # if name not in self._actions:
        #     raise RuntimeError(
        #         '{} is not an valid action name. Valid options are: {}'.format(name, ','.join(list(self._actions))))

        shard_mask = self._build_shard_mask(api.server.num_lanes(), name,
                                            list(args))

        # previous versions of the API provided the list of signers as an input, this was mostly done as a work
        # around for the multi-signature support. This has been deprecated, however, the compatibility is kept for the
        # single entity case
        signer = self._convert_to_single_entity(signer)

        return api.contracts.action(self.address,
                                    name,
                                    fee,
                                    signer,
                                    *args,
                                    shard_mask=shard_mask)

    def _build_shard_mask(self, num_lanes: int, name: Optional[str],
                          arguments: Optional[List[Any]] = None) -> BitVector:
        """Derive the shard mask for calling entry point ``name``.

        The contract state resource is always included. When ``name`` is
        given, the persistent globals it uses (resolved against ``arguments``)
        are added too. If the resources cannot be determined, a default
        ``BitVector()`` is returned and logged as a wildcard shard mask.

        :param num_lanes: number of lanes reported by the server.
        :param name: entry point name, or None to skip global analysis.
        :param arguments: call arguments for ``name``; defaults to none.
        """
        if arguments is None:
            arguments = []

        try:
            resource_addresses = [
                'fetch.contract.state.{}'.format(str(self.address)),
            ]

            # only process the entry point's resources if this function is actually present
            if name is not None:
                variables = self._parser.used_globals_to_addresses(
                    name, arguments)
                resource_addresses.extend(
                    ShardMask.state_to_address(str(self.address), variable)
                    for variable in variables)

            shard_mask = ShardMask.resources_to_shard_mask(
                resource_addresses, num_lanes)

        except (UnparsableAddress, UseWildcardShardMask, EtchParserError,
                AssertionError) as ex:
            logging.debug('Parser Error: %s', ex)
            logging.warning(
                "Couldn't auto-detect used shards, using wildcard shard mask")
            shard_mask = BitVector()

        return shard_mask

    @staticmethod
    def _convert_to_single_entity(value: Union[Iterable[Entity], Entity]):
        """Return ``value`` itself, or its only element if it is an iterable
        holding exactly one Entity.

        :raises ValueError: if no single Entity can be extracted.
        """
        if isinstance(value, Entity):
            return value

        try:
            # attempt to create a list of items from the input
            converted = list(value)
        except Exception as ex:
            # not iterable (or iteration failed): report as a bad input value
            raise ValueError(
                'Unable to extract single entity from input value') from ex

        if len(converted) == 1 and isinstance(converted[0], Entity):
            return converted[0]

        raise ValueError('Unable to extract single entity from input value')

    @staticmethod
    def _from_json_object(obj):
        """Build a Contract from the dict produced by ``_to_json_object``.

        Expects ``version`` == 1, base64 ``source`` and ``nonce``, and ``owner``.
        """
        assert obj['version'] == 1

        source = base64.b64decode(obj['source']).decode()
        owner = obj['owner']
        nonce = base64.b64decode(obj['nonce'].encode())

        sc = Contract(source, owner, nonce)

        return sc

    def _to_json_object(self):
        """Return the version-1 JSON-serialisable representation."""
        return {
            'version': 1,
            'nonce': self.nonce,
            'owner': None if self._owner is None else str(self._owner),
            'source': self.encoded_source
        }
# Example no. 2
# 0
class ParserTests(unittest.TestCase):
    """Tests for EtchParser parsing and analysis of Etch code.

    ``setUp`` parses CONTRACT_TEXT; many tests then re-parse small snippets
    (often wrapped in FUNCTION_BLOCK) with the same parser instance.
    """

    def setUp(self) -> None:
        """Parse CONTRACT_TEXT, failing the test if the fixture does not parse."""
        try:
            self.parser = EtchParser(CONTRACT_TEXT)
            self.assertIsNotNone(self.parser._parsed_tree,
                                 "Parsed tree missing when code passed")
        except ParseError as e:
            self.fail("Failed to parse contract text: \n" + str(e))

    def _assert_declared_type(self, type_name: str, token_type: str) -> None:
        """Check ``var a : <type_name>;`` parses to a declaration whose type
        token has ``token_type`` and value ``type_name``."""
        tree = self.parser.parse(
            FUNCTION_BLOCK.format("var a : {};".format(type_name)))
        instruction = next(tree.find_data("instruction"))
        declaration = instruction.children[0]
        self.assertEqual(declaration.data, 'declaration')
        self.assertEqual(declaration.children[1].type, token_type)
        self.assertEqual(declaration.children[1].value, type_name)

    def test_grammar(self):
        """Check that grammar compiles"""
        # TODO: Grammar is loaded from a file, which may impact unit test performance
        try:
            parser = EtchParser()
            self.assertIsNone(parser._parsed_tree,
                              "Parsed tree present when no code passed")
        except GrammarError as e:
            self.fail("Etch grammar failed to load with: \n" + str(e))

    def test_get_functions(self):
        """Check that functions properly identified"""
        functions = self.parser.get_functions()

        # Check all functions found
        function_dict = {f.name: f for f in functions}
        for expected_name in ['setup', 'transfer', 'balance', 'sub']:
            self.assertIn(expected_name, function_dict)

        # Check transfer parsed
        self.assertEqual(function_dict['transfer'].annotation, 'action')
        self.assertEqual(function_dict['transfer'].lines, (11, 24))
        self.assertIsNotNone(function_dict['transfer'].code_block)

        # Check return value correctly parsed
        self.assertEqual(function_dict['balance'].return_type, 'UInt64')

        # Check parameter block correctly parsed
        self.assertEqual(len(function_dict['setup'].parameters), 1)
        self.assertIsNone(function_dict['setup'].parameters[0].value)
        self.assertEqual(function_dict['setup'].parameters[0].name, 'owner')
        self.assertEqual(function_dict['setup'].parameters[0].ptype, 'Address')

    def test_entry_points(self):
        """Check entry points are grouped by their annotation"""
        entry_points = self.parser.entry_points()
        self.assertIn('init', entry_points)
        self.assertIn('action', entry_points)
        self.assertIn('query', entry_points)

        self.assertEqual(entry_points['init'], ['setup'])
        self.assertEqual(entry_points['action'], ['transfer'])
        self.assertEqual(entry_points['query'], ['balance'])

    def test_globals_declared(self):
        """Check persistent global declarations: name, type and sharding"""
        glob_decl = self.parser.globals_declared()
        self.assertEqual(set(glob_decl.keys()),
                         {'balance_state', 'owner_name'})
        self.assertEqual(glob_decl['balance_state'].name, 'balance_state')
        self.assertEqual(glob_decl['balance_state'].gtype, 'UInt64')
        self.assertEqual(glob_decl['balance_state'].is_sharded, True)

        self.assertEqual(glob_decl['owner_name'].name, 'owner_name')
        self.assertEqual(glob_decl['owner_name'].gtype, 'String')
        self.assertEqual(glob_decl['owner_name'].is_sharded, False)

    def test_globals_used(self):
        """Test accurate parsing of globals used in entry points"""
        # Test accurate parsing of declared globals
        glob_used = self.parser.globals_used('setup', ['abc'])
        self.assertEqual(len(glob_used), 1)
        self.assertEqual(len(glob_used[0]), 2)
        self.assertEqual(glob_used[0][0], 'balance_state')
        self.assertEqual(glob_used[0][1].value, 'abc')
        self.assertEqual(glob_used[0][1].name, 'owner')

    def test_global_addresses(self):
        """Test conversion of globals used by an entry point into addresses"""

        with patch('logging.warning') as mock_warn:
            glob_addresses = self.parser.used_globals_to_addresses(
                'transfer', ['abc', 'def', 100])
            self.assertEqual(mock_warn.call_count, 1)
        self.assertEqual(len(glob_addresses), 5)
        # Unsharded use statement
        self.assertEqual(glob_addresses[0], 'owner_name')
        # Sharded use statements
        self.assertEqual(glob_addresses[1], 'balance_state.abc')  # Parameter
        self.assertEqual(glob_addresses[2], 'balance_state.def')  # Parameter
        self.assertEqual(glob_addresses[3],
                         'balance_state.constant_string')  # String constant
        self.assertEqual(glob_addresses[4],
                         'balance_state.prefix.def')  # String concatenation

    def test_scope(self):
        """Tests which instructions are allowed at each scope"""
        # Regular instructions are not allowed at global scope
        with patch('logging.warning') as mock_warn:
            self.assertFalse(self.parser.parse("var a = 5;"))
            self.assertEqual(mock_warn.call_count, 2)

        # Allowed at global scope
        try:
            # Persistent global declaration
            self.assertIsNot(
                self.parser.parse("persistent sharded balance_state : UInt64;"),
                False)
            self.assertIsNot(
                self.parser.parse("persistent owner : String;"), False)

            # Functions
            self.assertIsNot(
                self.parser.parse("""function a(owner : String)
                var b = owner;
                endfunction"""), False)

            # Annotated functions
            self.assertIsNot(
                self.parser.parse("""@action
                function a(owner : String)
                var b = owner;
                endfunction"""), False)

            # Comments
            self.assertIsNot(self.parser.parse("// A comment"), False)
        except UnexpectedCharacters as e:
            self.fail("Etch parsing of top level statement failed: \n" +
                      str(e))

    def test_builtins(self):
        """Tests for correct parsing of all supported builtin types"""
        int_types = ['Int' + str(x) for x in [8, 16, 32, 64, 256]]
        uint_types = ['UInt' + str(x) for x in [8, 16, 32, 64, 256]]
        float_types = ['Float' + str(x) for x in [32, 64]]
        fixed_types = ['Fixed' + str(x) for x in [32, 64]]
        other_types = ['Boolean', 'String']

        # Test declaration of each builtin; each family is expected to be
        # tokenised with its own token type
        for type_names, token_type in [(int_types + uint_types, 'BASIC_TYPE'),
                                       (float_types, 'FLOAT_TYPE'),
                                       (fixed_types, 'FIXED_TYPE'),
                                       (other_types, 'NAME')]:
            for type_name in type_names:
                self._assert_declared_type(type_name, token_type)

        # TODO: Test these in a meaningful way, beyond simply that they parse
        # Test declaration of array
        self.parser.parse(
            FUNCTION_BLOCK.format("var myArray = Array<Int32>(5);"))
        # Test assignment to array
        self.parser.parse(FUNCTION_BLOCK.format("myArray[0] = 5;"))
        # Test assignment from array
        self.parser.parse(FUNCTION_BLOCK.format("b = myArray[0];"))
        self.parser.parse(FUNCTION_BLOCK.format("var b = myArray[0];"))

        # As above, for map type
        self.parser.parse(
            FUNCTION_BLOCK.format("var myArray = Map<String, Int32>(5);"))
        # Test assignment to map
        self.parser.parse(FUNCTION_BLOCK.format("myArray['test'] = 5;"))
        # Test assignment from map
        self.parser.parse(FUNCTION_BLOCK.format("b = myArray['test'];"))
        self.parser.parse(
            FUNCTION_BLOCK.format("var b = myArray['test'];"))

    def test_instantiation(self):
        """Tests for correct parsing of valid variable instantiation"""
        # Check that the following parse without error
        self.parser.parse(
            FUNCTION_BLOCK.format("var b = get();"))  # Untyped instantiation
        self.parser.parse(
            FUNCTION_BLOCK.format(
                "var b : UInt64 = get();"))  # Typed instantiation

    def test_template(self):
        """Tests for correct parsing of template variables"""
        self.parser.parse(FUNCTION_BLOCK.format("a = State<UInt64>();"))
        self.parser.parse(
            FUNCTION_BLOCK.format("a = State<UInt64, UInt64>();"))

        # Test function parsing with template parameters
        self.parser.parse(
            """function a(b : Array<StructuredData>) : Int32
        var c : State<UInt32>;
        endfunction
        """)
        functions = self.parser.get_functions()

        # Check that argument list correctly parsed
        self.assertEqual(functions[0].parameters[0].name, 'b')
        self.assertEqual(functions[0].parameters[0].ptype,
                         'Array<StructuredData>')

    def test_functions(self):
        """Tests correct detection of non-entry-point functions"""
        self.assertEqual(self.parser.subfunctions(), ['sub'])

    def test_class_function(self):
        """Tests correct ingestion of functions"""
        # Test minimal function
        tree = self.parser.parse("""function init()
        endfunction""")

        func = Function.from_tree(next(tree.find_data('function')))

        self.assertIsNone(func.annotation)
        self.assertIsNone(func.code_block)
        self.assertIsNone(func.return_type)
        self.assertEqual(func.name, 'init')
        self.assertEqual(func.parameters, [])

        # Test Function parsing from_tree
        tree = self.parser.parse("""
        @action
        function a(b : UInt64) : String
        return b;
        endfunction
        """)
        func = Function.from_tree(next(tree.find_data('annotation')))
        self.assertEqual(func.name, 'a')
        self.assertEqual(func.return_type, 'String')
        self.assertEqual(func.annotation, 'action')
        self.assertEqual(func.parameters[0].name, 'b')
        self.assertEqual(func.parameters[0].ptype, 'UInt64')

        # Test all_from_tree
        tree = self.parser.parse("""
        @action
        function a(b : UInt64) : String
        return 'test';
        endfunction
        
        function c(d: UInt64): String
        return 'test2';
        endfunction
        """)

        funcs = Function.all_from_tree(tree)
        self.assertEqual(funcs[0].name, 'a')
        self.assertEqual(funcs[0].return_type, 'String')
        self.assertEqual(funcs[0].annotation, 'action')
        self.assertEqual(funcs[0].parameters[0].name, 'b')
        self.assertEqual(funcs[0].parameters[0].ptype, 'UInt64')

        self.assertEqual(funcs[1].name, 'c')
        self.assertEqual(funcs[1].return_type, 'String')
        self.assertIsNone(funcs[1].annotation)
        self.assertEqual(funcs[1].parameters[0].name, 'd')
        self.assertEqual(funcs[1].parameters[0].ptype, 'UInt64')

    def test_nested_function_call(self):
        """Check that nested function calls are supported by parser"""
        # Only the parse call belongs in the try: a bare except here would
        # also swallow the assertion below and interrupts like Ctrl-C
        try:
            tree = self.parser.parse(NESTED_FUNCTION)
        except Exception as e:
            self.fail("Parsing of dot nested function calls failed: \n" +
                      str(e))
        self.assertIsNot(tree, False,
                         "Parsing of dot nested function calls failed")

    def test_expressions(self):
        """Check that common expressions parse correctly"""
        # Instantiation
        self.parser.parse(FUNCTION_BLOCK.format("var a = 1i32;"))
        # Binary operation
        self.parser.parse(FUNCTION_BLOCK.format("var a = 1i32 + 2i32;"))
        # Pre-unary operation
        self.parser.parse(FUNCTION_BLOCK.format("var a = - 2i32;"))
        # Post-unary operation
        self.parser.parse(FUNCTION_BLOCK.format("var a = 2i32++;"))
        # Comparison operation
        self.parser.parse(
            FUNCTION_BLOCK.format("var a = 2i32 == 3i32;"))
        # Type cast
        self.parser.parse(FUNCTION_BLOCK.format("var a = Int64(3i32);"))

    def test_assignments(self):
        """Check successful parsing of assignment operators"""
        FB_WITH_DECLARATION = FUNCTION_BLOCK.format("var a : Int64; {}")
        for operator in ['+=', '-=', '*=', '/=', '%=']:
            self.parser.parse(
                FB_WITH_DECLARATION.format("a {} 5;".format(operator)))

    def test_assert_statement(self):
        """Check boolean expressions valid in any context"""
        self.parser.parse(
            FUNCTION_BLOCK.format("assert(a >= 0 && a <= 15);"))

    def test_template_global(self):
        """Checks correct parsing of globals with template types"""
        self.parser.parse(TEMPLATE_GLOBAL)

        # Function A contains a non-sharded global of type Array<Address>
        addresses = self.parser.used_globals_to_addresses('A', [])
        self.assertEqual(addresses, ['users'])

        # Function B contains a sharded global of type Array<Address>
        addresses = self.parser.used_globals_to_addresses('B', [])
        self.assertEqual(addresses, ['sharded_users.abc'])

    def test_if_blocks(self):
        """Checks correct parsing of if blocks"""
        # Partial contract text with function block and variable instantiation
        PARTIAL_BLOCK = FUNCTION_BLOCK.format("""
        var a: Int64 = 5;
        var b: Int64 = 0;
        {}""")

        # Simple if block
        tree = self.parser.parse(
            PARTIAL_BLOCK.format("""
        if (a > 5)
            b = 6;
        endif"""))
        self.assertIsNot(tree, False)

        # If-else block
        tree = self.parser.parse(
            PARTIAL_BLOCK.format("""
        if (a > 5)
            b = 6;
        else
            b = 7;
        endif"""))
        self.assertIsNot(tree, False)

        # Nested if-else-if block
        tree = self.parser.parse(
            PARTIAL_BLOCK.format("""
        if (a > 5)
            b = 6;
        else if (a < 5)
                b = 4;
            endif
        endif"""))
        self.assertIsNot(tree, False)

        # If-elseif block
        tree = self.parser.parse(
            PARTIAL_BLOCK.format("""
        if (a > 5)
            b = 6;
        elseif (a < 5)
            b = 4;
        endif"""))
        self.assertIsNot(tree, False)

        # Complex example
        tree = self.parser.parse(
            PARTIAL_BLOCK.format("""
        if (a > 5 && a < 100)
            b = 6;
        elseif (a < 2 || a > 100)
            if (a < 0)
                b = 4;
            else
                b = 2;
            endif
        else
            b = 3;
        endif"""))
        self.assertIsNot(tree, False)

    def test_warn_on_parse_fail(self):
        """Check invalid code returns a falsy result and logs two warnings"""
        with patch('logging.warning') as mock_warn:
            tree = self.parser.parse("This code is not valid")
            self.assertFalse(tree)
            self.assertEqual(mock_warn.call_count, 2)
# Example no. 3
# 0
class Contract:
    """An Etch smart contract together with its owner and address nonce.

    The contract address is ``sha256(bytes(owner) + nonce)``, computed once in
    ``__init__``. The source is parsed on construction so that its ``@init``,
    ``@action`` and ``@query`` entry points are known.
    """

    def __init__(self, source: str, owner: AddressLike, nonce: bytes = None):
        """Create a contract definition.

        :param source: Etch source code.
        :param owner: address-like value of the contract owner.
        :param nonce: nonce mixed into the address; when omitted, 8 bytes from
            ``os.urandom`` are used.
        :raises RuntimeError: if more than one ``@init`` function is declared.
        """
        self._source = str(source)
        self._digest = _compute_digest(self._source)
        self._owner = Address(owner)
        self._nonce = bytes(urandom(8)) if nonce is None else nonce

        # The contract address depends only on the owner and the nonce
        hasher = hashlib.sha256()
        hasher.update(bytes(self._owner))
        hasher.update(self._nonce)

        self._address = Address(hasher.digest())

        # Etch parser for analysing contract
        self._parser = EtchParser(self._source)

        # Generate (de-duplicated) sets of action and query entry points
        entries = self._parser.entry_points(['init', 'action', 'query'])
        self._actions = list(set(entries.get('action', [])))
        self._queries = list(set(entries.get('query', [])))

        init = entries.get('init', [])
        if len(init) > 1:
            raise RuntimeError(
                'Contract may not have more than one @init function, found: {}'
                .format(', '.join(init)))
        self._init = init[0] if init else None

    @property
    def name(self):
        """The contract name: ``<hex digest>.<address>``."""
        return '{}.{}'.format(self.digest.to_hex(), self.address)

    def dumps(self):
        """Serialise the contract to a JSON string."""
        return json.dumps(self._to_json_object())

    def dump(self, fp):
        """Serialise the contract as JSON into the file-like object ``fp``."""
        return json.dump(self._to_json_object(), fp)

    @classmethod
    def loads(cls, s):
        """Deserialise a contract from a JSON string produced by ``dumps``."""
        return cls._from_json_object(json.loads(s))

    @classmethod
    def load(cls, fp):
        """Deserialise a contract from a file-like object written by ``dump``."""
        return cls._from_json_object(json.load(fp))

    @property
    def owner(self):
        """Address of the contract owner."""
        return self._owner

    @owner.setter
    def owner(self, owner):
        """Set the owner (converted with ``Address``).

        Note: the contract address computed in ``__init__`` is not recomputed.
        """
        self._owner = Address(owner)

    @property
    def source(self):
        """The Etch source code."""
        return self._source

    @property
    def digest(self):
        """The source digest object returned by ``_compute_digest``."""
        return self._digest

    @property
    def nonce(self) -> str:
        """Base64 string of the nonce."""
        return base64.b64encode(self._nonce).decode()

    @property
    def nonce_bytes(self) -> bytes:
        """The raw nonce bytes."""
        return self._nonce

    @property
    def address(self) -> Address:
        """The contract address derived from owner and nonce."""
        return self._address

    @property
    def encoded_source(self):
        """Base64 string of the source, which must be ASCII-encodable."""
        return base64.b64encode(self.source.encode('ascii')).decode()

    def create(self, api: ContractsApiLike, owner: Entity, fee: int):
        """Deploy the contract via its ``@init`` entry point.

        :param api: a ContractsApi or a LedgerApi.
        :param owner: entity set as contract owner and passed to the API call.
        :param fee: transaction fee.
        :raises RuntimeError: if the contract has no ``@init`` function.
        """
        # Validate before mutating state, so a failed call leaves the owner
        # unchanged
        if self._init is None:
            raise RuntimeError("Contract has no initialisation function")

        # Set contract owner (required for resource prefix)
        self.owner = owner

        # Resource addresses: the contract state plus the persistent globals
        # used by the init function
        shard_mask = self._build_shard_mask(
            api, self._init, [self._owner],
            ['fetch.contract.state.{}'.format(self.digest.to_hex())])

        return self._api(api).create(owner, self, fee, shard_mask=shard_mask)

    def query(self, api: ContractsApiLike, name: str, **kwargs):
        """Run the query ``name`` on this contract and return its result.

        :raises RuntimeError: if there is no owner, ``name`` is not a declared
            query, or the API reports failure.
        """
        if self._owner is None:
            raise RuntimeError(
                'Contract has no owner, unable to perform any queries. Did you deploy it?'
            )

        if name not in self._queries:
            raise RuntimeError(
                '{} is not an valid query name. Valid options are: {}'.format(
                    name, ','.join(list(self._queries))))

        # make the required query on the API
        success, response = self._api(api).query(self._digest, self.address,
                                                 name, **kwargs)

        if not success:
            if response is not None and "msg" in response:
                raise RuntimeError('Failed to make requested query: ' +
                                   response["msg"])
            else:
                raise RuntimeError(
                    'Failed to make requested query with no error message.')

        return response['result']

    def action(self, api: ContractsApiLike, name: str, fee: int,
               signers: List[Entity], *args):
        """Submit the action ``name`` with positional ``args``.

        :raises RuntimeError: if there is no owner or ``name`` is not a
            declared action.
        """
        if self._owner is None:
            raise RuntimeError(
                'Contract has no owner, unable to perform any actions. Did you deploy it?'
            )

        if name not in self._actions:
            raise RuntimeError(
                '{} is not an valid action name. Valid options are: {}'.format(
                    name, ','.join(list(self._actions))))

        # Resource addresses used by the action's persistent globals
        shard_mask = self._build_shard_mask(api, name, list(args))

        return self._api(api).action(self._digest,
                                     self.address,
                                     name,
                                     fee,
                                     self.owner,
                                     signers,
                                     *args,
                                     shard_mask=shard_mask)

    def _build_shard_mask(self, api: ContractsApiLike, name: str,
                          arguments: list, base_resources=()):
        """Build the shard mask for calling entry point ``name``.

        Resources are ``base_resources`` followed by the state addresses of the
        persistent globals ``name`` uses with ``arguments``. If those cannot be
        determined, a default ``BitVector()`` is returned and logged as a
        wildcard shard mask.
        """
        try:
            resource_addresses = list(base_resources)
            resource_addresses.extend(
                ShardMask.state_to_address(address, self)
                for address in self._parser.used_globals_to_addresses(
                    name, arguments))
        except (UnparsableAddress, UseWildcardShardMask):
            logging.warning(
                "Couldn't auto-detect used shards, using wildcard shard mask")
            return BitVector()

        # Generate shard mask from resource addresses
        return ShardMask.resources_to_shard_mask(resource_addresses,
                                                 api.server.num_lanes())

    @staticmethod
    def _api(api: ContractsApiLike):
        """Return the contracts API from a ContractsApi or a LedgerApi.

        :raises TypeError: for any other type of ``api``.
        """
        if isinstance(api, ContractsApi):
            return api
        if isinstance(api, LedgerApi):
            return api.contracts
        # An explicit raise: ``assert False`` is stripped under ``python -O``
        # and would make this silently return None
        raise TypeError('Expected ContractsApi or LedgerApi, got {}'.format(
            type(api).__name__))

    @staticmethod
    def _from_json_object(obj):
        """Build a Contract from the dict produced by ``_to_json_object``.

        Expects ``version`` == 1, base64 ``source`` and ``nonce``, and ``owner``.
        """
        assert obj['version'] == 1

        source = base64.b64decode(obj['source']).decode()
        owner = obj['owner']
        nonce = base64.b64decode(obj['nonce'].encode())

        sc = Contract(source, owner, nonce)

        return sc

    def _to_json_object(self):
        """Return the version-1 JSON-serialisable representation."""
        return {
            'version': 1,
            'nonce': self.nonce,
            'owner': None if self._owner is None else str(self._owner),
            'source': self.encoded_source
        }
# Example no. 4
# 0
class ShardMaskParsingTests(unittest.TestCase):
    """Tests of EtchParser global-usage analysis used for shard masks.

    Covers subfunctions that contain ``use`` statements, ``use any``, and
    behaviour after a failed parse. Each test parses its own fixture.
    """
    def setUp(self) -> None:
        """Create a parser with no source code."""
        try:
            self.parser = EtchParser()
            self.assertIsNone(self.parser._parsed_tree,
                              "Unexpected initialisation of parsed tree")
        except ParseError as e:
            # NOTE(review): "isntantiation" in this message is a typo
            self.fail("Parser isntantiation failed with: \n" + str(e))

    def test_outside_annotation(self):
        """Test handling of calls to subfunctions containing use statements"""
        self.parser.parse(NON_ENTRY_GLOBAL)

        # Detect call to non-entry function that uses globals
        with self.assertRaises(UnparsableAddress):
            glob_used = self.parser.globals_used('setup', ['abc'])
            # NOTE(review): globals_used is expected to raise above, so this
            # assertion never executes; it is dead code
            self.assertEqual(
                len(glob_used), 0,
                "Unexpected used globals found when declared in non annotated function"
            )

        # Test transfer function, which calls a non-global-using subfunction
        glob_used = self.parser.globals_used('transfer', ['abc', 'def', 100])
        # Each entry is (global name, parameter); check "<name>.<value>"
        self.assertEqual(
            '{}.{}'.format(glob_used[0][0], glob_used[0][1].value),
            'balance_state.abc')
        self.assertEqual(
            '{}.{}'.format(glob_used[1][0], glob_used[1][1].value),
            'balance_state.def')

    def test_global_using_subfunctions(self):
        """Test detection of non-annotated functions containing 'use' statements"""
        self.parser.parse(NON_ENTRY_GLOBAL)

        # List of non-annotated functions that use globals
        global_using_subfunctions = self.parser.global_using_subfunctions()
        self.assertIn('set_balance', global_using_subfunctions)
        self.assertNotIn('sub', global_using_subfunctions)

        # Test that wildcard used when annotated function calls global using subfunction
        with self.assertRaises(UnparsableAddress):
            self.parser.used_globals_to_addresses('setup', ['abc'])

        # Parsing of function that doesn't call global-using-subfunction should succeed
        glob_addresses = self.parser.used_globals_to_addresses(
            'transfer', ['abc', 'def', 100])
        self.assertEqual(glob_addresses,
                         ['balance_state.abc', 'balance_state.def'])

    def test_use_any(self):
        """Test correct handling of 'use any'"""
        self.parser.parse(USE_ANY_NON_SHARDED)

        # Test correct detection of all persistent globals when none are sharded
        # NOTE(review): here the result is compared as a set of names, unlike
        # the (name, parameter) tuples checked in test_outside_annotation
        used_globals = self.parser.globals_used('swap', [])
        self.assertEqual(set(used_globals), {'var1', 'var2'})

        # Test correct raising of wildcard-needed exception if any globals are sharded
        self.parser.parse(USE_ANY_SHARDED)
        with self.assertRaises(UseWildcardShardMask):
            used_globals = self.parser.globals_used('swap', [])

    def test_unparsable(self):
        """Test that parser raises an exception when parsing fails"""
        # A failed parse logs two warnings
        with patch('logging.warning') as mock_warn:
            self.parser.parse("This is not valid etch code")
            self.assertEqual(mock_warn.call_count, 2)

        # Analysis after a failed parse must raise rather than return results
        with self.assertRaises(EtchParserError):
            used_globals = self.parser.globals_used('entry', [])