| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- # Copyright 2016 by MPI-SWS and Data-Ken Research.
- # Licensed under the Apache 2.0 License.
- """Common utilities for the tests
- """
- import time
- import unittest
- import random
- random.seed()
- import sys
- import traceback
- import pdb
- from thingflow.base import IterableAsOutputThing, InputThing, FatalError,\
- SensorEvent, Filter
class RandomSensor:
    """Sensor whose samples are drawn from a Gaussian distribution.

    Each call to ``sample()`` returns ``random.gauss(mean, stddev)``. If
    ``stop_after_events`` is given, only that many samples are produced and
    further calls to ``sample()`` raise ``StopIteration``; otherwise samples
    are produced indefinitely.
    """
    def __init__(self, sensor_id, mean=100.0, stddev=20.0, stop_after_events=None):
        self.sensor_id = sensor_id
        self.mean = mean
        self.stddev = stddev
        self.stop_after_events = stop_after_events

        def draws():
            # mean/stddev/stop_after_events are captured from the constructor
            # arguments; later changes to the attributes do not affect draws.
            if stop_after_events is None:
                while True:  # go on forever
                    yield random.gauss(mean, stddev)
            for _ in range(stop_after_events):
                yield random.gauss(mean, stddev)

        self.generator = draws()

    def sample(self):
        return next(self.generator)

    def __repr__(self):
        fields = [self.sensor_id, self.mean, self.stddev]
        if self.stop_after_events is None:
            template = 'RandomSensor(%s, mean=%s, stddev=%s)'
        else:
            template = 'RandomSensor(%s, mean=%s, stddev=%s, stop_after_events=%s)'
            fields.append(self.stop_after_events)
        return template % tuple(fields)
class ValueListSensor:
    """Sensor that replays a fixed sequence of values.

    Each call to ``sample()`` returns the next element of ``values``; once
    they are exhausted, ``sample()`` raises ``StopIteration``.
    """
    def __init__(self, sensor_id, values):
        self.sensor_id = sensor_id

        def generator():
            # values is iterated lazily, starting at the first sample() call.
            yield from values

        self.generator = generator()

    def sample(self):
        return next(self.generator)

    def __repr__(self):
        # Wrap in a 1-tuple so a tuple-valued sensor_id is formatted as a
        # single value instead of being unpacked by the % operator.
        return 'ValueListSensor(%s)' % (self.sensor_id,)
def make_test_output_thing(sensor_id, mean=100.0, stddev=20.0, stop_after_events=None):
    """Return an example test output thing that generates random values.

    Each event is ``SensorEvent(sensor_id, time.time(), value)``. The value is
    drawn from ``random.gauss(mean, stddev)``, and the timestamp is taken when
    the event is generated. If ``stop_after_events`` is given, the underlying
    generator is exhausted after that many events; otherwise it yields events
    forever.
    """
    def generator():
        if stop_after_events is None:  # go on forever
            while True:
                yield SensorEvent(sensor_id, time.time(),
                                  random.gauss(mean, stddev))
        for _ in range(stop_after_events):
            yield SensorEvent(sensor_id, time.time(),
                              random.gauss(mean, stddev))
    # 1-tuple so that a tuple-valued sensor_id formats correctly instead of
    # raising TypeError.
    return IterableAsOutputThing(generator(), name='Sensor(%s)' % (sensor_id,))
def make_test_output_thing_from_vallist(sensor_id, values):
    """Create an output thing that emits the given values in order.

    Each value is wrapped in ``SensorEvent(sensor_id, time.time(), val)``.
    The timestamp is the real time at which that event is generated.
    """
    def generator():
        for val in values:
            yield SensorEvent(sensor_id, time.time(), val)
    # 1-tuple so that a tuple-valued sensor_id formats correctly instead of
    # raising TypeError.
    return IterableAsOutputThing(generator(), name='Sensor(%s)' % (sensor_id,))
