def test_in_and_out_parameters(self):
    """Check a value passed out of an entry that is also passed into a child.

    Regression test for issue122. The structure under test is::

              ___ e ___
           __c__      d(len=a)
          a    b(len=a)
    """
    def int_param(name, direction):
        return prm.Param(name, direction, _Integer())

    a = fld.Field('a', length=8)
    b = fld.Field('b', length=expr.compile('${a}'))
    c = seq.Sequence('c', [a, b])
    d = fld.Field('d', length=expr.compile('${c.a}'))
    e = seq.Sequence('e', [c, d])
    lookup = prm.ExpressionParameters([e])

    # Parameters declared by each entry.
    self.assertEqual([], lookup.get_params(e))
    self.assertEqual([int_param('a', prm.Param.OUT)], lookup.get_params(a))
    self.assertEqual([int_param('a', prm.Param.OUT)], lookup.get_params(c))
    self.assertEqual([int_param('a', prm.Param.IN)], lookup.get_params(b))
    self.assertEqual([int_param('c.a', prm.Param.IN)], lookup.get_params(d))

    # 'e' holds 'c.a' as a local; 'c' has no locals of its own.
    self.assertEqual([prm.Local('c.a', _Integer())], lookup.get_locals(e))
    self.assertEqual([], lookup.get_locals(c))
    self.assertTrue(lookup.is_value_referenced(a))

    # Variables each parent passes to its second child.
    self.assertEqual([int_param('a', prm.Param.IN)],
                     list(lookup.get_passed_variables(c, c.children[1])))
    self.assertEqual([int_param('c.a', prm.Param.IN)],
                     list(lookup.get_passed_variables(e, e.children[1])))
def test_reference_choice(self):
    """Reference a choice entry whose options all have integer values."""

    def end_values(entry_to_decode, raw):
        # Map each entry to its value, using only the end-of-entry events.
        values = {}
        for is_starting, _, entry, _, value in entry_to_decode.decode(raw):
            if not is_starting:
                values[entry] = value
        return values

    byte = seq.Sequence(
        "8 bit:",
        [fld.Field("id", 8, constraints=[Equals(dt.Data("\x00"))]),
         fld.Field("length", 8)],
        value=expr.compile("${length}"))
    word = seq.Sequence(
        "16 bit:",
        [fld.Field("id", 8, constraints=[Equals(dt.Data("\x01"))]),
         fld.Field("length", 16)],
        value=expr.compile("${length}"))
    length = chc.Choice("variable integer", [byte, word])
    data = fld.Field("data", expr.compile("${variable integer}"), fld.Field.TEXT)
    spec = seq.Sequence("spec", [length, data])

    # 8 bit option: length 0x20 == 32 bits, i.e. four characters of text.
    self.assertEqual("abcd", end_values(spec, dt.Data("\x00\x20abcde"))[data])
    # 16 bit option encoding the same length (0x0020).
    self.assertEqual("abcd", end_values(spec, dt.Data("\x01\x00\x20abcde"))[data])
def test_length_and_value_reference(self):
    """Reference both the length and the value of the same entry."""
    a = fld.Field('a', length=8)
    c = fld.Field('c', length=expr.compile('len{a}'))
    d = fld.Field('d', length=expr.compile('${a}'))
    b = seq.Sequence('b', [a, c, d])
    # NOTE(review): a quick sanity decode was left disabled here:
    #     list(b.decode(dt.Data('\x08cd')))
    # Confirm whether it should be enabled or removed.

    # Check the parameters passed around.
    lookup = prm.ExpressionParameters([b])
    a_outputs = [prm.Param('a', prm.Param.OUT, _Integer()),
                 prm.Param('a length', prm.Param.OUT, _Integer())]
    length_input = [prm.Param('a length', prm.Param.IN, _Integer())]

    self.assertEqual(a_outputs, lookup.get_params(a))
    self.assertEqual(a_outputs,
                     list(lookup.get_passed_variables(b, b.children[0])))
    self.assertEqual(length_input,
                     list(lookup.get_passed_variables(b, b.children[1])))
    self.assertEqual(length_input, lookup.get_params(c))
    self.assertEqual([prm.Local('a', _Integer()), prm.Local('a length', _Integer())],
                     lookup.get_locals(b))
    self.assertTrue(lookup.is_length_referenced(a))
