# Copyright (c) 2014 Kontron Europe GmbH
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

from array import array

from . import constants
from ..utils import ByteBuffer
from ..errors import (CompletionCodeError, EncodingError, DecodingError,
                      DescriptionError)


class BaseField(object):
    """Base class of all message field descriptors.

    A field describes how one named attribute of a message object is
    encoded into / decoded from a ``ByteBuffer``. Subclasses override
    ``encode``, ``decode`` and ``create``.

    Arguments:
    name -- attribute name on the message object
    length -- field length (meaning depends on the subclass)
    default -- value returned by ``create`` when set
    """

    def __init__(self, name, length, default=None):
        self.name = name
        self.length = length
        self.default = default

    def decode(self, obj, data):
        raise NotImplementedError()

    def encode(self, obj, data):
        # Report an unset field before falling through to the abstract error.
        if getattr(obj, self.name) is None:
            raise EncodingError('Field "%s" not set.' % self.name)
        raise NotImplementedError()

    def create(self):
        raise NotImplementedError()


class ByteArray(BaseField):
    """Fixed-length array of unsigned bytes, stored as ``array('B')``."""

    def __init__(self, name, length, default=None):
        BaseField.__init__(self, name, length)
        if default is not None:
            self.default = array('B', default)
        else:
            self.default = None

    def _length(self, obj):
        """Return the number of bytes this field occupies for ``obj``."""
        return self.length

    def encode(self, obj, data):
        """Push every element of the array as a 1-byte unsigned int.

        Raises EncodingError if the array length does not match
        ``_length(obj)``.
        """
        a = getattr(obj, self.name)
        expected = self._length(obj)
        if len(a) != expected:
            raise EncodingError('Array must be exactly %d bytes long '
                                '(but is %d long)' % (expected, len(a)))
        for value in a:
            data.push_unsigned_int(value, 1)

    def decode(self, obj, data):
        """Pop ``_length(obj)`` single bytes and store them as an array."""
        values = [data.pop_unsigned_int(1) for _ in range(self._length(obj))]
        setattr(obj, self.name, array('B', values))

    def create(self):
        """Return a copy of the default, or ``length`` zero bytes."""
        if self.default is not None:
            return array('B', self.default)
        else:
            return array('B', self.length * b'\x00')


class VariableByteArray(ByteArray):
    """Array of bytes with variable length.

    The length is dynamically computed by a function.
    """

    def __init__(self, name, length_func):
        ByteArray.__init__(self, name, None, None)
        # Called with the message object to obtain the current length.
        self._length_func = length_func

    def _length(self, obj):
        return self._length_func(obj)

    def create(self):
        return None


class UnsignedInt(BaseField):
    """Unsigned integer of ``length`` bytes."""

    def encode(self, obj, data):
        value = getattr(obj, self.name)
        data.push_unsigned_int(value, self.length)

    def decode(self, obj, data):
        value = data.pop_unsigned_int(self.length)
        setattr(obj, self.name, value)

    def create(self):
        if self.default is not None:
            return self.default
        else:
            return 0


class String(BaseField):
    """String field; ``length`` is passed to ``pop_string`` on decode."""

    def encode(self, obj, data):
        value = getattr(obj, self.name)
        data.push_string(value)

    def decode(self, obj, data):
        value = data.pop_string(self.length)
        setattr(obj, self.name, value)

    def create(self):
        if self.default is not None:
            return self.default
        else:
            return ''


class CompletionCode(UnsignedInt):
    """One-byte completion code.

    Decoding raises CompletionCodeError when the code differs from
    ``constants.CC_OK``; the decoded value is stored on the object first.
    """

    def __init__(self, name='completion_code'):
        UnsignedInt.__init__(self, name, 1, None)

    def decode(self, obj, data):
        UnsignedInt.decode(self, obj, data)
        cc = getattr(obj, self.name)
        if cc != constants.CC_OK:
            raise CompletionCodeError(cc)


class UnsignedIntMask(UnsignedInt):
    """Unsigned integer with an associated mask.

    The mask is only stored (as ``self.mask``); encode/decode are those of
    ``UnsignedInt`` and do not apply it.
    """

    def __init__(self, name, length, mask, default=None):
        UnsignedInt.__init__(self, name, length, default)
        self.mask = mask


class Timestamp(UnsignedInt):
    """Four-byte unsigned timestamp value."""

    def __init__(self, name):
        UnsignedInt.__init__(self, name, 4, None)


