Add URH
This commit is contained in:
		
							
								
								
									
										65
									
								
								Software/urh/tests/awre/AWRETestCase.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								Software/urh/tests/awre/AWRETestCase.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
			
		||||
import os
 | 
			
		||||
import tempfile
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import numpy
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
 | 
			
		||||
from tests.utils_testing import get_path_for_data_file
 | 
			
		||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
 | 
			
		||||
 | 
			
		||||
from urh.signalprocessing.MessageType import MessageType
 | 
			
		||||
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AWRETestCase(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        numpy.set_printoptions(linewidth=80)
 | 
			
		||||
        self.field_types = self.__init_field_types()
 | 
			
		||||
 | 
			
		||||
    def get_format_finder_from_protocol_file(self, filename: str, clear_participant_addresses=True, return_messages=False):
 | 
			
		||||
        proto_file = get_path_for_data_file(filename)
 | 
			
		||||
        protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
 | 
			
		||||
        protocol.from_xml_file(filename=proto_file, read_bits=True)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
        if clear_participant_addresses:
 | 
			
		||||
            ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        if return_messages:
 | 
			
		||||
            return ff, protocol.messages
 | 
			
		||||
        else:
 | 
			
		||||
            return ff
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def __init_field_types():
 | 
			
		||||
        result = []
 | 
			
		||||
        for field_type_function in FieldType.Function:
 | 
			
		||||
            result.append(FieldType(field_type_function.value, field_type_function))
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def clear_message_types(messages: list):
 | 
			
		||||
        mt = MessageType("empty")
 | 
			
		||||
        for msg in messages:
 | 
			
		||||
            msg.message_type = mt
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def save_protocol(name, protocol_generator, silent=False):
 | 
			
		||||
        filename = os.path.join(tempfile.gettempdir(), name + ".proto")
 | 
			
		||||
        if isinstance(protocol_generator, ProtocolGenerator):
 | 
			
		||||
            protocol_generator.to_file(filename)
 | 
			
		||||
        elif isinstance(protocol_generator, ProtocolAnalyzer):
 | 
			
		||||
            participants = list(set(msg.participant for msg in protocol_generator.messages))
 | 
			
		||||
            protocol_generator.to_xml_file(filename, [], participants=participants, write_bits=True)
 | 
			
		||||
        info = "Protocol written to " + filename
 | 
			
		||||
        if not silent:
 | 
			
		||||
            print()
 | 
			
		||||
            print("-" * len(info))
 | 
			
		||||
            print(info)
 | 
			
		||||
            print("-" * len(info))
 | 
			
		||||
							
								
								
									
										791
									
								
								Software/urh/tests/awre/AWRExperiments.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										791
									
								
								Software/urh/tests/awre/AWRExperiments.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,791 @@
 | 
			
		||||
import array
 | 
			
		||||
import multiprocessing
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
import time
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from tests.utils_testing import get_path_for_data_file
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.Preprocessor import Preprocessor
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.awre.engines.Engine import Engine
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Message import Message
 | 
			
		||||
from urh.signalprocessing.MessageType import MessageType
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
 | 
			
		||||
from urh.util.GenericCRC import GenericCRC
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_for_num_broken(protocol_nr, num_broken: list, num_messages: int, num_runs: int) -> list:
 | 
			
		||||
    random.seed(0)
 | 
			
		||||
    np.random.seed(0)
 | 
			
		||||
 | 
			
		||||
    result = []
 | 
			
		||||
    for broken in num_broken:
 | 
			
		||||
        tmp_accuracies = np.empty(num_runs, dtype=np.float64)
 | 
			
		||||
        tmp_accuracies_without_broken = np.empty(num_runs, dtype=np.float64)
 | 
			
		||||
        for i in range(num_runs):
 | 
			
		||||
            protocol, expected_labels = AWRExperiments.get_protocol(protocol_nr,
 | 
			
		||||
                                                                    num_messages=num_messages,
 | 
			
		||||
                                                                    num_broken_messages=broken,
 | 
			
		||||
                                                                    silent=True)
 | 
			
		||||
 | 
			
		||||
            AWRExperiments.run_format_finder_for_protocol(protocol)
 | 
			
		||||
            accuracy = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels)
 | 
			
		||||
            accuracy_without_broken = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels, broken)
 | 
			
		||||
            tmp_accuracies[i] = accuracy
 | 
			
		||||
            tmp_accuracies_without_broken[i] = accuracy_without_broken
 | 
			
		||||
 | 
			
		||||
        avg_accuracy = np.mean(tmp_accuracies)
 | 
			
		||||
        avg_accuracy_without_broken = np.mean(tmp_accuracies_without_broken)
 | 
			
		||||
 | 
			
		||||
        result.append((avg_accuracy, avg_accuracy_without_broken))
 | 
			
		||||
        print("Protocol {} with {} broken: {:>3}% {:>3}%".format(protocol_nr, broken, int(avg_accuracy),
 | 
			
		||||
                                                                 int(avg_accuracy_without_broken)))
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AWRExperiments(AWRETestCase):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_1() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="dead")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x1337"},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_2() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="dead01")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef24")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 72)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 24)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x1337"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 36},
 | 
			
		||||
                               sequence_number_increment=32,
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_3() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="1337")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef")
 | 
			
		||||
 | 
			
		||||
        checksum = GenericCRC.from_standard_checksum("CRC8 CCITT")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DATA, 10 * 8)
 | 
			
		||||
        mb.add_checksum_label(8, checksum)
 | 
			
		||||
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb_ack.add_checksum_label(8, checksum)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_4() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="1337")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef")
 | 
			
		||||
 | 
			
		||||
        checksum = GenericCRC.from_standard_checksum("CRC16 CCITT")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data1")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DATA, 8 * 8)
 | 
			
		||||
        mb.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        mb2 = MessageTypeBuilder("data2")
 | 
			
		||||
        mb2.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb2.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.DATA, 64 * 8)
 | 
			
		||||
        mb2.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb_ack.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        mt1, mt2, mt3 = mb.message_type, mb2.message_type, mb_ack.message_type
 | 
			
		||||
 | 
			
		||||
        preamble = "10001000" * 2
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mt1, mt2, mt3],
 | 
			
		||||
                               syncs_by_mt={mt1: "0x9a7d", mt2: "0x9a7d", mt3: "0x9a7d"},
 | 
			
		||||
                               preambles_by_mt={mt1: preamble, mt2: preamble, mt3: preamble},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_5() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="1337")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef")
 | 
			
		||||
        carl = Participant("Carl", address_hex="cafe")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
 | 
			
		||||
                               participants=[alice, bob, carl])
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_6() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="24")
 | 
			
		||||
        broadcast = Participant("Bob", address_hex="ff")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x8e88"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8},
 | 
			
		||||
                               participants=[alice, broadcast])
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_7() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice", address_hex="313370")
 | 
			
		||||
        bob = Participant("Bob", address_hex="031337")
 | 
			
		||||
        charly = Participant("Charly", address_hex="110000")
 | 
			
		||||
        daniel = Participant("Daniel", address_hex="001100")
 | 
			
		||||
        # broadcast = Participant("Broadcast", address_hex="ff")     #TODO: Sometimes messages to broadcast
 | 
			
		||||
 | 
			
		||||
        checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 24)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
 | 
			
		||||
        mb.add_label(FieldType.Function.DATA, 8 * 8)
 | 
			
		||||
        mb.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 24)
 | 
			
		||||
        mb_ack.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        mb_kex = MessageTypeBuilder("kex")
 | 
			
		||||
        mb_kex.add_label(FieldType.Function.PREAMBLE, 24)
 | 
			
		||||
        mb_kex.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_kex.add_label(FieldType.Function.DST_ADDRESS, 24)
 | 
			
		||||
        mb_kex.add_label(FieldType.Function.SRC_ADDRESS, 24)
 | 
			
		||||
        mb_kex.add_label(FieldType.Function.DATA, 64 * 8)
 | 
			
		||||
        mb_kex.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type, mb_kex.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x0420", mb_ack.message_type: "0x2222",
 | 
			
		||||
                                            mb_kex.message_type: "0x6767"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 4,
 | 
			
		||||
                                                mb_kex.message_type: "10" * 12},
 | 
			
		||||
                               participants=[alice, bob, charly, daniel])
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _prepare_protocol_8() -> ProtocolGenerator:
 | 
			
		||||
        alice = Participant("Alice")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data1")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 4)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 4)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DATA, 8 * 542)
 | 
			
		||||
 | 
			
		||||
        mb2 = MessageTypeBuilder("data2")
 | 
			
		||||
        mb2.add_label(FieldType.Function.PREAMBLE, 4)
 | 
			
		||||
        mb2.add_label(FieldType.Function.SYNC, 4)
 | 
			
		||||
        mb2.add_label(FieldType.Function.LENGTH, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.DATA, 8 * 260)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb2.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9", mb2.message_type: "0x9"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 2, mb2.message_type: "10" * 2},
 | 
			
		||||
                               sequence_number_increment=32,
 | 
			
		||||
                               participants=[alice],
 | 
			
		||||
                               little_endian=True)
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
 | 
			
		||||
    def test_export_to_latex(self):
 | 
			
		||||
        filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocols.tex")
 | 
			
		||||
        if os.path.isfile(filename):
 | 
			
		||||
            os.remove(filename)
 | 
			
		||||
 | 
			
		||||
        for i in range(1, 9):
 | 
			
		||||
            pg = getattr(self, "_prepare_protocol_" + str(i))()
 | 
			
		||||
            pg.export_to_latex(filename, i)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_protocol(cls, protocol_number: int, num_messages, num_broken_messages=0, silent=False):
 | 
			
		||||
        if protocol_number == 1:
 | 
			
		||||
            pg = cls._prepare_protocol_1()
 | 
			
		||||
        elif protocol_number == 2:
 | 
			
		||||
            pg = cls._prepare_protocol_2()
 | 
			
		||||
        elif protocol_number == 3:
 | 
			
		||||
            pg = cls._prepare_protocol_3()
 | 
			
		||||
        elif protocol_number == 4:
 | 
			
		||||
            pg = cls._prepare_protocol_4()
 | 
			
		||||
        elif protocol_number == 5:
 | 
			
		||||
            pg = cls._prepare_protocol_5()
 | 
			
		||||
        elif protocol_number == 6:
 | 
			
		||||
            pg = cls._prepare_protocol_6()
 | 
			
		||||
        elif protocol_number == 7:
 | 
			
		||||
            pg = cls._prepare_protocol_7()
 | 
			
		||||
        elif protocol_number == 8:
 | 
			
		||||
            pg = cls._prepare_protocol_8()
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Unknown protocol number")
 | 
			
		||||
 | 
			
		||||
        messages_types_with_data_field = [mt for mt in pg.protocol.message_types
 | 
			
		||||
                                          if mt.get_first_label_with_type(FieldType.Function.DATA)]
 | 
			
		||||
        i = -1
 | 
			
		||||
        while len(pg.protocol.messages) < num_messages:
 | 
			
		||||
            i += 1
 | 
			
		||||
            source = pg.participants[i % len(pg.participants)]
 | 
			
		||||
            destination = pg.participants[(i + 1) % len(pg.participants)]
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                data_bytes = 8
 | 
			
		||||
            else:
 | 
			
		||||
                # data_bytes = 16
 | 
			
		||||
                data_bytes = 64
 | 
			
		||||
 | 
			
		||||
            if len(messages_types_with_data_field) == 0:
 | 
			
		||||
                # set data automatically
 | 
			
		||||
                data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
 | 
			
		||||
                pg.generate_message(data=data, source=source, destination=destination)
 | 
			
		||||
            else:
 | 
			
		||||
                # search for message type with right data length
 | 
			
		||||
                mt = messages_types_with_data_field[i % len(messages_types_with_data_field)]
 | 
			
		||||
                data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
 | 
			
		||||
                data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
 | 
			
		||||
                pg.generate_message(message_type=mt, data=data, source=source, destination=destination)
 | 
			
		||||
 | 
			
		||||
            ack_message_type = next((mt for mt in pg.protocol.message_types if "ack" in mt.name), None)
 | 
			
		||||
            if ack_message_type:
 | 
			
		||||
                pg.generate_message(message_type=ack_message_type, data="", source=destination, destination=source)
 | 
			
		||||
 | 
			
		||||
        for i in range(num_broken_messages):
 | 
			
		||||
            msg = pg.protocol.messages[i]
 | 
			
		||||
            pos = random.randint(0, len(msg.plain_bits) // 2)
 | 
			
		||||
            msg.plain_bits[pos:] = array.array("B",
 | 
			
		||||
                                               [random.randint(0, 1) for _ in range(len(msg.plain_bits) - pos)])
 | 
			
		||||
 | 
			
		||||
        if num_broken_messages == 0:
 | 
			
		||||
            cls.save_protocol("protocol{}_{}_messages".format(protocol_number, num_messages), pg, silent=silent)
 | 
			
		||||
        else:
 | 
			
		||||
            cls.save_protocol("protocol{}_{}_broken".format(protocol_number, num_broken_messages), pg, silent=silent)
 | 
			
		||||
 | 
			
		||||
        expected_message_types = [msg.message_type for msg in pg.protocol.messages]
 | 
			
		||||
 | 
			
		||||
        # Delete message type information -> no prior knowledge
 | 
			
		||||
        cls.clear_message_types(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        # Delete data labels if present
 | 
			
		||||
        for mt in expected_message_types:
 | 
			
		||||
            data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
 | 
			
		||||
            if data_lbl:
 | 
			
		||||
                mt.remove(data_lbl)
 | 
			
		||||
 | 
			
		||||
        return pg.protocol, expected_message_types
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def calculate_accuracy(messages, expected_labels, num_broken_messages=0):
 | 
			
		||||
        """
 | 
			
		||||
        Calculate the accuracy of labels compared to expected labels
 | 
			
		||||
        Accuracy is 100% when labels == expected labels
 | 
			
		||||
        Accuracy drops by 1 / len(expected_labels) for every expected label not present in labels
 | 
			
		||||
 | 
			
		||||
        :type messages: list of Message
 | 
			
		||||
        :type expected_labels: list of MessageType
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        accuracy = sum(len(set(expected_labels[i]) & set(messages[i].message_type)) / len(expected_labels[i])
 | 
			
		||||
                       for i in range(num_broken_messages, len(messages)))
 | 
			
		||||
        try:
 | 
			
		||||
            accuracy /= (len(messages) - num_broken_messages)
 | 
			
		||||
        except ZeroDivisionError:
 | 
			
		||||
            accuracy = 0
 | 
			
		||||
 | 
			
		||||
        return accuracy * 100
 | 
			
		||||
 | 
			
		||||
    def test_against_num_messages(self):
 | 
			
		||||
        num_messages = list(range(1, 24, 1))
 | 
			
		||||
        accuracies = defaultdict(list)
 | 
			
		||||
 | 
			
		||||
        protocols = [1, 2, 3, 4, 5, 6, 7, 8]
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        np.random.seed(0)
 | 
			
		||||
        for protocol_nr in protocols:
 | 
			
		||||
            for n in num_messages:
 | 
			
		||||
                protocol, expected_labels = self.get_protocol(protocol_nr, num_messages=n)
 | 
			
		||||
                self.run_format_finder_for_protocol(protocol)
 | 
			
		||||
 | 
			
		||||
                accuracy = self.calculate_accuracy(protocol.messages, expected_labels)
 | 
			
		||||
                accuracies["protocol {}".format(protocol_nr)].append(accuracy)
 | 
			
		||||
 | 
			
		||||
        self.__plot(num_messages, accuracies, xlabel="Number of messages", ylabel="Accuracy in %", grid=True)
 | 
			
		||||
        self.__export_to_csv("/tmp/accuray-vs-messages", num_messages, accuracies)
 | 
			
		||||
 | 
			
		||||
    def test_against_error(self):
 | 
			
		||||
        Engine._DEBUG_ = False
 | 
			
		||||
        Preprocessor._DEBUG_ = False
 | 
			
		||||
 | 
			
		||||
        num_runs = 100
 | 
			
		||||
 | 
			
		||||
        num_messages = 30
 | 
			
		||||
        num_broken_messages = list(range(0, num_messages + 1))
 | 
			
		||||
        accuracies = defaultdict(list)
 | 
			
		||||
        accuracies_without_broken = defaultdict(list)
 | 
			
		||||
 | 
			
		||||
        protocols = [1, 2, 3, 4, 5, 6, 7, 8]
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        np.random.seed(0)
 | 
			
		||||
 | 
			
		||||
        with multiprocessing.Pool() as p:
 | 
			
		||||
            result = p.starmap(run_for_num_broken,
 | 
			
		||||
                               [(i, num_broken_messages, num_messages, num_runs) for i in protocols])
 | 
			
		||||
            for i, acc in enumerate(result):
 | 
			
		||||
                accuracies["protocol {}".format(i + 1)] = [a[0] for a in acc]
 | 
			
		||||
                accuracies_without_broken["protocol {}".format(i + 1)] = [a[1] for a in acc]
 | 
			
		||||
 | 
			
		||||
        self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies,
 | 
			
		||||
                    title="Overall Accuracy vs percentage of broken messages",
 | 
			
		||||
                    xlabel="Broken messages in %",
 | 
			
		||||
                    ylabel="Accuracy in %", grid=True)
 | 
			
		||||
        self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies_without_broken,
 | 
			
		||||
                    title=" Accuracy of unbroken vs percentage of broken messages",
 | 
			
		||||
                    xlabel="Broken messages in %",
 | 
			
		||||
                    ylabel="Accuracy in %", grid=True)
 | 
			
		||||
        self.__export_to_csv("/tmp/accuray-vs-error", num_broken_messages, accuracies, relative=num_messages)
 | 
			
		||||
        self.__export_to_csv("/tmp/accuray-vs-error-without-broken", num_broken_messages, accuracies_without_broken,
 | 
			
		||||
                             relative=num_messages)
 | 
			
		||||
 | 
			
		||||
    def test_performance(self):
 | 
			
		||||
        Engine._DEBUG_ = False
 | 
			
		||||
        Preprocessor._DEBUG_ = False
 | 
			
		||||
 | 
			
		||||
        num_messages = list(range(200, 205, 5))
 | 
			
		||||
        protocols = [1]
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        np.random.seed(0)
 | 
			
		||||
 | 
			
		||||
        performances = defaultdict(list)
 | 
			
		||||
 | 
			
		||||
        for protocol_nr in protocols:
 | 
			
		||||
            print("Running for protocol", protocol_nr)
 | 
			
		||||
            for messages in num_messages:
 | 
			
		||||
                protocol, _ = self.get_protocol(protocol_nr, messages, silent=True)
 | 
			
		||||
 | 
			
		||||
                t = time.time()
 | 
			
		||||
                self.run_format_finder_for_protocol(protocol)
 | 
			
		||||
                performances["protocol {}".format(protocol_nr)].append(time.time() - t)
 | 
			
		||||
 | 
			
		||||
        # self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
 | 
			
		||||
 | 
			
		||||
    def test_performance_real_protocols(self):
 | 
			
		||||
        Engine._DEBUG_ = False
 | 
			
		||||
        Preprocessor._DEBUG_ = False
 | 
			
		||||
 | 
			
		||||
        num_runs = 100
 | 
			
		||||
 | 
			
		||||
        num_messages = list(range(8, 512, 4))
 | 
			
		||||
        protocol_names = ["enocean", "homematic", "rwe"]
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        np.random.seed(0)
 | 
			
		||||
 | 
			
		||||
        performances = defaultdict(list)
 | 
			
		||||
 | 
			
		||||
        for protocol_name in protocol_names:
 | 
			
		||||
            for messages in num_messages:
 | 
			
		||||
                if protocol_name == "homematic":
 | 
			
		||||
                    protocol = self.generate_homematic(messages, save_protocol=False)
 | 
			
		||||
                elif protocol_name == "enocean":
 | 
			
		||||
                    protocol = self.generate_enocean(messages, save_protocol=False)
 | 
			
		||||
                elif protocol_name == "rwe":
 | 
			
		||||
                    protocol = self.generate_rwe(messages, save_protocol=False)
 | 
			
		||||
                else:
 | 
			
		||||
                    raise ValueError("Unknown protocol name")
 | 
			
		||||
 | 
			
		||||
                tmp_performances = np.empty(num_runs, dtype=np.float64)
 | 
			
		||||
                for i in range(num_runs):
 | 
			
		||||
                    print("\r{0} with {1:02d} messages ({2}/{3} runs)".format(protocol_name, messages, i + 1, num_runs),
 | 
			
		||||
                          flush=True, end="")
 | 
			
		||||
 | 
			
		||||
                    t = time.time()
 | 
			
		||||
                    self.run_format_finder_for_protocol(protocol)
 | 
			
		||||
                    tmp_performances[i] = time.time() - t
 | 
			
		||||
                    self.clear_message_types(protocol.messages)
 | 
			
		||||
 | 
			
		||||
                mean_performance = tmp_performances.mean()
 | 
			
		||||
                print(" {:.2f}s".format(mean_performance))
 | 
			
		||||
                performances["{}".format(protocol_name)].append(mean_performance)
 | 
			
		||||
 | 
			
		||||
        self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True)
 | 
			
		||||
        self.__export_to_csv("/tmp/performance.csv", num_messages, performances)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def __export_to_csv(filename: str, x: list, y: dict, relative=None):
 | 
			
		||||
        if not filename.endswith(".csv"):
 | 
			
		||||
            filename += ".csv"
 | 
			
		||||
 | 
			
		||||
        with open(filename, "w") as f:
 | 
			
		||||
            f.write("N,")
 | 
			
		||||
            if relative is not None:
 | 
			
		||||
                f.write("NRel,")
 | 
			
		||||
            for y_cap in sorted(y):
 | 
			
		||||
                f.write(y_cap + ",")
 | 
			
		||||
            f.write("\n")
 | 
			
		||||
 | 
			
		||||
            for i, x_val in enumerate(x):
 | 
			
		||||
                f.write("{},".format(x_val))
 | 
			
		||||
                if relative is not None:
 | 
			
		||||
                    f.write("{},".format(100 * x_val / relative))
 | 
			
		||||
 | 
			
		||||
                for y_cap in sorted(y):
 | 
			
		||||
                    f.write("{},".format(y[y_cap][i]))
 | 
			
		||||
                f.write("\n")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def __plot(x: list, y: dict, xlabel: str, ylabel: str, grid=False, title=None):
 | 
			
		||||
        plt.xlabel(xlabel)
 | 
			
		||||
        plt.ylabel(ylabel)
 | 
			
		||||
 | 
			
		||||
        for y_cap, y_values in sorted(y.items()):
 | 
			
		||||
            plt.plot(x, y_values, label=y_cap)
 | 
			
		||||
 | 
			
		||||
        if grid:
 | 
			
		||||
            plt.grid()
 | 
			
		||||
 | 
			
		||||
        if title:
 | 
			
		||||
            plt.title(title)
 | 
			
		||||
 | 
			
		||||
        plt.legend()
 | 
			
		||||
        plt.show()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def run_format_finder_for_protocol(protocol: ProtocolAnalyzer):
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        for msg_type, indices in ff.existing_message_types.items():
 | 
			
		||||
            for i in indices:
 | 
			
		||||
                protocol.messages[i].message_type = msg_type
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def generate_homematic(cls, num_messages: int, save_protocol=True):
 | 
			
		||||
        mb_m_frame = MessageTypeBuilder("mframe")
 | 
			
		||||
        mb_c_frame = MessageTypeBuilder("cframe")
 | 
			
		||||
        mb_r_frame = MessageTypeBuilder("rframe")
 | 
			
		||||
        mb_a_frame = MessageTypeBuilder("aframe")
 | 
			
		||||
 | 
			
		||||
        participants = [Participant("CCU", address_hex="3927cc"), Participant("Switch", address_hex="3101cc")]
 | 
			
		||||
 | 
			
		||||
        checksum = GenericCRC.from_standard_checksum("CRC16 CC1101")
 | 
			
		||||
        for mb_builder in [mb_m_frame, mb_c_frame, mb_r_frame, mb_a_frame]:
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.PREAMBLE, 32)
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.SYNC, 32)
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.TYPE, 16)
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.SRC_ADDRESS, 24)
 | 
			
		||||
            mb_builder.add_label(FieldType.Function.DST_ADDRESS, 24)
 | 
			
		||||
            if mb_builder.name == "mframe":
 | 
			
		||||
                mb_builder.add_label(FieldType.Function.DATA, 16, name="command")
 | 
			
		||||
            elif mb_builder.name == "cframe":
 | 
			
		||||
                mb_builder.add_label(FieldType.Function.DATA, 16 * 4, name="command+challenge+magic")
 | 
			
		||||
            elif mb_builder.name == "rframe":
 | 
			
		||||
                mb_builder.add_label(FieldType.Function.DATA, 32 * 4, name="cipher")
 | 
			
		||||
            elif mb_builder.name == "aframe":
 | 
			
		||||
                mb_builder.add_label(FieldType.Function.DATA, 10 * 4, name="command + auth")
 | 
			
		||||
            mb_builder.add_checksum_label(16, checksum)
 | 
			
		||||
 | 
			
		||||
        message_types = [mb_m_frame.message_type, mb_c_frame.message_type, mb_r_frame.message_type,
 | 
			
		||||
                         mb_a_frame.message_type]
 | 
			
		||||
        preamble = "0xaaaaaaaa"
 | 
			
		||||
        sync = "0xe9cae9ca"
 | 
			
		||||
        initial_sequence_number = 36
 | 
			
		||||
        pg = ProtocolGenerator(message_types, participants,
 | 
			
		||||
                               preambles_by_mt={mt: preamble for mt in message_types},
 | 
			
		||||
                               syncs_by_mt={mt: sync for mt in message_types},
 | 
			
		||||
                               sequence_numbers={mt: initial_sequence_number for mt in message_types},
 | 
			
		||||
                               message_type_codes={mb_m_frame.message_type: 42560,
 | 
			
		||||
                                                   mb_c_frame.message_type: 40962,
 | 
			
		||||
                                                   mb_r_frame.message_type: 40963,
 | 
			
		||||
                                                   mb_a_frame.message_type: 32770})
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            mt = pg.message_types[i % 4]
 | 
			
		||||
            data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length
 | 
			
		||||
            data = "".join(random.choice(["0", "1"]) for _ in range(data_length))
 | 
			
		||||
            pg.generate_message(mt, data, source=pg.participants[i % 2], destination=pg.participants[(i + 1) % 2])
 | 
			
		||||
 | 
			
		||||
        if save_protocol:
 | 
			
		||||
            cls.save_protocol("homematic", pg)
 | 
			
		||||
 | 
			
		||||
        cls.clear_message_types(pg.messages)
 | 
			
		||||
        return pg.protocol
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def generate_enocean(cls, num_messages: int, save_protocol=True):
 | 
			
		||||
        filename = get_path_for_data_file("enocean_bits.txt")
 | 
			
		||||
        enocean_bits = []
 | 
			
		||||
        with open(filename, "r") as f:
 | 
			
		||||
            for line in map(str.strip, f):
 | 
			
		||||
                enocean_bits.append(line)
 | 
			
		||||
 | 
			
		||||
        protocol = ProtocolAnalyzer(None)
 | 
			
		||||
        message_type = MessageType("empty")
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            msg = Message.from_plain_bits_str(enocean_bits[i % len(enocean_bits)])
 | 
			
		||||
            msg.message_type = message_type
 | 
			
		||||
            protocol.messages.append(msg)
 | 
			
		||||
 | 
			
		||||
        if save_protocol:
 | 
			
		||||
            cls.save_protocol("enocean", protocol)
 | 
			
		||||
 | 
			
		||||
        return protocol
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def generate_rwe(cls, num_messages: int, save_protocol=True):
 | 
			
		||||
        proto_file = get_path_for_data_file("rwe.proto.xml")
 | 
			
		||||
        protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
 | 
			
		||||
        protocol.from_xml_file(filename=proto_file, read_bits=True)
 | 
			
		||||
        messages = protocol.messages
 | 
			
		||||
 | 
			
		||||
        result = ProtocolAnalyzer(None)
 | 
			
		||||
        message_type = MessageType("empty")
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            msg = messages[i % len(messages)]  # type: Message
 | 
			
		||||
            msg.message_type = message_type
 | 
			
		||||
            result.messages.append(msg)
 | 
			
		||||
 | 
			
		||||
        if save_protocol:
 | 
			
		||||
            cls.save_protocol("rwe", result)
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def test_export_latex_table(self):
 | 
			
		||||
        def bold_latex(s):
 | 
			
		||||
            return r"\textbf{" + str(s) + r"}"
 | 
			
		||||
 | 
			
		||||
        comments = {
 | 
			
		||||
            1: "common protocol",
 | 
			
		||||
            2: "unusual field sizes",
 | 
			
		||||
            3: "contains ack and CRC8 CCITT",
 | 
			
		||||
            4: "contains ack and CRC16 CCITT",
 | 
			
		||||
            5: "three participants with ack frame",
 | 
			
		||||
            6: "short address",
 | 
			
		||||
            7: "four participants, varying preamble size, varying sync words",
 | 
			
		||||
            8: "nibble fields + LE"
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        bold = {i: defaultdict(bool) for i in range(1, 9)}
 | 
			
		||||
        bold[2][FieldType.Function.PREAMBLE] = True
 | 
			
		||||
        bold[2][FieldType.Function.SRC_ADDRESS] = True
 | 
			
		||||
        bold[2][FieldType.Function.DST_ADDRESS] = True
 | 
			
		||||
 | 
			
		||||
        bold[3][FieldType.Function.CHECKSUM] = True
 | 
			
		||||
 | 
			
		||||
        bold[4][FieldType.Function.CHECKSUM] = True
 | 
			
		||||
 | 
			
		||||
        bold[6][FieldType.Function.SRC_ADDRESS] = True
 | 
			
		||||
 | 
			
		||||
        bold[7][FieldType.Function.PREAMBLE] = True
 | 
			
		||||
        bold[7][FieldType.Function.SYNC] = True
 | 
			
		||||
        bold[7][FieldType.Function.SRC_ADDRESS] = True
 | 
			
		||||
        bold[7][FieldType.Function.DST_ADDRESS] = True
 | 
			
		||||
 | 
			
		||||
        bold[8][FieldType.Function.PREAMBLE] = True
 | 
			
		||||
        bold[8][FieldType.Function.SYNC] = True
 | 
			
		||||
 | 
			
		||||
        filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocol_table.tex")
 | 
			
		||||
        rowcolors = [r"\rowcolor{black!10}", r"\rowcolor{black!20}"]
 | 
			
		||||
 | 
			
		||||
        with open(filename, "w") as f:
 | 
			
		||||
            f.write(r"\begin{table*}[!h]" + "\n")
 | 
			
		||||
            f.write(
 | 
			
		||||
                "\t" + r"\caption{Properties of tested protocols whereby $\times$ means field is not present and $N_P$ is the number of participants.}" + "\n")
 | 
			
		||||
            f.write("\t" + r"\label{tab:protocols}" + "\n")
 | 
			
		||||
            f.write("\t" + r"\centering" + "\n")
 | 
			
		||||
            f.write("\t" + r"\begin{tabularx}{\linewidth}{cp{2.5cm}llcccccccc}" + "\n")
 | 
			
		||||
            f.write("\t\t" + r"\hline" + "\n")
 | 
			
		||||
            f.write("\t\t" + r"\rowcolor{black!90}" + "\n")
 | 
			
		||||
            f.write("\t\t" + r"\textcolor{white}{\textbf{\#}} & "
 | 
			
		||||
                             r"\textcolor{white}{\textbf{Comment}} & "
 | 
			
		||||
                             r"\textcolor{white}{$\mathbf{ N_P }$} & "
 | 
			
		||||
                             r"\textcolor{white}{\textbf{Message}} & "
 | 
			
		||||
                             r"\textcolor{white}{\textbf{Even/odd}} & "
 | 
			
		||||
                             r"\multicolumn{7}{c}{\textcolor{white}{\textbf{Size of field in bit (BE=Big Endian, LE=Little Endian)}}}\\"
 | 
			
		||||
                             "\n\t\t"
 | 
			
		||||
                             r"\rowcolor{black!90}"
 | 
			
		||||
                             "\n\t\t"
 | 
			
		||||
                             r"& & & \textcolor{white}{\textbf{Type}} & \textcolor{white}{\textbf{message data}} &"
 | 
			
		||||
                             r"\textcolor{white}{Preamble} & "
 | 
			
		||||
                             r"\textcolor{white}{Sync} & "
 | 
			
		||||
                             r"\textcolor{white}{Length}  & "
 | 
			
		||||
                             r"\textcolor{white}{SRC} & "
 | 
			
		||||
                             r"\textcolor{white}{DST} & "
 | 
			
		||||
                             r"\textcolor{white}{SEQ Nr} & "
 | 
			
		||||
                             r"\textcolor{white}{CRC}  \\" + "\n")
 | 
			
		||||
            f.write("\t\t" + r"\hline" + "\n")
 | 
			
		||||
 | 
			
		||||
            rowcolor_index = 0
 | 
			
		||||
            for i in range(1, 9):
 | 
			
		||||
                pg = getattr(self, "_prepare_protocol_" + str(i))()
 | 
			
		||||
                assert isinstance(pg, ProtocolGenerator)
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    data1 = next(mt for mt in pg.message_types if mt.name == "data1")
 | 
			
		||||
                    data2 = next(mt for mt in pg.message_types if mt.name == "data2")
 | 
			
		||||
 | 
			
		||||
                    data1_len = data1.get_first_label_with_type(FieldType.Function.DATA).length // 8
 | 
			
		||||
                    data2_len = data2.get_first_label_with_type(FieldType.Function.DATA).length // 8
 | 
			
		||||
 | 
			
		||||
                except StopIteration:
 | 
			
		||||
                    data1_len, data2_len = 8, 64
 | 
			
		||||
 | 
			
		||||
                rowcolor = rowcolors[rowcolor_index % len(rowcolors)]
 | 
			
		||||
                rowcount = 0
 | 
			
		||||
                for j, mt in enumerate(pg.message_types):
 | 
			
		||||
                    if mt.name == "data2":
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    rowcount += 1
 | 
			
		||||
                    if j == 0:
 | 
			
		||||
                        protocol_nr, participants = str(i), len(pg.participants)
 | 
			
		||||
                        if participants > 2:
 | 
			
		||||
                            participants = bold_latex(participants)
 | 
			
		||||
                    else:
 | 
			
		||||
                        protocol_nr, participants = " ", " "
 | 
			
		||||
 | 
			
		||||
                    f.write("\t\t" + rowcolor + "\n")
 | 
			
		||||
 | 
			
		||||
                    if len(pg.message_types) == 1 or (
 | 
			
		||||
                            mt.name == "data1" and "ack" not in {m.name for m in pg.message_types}):
 | 
			
		||||
                        f.write("\t\t{} & {} & {} & {} &".format(protocol_nr, comments[i], participants,
 | 
			
		||||
                                                                 mt.name.replace("1", "")))
 | 
			
		||||
                    elif j == len(pg.message_types) - 1:
 | 
			
		||||
                        f.write(
 | 
			
		||||
                            "\t\t{} & \\multirow{{{}}}{{\\linewidth}}{{{}}} & {} & {} &".format(protocol_nr, -rowcount,
 | 
			
		||||
                                                                                                comments[i],
 | 
			
		||||
                                                                                                participants,
 | 
			
		||||
                                                                                                mt.name.replace("1",
 | 
			
		||||
                                                                                                                "")))
 | 
			
		||||
                    else:
 | 
			
		||||
                        f.write("\t\t{} & & {} & {} &".format(protocol_nr, participants, mt.name.replace("1", "")))
 | 
			
		||||
                    data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA)
 | 
			
		||||
 | 
			
		||||
                    if mt.name == "data1" or mt.name == "data2":
 | 
			
		||||
                        f.write("{}/{} byte &".format(data1_len, data2_len))
 | 
			
		||||
                    elif mt.name == "data" and data_lbl is None:
 | 
			
		||||
                        f.write("{}/{} byte &".format(data1_len, data2_len))
 | 
			
		||||
                    elif data_lbl is not None:
 | 
			
		||||
                        f.write("{0}/{0} byte & ".format(data_lbl.length // 8))
 | 
			
		||||
                    else:
 | 
			
		||||
                        f.write(r"$ \times $ & ")
 | 
			
		||||
 | 
			
		||||
                    for t in (FieldType.Function.PREAMBLE, FieldType.Function.SYNC, FieldType.Function.LENGTH,
 | 
			
		||||
                              FieldType.Function.SRC_ADDRESS, FieldType.Function.DST_ADDRESS,
 | 
			
		||||
                              FieldType.Function.SEQUENCE_NUMBER,
 | 
			
		||||
                              FieldType.Function.CHECKSUM):
 | 
			
		||||
                        lbl = mt.get_first_label_with_type(t)
 | 
			
		||||
                        if lbl is not None:
 | 
			
		||||
                            if bold[i][lbl.field_type.function]:
 | 
			
		||||
                                f.write(bold_latex(lbl.length))
 | 
			
		||||
                            else:
 | 
			
		||||
                                f.write(str(lbl.length))
 | 
			
		||||
                            if lbl.length > 8 and t in (FieldType.Function.LENGTH, FieldType.Function.SEQUENCE_NUMBER):
 | 
			
		||||
                                f.write(" ({})".format(bold_latex("LE") if pg.little_endian else "BE"))
 | 
			
		||||
                        else:
 | 
			
		||||
                            f.write(r"$ \times $")
 | 
			
		||||
 | 
			
		||||
                        if t != FieldType.Function.CHECKSUM:
 | 
			
		||||
                            f.write(" & ")
 | 
			
		||||
                        else:
 | 
			
		||||
                            f.write(r"\\" + "\n")
 | 
			
		||||
 | 
			
		||||
                rowcolor_index += 1
 | 
			
		||||
 | 
			
		||||
            f.write("\t" + r"\end{tabularx}" + "\n")
 | 
			
		||||
 | 
			
		||||
            f.write(r"\end{table*}" + "\n")
 | 
			
		||||
							
								
								
									
										179
									
								
								Software/urh/tests/awre/TestAWREHistograms.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										179
									
								
								Software/urh/tests/awre/TestAWREHistograms.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,179 @@
 | 
			
		||||
import random
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.Histogram import Histogram
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
 | 
			
		||||
SHOW_PLOTS = True
 | 
			
		||||
 | 
			
		||||
class TestAWREHistograms(AWRETestCase):
 | 
			
		||||
    def test_very_simple_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a very simple protocol consisting just of a preamble, sync and some random data
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("very_simple_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 8)
 | 
			
		||||
 | 
			
		||||
        num_messages = 10
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a"})
 | 
			
		||||
        for _ in range(num_messages):
 | 
			
		||||
            pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 255), 8))
 | 
			
		||||
 | 
			
		||||
        self.save_protocol("very_simple", pg)
 | 
			
		||||
 | 
			
		||||
        h = Histogram(FormatFinder.get_bitvectors_from_messages(pg.protocol.messages))
 | 
			
		||||
        if SHOW_PLOTS:
 | 
			
		||||
            h.plot()
 | 
			
		||||
 | 
			
		||||
    def test_simple_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a simple protocol with preamble, sync and length field and some random data
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("simple_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
 | 
			
		||||
        num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a9d"})
 | 
			
		||||
        for data_length, num_messages in num_messages_by_data_length.items():
 | 
			
		||||
            for _ in range(num_messages):
 | 
			
		||||
                pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** data_length - 1), data_length))
 | 
			
		||||
 | 
			
		||||
        self.save_protocol("simple", pg)
 | 
			
		||||
 | 
			
		||||
        plt.subplot("221")
 | 
			
		||||
        plt.title("All messages")
 | 
			
		||||
        format_finder = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        for i, sync_end in enumerate(format_finder.sync_ends):
 | 
			
		||||
            self.assertEqual(sync_end, 24, msg=str(i))
 | 
			
		||||
 | 
			
		||||
        h = Histogram(format_finder.bitvectors)
 | 
			
		||||
        h.subplot_on(plt)
 | 
			
		||||
 | 
			
		||||
        bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
 | 
			
		||||
        bitvectors_by_length = defaultdict(list)
 | 
			
		||||
        for bitvector in bitvectors:
 | 
			
		||||
            bitvectors_by_length[len(bitvector)].append(bitvector)
 | 
			
		||||
 | 
			
		||||
        for i, (message_length, bitvectors) in enumerate(bitvectors_by_length.items()):
 | 
			
		||||
            plt.subplot(2, 2, i + 2)
 | 
			
		||||
            plt.title("Messages with length {} ({})".format(message_length, len(bitvectors)))
 | 
			
		||||
            Histogram(bitvectors).subplot_on(plt)
 | 
			
		||||
 | 
			
		||||
        if SHOW_PLOTS:
 | 
			
		||||
            plt.show()
 | 
			
		||||
 | 
			
		||||
    def test_medium_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a protocol with preamble, sync, length field, 2 participants and addresses and seq nr and random data
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("medium_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        alice = Participant("Alice", "A", "1234", color_index=0)
 | 
			
		||||
        bob = Participant("Bob", "B", "5a9d", color_index=1)
 | 
			
		||||
 | 
			
		||||
        num_messages = 100
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x1c"}, little_endian=False)
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            len_data = random.randint(1, 5)
 | 
			
		||||
            data = "".join(pg.decimal_to_bits(random.randint(0, 2 ** 8 - 1), 8) for _ in range(len_data))
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, dest = alice, bob
 | 
			
		||||
            else:
 | 
			
		||||
                source, dest = bob, alice
 | 
			
		||||
            pg.generate_message(data=data, source=source, destination=dest)
 | 
			
		||||
 | 
			
		||||
        self.save_protocol("medium", pg)
 | 
			
		||||
 | 
			
		||||
        plt.subplot(2, 2, 1)
 | 
			
		||||
        plt.title("All messages")
 | 
			
		||||
        bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
 | 
			
		||||
        h = Histogram(bitvectors)
 | 
			
		||||
        h.subplot_on(plt)
 | 
			
		||||
 | 
			
		||||
        for i, (participant, bitvectors) in enumerate(
 | 
			
		||||
                sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())):
 | 
			
		||||
            plt.subplot(2, 2, i + 3)
 | 
			
		||||
            plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors)))
 | 
			
		||||
            Histogram(bitvectors).subplot_on(plt)
 | 
			
		||||
 | 
			
		||||
        if SHOW_PLOTS:
 | 
			
		||||
            plt.show()
 | 
			
		||||
 | 
			
		||||
    def get_bitvectors_by_participant(self, messages):
 | 
			
		||||
        import numpy as np
 | 
			
		||||
        result = defaultdict(list)
 | 
			
		||||
        for msg in messages:  # type: Message
 | 
			
		||||
            result[msg.participant].append(np.array(msg.decoded_bits, dtype=np.uint8, order="C"))
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def test_ack_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a protocol with acks
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        alice = Participant("Alice", "A", "1234", color_index=0)
 | 
			
		||||
        bob = Participant("Bob", "B", "5a9d", color_index=1)
 | 
			
		||||
 | 
			
		||||
        num_messages = 50
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0xbf", mb_ack.message_type: "0xbf"},
 | 
			
		||||
                               little_endian=False)
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, dest = alice, bob
 | 
			
		||||
            else:
 | 
			
		||||
                source, dest = bob, alice
 | 
			
		||||
            pg.generate_message(data="0xffff", source=source, destination=dest)
 | 
			
		||||
            pg.generate_message(data="", source=dest, destination=source, message_type=mb_ack.message_type)
 | 
			
		||||
 | 
			
		||||
        self.save_protocol("proto_with_acks", pg)
 | 
			
		||||
 | 
			
		||||
        plt.subplot(2, 2, 1)
 | 
			
		||||
        plt.title("All messages")
 | 
			
		||||
        bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)
 | 
			
		||||
        h = Histogram(bitvectors)
 | 
			
		||||
        h.subplot_on(plt)
 | 
			
		||||
 | 
			
		||||
        for i, (participant, bitvectors) in enumerate(
 | 
			
		||||
                sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())):
 | 
			
		||||
            plt.subplot(2, 2, i + 3)
 | 
			
		||||
            plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors)))
 | 
			
		||||
            Histogram(bitvectors).subplot_on(plt)
 | 
			
		||||
 | 
			
		||||
        if SHOW_PLOTS:
 | 
			
		||||
            plt.show()
 | 
			
		||||
							
								
								
									
										0
									
								
								Software/urh/tests/awre/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								Software/urh/tests/awre/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										386
									
								
								Software/urh/tests/awre/test_address_engine.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										386
									
								
								Software/urh/tests/awre/test_address_engine.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,386 @@
 | 
			
		||||
import random
 | 
			
		||||
from array import array
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from tests.utils_testing import get_path_for_data_file
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.awre.engines.AddressEngine import AddressEngine
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Message import Message
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
 | 
			
		||||
from urh.util import util
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestAddressEngine(AWRETestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super().setUp()
 | 
			
		||||
        self.alice = Participant("Alice", "A", address_hex="1234")
 | 
			
		||||
        self.bob = Participant("Bob", "B", address_hex="cafe")
 | 
			
		||||
 | 
			
		||||
    def test_one_participant(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a simple protocol with
 | 
			
		||||
        preamble, sync and length field (8 bit) and some random data
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("simple_address_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"},
 | 
			
		||||
                               participants=[self.alice])
 | 
			
		||||
        for data_length, num_messages in num_messages_by_data_length.items():
 | 
			
		||||
            for i in range(num_messages):
 | 
			
		||||
                pg.generate_message(data=pg.decimal_to_bits(22 * i, data_length), source=self.alice)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("address_one_participant", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
 | 
			
		||||
        address_dict = address_engine.find_addresses()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(address_dict), 0)
 | 
			
		||||
 | 
			
		||||
    def test_two_participants(self):
 | 
			
		||||
        mb = MessageTypeBuilder("address_two_participants")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 50
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"},
 | 
			
		||||
                               participants=[self.alice, self.bob])
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, destination = self.alice, self.bob
 | 
			
		||||
                data_length = 8
 | 
			
		||||
            else:
 | 
			
		||||
                source, destination = self.bob, self.alice
 | 
			
		||||
                data_length = 16
 | 
			
		||||
            pg.generate_message(data=pg.decimal_to_bits(4 * i, data_length), source=source, destination=destination)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("address_two_participants", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
 | 
			
		||||
        address_dict = address_engine.find_addresses()
 | 
			
		||||
        self.assertEqual(len(address_dict), 2)
 | 
			
		||||
        addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
 | 
			
		||||
        addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
 | 
			
		||||
        self.assertIn(self.alice.address_hex, addresses_1)
 | 
			
		||||
        self.assertIn(self.alice.address_hex, addresses_2)
 | 
			
		||||
        self.assertIn(self.bob.address_hex, addresses_1)
 | 
			
		||||
        self.assertIn(self.bob.address_hex, addresses_2)
 | 
			
		||||
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        self.assertEqual(len(ff.known_participant_addresses), 0)
 | 
			
		||||
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.known_participant_addresses), 2)
 | 
			
		||||
        self.assertIn(bytes([int(h, 16) for h in self.alice.address_hex]),
 | 
			
		||||
                      map(bytes, ff.known_participant_addresses.values()))
 | 
			
		||||
        self.assertIn(bytes([int(h, 16) for h in self.bob.address_hex]),
 | 
			
		||||
                      map(bytes, ff.known_participant_addresses.values()))
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        mt = ff.message_types[0]
 | 
			
		||||
        dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(dst_addr)
 | 
			
		||||
        self.assertEqual(dst_addr.start, 32)
 | 
			
		||||
        self.assertEqual(dst_addr.length, 16)
 | 
			
		||||
        src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(src_addr)
 | 
			
		||||
        self.assertEqual(src_addr.start, 48)
 | 
			
		||||
        self.assertEqual(src_addr.length, 16)
 | 
			
		||||
 | 
			
		||||
    def test_two_participants_with_ack_messages(self):
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 50
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
 | 
			
		||||
                               participants=[self.alice, self.bob])
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, destination = self.alice, self.bob
 | 
			
		||||
                data_length = 8
 | 
			
		||||
            else:
 | 
			
		||||
                source, destination = self.bob, self.alice
 | 
			
		||||
                data_length = 16
 | 
			
		||||
            pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
 | 
			
		||||
                                source=source, destination=destination)
 | 
			
		||||
            pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("address_two_participants_with_acks", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
 | 
			
		||||
        address_dict = address_engine.find_addresses()
 | 
			
		||||
        self.assertEqual(len(address_dict), 2)
 | 
			
		||||
        addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
 | 
			
		||||
        addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
 | 
			
		||||
        self.assertIn(self.alice.address_hex, addresses_1)
 | 
			
		||||
        self.assertIn(self.alice.address_hex, addresses_2)
 | 
			
		||||
        self.assertIn(self.bob.address_hex, addresses_1)
 | 
			
		||||
        self.assertIn(self.bob.address_hex, addresses_2)
 | 
			
		||||
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 2)
 | 
			
		||||
        mt = ff.message_types[1]
 | 
			
		||||
        dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(dst_addr)
 | 
			
		||||
        self.assertEqual(dst_addr.start, 32)
 | 
			
		||||
        self.assertEqual(dst_addr.length, 16)
 | 
			
		||||
        src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(src_addr)
 | 
			
		||||
        self.assertEqual(src_addr.start, 48)
 | 
			
		||||
        self.assertEqual(src_addr.length, 16)
 | 
			
		||||
 | 
			
		||||
        mt = ff.message_types[0]
 | 
			
		||||
        dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(dst_addr)
 | 
			
		||||
        self.assertEqual(dst_addr.start, 32)
 | 
			
		||||
        self.assertEqual(dst_addr.length, 16)
 | 
			
		||||
 | 
			
		||||
    def test_two_participants_with_ack_messages_and_type(self):
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.TYPE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 50
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
 | 
			
		||||
                               participants=[self.alice, self.bob])
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, destination = self.alice, self.bob
 | 
			
		||||
                data_length = 8
 | 
			
		||||
            else:
 | 
			
		||||
                source, destination = self.bob, self.alice
 | 
			
		||||
                data_length = 16
 | 
			
		||||
            pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
 | 
			
		||||
                                source=source, destination=destination)
 | 
			
		||||
            pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("address_two_participants_with_acks_and_types", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        address_engine = AddressEngine(ff.hexvectors, ff.participant_indices)
 | 
			
		||||
        address_dict = address_engine.find_addresses()
 | 
			
		||||
        self.assertEqual(len(address_dict), 2)
 | 
			
		||||
        addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0]))
 | 
			
		||||
        addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1]))
 | 
			
		||||
        self.assertIn(self.alice.address_hex, addresses_1)
 | 
			
		||||
        self.assertIn(self.alice.address_hex, addresses_2)
 | 
			
		||||
        self.assertIn(self.bob.address_hex, addresses_1)
 | 
			
		||||
        self.assertIn(self.bob.address_hex, addresses_2)
 | 
			
		||||
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 2)
 | 
			
		||||
        mt = ff.message_types[1]
 | 
			
		||||
        dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(dst_addr)
 | 
			
		||||
        self.assertEqual(dst_addr.start, 40)
 | 
			
		||||
        self.assertEqual(dst_addr.length, 16)
 | 
			
		||||
        src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(src_addr)
 | 
			
		||||
        self.assertEqual(src_addr.start, 56)
 | 
			
		||||
        self.assertEqual(src_addr.length, 16)
 | 
			
		||||
 | 
			
		||||
        mt = ff.message_types[0]
 | 
			
		||||
        dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
        self.assertIsNotNone(dst_addr)
 | 
			
		||||
        self.assertEqual(dst_addr.start, 32)
 | 
			
		||||
        self.assertEqual(dst_addr.length, 16)
 | 
			
		||||
 | 
			
		||||
    def test_three_participants_with_ack(self):
 | 
			
		||||
        alice = Participant("Alice", address_hex="1337")
 | 
			
		||||
        bob = Participant("Bob", address_hex="4711")
 | 
			
		||||
        carl = Participant("Carl", address_hex="cafe")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8},
 | 
			
		||||
                               participants=[alice, bob, carl])
 | 
			
		||||
 | 
			
		||||
        i = -1
 | 
			
		||||
        while len(pg.protocol.messages) < 20:
 | 
			
		||||
            i += 1
 | 
			
		||||
            source = pg.participants[i % len(pg.participants)]
 | 
			
		||||
            destination = pg.participants[(i + 1) % len(pg.participants)]
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                data_bytes = 8
 | 
			
		||||
            else:
 | 
			
		||||
                data_bytes = 16
 | 
			
		||||
 | 
			
		||||
            data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8))
 | 
			
		||||
            pg.generate_message(data=data, source=source, destination=destination)
 | 
			
		||||
 | 
			
		||||
            if "ack" in (msg_type.name for msg_type in pg.protocol.message_types):
 | 
			
		||||
                pg.generate_message(message_type=1, data="", source=destination, destination=source)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        self.assertEqual(len(ff.known_participant_addresses), 0)
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        # Since there are ACKS in this protocol, the engine must be able to assign the correct participant addresses
 | 
			
		||||
        # IN CORRECT ORDER!
 | 
			
		||||
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
 | 
			
		||||
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
 | 
			
		||||
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[2]), "cafe")
 | 
			
		||||
 | 
			
		||||
    def test_protocol_with_acks_and_checksum(self):
 | 
			
		||||
        proto_file = get_path_for_data_file("ack_frames_with_crc.proto.xml")
 | 
			
		||||
        protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
 | 
			
		||||
        protocol.from_xml_file(filename=proto_file, read_bits=True)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337")
 | 
			
		||||
        self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711")
 | 
			
		||||
 | 
			
		||||
        for mt in ff.message_types:
 | 
			
		||||
            preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
 | 
			
		||||
            self.assertEqual(preamble.start, 0)
 | 
			
		||||
            self.assertEqual(preamble.length, 16)
 | 
			
		||||
            sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
            self.assertEqual(sync.start, 16)
 | 
			
		||||
            self.assertEqual(sync.length, 16)
 | 
			
		||||
            length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
            self.assertEqual(length.start, 32)
 | 
			
		||||
            self.assertEqual(length.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_address_engine_performance(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("35_messages.proto.xml", return_messages=True)
 | 
			
		||||
 | 
			
		||||
        engine = AddressEngine(ff.hexvectors, ff.participant_indices)
 | 
			
		||||
        engine.find()
 | 
			
		||||
 | 
			
		||||
    def test_paper_example(self):
 | 
			
		||||
        alice = Participant("Alice", "A")
 | 
			
		||||
        bob = Participant("Bob", "B")
 | 
			
		||||
        participants = [alice, bob]
 | 
			
		||||
        msg1 = Message.from_plain_hex_str("aabb1234")
 | 
			
		||||
        msg1.participant = alice
 | 
			
		||||
        msg2 = Message.from_plain_hex_str("aabb6789")
 | 
			
		||||
        msg2.participant = alice
 | 
			
		||||
        msg3 = Message.from_plain_hex_str("bbaa4711")
 | 
			
		||||
        msg3.participant = bob
 | 
			
		||||
        msg4 = Message.from_plain_hex_str("bbaa1337")
 | 
			
		||||
        msg4.participant = bob
 | 
			
		||||
 | 
			
		||||
        protocol = ProtocolAnalyzer(None)
 | 
			
		||||
        protocol.messages.extend([msg1, msg2, msg3, msg4])
 | 
			
		||||
        #self.save_protocol("paper_example", protocol)
 | 
			
		||||
 | 
			
		||||
        bitvectors = FormatFinder.get_bitvectors_from_messages(protocol.messages)
 | 
			
		||||
        hexvectors = FormatFinder.get_hexvectors(bitvectors)
 | 
			
		||||
        address_engine = AddressEngine(hexvectors, participant_indices=[participants.index(msg.participant) for msg in
 | 
			
		||||
                                                                        protocol.messages])
 | 
			
		||||
 | 
			
		||||
    def test_find_common_sub_sequence(self):
 | 
			
		||||
        from urh.cythonext import awre_util
 | 
			
		||||
        str1 = "0612345678"
 | 
			
		||||
        str2 = "0756781234"
 | 
			
		||||
 | 
			
		||||
        seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
 | 
			
		||||
        seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
 | 
			
		||||
 | 
			
		||||
        indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2)
 | 
			
		||||
        self.assertEqual(len(indices), 2)
 | 
			
		||||
        for ind in indices:
 | 
			
		||||
            s = str1[slice(*ind)]
 | 
			
		||||
            self.assertIn(s, ("5678", "1234"))
 | 
			
		||||
            self.assertIn(s, str1)
 | 
			
		||||
            self.assertIn(s, str2)
 | 
			
		||||
 | 
			
		||||
    def test_find_first_occurrence(self):
 | 
			
		||||
        from urh.cythonext import awre_util
 | 
			
		||||
        str1 = "00" * 100 + "1234500012345" + "00" * 100
 | 
			
		||||
        str2 = "12345"
 | 
			
		||||
 | 
			
		||||
        seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C")
 | 
			
		||||
        seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C")
 | 
			
		||||
        indices = awre_util.find_occurrences(seq1, seq2)
 | 
			
		||||
        self.assertEqual(len(indices), 2)
 | 
			
		||||
        index = indices[0]
 | 
			
		||||
        self.assertEqual(str1[index:index + len(str2)], str2)
 | 
			
		||||
 | 
			
		||||
        # Test with ignoring indices
 | 
			
		||||
        indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 205))))
 | 
			
		||||
        self.assertEqual(len(indices), 1)
 | 
			
		||||
 | 
			
		||||
        # Test with ignoring indices
 | 
			
		||||
        indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 210))))
 | 
			
		||||
        self.assertEqual(len(indices), 0)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(awre_util.find_occurrences(seq1, np.ones(10, dtype=np.uint8)), [])
 | 
			
		||||
							
								
								
									
										256
									
								
								Software/urh/tests/awre/test_awre_preprocessing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										256
									
								
								Software/urh/tests/awre/test_awre_preprocessing.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,256 @@
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.Preprocessor import Preprocessor
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Message import Message
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestAWREPreprocessing(AWRETestCase):
 | 
			
		||||
    def test_very_simple_sync_word_finding(self):
 | 
			
		||||
        preamble = "10101010"
 | 
			
		||||
        sync = "1101"
 | 
			
		||||
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)],
 | 
			
		||||
                                           num_messages=(20,),
 | 
			
		||||
                                           data=(lambda i: 10 * i,))
 | 
			
		||||
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
        #self.save_protocol("very_simple_sync_test", pg)
 | 
			
		||||
        self.assertGreaterEqual(len(possible_syncs), 1)
 | 
			
		||||
        self.assertEqual(preprocessor.find_possible_syncs()[0], sync)
 | 
			
		||||
 | 
			
		||||
    def test_simple_sync_word_finding(self):
 | 
			
		||||
        preamble = "10101010"
 | 
			
		||||
        sync = "1001"
 | 
			
		||||
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "1010", sync)],
 | 
			
		||||
                                           num_messages=(20, 5),
 | 
			
		||||
                                           data=(lambda i: 10 * i, lambda i: 22 * i))
 | 
			
		||||
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
        #self.save_protocol("simple_sync_test", pg)
 | 
			
		||||
        self.assertGreaterEqual(len(possible_syncs), 1)
 | 
			
		||||
        self.assertEqual(preprocessor.find_possible_syncs()[0], sync)
 | 
			
		||||
 | 
			
		||||
    def test_sync_word_finding_odd_preamble(self):
 | 
			
		||||
        preamble = "0101010"
 | 
			
		||||
        sync = "1101"
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
 | 
			
		||||
                                           num_messages=(20, 5),
 | 
			
		||||
                                           data=(lambda i: 10 * i, lambda i: i))
 | 
			
		||||
 | 
			
		||||
        # If we have a odd preamble length, the last bit of the preamble is counted to the sync
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("odd_preamble", pg)
 | 
			
		||||
        self.assertEqual(preamble[-1] + sync[:-1], possible_syncs[0])
 | 
			
		||||
 | 
			
		||||
    def test_sync_word_finding_special_preamble(self):
 | 
			
		||||
        preamble = "111001110011100"
 | 
			
		||||
        sync = "0110"
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
 | 
			
		||||
                                           num_messages=(20, 5),
 | 
			
		||||
                                           data=(lambda i: 10 * i, lambda i: i))
 | 
			
		||||
 | 
			
		||||
        # If we have a odd preamble length, the last bit of the preamble is counted to the sync
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("special_preamble", pg)
 | 
			
		||||
        self.assertEqual(sync, possible_syncs[0])
 | 
			
		||||
 | 
			
		||||
    def test_sync_word_finding_errored_preamble(self):
 | 
			
		||||
        preamble = "00010101010"  # first bits are wrong
 | 
			
		||||
        sync = "0110"
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)],
 | 
			
		||||
                                           num_messages=(20, 5),
 | 
			
		||||
                                           data=(lambda i: 10 * i, lambda i: i))
 | 
			
		||||
 | 
			
		||||
        # If we have a odd preamble length, the last bit of the preamble is counted to the sync
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("errored_preamble", pg)
 | 
			
		||||
        self.assertIn(preamble[-1] + sync[:-1], possible_syncs)
 | 
			
		||||
 | 
			
		||||
    def test_sync_word_finding_with_two_sync_words(self):
 | 
			
		||||
        preamble = "0xaaaa"
 | 
			
		||||
        sync1, sync2 = "0x1234", "0xcafe"
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync1), (preamble, sync2)],
 | 
			
		||||
                                           num_messages=(15, 10),
 | 
			
		||||
                                           data=(lambda i: 12 * i, lambda i: 16 * i))
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
        #self.save_protocol("two_syncs", pg)
 | 
			
		||||
        self.assertGreaterEqual(len(possible_syncs), 2)
 | 
			
		||||
        self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs)
 | 
			
		||||
        self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs)
 | 
			
		||||
 | 
			
		||||
    def test_multiple_sync_words(self):
 | 
			
		||||
        hex_messages = [
 | 
			
		||||
            "aaS1234",
 | 
			
		||||
            "aaScafe",
 | 
			
		||||
            "aaSdead",
 | 
			
		||||
            "aaSbeef",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for i in range(1, 256):
 | 
			
		||||
            messages = []
 | 
			
		||||
            sync = "{0:02x}".format(i)
 | 
			
		||||
            if sync.startswith("a"):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            for msg in hex_messages:
 | 
			
		||||
                messages.append(Message.from_plain_hex_str(msg.replace("S", sync)))
 | 
			
		||||
 | 
			
		||||
            for i in range(1, len(messages)):
 | 
			
		||||
                messages[i].message_type = messages[0].message_type
 | 
			
		||||
 | 
			
		||||
            ff = FormatFinder(messages)
 | 
			
		||||
            ff.run()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(len(ff.message_types), 1, msg=sync)
 | 
			
		||||
 | 
			
		||||
            preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
 | 
			
		||||
            self.assertEqual(preamble.start, 0, msg=sync)
 | 
			
		||||
            self.assertEqual(preamble.length, 8, msg=sync)
 | 
			
		||||
 | 
			
		||||
            sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
            self.assertEqual(sync.start, 8, msg=sync)
 | 
			
		||||
            self.assertEqual(sync.length, 8, msg=sync)
 | 
			
		||||
 | 
			
		||||
    def test_sync_word_finding_varying_message_length(self):
 | 
			
		||||
        hex_messages = [
 | 
			
		||||
            "aaaa9a7d0f1337471100009a44ebdd13517bf9",
 | 
			
		||||
            "aaaa9a7d4747111337000134a4473c002b909630b11df37e34728c79c60396176aff2b5384e82f31511581d0cbb4822ad1b6734e2372ad5cf4af4c9d6b067e5f7ec359ec443c3b5ddc7a9e",
 | 
			
		||||
            "aaaa9a7d0f13374711000205ee081d26c86b8c",
 | 
			
		||||
            "aaaa9a7d474711133700037cae4cda789885f88f5fb29adc9acf954cb2850b9d94e7f3b009347c466790e89f2bcd728987d4670690861bbaa120f71f14d4ef8dc738a6d7c30e7d2143c267",
 | 
			
		||||
            "aaaa9a7d0f133747110004c2906142300427f3"
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        messages = [Message.from_plain_hex_str(hex_msg) for hex_msg in hex_messages]
 | 
			
		||||
        for i in range(1, len(messages)):
 | 
			
		||||
            messages[i].message_type = messages[0].message_type
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(messages)
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
 | 
			
		||||
        self.assertEqual(preamble.start, 0)
 | 
			
		||||
        self.assertEqual(preamble.length, 16)
 | 
			
		||||
 | 
			
		||||
        sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
        self.assertEqual(sync.start, 16)
 | 
			
		||||
        self.assertEqual(sync.length, 16)
 | 
			
		||||
 | 
			
		||||
    def test_sync_word_finding_common_prefix(self):
 | 
			
		||||
        """
 | 
			
		||||
        Messages are very similar (odd and even ones are the same)
 | 
			
		||||
        However, they do not have two different sync words!
 | 
			
		||||
        The algorithm needs to check for a common prefix of the two found sync words
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        sync = "0x1337"
 | 
			
		||||
        num_messages = 10
 | 
			
		||||
 | 
			
		||||
        alice = Participant("Alice", address_hex="dead01")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef24")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("protocol_with_one_message_type")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 72)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 24)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 24)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x1337"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 36},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, destination = alice, bob
 | 
			
		||||
                data_length = 8
 | 
			
		||||
            else:
 | 
			
		||||
                source, destination = bob, alice
 | 
			
		||||
                data_length = 16
 | 
			
		||||
            pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
 | 
			
		||||
                                source=source, destination=destination)
 | 
			
		||||
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages])
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
        #self.save_protocol("sync_by_common_prefix", pg)
 | 
			
		||||
        self.assertEqual(len(possible_syncs), 1)
 | 
			
		||||
 | 
			
		||||
        # +0000 is okay, because this will get fixed by correction in FormatFinder
 | 
			
		||||
        self.assertIn(possible_syncs[0], [ProtocolGenerator.to_bits(sync), ProtocolGenerator.to_bits(sync) + "0000"])
 | 
			
		||||
 | 
			
		||||
    def test_with_given_preamble_and_sync(self):
 | 
			
		||||
        preamble = "10101010"
 | 
			
		||||
        sync = "10011"
 | 
			
		||||
        pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)],
 | 
			
		||||
                                           num_messages=(20,),
 | 
			
		||||
                                           data=(lambda i: 10 * i,))
 | 
			
		||||
 | 
			
		||||
        # If we have a odd preamble length, the last bit of the preamble is counted to the sync
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages],
 | 
			
		||||
                                    existing_message_types={i: msg.message_type for i, msg in
 | 
			
		||||
                                                            enumerate(pg.protocol.messages)})
 | 
			
		||||
        preamble_starts, preamble_lengths, sync_len = preprocessor.preprocess()
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("given_preamble", pg)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(all(preamble_start == 0 for preamble_start in preamble_starts))
 | 
			
		||||
        self.assertTrue(all(preamble_length == len(preamble) for preamble_length in preamble_lengths))
 | 
			
		||||
        self.assertEqual(sync_len, len(sync))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def build_protocol_generator(preamble_syncs: list, num_messages: tuple, data: tuple) -> ProtocolGenerator:
 | 
			
		||||
        message_types = []
 | 
			
		||||
        preambles_by_mt = dict()
 | 
			
		||||
        syncs_by_mt = dict()
 | 
			
		||||
 | 
			
		||||
        assert len(preamble_syncs) == len(num_messages) == len(data)
 | 
			
		||||
 | 
			
		||||
        for i, (preamble, sync_word) in enumerate(preamble_syncs):
 | 
			
		||||
            assert isinstance(preamble, str)
 | 
			
		||||
            assert isinstance(sync_word, str)
 | 
			
		||||
 | 
			
		||||
            preamble, sync_word = map(ProtocolGenerator.to_bits, (preamble, sync_word))
 | 
			
		||||
 | 
			
		||||
            mb = MessageTypeBuilder("message type #{0}".format(i))
 | 
			
		||||
            mb.add_label(FieldType.Function.PREAMBLE, len(preamble))
 | 
			
		||||
            mb.add_label(FieldType.Function.SYNC, len(sync_word))
 | 
			
		||||
 | 
			
		||||
            message_types.append(mb.message_type)
 | 
			
		||||
            preambles_by_mt[mb.message_type] = preamble
 | 
			
		||||
            syncs_by_mt[mb.message_type] = sync_word
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator(message_types, preambles_by_mt=preambles_by_mt, syncs_by_mt=syncs_by_mt)
 | 
			
		||||
        for i, msg_type in enumerate(message_types):
 | 
			
		||||
            for j in range(num_messages[i]):
 | 
			
		||||
                if callable(data[i]):
 | 
			
		||||
                    msg_data = pg.decimal_to_bits(data[i](j), num_bits=8)
 | 
			
		||||
                else:
 | 
			
		||||
                    msg_data = data[i]
 | 
			
		||||
 | 
			
		||||
                pg.generate_message(message_type=msg_type, data=msg_data)
 | 
			
		||||
 | 
			
		||||
        return pg
 | 
			
		||||
							
								
								
									
										149
									
								
								Software/urh/tests/awre/test_awre_real_protocols.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										149
									
								
								Software/urh/tests/awre/test_awre_real_protocols.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,149 @@
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from tests.utils_testing import get_path_for_data_file
 | 
			
		||||
