utils.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright 2016 by MPI-SWS and Data-Ken Research.
  2. # Licensed under the Apache 2.0 License.
  3. """Common utilities for the tests
  4. """
  5. import time
  6. import unittest
  7. import random
  8. random.seed()
  9. import sys
  10. import traceback
  11. import pdb
  12. from thingflow.base import IterableAsOutputThing, InputThing, FatalError,\
  13. SensorEvent, Filter
  14. class RandomSensor:
  15. def __init__(self, sensor_id, mean=100.0, stddev=20.0, stop_after_events=None):
  16. self.sensor_id = sensor_id
  17. self.mean = mean
  18. self.stddev = stddev
  19. self.stop_after_events = stop_after_events
  20. if stop_after_events is not None:
  21. def generator():
  22. for i in range(stop_after_events):
  23. yield random.gauss(mean, stddev)
  24. else: # go on forever
  25. def generator():
  26. while True:
  27. yield random.gauss(mean, stddev)
  28. self.generator = generator()
  29. def sample(self):
  30. return self.generator.__next__()
  31. def __repr__(self):
  32. if self.stop_after_events is None:
  33. return 'RandomSensor(%s, mean=%s, stddev=%s)' % \
  34. (self.sensor_id, self.mean, self.stddev)
  35. else:
  36. return 'RandomSensor(%s, mean=%s, stddev=%s, stop_after_events=%s)' % \
  37. (self.sensor_id, self.mean, self.stddev, self.stop_after_events)
  38. class ValueListSensor:
  39. def __init__(self, sensor_id, values):
  40. self.sensor_id = sensor_id
  41. def generator():
  42. for v in values:
  43. yield v
  44. self.generator = generator()
  45. def sample(self):
  46. return self.generator.__next__()
  47. def __repr__(self):
  48. return 'ValueListSensor(%s)' % self.sensor_id
  49. def make_test_output_thing(sensor_id, mean=100.0, stddev=20.0, stop_after_events=None):
  50. """Here is an exmple test output_thing that generates a random value"""
  51. if stop_after_events is not None:
  52. def generator():
  53. for i in range(stop_after_events):
  54. yield SensorEvent(sensor_id, time.time(),
  55. random.gauss(mean, stddev))
  56. else: # go on forever
  57. def generator():
  58. while True:
  59. yield SensorEvent(sensor_id, time.time(),
  60. random.gauss(mean, stddev))
  61. g = generator()
  62. o = IterableAsOutputThing(g, name='Sensor(%s)' % sensor_id)
  63. return o
  64. def make_test_output_thing_from_vallist(sensor_id, values):
  65. """Create a output_thing that generates the list of values when sampled, but uses
  66. real timestamps.
  67. """
  68. def generator():
  69. for val in values:
  70. yield SensorEvent(sensor_id, time.time(), val)
  71. o = IterableAsOutputThing(generator(), name='Sensor(%s)' % sensor_id)
  72. return o
  73. class ValidationInputThing(InputThing):
  74. """Compare the values in a event stream to the expected values.
  75. Use the test_case for the assertions (for proper error reporting in a unit
  76. test).
  77. """
  78. def __init__(self, expected_stream, test_case,
  79. extract_value_fn=lambda event:event.val):
  80. self.expected_stream = expected_stream
  81. self.next_idx = 0
  82. self.test_case = test_case # this can be either a method or a class
  83. self.extract_value_fn = extract_value_fn
  84. self.completed = False
  85. self.name = "ValidationInputThing(%s)" % \
  86. test_case.__class__.__name__ \
  87. if isinstance(test_case, unittest.TestCase) \
  88. else "ValidationInputThing(%s.%s)" % \
  89. (test_case.__self__.__class__.__name__,
  90. test_case.__name__)
  91. def on_next(self, x):
  92. tcls = self.test_case if isinstance(self.test_case, unittest.TestCase)\
  93. else self.test_case.__self__
  94. tcls.assertLess(self.next_idx, len(self.expected_stream),
  95. "Got an event after reaching the end of the expected stream")
  96. expected = self.expected_stream[self.next_idx]
  97. actual = self.extract_value_fn(x)
  98. tcls.assertEqual(actual, expected,
  99. "Values for element %d of event stream mismatch" %
  100. self.next_idx)
  101. self.next_idx += 1
  102. def on_completed(self):
  103. tcls = self.test_case if isinstance(self.test_case, unittest.TestCase)\
  104. else self.test_case.__self__
  105. tcls.assertEqual(self.next_idx, len(self.expected_stream),
  106. "Got on_completed() before end of stream")
  107. self.completed = True
  108. def on_error(self, exc):
  109. tcls = self.test_case if isinstance(self.test_case, unittest.TestCase)\
  110. else self.test_case.__self__
  111. tcls.assertTrue(False,
  112. "Got an unexpected on_error call with parameter: %s" %
  113. exc)
  114. def __repr__(self):
  115. return self.name
  116. class SensorEventValidationInputThing(InputThing):
  117. """Compare the full events in a sensor event stream to the expected events.
  118. Use the test_case for the assertions (for proper error reporting in a unit
  119. test).
  120. """
  121. def __init__(self, expected_sensor_events, test_case):
  122. self.expected_sensor_events = expected_sensor_events
  123. self.next_idx = 0
  124. self.test_case = test_case
  125. self.completed = False
  126. def on_next(self, x):
  127. tc = self.test_case
  128. tc.assertLess(self.next_idx, len(self.expected_sensor_events),
  129. "Got an event after reaching the end of the expected stream")
  130. expected = self.expected_sensor_events[self.next_idx]
  131. actual = x
  132. tc.assertEqual(actual.val, expected.val,
  133. "Values for element %d of event stream mismatch" % self.next_idx)
  134. tc.assertEqual(actual.sensor_id, expected.sensor_id,
  135. "sensor ids for element %d of event stream mismatch" % self.next_idx)
  136. # since the timestamp is a floating point number, we only check that
  137. # the timestamps are "close enough"
  138. tc.assertAlmostEqual(actual.ts, expected.ts, places=5,
  139. msg="Timestamps for element %d of event stream mismatch" % self.next_idx)
  140. self.next_idx += 1
  141. def on_completed(self):
  142. tc = self.test_case
  143. tc.assertEqual(self.next_idx, len(self.expected_sensor_events),
  144. "Got on_completed() before end of stream")
  145. self.completed = True
  146. def on_error(self, exc):
  147. tc = self.test_case
  148. tc.assertTrue(False,
  149. "Got an unexpected on_error call with parameter: %s" % exc)
  150. class ValidateAndStopInputThing(ValidationInputThing):
  151. """A version of ValidationInputThing that calls a stop
  152. function after the specified events have been received.
  153. """
  154. def __init__(self, expected_stream, test_case, stop_fn,
  155. extract_value_fn=lambda event:event.val):
  156. super().__init__(expected_stream, test_case,
  157. extract_value_fn=extract_value_fn)
  158. self.stop_fn = stop_fn
  159. def on_next(self, x):
  160. super().on_next(x)
  161. if self.next_idx==len(self.expected_stream):
  162. print("ValidateAndStopInputThing: stopping")
  163. self.stop_fn()
  164. class CaptureInputThing(InputThing):
  165. """Capture the sequence of events in a list for later use.
  166. """
  167. def __init__(self, expecting_error=False):
  168. self.events = []
  169. self.completed = False
  170. self.expecting_error = expecting_error
  171. self.errored = False
  172. def on_next(self, x):
  173. self.events.append(x)
  174. def on_completed(self):
  175. self.completed = True
  176. def on_error(self, e):
  177. if self.expecting_error:
  178. self.errored = True
  179. else:
  180. raise FatalError("Should not get on_error, got on_error(%s)" % e)
  181. class StopAfterN(Filter):
  182. """Filter to call a stop function after N events.
  183. Usually, the stop function is the deschedule function for an upstream sensor.
  184. """
  185. def __init__(self, previous_in_chain, stop_fn, N=5):
  186. super().__init__(previous_in_chain)
  187. self.stop_fn = stop_fn
  188. self.N = N
  189. assert N>0
  190. self.count = 0
  191. def on_next(self, x):
  192. self._dispatch_next(x)
  193. self.count += 1
  194. if self.count==self.N:
  195. print("stopping after %d events" % self.N)
  196. self.stop_fn()
  197. def trace_on_error(f):
  198. """Decorator helpful when debugging. Will put the decorated function/method
  199. into the debugger when an exception is thrown
  200. """
  201. def decorator(*args, **kwargs):
  202. try:
  203. return f(*args, **kwargs)
  204. except Exception as e:
  205. info = sys.exc_info()
  206. traceback.print_exception(*info)
  207. pdb.post_mortem(info[2])
  208. return decorator