def test_sequence_with_referenced_value(self):
    """A sequence whose value comes from a renamed child can be referenced."""
    def int_param(name, direction):
        return prm.Param(name, direction, _Integer())

    a = fld.Field('a', length=8)
    b = seq.Sequence('b', [ent.Child('b:', a)], value=expr.compile('${b:}'))
    c = fld.Field('c', length=expr.compile('${b} * 8'))
    d = seq.Sequence('d', [a, b, c])
    lookup = prm.ExpressionParameters([a, d])

    # 'b' needs a local for its renamed child 'b:'.
    self.assertEqual([prm.Local('b:', _Integer())], lookup.get_locals(b))
    self.assertEqual([int_param('a', prm.Param.OUT)], lookup.get_params(a))
    self.assertEqual([int_param('b', prm.Param.OUT)], lookup.get_params(b))
    self.assertEqual([int_param('b', prm.Param.IN)], lookup.get_params(c))
    # The outer sequence 'd' has no parameters itself.
    self.assertEqual([], lookup.get_params(d))
def test_renamed_common_reference(self):
    """A common entry renamed with Child is passed out under its new name."""
    text_digit = fld.Field('text digit', 8,
                           constraints=[Minimum(48), Maximum(58)])
    digit = seq.Sequence('digit', [text_digit],
                         value=expr.compile("${text digit} - 48"))
    length = ent.Child('length', digit)
    data = fld.Field('data', length=expr.compile("${length} * 8"))
    b = seq.Sequence('b', [length, data])
    lookup = prm.ExpressionParameters([b])

    self.assertEqual([], lookup.get_params(b))
    # The common entry itself still outputs under its own name...
    self.assertEqual([prm.Param('digit', prm.Param.OUT, _Integer())],
                     lookup.get_params(digit))
    # ...while the parent receives it under the Child's name.
    self.assertEqual([prm.Param('length', prm.Param.OUT, _Integer())],
                     list(lookup.get_passed_variables(b, b.children[0])))
def test_sequence_expected_value(self):
    """Constraints on a sequence are checked against its computed value."""
    a = seq.Sequence('a', [fld.Field('b', 8), fld.Field('c', 8)],
                     value=expr.compile('${b} + ${c}'))
    a.constraints.append(Equals(7))

    # Both byte pairs sum to 7, so decoding succeeds.
    for raw in ('\x03\x04', '\x06\x01'):
        list(a.decode(dt.Data(raw)))

    # These sum to 6 and 8 respectively, so decoding is rejected.
    for raw in ('\x05\x01', '\x07\x01'):
        self.assertRaises(ConstraintError, list, a.decode(dt.Data(raw)))
def signed_litte_endian(self, length_expr, attributes=None):
    """Return a (cached) entry decoding a signed little endian integer.

    If ``length_expr`` cannot be evaluated without decoded data, this
    delegates to ``_variable_length_signed_little_endian``. Otherwise the
    entry is a sequence of bytes: ``byte 0:`` .. ``byte n-2:`` of 8 bits,
    then a 1 bit ``signed:`` flag and the 7 remaining bits of the last byte.
    Its value is the two's complement integer, with byte 0 least significant.
    Entries are cached in ``self.common`` under their name.

    Raises:
        IntegerError: if the length is not a positive multiple of 8.
    """
    try:
        length = length_expr.evaluate({})
    except UndecodedReferenceError:
        return self._variable_length_signed_little_endian(length_expr, attributes=attributes)

    # A zero/negative length would otherwise pass the multiple-of-8 check and
    # then fail obscurely while building the value expression.
    if length <= 0:
        raise IntegerError('The length of little endian fields must be positive.')
    if length % 8 != 0:
        raise IntegerError('The length of little endian fields must be a multiple of 8.')

    name = 'little endian integer %s' % length
    try:
        return self.common[name]
    except KeyError:
        pass

    # Integer division: '/' would produce a float under Python 3.
    num_bytes = length // 8
    children = [Field('byte %i:' % i, 8, attributes=attributes)
                for i in range(num_bytes - 1)]
    # The sign bit is the first bit of the final (most significant) byte.
    children.append(Field('signed:', 1, attributes=attributes))
    children.append(Field('byte %i:' % (num_bytes - 1), 7, attributes=attributes))

    # Combine the bytes most significant first; byte 0 is the least
    # significant. Equivalent to a left fold over the reversed names.
    byte_refs = ['${byte %i:}' % i for i in reversed(range(num_bytes))]
    reference = byte_refs[0]
    for byte_ref in byte_refs[1:]:
        reference = '(%s) * 256 + %s' % (reference, byte_ref)

    # We define the minimum as being '-number - 1' to avoid compiler
    # warnings in C, where there are no negative constants, just
    # constants that are then negated (and the positive version of
    # the constant may be too big a number).
    maximum = 1 << (length - 1)
    value_text = '${signed:} * (0 - %i - 1) + %s' % (maximum - 1, reference)
    result = Sequence(name, children, value=compile(value_text), attributes=attributes)
    self.common[name] = result
    return result