class ValidationInputThing(InputThing):
    """Compare the values in an event stream to the expected values.

    Each event passed to ``on_next`` is mapped through ``extract_value_fn``
    and compared, in order, with the next element of ``expected_stream``.
    The assertions are made through the test case, so mismatches are
    reported as ordinary unit-test failures.

    Parameters:
      expected_stream -- indexable sequence of expected values.
      test_case -- a ``unittest.TestCase`` instance, or a bound method of one
        (e.g. ``self.test_foo``), in which case the assertions are made on
        the method's ``__self__``.
      extract_value_fn -- maps an event to the value that is compared;
        defaults to the event's ``val`` attribute.

    ``completed`` becomes True once ``on_completed`` has verified that every
    expected value was received.
    """
    def __init__(self, expected_stream, test_case,
                 extract_value_fn=lambda event:event.val):
        self.expected_stream = expected_stream
        self.next_idx = 0
        # Either a TestCase instance or a bound method of a TestCase.
        self.test_case = test_case
        self.extract_value_fn = extract_value_fn
        self.completed = False
        if isinstance(test_case, unittest.TestCase):
            self.name = "ValidationInputThing(%s)" % \
                        test_case.__class__.__name__
        else:
            self.name = "ValidationInputThing(%s.%s)" % \
                        (test_case.__self__.__class__.__name__,
                         test_case.__name__)

    def _assertion_target(self):
        """Return the TestCase instance on which to make assertions."""
        if isinstance(self.test_case, unittest.TestCase):
            return self.test_case
        return self.test_case.__self__

    def on_next(self, x):
        tc = self._assertion_target()
        tc.assertLess(self.next_idx, len(self.expected_stream),
                      "Got an event after reaching the end of the expected stream")
        expected = self.expected_stream[self.next_idx]
        actual = self.extract_value_fn(x)
        tc.assertEqual(actual, expected,
                       "Values for element %d of event stream mismatch" %
                       self.next_idx)
        self.next_idx += 1

    def on_completed(self):
        self._assertion_target().assertEqual(
            self.next_idx, len(self.expected_stream),
            "Got on_completed() before end of stream")
        self.completed = True

    def on_error(self, exc):
        # fail() reports exactly this message, rather than the
        # "False is not true : ..." produced by assertTrue(False, ...).
        # The 1-tuple keeps the formatting safe for a tuple-valued argument.
        self._assertion_target().fail(
            "Got an unexpected on_error call with parameter: %s" % (exc,))

    def __repr__(self):
        return self.name
-
class SensorEventValidationInputThing(InputThing):
    """Compare the full events in a sensor event stream to the expected events.

    Events are checked in order against ``expected_sensor_events``:
      * ``val`` and ``sensor_id`` must be equal;
      * ``ts`` must agree to 5 decimal places.

    The assertions are made through the test case, so mismatches are
    reported as ordinary unit-test failures.

    ``test_case`` is normally a ``unittest.TestCase`` instance. A bound
    method of one is also accepted, in which case the assertions are made on
    the method's ``__self__``. ``completed`` becomes True once
    ``on_completed`` has verified that every expected event was received.
    """
    def __init__(self, expected_sensor_events, test_case):
        self.expected_sensor_events = expected_sensor_events
        self.next_idx = 0
        self.test_case = test_case
        self.completed = False

    def _assertion_target(self):
        """Return the object on which to make assertions.

        A bound test method is resolved to its ``__self__``. Anything else
        (normally a TestCase) is used as-is, so existing callers are
        unaffected.
        """
        tc = self.test_case
        if not isinstance(tc, unittest.TestCase) and hasattr(tc, '__self__'):
            return tc.__self__
        return tc

    def on_next(self, x):
        tc = self._assertion_target()
        tc.assertLess(self.next_idx, len(self.expected_sensor_events),
                      "Got an event after reaching the end of the expected stream")
        expected = self.expected_sensor_events[self.next_idx]
        actual = x
        tc.assertEqual(actual.val, expected.val,
                       "Values for element %d of event stream mismatch" % self.next_idx)
        tc.assertEqual(actual.sensor_id, expected.sensor_id,
                       "sensor ids for element %d of event stream mismatch" % self.next_idx)
        # since the timestamp is a floating point number, we only check that
        # the timestamps are "close enough"
        tc.assertAlmostEqual(actual.ts, expected.ts, places=5,
                             msg="Timestamps for element %d of event stream mismatch" % self.next_idx)
        self.next_idx += 1

    def on_completed(self):
        self._assertion_target().assertEqual(
            self.next_idx, len(self.expected_sensor_events),
            "Got on_completed() before end of stream")
        self.completed = True

    def on_error(self, exc):
        # fail() reports exactly this message; the 1-tuple keeps the
        # formatting safe for a tuple-valued argument.
        self._assertion_target().fail(
            "Got an unexpected on_error call with parameter: %s" % (exc,))