class Conditional(object):
    """Wrap a field so it is only encoded/decoded when ``cond_fn(obj)``.

    Attribute access not defined here is delegated to the wrapped field.
    """

    def __init__(self, cond_fn, field):
        self._condition_fn = cond_fn
        self._field = field

    def __getattr__(self, name):
        # Read ``_field`` from __dict__ so an instance created without
        # __init__ (as copy/pickle do) raises AttributeError instead of
        # recursing forever through __getattr__.
        try:
            field = self.__dict__['_field']
        except KeyError:
            raise AttributeError(name)
        return getattr(field, name)

    def encode(self, obj, data):
        if self._condition_fn(obj):
            self._field.encode(obj, data)

    def decode(self, obj, data):
        if self._condition_fn(obj):
            self._field.decode(obj, data)

    def create(self):
        return self._field.create()


class Optional(object):
    """Wrap a trailing field that may be absent.

    On decode, the field is set to None if no data is left. On encode, it
    is skipped when its value is None. Other attribute access is delegated
    to the wrapped field.
    """

    def __init__(self, field):
        self._field = field

    def __getattr__(self, name):
        # See Conditional.__getattr__: avoid infinite recursion when
        # ``_field`` has not been set yet.
        try:
            field = self.__dict__['_field']
        except KeyError:
            raise AttributeError(name)
        return getattr(field, name)

    def decode(self, obj, data):
        if len(data) > 0:
            self._field.decode(obj, data)
        else:
            setattr(obj, self._field.name, None)

    def encode(self, obj, data):
        if getattr(obj, self._field.name) is not None:
            self._field.encode(obj, data)

    def create(self):
        return None


class RemainingBytes(BaseField):
    """Consumes all remaining bytes of the buffer on decode."""

    def __init__(self, name):
        BaseField.__init__(self, name, None)

    def encode(self, obj, data):
        a = getattr(obj, self.name)
        data.extend(a)

    def decode(self, obj, data):
        setattr(obj, self.name, array('B', data[:]))
        # Empty the underlying buffer so nothing is reported as extra data.
        del data.array[:]

    def create(self):
        return array('B')


class Bitfield(BaseField):
    """Multi-byte field split into named bit ranges.

    Bits are laid out starting at bit offset 0 in declaration order. Their
    widths must sum to ``8 * length``. The value is encoded least
    significant byte first.
    """

    class Bit(object):
        """A named bit range of ``width`` bits."""

        def __init__(self, name, width=1, default=None):
            self.name = name
            self._width = width
            self.default = default

    class ReservedBit(Bit):
        """An anonymous bit range with an auto-generated unique name."""

        # Not used; the shared counter is Bitfield.reserved_bit_counter.
        counter = 0

        def __init__(self, width, default=0):
            Bitfield.Bit.__init__(self, 'reserved_bit_%d' %
                                  Bitfield.reserved_bit_counter,
                                  width, default)
            Bitfield.reserved_bit_counter += 1

    class BitWrapper(object):
        """Per-message holder exposing each bit range as an attribute."""

        def __init__(self, bits, length):
            self._bits = bits
            self._length = length

            for bit in bits:
                if hasattr(self, bit.name):
                    raise DescriptionError('Bit with name "%s" already added'
                                           % bit.name)
                if bit.default is not None:
                    setattr(self, bit.name, bit.default)
                else:
                    setattr(self, bit.name, 0)

        def __str__(self):
            s = '['
            for attr in dir(self):
                if attr.startswith('_'):
                    continue
                s += '%s=%s, ' % (attr, getattr(self, attr))
            s += ']'
            return s

        def __int__(self):
            return self._value

        def _get_value(self):
            """Combine all bit attributes into one integer."""
            value = 0
            for bit in self._bits:
                bit_value = getattr(self, bit.name)
                if bit_value is None:
                    bit_value = bit.default
                if bit_value is None:
                    raise EncodingError('Bitfield "%s" not set.' % bit.name)
                # Truncate to the bit width, then shift into position.
                value |= (bit_value & (2**bit._width - 1)) << bit.offset
            return value

        def _set_value(self, value):
            """Split an integer into the individual bit attributes."""
            for bit in self._bits:
                tmp = (value >> bit.offset) & (2**bit._width - 1)
                setattr(self, bit.name, tmp)

        _value = property(_get_value, _set_value)

    # Shared counter used to name ReservedBit instances uniquely.
    reserved_bit_counter = 0

    def __init__(self, name, length, *bits):
        BaseField.__init__(self, name, length)
        self._bits = bits
        self._precalc_offsets()

    def _precalc_offsets(self):
        """Assign each bit its offset; check that widths fill ``length``."""
        offset = 0
        for b in self._bits:
            b.offset = offset
            offset += b._width

        if offset != 8 * self.length:
            raise DescriptionError('Bit description does not match bitfield '
                                   'length')

    def encode(self, obj, data):
        wrapper = getattr(obj, self.name)
        value = wrapper._value
        # Emit least significant byte first.
        for i in range(self.length):
            data.push_unsigned_int((value >> (8*i)) & 0xff, 1)

    def decode(self, obj, data):
        value = 0
        for i in range(self.length):
            try:
                value |= data.pop_unsigned_int(1) << (8*i)
            except IndexError:
                raise DecodingError('Data too short for message')
        wrapper = getattr(obj, self.name)
        wrapper._value = value

    def create(self):
        return Bitfield.BitWrapper(self._bits, self.length)