def test_referenced_field_length(self):
    """The value range of an expression using len{} follows the field length."""
    a = Field('a', length=4)
    b = Sequence('b', [], value=compile('len{a} * 8 + 4'))
    c = Sequence('c', [a, b])
    params = ExpressionParameters([c])
    # Named value_range to avoid shadowing the builtin range().
    value_range = EntryValueType(b).range(params)
    # len{a} is 4, so the value is always 4 * 8 + 4 == 36.
    self.assertEqual(36, value_range.min)
    self.assertEqual(36, value_range.max)
def test_unused_parameters_with_same_name(self):
    """Several unused outputs with the same name share one local variable.

    This used to duplicate the local for the vfat specification, where the
    different bootsector types all had the same output parameters.
    """
    a1 = fld.Field('a', length=16)
    a2 = fld.Field('a', length=8)
    # 'c' doesn't use the outputs of a1 and a2, so it should need only a
    # single local variable.
    c = seq.Sequence('c', [a1, a2])
    # Other entries that do use the values of a1 and a2.
    d1 = seq.Sequence('d1', [a1, fld.Field('e1', length=expr.compile('${a}'))])
    d2 = seq.Sequence('d2', [a2, fld.Field('e2', length=expr.compile('${a}'))])
    lookup = prm.ExpressionParameters([a1, a2, c, d1, d2])

    unused = [prm.Param('unused a', prm.Param.OUT, _Integer())]
    for index in (0, 1):
        self.assertEqual(unused,
                         list(lookup.get_passed_variables(c, c.children[index])))
    self.assertEqual([prm.Local('unused a', _Integer())], lookup.get_locals(c))
def test_choice_reference(self):
    """Reference a choice, which in effect references each of its options.

    Every option of the choice has an integer value, and a later field uses
    the top level choice as its length.
    """
    # Named len_field rather than 'len' to avoid shadowing the builtin.
    len_field = fld.Field('len', length=8)
    a = fld.Field('a', length=32)
    b = fld.Field('b', length=16)
    c = fld.Field('c', length=8)
    var_len = chc.Choice('var_len', [a, b, c], length=expr.compile('${len}'))
    data = fld.Field('data', length=expr.compile('${var_len}'))
    spec = seq.Sequence('spec', [len_field, var_len, data])

    # Check the parameters passed in and out of each entry.
    lookup = prm.ExpressionParameters([spec])
    self.assertEqual([], lookup.get_params(spec))
    self.assertEqual([prm.Param('len', prm.Param.OUT, _Integer())],
                     lookup.get_params(len_field))
    self.assertEqual([prm.Param('len', prm.Param.IN, _Integer()),
                      prm.Param('var_len', prm.Param.OUT, _Integer())],
                     lookup.get_params(var_len))
    self.assertEqual([prm.Param('a', prm.Param.OUT, _Integer())],
                     lookup.get_params(a))
    self.assertEqual([prm.Param('b', prm.Param.OUT, _Integer())],
                     lookup.get_params(b))
    self.assertEqual([prm.Param('var_len', prm.Param.IN, _Integer())],
                     lookup.get_params(data))

    # Each option's output is mapped onto the choice's 'var_len' output.
    for index in (0, 1, 2):
        self.assertEqual(
            [prm.Param('var_len', prm.Param.OUT, _Integer())],
            list(lookup.get_passed_variables(var_len, var_len.children[index])))

    # And validate the locals; of the entries checked, only 'spec' has any.
    self.assertEqual([prm.Local('len', _Integer()), prm.Local('var_len', _Integer())],
                     lookup.get_locals(spec))
    for entry in (len_field, var_len, a, data):
        self.assertEqual([], lookup.get_locals(entry))