class ValidateAndStopInputThing(ValidationInputThing):
    """ValidationInputThing that also calls ``stop_fn`` when validation ends.

    ``stop_fn`` is called once every expected event has been received.
    Specifically, it runs right after the event that brings ``next_idx`` up
    to ``len(expected_stream)`` has passed validation. Its role is to stop
    the event source.
    """
    def __init__(self, expected_stream, test_case, stop_fn,
                 extract_value_fn=lambda evt:evt.val):
        super().__init__(expected_stream, test_case,
                         extract_value_fn=extract_value_fn)
        self.stop_fn = stop_fn

    def on_next(self, x):
        super().on_next(x)
        all_seen = len(self.expected_stream) == self.next_idx
        if not all_seen:
            return
        print("ValidateAndStopInputThing: stopping")
        self.stop_fn()
class CaptureInputThing(InputThing):
    """Record the events of a stream in a list for later inspection.

    Attributes:
      events -- the events passed to on_next, in arrival order.
      completed -- set to True by on_completed.
      expecting_error -- when False, on_error raises FatalError.
      errored -- set to True by on_error when an error was expected.
    """
    def __init__(self, expecting_error=False):
        self.events = []
        self.expecting_error = expecting_error
        self.completed = self.errored = False

    def on_next(self, x):
        self.events.append(x)

    def on_completed(self):
        self.completed = True

    def on_error(self, e):
        if not self.expecting_error:
            raise FatalError("Should not get on_error, got on_error(%s)" % e)
        self.errored = True
class StopAfterN(Filter):
    """Filter that passes events through and calls a stop function after N.

    Usually, the stop function is the deschedule function for an upstream
    sensor. Every event is dispatched downstream. ``stop_fn`` is called once,
    right after the Nth event has been dispatched; any later events are still
    passed through.

    Raises ValueError if N is not positive.
    """
    def __init__(self, previous_in_chain, stop_fn, N=5):
        # Validate before calling Filter.__init__ (which presumably connects
        # to previous_in_chain), so a bad N leaves nothing half-wired. An
        # assert would be stripped under ``python -O``. ``not N > 0`` keeps
        # the original acceptance rule exactly (e.g. NaN is rejected).
        if not N > 0:
            raise ValueError("N must be positive, got %r" % (N,))
        super().__init__(previous_in_chain)
        self.stop_fn = stop_fn
        self.N = N
        self.count = 0

    def on_next(self, x):
        self._dispatch_next(x)
        self.count += 1
        if self.count == self.N:
            print("stopping after %d events" % self.N)
            self.stop_fn()
def trace_on_error(f):
    """Decorator helpful when debugging.

    If the decorated function/method raises an exception, the traceback is
    printed and the post-mortem debugger (``pdb.post_mortem``) is started on
    it. When the debugger session ends, the exception is re-raised, so the
    failure is still reported (e.g. a decorated test still fails) instead of
    the call silently returning None. The wrapped function's name and
    docstring are preserved.
    """
    import functools

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception:
            info = sys.exc_info()
            traceback.print_exception(*info)
            pdb.post_mortem(info[2])
            raise
    return decorator
|