class GroupExtensionIdentifier(UnsignedInt):
    """One-byte group extension identifier field."""

    def __init__(self, name='picmg_identifier', value=None):
        UnsignedInt.__init__(self, name, 1, value)


class EventMessageRevision(UnsignedInt):
    """One-byte event message revision field."""

    def __init__(self, value=None):
        UnsignedInt.__init__(self, 'event_message_rev', 1, value)


class Message(object):
    """Base class for messages described by a ``__fields__`` sequence.

    Subclasses provide ``__netfn__``, ``__cmdid__`` and optionally
    ``__fields__``, ``__default_lun__`` and ``__group_extension__``.
    """

    RESERVED_FIELD_NAMES = ['cmdid', 'netfn', 'lun', 'group_extension']

    __default_lun__ = 0
    __group_extension__ = None
    __not_implemented__ = False

    def __init__(self, *args, **kwargs):
        """Message constructor with ([buf], [field=val,...]) prototype.

        Arguments:

        buf -- optional message buffer to decode

        Optional keyword arguments correspond to members to set (matching
        fields in self.__fields__, or 'data'). NOTE: ``_set_field`` is not
        implemented yet, so passing keyword arguments currently raises
        NotImplementedError.
        """

        # create message fields
        if hasattr(self, '__fields__'):
            self._create_fields()

        # set default lun
        self.lun = self.__default_lun__

        self.data = ''
        if args:
            self._decode(args[0])
        else:
            for (name, value) in kwargs.items():
                self._set_field(name, value)

    def _set_field(self, name, value):
        # TODO walk along the properties..
        raise NotImplementedError()

    def __str__(self):
        return '{} [netfn={}, cmd={}, grp={}]'.format(type(self).__name__,
                                                      self.netfn, self.cmdid,
                                                      self.group_extension)

    def _create_fields(self):
        """Set every field attribute to its ``create()`` initial value.

        Raises DescriptionError for reserved or duplicate field names.
        """
        for field in self.__fields__:
            if field.name in self.RESERVED_FIELD_NAMES:
                raise DescriptionError('Field name "%s" is reserved' %
                                       field.name)
            if hasattr(self, field.name):
                raise DescriptionError('Field "%s" already added' %
                                       field.name)
            setattr(self, field.name, field.create())

    def _encode_fields(self):
        """Encode all fields (if any) into a new ByteBuffer and return it."""
        data = ByteBuffer()
        for field in getattr(self, '__fields__', ()):
            field.encode(self, data)
        return data

    def _pack(self):
        """Pack the message and return an array."""
        return self._encode_fields().array

    def _encode(self):
        """Encode the message and return a bytestring."""
        return self._encode_fields().tostring()

    def _decode(self, data):
        """Decode the bytestring message.

        Decoding stops at the first CompletionCodeError (non-OK completion
        code). Raises DecodingError if bytes are left over and no non-zero
        completion code was seen.
        """
        if not hasattr(self, '__fields__'):
            return
        data = ByteBuffer(data)
        cc = None
        for field in self.__fields__:
            try:
                field.decode(self, data)
            except CompletionCodeError as e:
                # stop decoding on completion code != 0
                cc = e.cc
                break

        if (cc is None or cc == 0) and len(data) > 0:
            raise DecodingError('Data has extra bytes')

    def _is_request(self):
        # Even network function codes are requests.
        return self.__netfn__ & 1 == 0

    def _is_response(self):
        # Odd network function codes are responses.
        return self.__netfn__ & 1 == 1

    netfn = property(lambda s: s.__netfn__)
    cmdid = property(lambda s: s.__cmdid__)
    group_extension = property(lambda s: s.__group_extension__)


def encode_message(msg):
    """Return ``msg`` encoded via ``Message._encode``."""
    return msg._encode()


def decode_message(msg, data):
    """Decode ``data`` into ``msg`` in place via ``Message._decode``."""
    return msg._decode(data)


def pack_message(msg):
    """Return ``msg`` packed into an array via ``Message._pack``."""
    return msg._pack()