def test_sequence_value(self):
    """Only the sequence that produces a value exposes a 'result' output."""
    a = fld.Field('a:', 8, fld.Field.INTEGER)
    b = seq.Sequence('b', [a], expr.compile('${a:}'))
    lookup = prm.ResultParameters([b])

    self.assertEqual([], lookup.get_params(a))
    self.assertEqual([prm.Param('result', prm.Param.OUT, EntryType(b))],
                     lookup.get_params(b))
    self.assertEqual([], lookup.get_locals(a))
    self.assertEqual([], lookup.get_locals(b))
    # Wrap in list(), as the other parameter tests do, so this works whether
    # get_passed_variables returns a list or an iterator (an iterator never
    # compares equal to []).
    self.assertEqual([], list(lookup.get_passed_variables(b, b.children[0])))
def test_name_ends_in_length(self):
    """A field whose name ends in ' length' is referenced by value.

    The asserts check that '${data length}' counts as a value reference to
    the field 'data length', and not as a length reference.
    """
    length_field = fld.Field('data length', 8, fld.Field.INTEGER)
    data = fld.Field('data', expr.compile('${data length} * 8'))
    c = seq.Sequence('c', [length_field, data])
    params = prm.ExpressionParameters([c])

    self.assertEqual([prm.Local('data length', _Integer())], params.get_locals(c))
    self.assertEqual([], params.get_params(c))
    self.assertEqual([prm.Param('data length', prm.Param.OUT, _Integer())],
                     params.get_params(length_field))
    self.assertEqual([prm.Param('data length', prm.Param.IN, _Integer())],
                     params.get_params(data))
    self.assertEqual(True, params.is_value_referenced(length_field))
    self.assertEqual(False, params.is_length_referenced(length_field))
def test_reference_choice(self):
    """Reference a choice entry whose options all have integer values."""
    byte = seq.Sequence('8 bit:', [
        fld.Field('id', 8, constraints=[Equals(dt.Data('\x00'))]),
        fld.Field('length', 8)], value=expr.compile('${length}'))
    word = seq.Sequence('16 bit:', [
        fld.Field('id', 8, constraints=[Equals(dt.Data('\x01'))]),
        fld.Field('length', 16)], value=expr.compile('${length}'))
    length = chc.Choice('variable integer', [byte, word])
    data = fld.Field('data', expr.compile('${variable integer}'), fld.Field.TEXT)
    spec = seq.Sequence('spec', [length, data])

    # The same 32 bit length, encoded first by the 8 bit option and then by
    # the 16 bit option.
    for raw in ('\x00\x20abcde', '\x01\x00\x20abcde'):
        results = {}
        for is_starting, _, entry, _, value in spec.decode(dt.Data(raw)):
            if not is_starting:
                results[entry] = value
        self.assertEqual('abcd', results[data])
def test_expression(self):
    """Saving a field with an expression length writes the expression text."""
    b = fld.Field('b', format=fld.Field.INTEGER, length=8)
    c = fld.Field('c', length=expr.compile('${a} * 8'))
    a = seq.Sequence('a', [b, c])
    expected = """<protocol> <sequence name="a"> <field name="b" length="8" type="integer" /> <field name="c" length="(${a} * 8)" /> </sequence> </protocol>"""
    saved = xml.save(a)
    assert_xml_equivalent(expected, saved)