from urh.awre.CommonRange import CommonRange
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.Preprocessor import Preprocessor
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Message import Message
 | 
			
		||||
from urh.signalprocessing.MessageType import MessageType
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
class TestAWRERealProtocols(AWRETestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super().setUp()
 | 
			
		||||
        alice = Participant("Alice", "A")
 | 
			
		||||
        bob = Participant("Bob", "B")
 | 
			
		||||
        self.participants = [alice, bob]
 | 
			
		||||
 | 
			
		||||
    def test_format_finding_enocean(self):
 | 
			
		||||
        enocean_protocol = ProtocolAnalyzer(None)
 | 
			
		||||
        with open(get_path_for_data_file("enocean_bits.txt")) as f:
 | 
			
		||||
            for line in f:
 | 
			
		||||
                enocean_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", "")))
 | 
			
		||||
                enocean_protocol.messages[-1].message_type = enocean_protocol.default_message_type
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(enocean_protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
 | 
			
		||||
        message_types = ff.message_types
 | 
			
		||||
        self.assertEqual(len(message_types), 1)
 | 
			
		||||
 | 
			
		||||
        preamble = message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)
 | 
			
		||||
        self.assertEqual(preamble.start, 0)
 | 
			
		||||
        self.assertEqual(preamble.length, 8)
 | 
			
		||||
 | 
			
		||||
        sync = message_types[0].get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
        self.assertEqual(sync.start, 8)
 | 
			
		||||
        self.assertEqual(sync.length, 4)
 | 
			
		||||
 | 
			
		||||
        checksum = message_types[0].get_first_label_with_type(FieldType.Function.CHECKSUM)
 | 
			
		||||
        self.assertEqual(checksum.start, 56)
 | 
			
		||||
        self.assertEqual(checksum.length, 4)
 | 
			
		||||
 | 
			
		||||
        self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
 | 
			
		||||
        self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
 | 
			
		||||
        self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
 | 
			
		||||
        self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER))
 | 
			
		||||
 | 
			
		||||
    def test_format_finding_rwe(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("rwe.proto.xml", return_messages=True)
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        sync1, sync2 = "0x9a7d9a7d", "0x67686768"
 | 
			
		||||
 | 
			
		||||
        preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in messages])
 | 
			
		||||
        possible_syncs = preprocessor.find_possible_syncs()
 | 
			
		||||
        self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs)
 | 
			
		||||
        self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs)
 | 
			
		||||
 | 
			
		||||
        ack_messages = (3, 5, 7, 9, 11, 13, 15, 17, 20)
 | 
			
		||||
        ack_message_type = next(mt for mt, messages in ff.existing_message_types.items() if ack_messages[0] in messages)
 | 
			
		||||
        self.assertTrue(all(ack_msg in ff.existing_message_types[ack_message_type] for ack_msg in ack_messages))
 | 
			
		||||
 | 
			
		||||
        for mt in ff.message_types:
 | 
			
		||||
            preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE)
 | 
			
		||||
            self.assertEqual(preamble.start, 0)
 | 
			
		||||
            self.assertEqual(preamble.length, 32)
 | 
			
		||||
 | 
			
		||||
            sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
            self.assertEqual(sync.start, 32)
 | 
			
		||||
            self.assertEqual(sync.length, 32)
 | 
			
		||||
 | 
			
		||||
            length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
            self.assertEqual(length.start, 64)
 | 
			
		||||
            self.assertEqual(length.length, 8)
 | 
			
		||||
 | 
			
		||||
            dst = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
            self.assertEqual(dst.length, 24)
 | 
			
		||||
 | 
			
		||||
            if mt == ack_message_type or 1 in ff.existing_message_types[mt]:
 | 
			
		||||
                self.assertEqual(dst.start, 72)
 | 
			
		||||
            else:
 | 
			
		||||
                self.assertEqual(dst.start, 88)
 | 
			
		||||
 | 
			
		||||
            if mt != ack_message_type and 1 not in ff.existing_message_types[mt]:
 | 
			
		||||
                src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
                self.assertEqual(src.start, 112)
 | 
			
		||||
                self.assertEqual(src.length, 24)
 | 
			
		||||
            elif 1 in ff.existing_message_types[mt]:
 | 
			
		||||
                # long ack
 | 
			
		||||
                src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
                self.assertEqual(src.start, 96)
 | 
			
		||||
                self.assertEqual(src.length, 24)
 | 
			
		||||
 | 
			
		||||
            crc = mt.get_first_label_with_type(FieldType.Function.CHECKSUM)
 | 
			
		||||
            self.assertIsNotNone(crc)
 | 
			
		||||
 | 
			
		||||
    def test_homematic(self):
 | 
			
		||||
        proto_file = get_path_for_data_file("homematic.proto.xml")
 | 
			
		||||
        protocol = ProtocolAnalyzer(signal=None, filename=proto_file)
 | 
			
		||||
        protocol.message_types = []
 | 
			
		||||
        protocol.from_xml_file(filename=proto_file, read_bits=True)
 | 
			
		||||
        # prevent interfering with preassinged labels
 | 
			
		||||
        protocol.message_types = [MessageType("Default")]
 | 
			
		||||
 | 
			
		||||
        participants = sorted({msg.participant for msg in protocol.messages})
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
        ff = FormatFinder(protocol.messages, participants=participants)
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
 | 
			
		||||
        self.assertGreater(len(ff.message_types), 0)
 | 
			
		||||
 | 
			
		||||
        for i, message_type in enumerate(ff.message_types):
 | 
			
		||||
            preamble = message_type.get_first_label_with_type(FieldType.Function.PREAMBLE)
 | 
			
		||||
            self.assertEqual(preamble.start, 0)
 | 
			
		||||
            self.assertEqual(preamble.length, 32)
 | 
			
		||||
 | 
			
		||||
            sync = message_type.get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
            self.assertEqual(sync.start, 32)
 | 
			
		||||
            self.assertEqual(sync.length, 32)
 | 
			
		||||
 | 
			
		||||
            length = message_type.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
            self.assertEqual(length.start, 64)
 | 
			
		||||
            self.assertEqual(length.length, 8)
 | 
			
		||||
 | 
			
		||||
            seq = message_type.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
            self.assertEqual(seq.start, 72)
 | 
			
		||||
            self.assertEqual(seq.length, 8)
 | 
			
		||||
 | 
			
		||||
            src = message_type.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
            self.assertEqual(src.start, 96)
 | 
			
		||||
            self.assertEqual(src.length, 24)
 | 
			
		||||
 | 
			
		||||
            dst = message_type.get_first_label_with_type(FieldType.Function.DST_ADDRESS)
 | 
			
		||||
            self.assertEqual(dst.start, 120)
 | 
			
		||||
            self.assertEqual(dst.length, 24)
 | 
			
		||||
 | 
			
		||||
            checksum = message_type.get_first_label_with_type(FieldType.Function.CHECKSUM)
 | 
			
		||||
            self.assertEqual(checksum.length, 16)
 | 
			
		||||
            self.assertIn("CC1101", checksum.checksum.caption)
 | 
			
		||||
 | 
			
		||||
            for msg_index in ff.existing_message_types[message_type]:
 | 
			
		||||
                msg_len = len(protocol.messages[msg_index])
 | 
			
		||||
                self.assertEqual(checksum.start, msg_len-16)
 | 
			
		||||
                self.assertEqual(checksum.end, msg_len)
 | 
			
		||||
							
								
								
									
										102
									
								
								Software/urh/tests/awre/test_checksum_engine.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								Software/urh/tests/awre/test_checksum_engine.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,102 @@
 | 
			
		||||