def test_common_entry_with_input_parameter(self):
    """Resolve a common entry whose input maps to several different entries.

    The common entry 'a' is used in two locations; each time its input 'b'
    resolves to a field of a different length.
    """
    a = fld.Field('a', length=expr.compile('${b}'))
    c = seq.Sequence('c', [fld.Field('b', 8), a])
    d = seq.Sequence('d', [fld.Field('b', 16), a])
    lookup = prm.ExpressionParameters([a, c, d])

    expected = [prm.Param('b', prm.Param.OUT, _Integer())]
    for parent in (c, d):
        self.assertEqual(expected,
                         list(lookup.get_passed_variables(parent, parent.children[0])))
def signed_big_endian(self, length_expr, attributes=None):
    """Return a (cached) entry decoding a signed big endian integer.

    The entry is a sequence of a 1 bit ``signed:`` flag followed by
    ``integer value:``, whose length is ``length_expr - 1``. Its value is
    ``signed * -(1 << (length - 1)) + integer value``, i.e. the two's
    complement interpretation. Entries are cached in ``self.common`` under a
    name that embeds the evaluated length when it is a constant, or the
    expression text otherwise.
    """
    try:
        length_text = length_expr.evaluate({})
    except UndecodedReferenceError:
        length_text = length_expr
    name = 'big endian integer %s' % length_text

    try:
        return self.common[name]
    except KeyError:
        pass

    sign_bit = Field('signed:', 1, attributes=attributes)
    low_bits_length = ArithmeticExpression(operator.sub, length_expr, Constant(1))
    low_bits = Field('integer value:', low_bits_length, attributes=attributes)
    value_text = ('${signed:} * ((0 - 1) << (%s - 1)) + ${integer value:}'
                  % (length_expr,))
    result = Sequence(name, [sign_bit, low_bits], value=compile(value_text),
                      attributes=attributes)
    self.common[name] = result
    return result
def _variable_length_signed_little_endian(self, length_expr, attributes=None):
    """Return a (cached) signed little endian entry whose length is only known at decode time.

    The entry wraps a choice between fixed size signed little endian
    integers (8, 16, 24, 32 and 64 bits). Each option first checks that
    ``length_expr`` equals its size, then decodes the integer as a child
    named ``value``. Both the option's value and the entry's value are that
    decoded integer.
    """
    # Include the length expression in the cache name (as signed_big_endian
    # does); otherwise entries built for different length expressions would
    # all share the first one created.
    name = 'variable length integer %s' % length_expr
    try:
        return self.common[name]
    except KeyError:
        pass

    options = []
    for length in (8, 16, 24, 32, 64):
        option = self.signed_litte_endian(Constant(length))
        # Empty entry whose value is the runtime length; its Equals
        # constraint rejects this option when the length doesn't match.
        # Previously the option's own value was length_expr, so the final
        # value was the length instead of the decoded integer.
        length_check = Sequence('length:', [], value=length_expr,
                                constraints=[Equals(length)])
        options.append(Sequence('variable %s' % option.name,
                                [length_check, Child('value', option)],
                                value=compile('${value}')))

    # We wrap the choice inside a sequence, as choices don't currently
    # 'compile' to integers (and sequences do).
    var_name = 'variable integer types:'
    result = Sequence(name, [Choice(var_name, options, attributes=attributes)],
                      value=compile('${%s}' % var_name), attributes=attributes)
    self.common[name] = result
    return result
def test_param_ordering(self):
    """Parameters and locals are reported in the same order as the children."""
    a, b, c = [fld.Field(field_name, 8) for field_name in ('a', 'b', 'c')]
    d = seq.Sequence('d', [a, b, c])
    e = fld.Field('e', expr.compile('${d.a} + ${d.b} + ${d.c}'))
    f = seq.Sequence('f', [d, e])
    params = prm.ExpressionParameters([f])

    qualified = ('d.a', 'd.b', 'd.c')
    self.assertEqual([prm.Local(n, _Integer()) for n in qualified],
                     params.get_locals(f))
    self.assertEqual([prm.Param(n, prm.Param.OUT, _Integer()) for n in ('a', 'b', 'c')],
                     params.get_params(d))
    self.assertEqual([prm.Param(n, prm.Param.OUT, _Integer()) for n in qualified],
                     list(params.get_passed_variables(f, f.children[0])))
def test_sequence_with_expected_value(self):
    """An empty sequence whose value meets its constraint renders as an empty element."""
    entry = seq.Sequence('a', [], value=expr.compile('1'), constraints=[Equals(1)])
    rendered = xml.to_string(entry, dt.Data())
    self.assertEqual('<a></a>\n', rendered)
def test_sequence_with_children_and_value(self):
    """A sequence with both children and a value renders both in its output."""
    child = fld.Field('b', length=8, format=fld.Field.INTEGER)
    entry = seq.Sequence('a', [child], value=expr.compile('11'))
    rendered = xml.to_string(entry, dt.Data('\xff'))
    self.assertEqual('<a>\n <b>255</b>\n 11\n</a>\n', rendered)
def test_different_child_name(self):
    """An entry renamed with Child is rendered and referenced by its new name."""
    digit = fld.Field('digit:', length=8)
    number = seq.Sequence('number', [digit], value=expr.compile("${digit:} - 48"))
    length = ent.Child('length', number)
    data = fld.Field('data', length=expr.compile('${length} * 8'),
                     format=fld.Field.TEXT)
    header = seq.Sequence('header', [length, data])
    # '5' is 53, so the length is 53 - 48 == 5 characters of text.
    rendered = xml.to_string(header, dt.Data('5abcde'))
    self.assertEqual('<header>\n <length>5</length>\n <data>abcde</data>\n</header>\n',
                     rendered)
def test_multiple_range(self):
    """A constant product has a range of exactly one value."""
    value_range = expression_range(compile('8 * 1 * 4'))
    self.assertEqual(32, value_range.min)
    self.assertEqual(32, value_range.max)
def _parse_expression(self, text):
    """Compile an expression string from the specification being loaded.

    Raises:
        XmlExpressionError: if ``exp.compile`` raises ``exp.ExpressionError``;
            it is constructed from that error, ``self._filename`` and
            ``self.locator``.
    """
    try:
        return exp.compile(text)
    # 'except X as e' works on Python 2.6+ and 3; 'except X, e' is a
    # SyntaxError under Python 3.
    except exp.ExpressionError as ex:
        raise XmlExpressionError(ex, self._filename, self.locator)
def test_divide_range(self):
    """Chained constant division has a range of exactly one value."""
    value_range = expression_range(compile('16 / 2 / 4'))
    self.assertEqual(2, value_range.min)
    self.assertEqual(2, value_range.max)
def test_mod_range(self):
    """A value modulo 2 is reported as lying in the range [0, 1]."""
    actual = expression_range(compile('100 % 2'))
    self.assertEqual(Range(0, 1), actual)
def test_add_range(self):
    """Constant addition has a range of exactly one value."""
    value_range = expression_range(compile('(10 + 3) + 7'))
    self.assertEqual(20, value_range.min)
    self.assertEqual(20, value_range.max)
def test_subtract_range(self):
    """Constant subtraction (including a nested one) has a single-value range."""
    value_range = expression_range(compile('95 - (100 - 20)'))
    self.assertEqual(15, value_range.min)
    self.assertEqual(15, value_range.max)
def test_exception_when_referencing_text_field(self):
    """Using a text field's value in an integer expression is rejected."""
    text_field = fld.Field('b', length=8, format=fld.Field.TEXT)
    dependent = fld.Field('c', length=expr.compile('${b}'))
    a = seq.Sequence('a', [text_field, dependent])
    self.assertRaises(prm.BadReferenceTypeError, prm.ExpressionParameters, [a])
def test_referencing_sequence_without_value(self):
    """Referencing a sequence that has no value is rejected."""
    valueless = seq.Sequence('a', [])
    dependent = fld.Field('b', expr.compile('${a}'), fld.Field.INTEGER)
    c = seq.Sequence('c', [valueless, dependent])
    self.assertRaises(prm.BadReferenceError, prm.ExpressionParameters, [c])
def test_constant_range(self):
    """A lone constant has a range of exactly that constant."""
    value_range = expression_range(compile('8'))
    self.assertEqual(8, value_range.min)
    self.assertEqual(8, value_range.max)