import array
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.CommonRange import ChecksumRange
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.awre.engines.ChecksumEngine import ChecksumEngine
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.util import util
 | 
			
		||||
from urh.util.GenericCRC import GenericCRC
 | 
			
		||||
from urh.cythonext import util as c_util
 | 
			
		||||
 | 
			
		||||
class TestChecksumEngine(AWRETestCase):
 | 
			
		||||
    def test_find_crc8(self):
 | 
			
		||||
        messages = ["aabbcc7d", "abcdee24", "dacafe33"]
 | 
			
		||||
        message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
 | 
			
		||||
 | 
			
		||||
        checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
 | 
			
		||||
        result = checksum_engine.find()
 | 
			
		||||
        self.assertEqual(len(result), 1)
 | 
			
		||||
        checksum_range = result[0]  # type: ChecksumRange
 | 
			
		||||
        self.assertEqual(checksum_range.length, 8)
 | 
			
		||||
        self.assertEqual(checksum_range.start, 24)
 | 
			
		||||
 | 
			
		||||
        reference = GenericCRC()
 | 
			
		||||
        reference.set_polynomial_from_hex("0x07")
 | 
			
		||||
        self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(checksum_range.message_indices, {0, 1, 2})
 | 
			
		||||
 | 
			
		||||
    def test_find_crc16(self):
 | 
			
		||||
        messages = ["12345678347B", "abcdefffABBD", "cafe1337CE12"]
 | 
			
		||||
        message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
 | 
			
		||||
 | 
			
		||||
        checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
 | 
			
		||||
        result = checksum_engine.find()
 | 
			
		||||
        self.assertEqual(len(result), 1)
 | 
			
		||||
        checksum_range = result[0]  # type: ChecksumRange
 | 
			
		||||
        self.assertEqual(checksum_range.start, 32)
 | 
			
		||||
        self.assertEqual(checksum_range.length, 16)
 | 
			
		||||
 | 
			
		||||
        reference = GenericCRC()
 | 
			
		||||
        reference.set_polynomial_from_hex("0x8005")
 | 
			
		||||
        self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(checksum_range.message_indices, {0, 1, 2})
 | 
			
		||||
 | 
			
		||||
    def test_find_crc32(self):
 | 
			
		||||
        messages = ["deadcafe5D7F3F5A", "47111337E3319242", "beefaffe0DCD0E15"]
 | 
			
		||||
        message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)]
 | 
			
		||||
 | 
			
		||||
        checksum_engine = ChecksumEngine(message_bits, n_gram_length=8)
 | 
			
		||||
        result = checksum_engine.find()
 | 
			
		||||
        self.assertEqual(len(result), 1)
 | 
			
		||||
        checksum_range = result[0]  # type: ChecksumRange
 | 
			
		||||
        self.assertEqual(checksum_range.start, 32)
 | 
			
		||||
        self.assertEqual(checksum_range.length, 32)
 | 
			
		||||
 | 
			
		||||
        reference = GenericCRC()
 | 
			
		||||
        reference.set_polynomial_from_hex("0x04C11DB7")
 | 
			
		||||
        self.assertEqual(checksum_range.crc.polynomial, reference.polynomial)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(checksum_range.message_indices, {0, 1, 2})
 | 
			
		||||
 | 
			
		||||
    def test_find_generated_crc16(self):
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DATA, 32)
 | 
			
		||||
        mb.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT"))
 | 
			
		||||
 | 
			
		||||
        mb2 = MessageTypeBuilder("data2")
 | 
			
		||||
        mb2.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb2.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb2.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb2.add_label(FieldType.Function.DATA, 16)
 | 
			
		||||
 | 
			
		||||
        mb2.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT"))
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb2.message_type], syncs_by_mt={mb.message_type: "0x1234", mb2.message_type: "0x1234"})
 | 
			
		||||
 | 
			
		||||
        num_messages = 5
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            pg.generate_message(data="{0:032b}".format(i), message_type=mb.message_type)
 | 
			
		||||
            pg.generate_message(data="{0:016b}".format(i), message_type=mb2.message_type)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("crc16_test", pg)
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 2)
 | 
			
		||||
        for mt in ff.message_types:
 | 
			
		||||
            checksum_label = mt.get_first_label_with_type(FieldType.Function.CHECKSUM)
 | 
			
		||||
            self.assertEqual(checksum_label.length, 16)
 | 
			
		||||
            self.assertEqual(checksum_label.checksum.caption, "CRC16 CCITT")
 | 
			
		||||
							
								
								
									
										35
									
								
								Software/urh/tests/awre/test_common_range.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								Software/urh/tests/awre/test_common_range.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from urh.awre.CommonRange import CommonRange
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCommonRange(unittest.TestCase):
 | 
			
		||||
    def test_ensure_not_overlaps(self):
 | 
			
		||||
        test_range = CommonRange(start=4, length=8, value="12345678")
 | 
			
		||||
        self.assertEqual(test_range.end, 11)
 | 
			
		||||
 | 
			
		||||
        # no overlapping
 | 
			
		||||
        self.assertEqual(test_range, test_range.ensure_not_overlaps(0, 3)[0])
 | 
			
		||||
        self.assertEqual(test_range, test_range.ensure_not_overlaps(20, 24)[0])
 | 
			
		||||
 | 
			
		||||
        # overlapping on left
 | 
			
		||||
        result = test_range.ensure_not_overlaps(2, 6)[0]
 | 
			
		||||
        self.assertEqual(result.start, 6)
 | 
			
		||||
        self.assertEqual(result.end, 11)
 | 
			
		||||
 | 
			
		||||
        # overlapping on right
 | 
			
		||||
        result = test_range.ensure_not_overlaps(6, 14)[0]
 | 
			
		||||
        self.assertEqual(result.start, 4)
 | 
			
		||||
        self.assertEqual(result.end, 5)
 | 
			
		||||
 | 
			
		||||
        # full overlapping
 | 
			
		||||
        self.assertEqual(len(test_range.ensure_not_overlaps(3, 14)), 0)
 | 
			
		||||
 | 
			
		||||
        # overlapping in the middle
 | 
			
		||||
        result = test_range.ensure_not_overlaps(6, 9)
 | 
			
		||||
        self.assertEqual(len(result), 2)
 | 
			
		||||
        left, right = result[0], result[1]
 | 
			
		||||
        self.assertEqual(left.start, 4)
 | 
			
		||||
        self.assertEqual(left.end, 5)
 | 
			
		||||
        self.assertEqual(right.start, 10)
 | 
			
		||||
        self.assertEqual(right.end, 11)
 | 
			
		||||
							
								
								
									
										102
									
								
								Software/urh/tests/awre/test_format_finder.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								Software/urh/tests/awre/test_format_finder.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,102 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.CommonRange import CommonRange, CommonRangeContainer
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestFormatFinder(AWRETestCase):
 | 
			
		||||
    def test_create_message_types_1(self):
 | 
			
		||||
        rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
 | 
			
		||||
        rng1.message_indices = {0, 1, 2}
 | 
			
		||||
        rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
 | 
			
		||||
        rng2.message_indices = {0, 1, 2}
 | 
			
		||||
 | 
			
		||||
        message_types = FormatFinder.create_common_range_containers({rng1, rng2})
 | 
			
		||||
        self.assertEqual(len(message_types), 1)
 | 
			
		||||
 | 
			
		||||
        expected = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
 | 
			
		||||
        self.assertEqual(message_types[0], expected)
 | 
			
		||||
 | 
			
		||||
    def test_create_message_types_2(self):
 | 
			
		||||
        rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
 | 
			
		||||
        rng1.message_indices = {0, 2, 4, 6, 8, 12}
 | 
			
		||||
        rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
 | 
			
		||||
        rng2.message_indices = {1, 2, 3, 4, 5, 12}
 | 
			
		||||
        rng3 = CommonRange(16, 8, "1" * 8, score=1, field_type="Seq")
 | 
			
		||||
        rng3.message_indices = {1, 3, 5, 7, 12}
 | 
			
		||||
 | 
			
		||||
        message_types = FormatFinder.create_common_range_containers({rng1, rng2, rng3})
 | 
			
		||||
        expected1 = CommonRangeContainer([rng1], message_indices={0, 6, 8})
 | 
			
		||||
        expected2 = CommonRangeContainer([rng1, rng2], message_indices={2, 4})
 | 
			
		||||
        expected3 = CommonRangeContainer([rng1, rng2, rng3], message_indices={12})
 | 
			
		||||
        expected4 = CommonRangeContainer([rng2, rng3], message_indices={1, 3, 5})
 | 
			
		||||
        expected5 = CommonRangeContainer([rng3], message_indices={7})
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(message_types), 5)
 | 
			
		||||
 | 
			
		||||
        self.assertIn(expected1, message_types)
 | 
			
		||||
        self.assertIn(expected2, message_types)
 | 
			
		||||
        self.assertIn(expected3, message_types)
 | 
			
		||||
        self.assertIn(expected4, message_types)
 | 
			
		||||
        self.assertIn(expected5, message_types)
 | 
			
		||||
 | 
			
		||||
    def test_retransform_message_indices(self):
 | 
			
		||||
        sync_ends = np.array([12, 12, 12, 14, 14])
 | 
			
		||||
 | 
			
		||||
        rng = CommonRange(0, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2, 3, 4})
 | 
			
		||||
        retransformed_ranges = FormatFinder.retransform_message_indices([rng], [0, 1, 2, 3, 4], sync_ends)
 | 
			
		||||
 | 
			
		||||
        # two different sync ends
 | 
			
		||||
        self.assertEqual(len(retransformed_ranges), 2)
 | 
			
		||||
 | 
			
		||||
        expected1 = CommonRange(12, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2})
 | 
			
		||||
        expected2 = CommonRange(14, 8, "1" * 8, score=1, field_type="length", message_indices={3, 4})
 | 
			
		||||
 | 
			
		||||
        self.assertIn(expected1, retransformed_ranges)
 | 
			
		||||
        self.assertIn(expected2, retransformed_ranges)
 | 
			
		||||
 | 
			
		||||
    def test_handle_no_overlapping_conflict(self):
 | 
			
		||||
        rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length")
 | 
			
		||||
        rng1.message_indices = {0, 1, 2}
 | 
			
		||||
        rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address")
 | 
			
		||||
        rng2.message_indices = {0, 1, 2}
 | 
			
		||||
 | 
			
		||||
        container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
 | 
			
		||||
 | 
			
		||||
        # no conflict
 | 
			
		||||
        result = FormatFinder.handle_overlapping_conflict([container])
 | 
			
		||||
        self.assertEqual(len(result), 1)
 | 
			
		||||
        self.assertEqual(len(result[0]), 2)
 | 
			
		||||
        self.assertIn(rng1, result[0])
 | 
			
		||||
        self.assertEqual(result[0].message_indices, {0, 1, 2})
 | 
			
		||||
        self.assertIn(rng2, result[0])
 | 
			
		||||
 | 
			
		||||
    def test_handle_easy_overlapping_conflict(self):
 | 
			
		||||
        # Easy conflict: First Label has higher score
 | 
			
		||||
        rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length")
 | 
			
		||||
        rng1.message_indices = {0, 1, 2}
 | 
			
		||||
        rng2 = CommonRange(8, 8, "1" * 8, score=0.8, field_type="Address")
 | 
			
		||||
        rng2.message_indices = {0, 1, 2}
 | 
			
		||||
 | 
			
		||||
        container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2})
 | 
			
		||||
        result = FormatFinder.handle_overlapping_conflict([container])
 | 
			
		||||
        self.assertEqual(len(result), 1)
 | 
			
		||||
        self.assertEqual(len(result[0]), 1)
 | 
			
		||||
        self.assertIn(rng1, result[0])
 | 
			
		||||
        self.assertEqual(result[0].message_indices, {0, 1, 2})
 | 
			
		||||
 | 
			
		||||
    def test_handle_medium_overlapping_conflict(self):
 | 
			
		||||
        rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length")
 | 
			
		||||
        rng2 = CommonRange(4, 10, "1" * 8, score=0.8, field_type="Address")
 | 
			
		||||
        rng3 = CommonRange(15, 20, "1" * 8, score=1, field_type="Seq")
 | 
			
		||||
        rng4 = CommonRange(60, 80, "1" * 8, score=0.8, field_type="Type")
 | 
			
		||||
        rng5 = CommonRange(70, 90, "1" * 8, score=0.9, field_type="Data")
 | 
			
		||||
 | 
			
		||||
        container = CommonRangeContainer([rng1, rng2, rng3, rng4, rng5])
 | 
			
		||||
        result = FormatFinder.handle_overlapping_conflict([container])
 | 
			
		||||
        self.assertEqual(len(result), 1)
 | 
			
		||||
        self.assertEqual(len(result[0]), 3)
 | 
			
		||||
        self.assertIn(rng1, result[0])
 | 
			
		||||
        self.assertIn(rng3, result[0])
 | 
			
		||||
        self.assertIn(rng5, result[0])
 | 
			
		||||
							
								
								
									
										236
									
								
								Software/urh/tests/awre/test_generated_protocols.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										236
									
								
								Software/urh/tests/awre/test_generated_protocols.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,236 @@
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre import AutoAssigner
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.Preprocessor import Preprocessor
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
from urh.util import util
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestGeneratedProtocols(AWRETestCase):
 | 
			
		||||
    def __check_addresses(self, messages, format_finder, known_participant_addresses):
 | 
			
		||||
        """
 | 
			
		||||
        Use the AutoAssigner used also in main GUI to test assigned participant addresses to get same results
 | 
			
		||||
        as in main program and not rely on cache of FormatFinder, because values there might be false
 | 
			
		||||
        but SRC address labels still on right position which is the basis for Auto Assigner
 | 
			
		||||
 | 
			
		||||
        :param messages:
 | 
			
		||||
        :param format_finder:
 | 
			
		||||
        :param known_participant_addresses:
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        for msg_type, indices in format_finder.existing_message_types.items():
 | 
			
		||||
            for i in indices:
 | 
			
		||||
                messages[i].message_type = msg_type
 | 
			
		||||
 | 
			
		||||
        participants = list(set(m.participant for m in messages))
 | 
			
		||||
        for p in participants:
 | 
			
		||||
            p.address_hex = ""
 | 
			
		||||
        AutoAssigner.auto_assign_participant_addresses(messages, participants)
 | 
			
		||||
 | 
			
		||||
        for i in range(len(participants)):
 | 
			
		||||
            self.assertIn(participants[i].address_hex,
 | 
			
		||||
                          list(map(util.convert_numbers_to_hex_string, known_participant_addresses.values())),
 | 
			
		||||
                          msg=" [ " + " ".join(p.address_hex for p in participants) + " ]")
 | 
			
		||||
 | 
			
		||||
    def test_without_preamble(self):
 | 
			
		||||
        alice = Participant("Alice", address_hex="24")
 | 
			
		||||
        broadcast = Participant("Broadcast", address_hex="ff")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x8e88"},
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8},
 | 
			
		||||
                               participants=[alice, broadcast])
 | 
			
		||||
 | 
			
		||||
        for i in range(20):
 | 
			
		||||
            data_bits = 16 if i % 2 == 0 else 32
 | 
			
		||||
            source = pg.participants[i % 2]
 | 
			
		||||
            destination = pg.participants[(i + 1) % 2]
 | 
			
		||||
            pg.generate_message(data="1010" * (data_bits // 4), source=source, destination=destination)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("without_preamble", pg)
 | 
			
		||||
        self.clear_message_types(pg.messages)
 | 
			
		||||
        ff = FormatFinder(pg.messages)
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
 | 
			
		||||
        mt = ff.message_types[0]
 | 
			
		||||
        sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
        self.assertEqual(sync.start, 0)
 | 
			
		||||
        self.assertEqual(sync.length, 16)
 | 
			
		||||
 | 
			
		||||
        length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
        self.assertEqual(length.start, 16)
 | 
			
		||||
        self.assertEqual(length.length, 8)
 | 
			
		||||
 | 
			
		||||
        dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
        self.assertEqual(dst.start, 24)
 | 
			
		||||
        self.assertEqual(dst.length, 8)
 | 
			
		||||
 | 
			
		||||
        seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
        self.assertEqual(seq.start, 32)
 | 
			
		||||
        self.assertEqual(seq.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_without_preamble_random_data(self):
 | 
			
		||||
        ff = self.get_format_finder_from_protocol_file("without_ack_random_data.proto.xml")
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
 | 
			
		||||
        mt = ff.message_types[0]
 | 
			
		||||
        sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
        self.assertEqual(sync.start, 0)
 | 
			
		||||
        self.assertEqual(sync.length, 16)
 | 
			
		||||
 | 
			
		||||
        length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
        self.assertEqual(length.start, 16)
 | 
			
		||||
        self.assertEqual(length.length, 8)
 | 
			
		||||
 | 
			
		||||
        dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
        self.assertEqual(dst.start, 24)
 | 
			
		||||
        self.assertEqual(dst.length, 8)
 | 
			
		||||
 | 
			
		||||
        seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
        self.assertEqual(seq.start, 32)
 | 
			
		||||
        self.assertEqual(seq.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_without_preamble_random_data2(self):
 | 
			
		||||
        ff = self.get_format_finder_from_protocol_file("without_ack_random_data2.proto.xml")
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
 | 
			
		||||
        mt = ff.message_types[0]
 | 
			
		||||
        sync = mt.get_first_label_with_type(FieldType.Function.SYNC)
 | 
			
		||||
        self.assertEqual(sync.start, 0)
 | 
			
		||||
        self.assertEqual(sync.length, 16)
 | 
			
		||||
 | 
			
		||||
        length = mt.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
        self.assertEqual(length.start, 16)
 | 
			
		||||
        self.assertEqual(length.length, 8)
 | 
			
		||||
 | 
			
		||||
        dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS)
 | 
			
		||||
        self.assertEqual(dst.start, 24)
 | 
			
		||||
        self.assertEqual(dst.length, 8)
 | 
			
		||||
 | 
			
		||||
        seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
        self.assertEqual(seq.start, 32)
 | 
			
		||||
        self.assertEqual(seq.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_with_checksum(self):
 | 
			
		||||
        ff = self.get_format_finder_from_protocol_file("with_checksum.proto.xml", clear_participant_addresses=False)
 | 
			
		||||
        known_participant_addresses = ff.known_participant_addresses.copy()
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertIn(known_participant_addresses[0].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
        self.assertIn(known_participant_addresses[1].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 3)
 | 
			
		||||
 | 
			
		||||
    def test_with_only_one_address(self):
 | 
			
		||||
        ff = self.get_format_finder_from_protocol_file("only_one_address.proto.xml", clear_participant_addresses=False)
 | 
			
		||||
        known_participant_addresses = ff.known_participant_addresses.copy()
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertIn(known_participant_addresses[0].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
        self.assertIn(known_participant_addresses[1].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
 | 
			
		||||
    def test_with_four_broken(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("four_broken.proto.xml",
 | 
			
		||||
                                                                 clear_participant_addresses=False,
 | 
			
		||||
                                                                 return_messages=True)
 | 
			
		||||
 | 
			
		||||
        assert isinstance(ff, FormatFinder)
 | 
			
		||||
        known_participant_addresses = ff.known_participant_addresses.copy()
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.__check_addresses(messages, ff, known_participant_addresses)
 | 
			
		||||
 | 
			
		||||
        for i in range(4, len(messages)):
 | 
			
		||||
            mt = next(mt for mt, indices in ff.existing_message_types.items() if i in indices)
 | 
			
		||||
            self.assertIsNotNone(mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER))
 | 
			
		||||
 | 
			
		||||
    def test_with_one_address_one_message_type(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("one_address_one_mt.proto.xml",
 | 
			
		||||
                                                                 clear_participant_addresses=False,
 | 
			
		||||
                                                                 return_messages=True)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(messages), 17)
 | 
			
		||||
        self.assertEqual(len(ff.hexvectors), 17)
 | 
			
		||||
 | 
			
		||||
        known_participant_addresses = ff.known_participant_addresses.copy()
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
 | 
			
		||||
        self.assertIn(known_participant_addresses[0].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
        self.assertIn(known_participant_addresses[1].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
 | 
			
		||||
    def test_without_preamble_24_messages(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("no_preamble24.proto.xml",
 | 
			
		||||
                                                                 clear_participant_addresses=False,
 | 
			
		||||
                                                                 return_messages=True)
 | 
			
		||||
 | 
			
		||||
        known_participant_addresses = ff.known_participant_addresses.copy()
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
 | 
			
		||||
        self.assertIn(known_participant_addresses[0].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
        self.assertIn(known_participant_addresses[1].tostring(),
 | 
			
		||||
                      list(map(bytes, ff.known_participant_addresses.values())))
 | 
			
		||||
 | 
			
		||||
    def test_with_three_syncs_different_preamble_lengths(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("three_syncs.proto.xml", return_messages=True)
 | 
			
		||||
        preprocessor = Preprocessor(ff.get_bitvectors_from_messages(messages))
 | 
			
		||||
        sync_words = preprocessor.find_possible_syncs()
 | 
			
		||||
        self.assertIn("0000010000100000", sync_words, msg="Sync 1")
 | 
			
		||||
        self.assertIn("0010001000100010", sync_words, msg="Sync 2")
 | 
			
		||||
        self.assertIn("0110011101100111", sync_words, msg="Sync 3")
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        expected_sync_ends = [32, 24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24]
 | 
			
		||||
 | 
			
		||||
        for i, (s1, s2) in enumerate(zip(expected_sync_ends, ff.sync_ends)):
 | 
			
		||||
            self.assertEqual(s1, s2, msg=str(i))
 | 
			
		||||
 | 
			
		||||
    def test_with_four_participants(self):
 | 
			
		||||
        ff, messages = self.get_format_finder_from_protocol_file("four_participants.proto.xml",
 | 
			
		||||
                                                                 clear_participant_addresses=False,
 | 
			
		||||
                                                                 return_messages=True)
 | 
			
		||||
 | 
			
		||||
        known_participant_addresses = ff.known_participant_addresses.copy()
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
 | 
			
		||||
        ff.run()
 | 
			
		||||
 | 
			
		||||
        self.__check_addresses(messages, ff, known_participant_addresses)
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 3)
 | 
			
		||||
							
								
								
									
										167
									
								
								Software/urh/tests/awre/test_length_engine.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										167
									
								
								Software/urh/tests/awre/test_length_engine.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,167 @@
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.awre.engines.LengthEngine import LengthEngine
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.ProtocoLabel import ProtocolLabel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLengthEngine(AWRETestCase):
 | 
			
		||||
    def test_simple_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a simple protocol with
 | 
			
		||||
        preamble, sync and length field (8 bit) and some random data
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("simple_length_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
 | 
			
		||||
        num_messages_by_data_length = {8: 5, 16: 10, 32: 15}
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"})
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        for data_length, num_messages in num_messages_by_data_length.items():
 | 
			
		||||
            for i in range(num_messages):
 | 
			
		||||
                pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)]))
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("simple_length", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        length_engine = LengthEngine(ff.bitvectors)
 | 
			
		||||
        highscored_ranges = length_engine.find(n_gram_length=8)
 | 
			
		||||
        self.assertEqual(len(highscored_ranges), 3)
 | 
			
		||||
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertGreater(len(ff.message_types[0]), 0)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
        self.assertEqual(label.start, 24)
 | 
			
		||||
        self.assertEqual(label.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_easy_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        preamble, sync, sequence number, length field (8 bit) and some random data
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("easy_length_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        num_messages_by_data_length = {32: 10, 64: 15, 16: 5, 24: 7}
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               preambles_by_mt={mb.message_type: "10" * 8},
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0xcafe"})
 | 
			
		||||
        for data_length, num_messages in num_messages_by_data_length.items():
 | 
			
		||||
            for i in range(num_messages):
 | 
			
		||||
                if i % 4 == 0:
 | 
			
		||||
                    data = "1" * data_length
 | 
			
		||||
                elif i % 4 == 1:
 | 
			
		||||
                    data = "0" * data_length
 | 
			
		||||
                elif i % 4 == 2:
 | 
			
		||||
                    data = "10" * (data_length // 2)
 | 
			
		||||
                else:
 | 
			
		||||
                    data = "01" * (data_length // 2)
 | 
			
		||||
 | 
			
		||||
                pg.generate_message(data=data)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("easy_length", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        length_engine = LengthEngine(ff.bitvectors)
 | 
			
		||||
        highscored_ranges = length_engine.find(n_gram_length=8)
 | 
			
		||||
        self.assertEqual(len(highscored_ranges), 4)
 | 
			
		||||
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertGreater(len(ff.message_types[0]), 0)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
        self.assertIsInstance(label, ProtocolLabel)
 | 
			
		||||
        self.assertEqual(label.start, 32)
 | 
			
		||||
        self.assertEqual(label.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_medium_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Protocol with two message types. Length field only present in one of them
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb1 = MessageTypeBuilder("data")
 | 
			
		||||
        mb1.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb1.add_label(FieldType.Function.SYNC, 8)
 | 
			
		||||
        mb1.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb1.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        mb2 = MessageTypeBuilder("ack")
 | 
			
		||||
        mb2.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb2.add_label(FieldType.Function.SYNC, 8)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb1.message_type, mb2.message_type],
 | 
			
		||||
                               syncs_by_mt={mb1.message_type: "11110011",
 | 
			
		||||
                                            mb2.message_type: "11110011"})
 | 
			
		||||
        num_messages_by_data_length = {8: 5, 16: 10, 32: 5}
 | 
			
		||||
        for data_length, num_messages in num_messages_by_data_length.items():
 | 
			
		||||
            for i in range(num_messages):
 | 
			
		||||
                pg.generate_message(data=pg.decimal_to_bits(10 * i, data_length), message_type=mb1.message_type)
 | 
			
		||||
                pg.generate_message(message_type=mb2.message_type, data="0xaf")
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("medium_length", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 2)
 | 
			
		||||
        length_mt = next(
 | 
			
		||||
            mt for mt in ff.message_types if mt.get_first_label_with_type(FieldType.Function.LENGTH) is not None)
 | 
			
		||||
        length_label = length_mt.get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
 | 
			
		||||
        for i, sync_end in enumerate(ff.sync_ends):
 | 
			
		||||
            self.assertEqual(sync_end, 16, msg=str(i))
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(16, length_label.start)
 | 
			
		||||
        self.assertEqual(8, length_label.length)
 | 
			
		||||
 | 
			
		||||
    def test_little_endian_16_bit(self):
 | 
			
		||||
        mb = MessageTypeBuilder("little_endian_16_length_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages_by_data_length = {256*8: 5, 16: 4, 512: 2}
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"},
 | 
			
		||||
                               little_endian=True)
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        for data_length, num_messages in num_messages_by_data_length.items():
 | 
			
		||||
            for i in range(num_messages):
 | 
			
		||||
                pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)]))
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("little_endian_16_length_test", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        length_engine = LengthEngine(ff.bitvectors)
 | 
			
		||||
        highscored_ranges = length_engine.find(n_gram_length=8)
 | 
			
		||||
        self.assertEqual(len(highscored_ranges), 3)
 | 
			
		||||
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertGreater(len(ff.message_types[0]), 0)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)
 | 
			
		||||
        self.assertEqual(label.start, 24)
 | 
			
		||||
        self.assertEqual(label.length, 16)
 | 
			
		||||
							
								
								
									
										198
									
								
								Software/urh/tests/awre/test_partially_labeled.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								Software/urh/tests/awre/test_partially_labeled.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,198 @@
 | 
			
		||||
import copy
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
from urh.signalprocessing.MessageType import MessageType
 | 
			
		||||
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPartiallyLabeled(AWRETestCase):
 | 
			
		||||
    """
 | 
			
		||||
    Some tests if there are already information about the message types present
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    def test_fully_labeled(self):
 | 
			
		||||
        """
 | 
			
		||||
        For fully labeled protocol, nothing should be done
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        protocol = self.__prepare_example_protocol()
 | 
			
		||||
        message_types = sorted(copy.deepcopy(protocol.message_types), key=lambda x: x.name)
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(message_types), len(ff.message_types))
 | 
			
		||||
 | 
			
		||||
        for mt1, mt2 in zip(message_types, ff.message_types):
 | 
			
		||||
            self.assertTrue(self.__message_types_have_same_labels(mt1, mt2))
 | 
			
		||||
 | 
			
		||||
    def test_one_message_type_empty(self):
 | 
			
		||||
        """
 | 
			
		||||
        Empty the "ACK" message type, the labels should be find by FormatFinder
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        protocol = self.__prepare_example_protocol()
 | 
			
		||||
        n_message_types = len(protocol.message_types)
 | 
			
		||||
        ack_mt = next(mt for mt in protocol.message_types if mt.name == "ack")
 | 
			
		||||
        ack_mt.clear()
 | 
			
		||||
        self.assertEqual(len(ack_mt), 0)
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(n_message_types, len(ff.message_types))
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ack_mt), 4, msg=str(ack_mt))
 | 
			
		||||
 | 
			
		||||
    def test_given_address_information(self):
 | 
			
		||||
        """
 | 
			
		||||
        Empty both message types and see if addresses are found, when information of participant addresses is given
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        protocol = self.__prepare_example_protocol()
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(2, len(ff.message_types))
 | 
			
		||||
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.PREAMBLE))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SYNC))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.LENGTH))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
 | 
			
		||||
        self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
 | 
			
		||||
 | 
			
		||||
    def test_type_part_already_labeled(self):
 | 
			
		||||
        protocol = self.__prepare_simple_example_protocol()
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
 | 
			
		||||
        # overlaps type
 | 
			
		||||
        ff.message_types[0].add_protocol_label_start_length(32, 8)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(1, len(ff.message_types))
 | 
			
		||||
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
 | 
			
		||||
 | 
			
		||||
    def test_length_part_already_labeled(self):
 | 
			
		||||
        protocol = self.__prepare_simple_example_protocol()
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
 | 
			
		||||
        # overlaps length
 | 
			
		||||
        ff.message_types[0].add_protocol_label_start_length(24, 8)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(1, len(ff.message_types))
 | 
			
		||||
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
 | 
			
		||||
        self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
 | 
			
		||||
 | 
			
		||||
    def test_address_part_already_labeled(self):
 | 
			
		||||
        protocol = self.__prepare_simple_example_protocol()
 | 
			
		||||
        self.clear_message_types(protocol.messages)
 | 
			
		||||
        ff = FormatFinder(protocol.messages)
 | 
			
		||||
 | 
			
		||||
        # overlaps dst address
 | 
			
		||||
        ff.message_types[0].add_protocol_label_start_length(40, 16)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(1, len(ff.message_types))
 | 
			
		||||
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH))
 | 
			
		||||
        self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS))
 | 
			
		||||
        self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def __message_types_have_same_labels(mt1: MessageType, mt2: MessageType):
 | 
			
		||||
        if len(mt1) != len(mt2):
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        for i, lbl in enumerate(mt1):
 | 
			
		||||
            if lbl != mt2[i]:
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def __prepare_example_protocol(self) -> ProtocolAnalyzer:
 | 
			
		||||
        alice = Participant("Alice", "A", address_hex="1234")
 | 
			
		||||
        bob = Participant("Bob", "B", address_hex="cafe")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.TYPE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb_ack = MessageTypeBuilder("ack")
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 50
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type, mb_ack.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, destination = alice, bob
 | 
			
		||||
                data_length = 8
 | 
			
		||||
            else:
 | 
			
		||||
                source, destination = bob, alice
 | 
			
		||||
                data_length = 16
 | 
			
		||||
            pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length),
 | 
			
		||||
                                source=source, destination=destination)
 | 
			
		||||
            pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("labeled_protocol", pg)
 | 
			
		||||
 | 
			
		||||
        return pg.protocol
 | 
			
		||||
 | 
			
		||||
    def __prepare_simple_example_protocol(self):
 | 
			
		||||
        random.seed(0)
 | 
			
		||||
        alice = Participant("Alice", "A", address_hex="1234")
 | 
			
		||||
        bob = Participant("Bob", "B", address_hex="cafe")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("data")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.TYPE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x6768"},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        for i in range(10):
 | 
			
		||||
            pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(16)]), source=alice, destination=bob)
 | 
			
		||||
            pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(8)]), source=bob, destination=alice)
 | 
			
		||||
 | 
			
		||||
        return pg.protocol
 | 
			
		||||
							
								
								
									
										182
									
								
								Software/urh/tests/awre/test_sequence_number_engine.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										182
									
								
								Software/urh/tests/awre/test_sequence_number_engine.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,182 @@
 | 
			
		||||
from tests.awre.AWRETestCase import AWRETestCase
 | 
			
		||||
from urh.awre.CommonRange import CommonRange
 | 
			
		||||
from urh.awre.FormatFinder import FormatFinder
 | 
			
		||||
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
 | 
			
		||||
from urh.awre.ProtocolGenerator import ProtocolGenerator
 | 
			
		||||
from urh.awre.engines.SequenceNumberEngine import SequenceNumberEngine
 | 
			
		||||
from urh.signalprocessing.FieldType import FieldType
 | 
			
		||||
from urh.signalprocessing.Participant import Participant
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSequenceNumberEngine(AWRETestCase):
 | 
			
		||||
    def test_simple_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a simple protocol with
 | 
			
		||||
        preamble, sync and increasing sequence number (8 bit) and some constant data
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        mb = MessageTypeBuilder("simple_seq_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8)
 | 
			
		||||
 | 
			
		||||
        num_messages = 20
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"})
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            pg.generate_message(data="0xcafe")
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("simple_sequence_number", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        seq_engine = SequenceNumberEngine(ff.bitvectors, n_gram_length=8)
 | 
			
		||||
        highscored_ranges = seq_engine.find()
 | 
			
		||||
        self.assertEqual(len(highscored_ranges), 1)
 | 
			
		||||
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertGreater(len(ff.message_types[0]), 0)
 | 
			
		||||
        self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
        self.assertEqual(label.start, 24)
 | 
			
		||||
        self.assertEqual(label.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_16bit_seq_nr(self):
 | 
			
		||||
        mb = MessageTypeBuilder("16bit_seq_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 10
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=64)
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            pg.generate_message(data="0xcafe")
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("16bit_seq", pg)
 | 
			
		||||
 | 
			
		||||
        bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages)
 | 
			
		||||
        seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8)
 | 
			
		||||
        highscored_ranges = seq_engine.find()
 | 
			
		||||
        self.assertEqual(len(highscored_ranges), 1)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertGreater(len(ff.message_types[0]), 0)
 | 
			
		||||
        self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
        self.assertEqual(label.start, 24)
 | 
			
		||||
        self.assertEqual(label.length, 16)
 | 
			
		||||
 | 
			
		||||
    def test_16bit_seq_nr_with_zeros_in_first_part(self):
 | 
			
		||||
        mb = MessageTypeBuilder("16bit_seq_first_byte_zero_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 10
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=1)
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            pg.generate_message(data="0xcafe" + "abc" * i)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("16bit_seq_first_byte_zero_test", pg)
 | 
			
		||||
 | 
			
		||||
        bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages)
 | 
			
		||||
        seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8)
 | 
			
		||||
        highscored_ranges = seq_engine.find()
 | 
			
		||||
        self.assertEqual(len(highscored_ranges), 1)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertGreater(len(ff.message_types[0]), 0)
 | 
			
		||||
        self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
 | 
			
		||||
        # Not consider constants as part of SEQ Nr!
 | 
			
		||||
        self.assertEqual(label.start, 40)
 | 
			
		||||
        self.assertEqual(label.length, 8)
 | 
			
		||||
 | 
			
		||||
    def test_no_sequence_number(self):
 | 
			
		||||
        """
 | 
			
		||||
        Ensure no sequence number is labeled, when it cannot be found
 | 
			
		||||
 | 
			
		||||
        :return:
 | 
			
		||||
        """
 | 
			
		||||
        alice = Participant("Alice", address_hex="dead")
 | 
			
		||||
        bob = Participant("Bob", address_hex="beef")
 | 
			
		||||
 | 
			
		||||
        mb = MessageTypeBuilder("protocol_with_one_message_type")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.LENGTH, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SRC_ADDRESS, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.DST_ADDRESS, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 3
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x1337"},
 | 
			
		||||
                               participants=[alice, bob])
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                source, destination = alice, bob
 | 
			
		||||
            else:
 | 
			
		||||
                source, destination = bob, alice
 | 
			
		||||
            pg.generate_message(data="", source=source, destination=destination)
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("protocol_1", pg)
 | 
			
		||||
 | 
			
		||||
        # Delete message type information -> no prior knowledge
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        ff.known_participant_addresses.clear()
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 0)
 | 
			
		||||
 | 
			
		||||
    def test_sequence_number_little_endian_16_bit(self):
 | 
			
		||||
        mb = MessageTypeBuilder("16bit_seq_test")
 | 
			
		||||
        mb.add_label(FieldType.Function.PREAMBLE, 8)
 | 
			
		||||
        mb.add_label(FieldType.Function.SYNC, 16)
 | 
			
		||||
        mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16)
 | 
			
		||||
 | 
			
		||||
        num_messages = 8
 | 
			
		||||
 | 
			
		||||
        pg = ProtocolGenerator([mb.message_type],
 | 
			
		||||
                               syncs_by_mt={mb.message_type: "0x9a9d"},
 | 
			
		||||
                               little_endian=True, sequence_number_increment=64)
 | 
			
		||||
 | 
			
		||||
        for i in range(num_messages):
 | 
			
		||||
            pg.generate_message(data="0xcafe")
 | 
			
		||||
 | 
			
		||||
        #self.save_protocol("16bit_litte_endian_seq", pg)
 | 
			
		||||
 | 
			
		||||
        self.clear_message_types(pg.protocol.messages)
 | 
			
		||||
        ff = FormatFinder(pg.protocol.messages)
 | 
			
		||||
        ff.perform_iteration()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ff.message_types), 1)
 | 
			
		||||
        self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1)
 | 
			
		||||
        label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)
 | 
			
		||||
        self.assertEqual(label.start, 24)
 | 
			
		||||
        self.assertEqual(label.length, 16)
 | 
			
		||||
		Reference in New Issue
	
	Block a user