# kalman_model.py (4.1 KB)
  1. import asyncio
  2. import numpy as np
  3. from sklearn import linear_model
  4. # For Kalman filtering
  5. from filterpy.kalman import KalmanFilter
  6. from filterpy.common import Q_discrete_white_noise
  7. from thingflow.base import OutputThing, InputThing, from_iterable, Scheduler
  8. class SGDLinearRegressionModel(OutputThing, InputThing):
  9. def __init__(self):
  10. OutputThing.__init__(self, ports=['train', 'observe', 'predict'])
  11. self.clf = linear_model.SGDRegressor()
  12. def on_train_next(self, x):
  13. print("On train next called")
  14. # training input: train the model
  15. xx = np.asarray(x[0])
  16. yy = np.asarray(x[1])
  17. self.clf.partial_fit(xx, yy)
  18. def on_train_error(self, x):
  19. print("On train error called")
  20. self.on_error(x)
  21. def on_train_completed(self):
  22. print("On train completed called")
  23. self.on_completed()
  24. def on_observe_next(self, x):
  25. print("On observe next called")
  26. xx = np.asarray(x)
  27. p = self.clf.predict(xx)
  28. self._dispatch_next(p, port='predict')
  29. def on_observe_error(self, x):
  30. self.on_error(x)
  31. def on_observe_completed(self):
  32. self.on_completed()
  33. class FilterModel(OutputThing, InputThing):
  34. def __init__(self, filter):
  35. OutputThing.__init__(self, ports=['observe', 'predict'])
  36. self.filter = filter
  37. def on_observe_next(self, measurement):
  38. print("On observerain next called")
  39. # training input: train the model
  40. self.filter.predict()
  41. self.filter.update(measurement)
  42. self._dispatch_next(self.filter.x, port='predict')
  43. def on_observe_error(self, x):
  44. print("On observe error called")
  45. self.on_error(x)
  46. def on_observe_completed(self):
  47. print("On observe completed called")
  48. self.on_completed()
  49. class KalmanFilterModel(FilterModel):
  50. """Implements Kalman filters using filterpy.
  51. x' = Fx + Bu + w
  52. y = H x + ww
  53. """
  54. def __init__(self, dim_state, dim_control, dim_measurement,
  55. initial_state_mean, initial_state_covariance,
  56. matrix_F, matrix_B,
  57. process_noise_Q,
  58. matrix_H, measurement_noise_R):
  59. filter = KalmanFilter(dim_x=dim_state, dim_u=dim_control, dim_z=dim_measurement)
  60. filter.x = initial_state_mean
  61. filter.P = initial_state_covariance
  62. filter.Q = process_noise_Q
  63. filter.F = matrix_F
  64. filter.B = matrix_B
  65. filter.H = matrix_H
  66. filter.R = measurement_noise_R # covariance matrix
  67. super().__init__(filter)
  68. def main_linear():
  69. obs_stream = from_iterable(iter([ [ [ [1.0, 1.0], [2.0, 2.0]], [1.0, 2.0] ], [ [ [6.0, 6.0], [9.0, 9.0]], [6.0, 9.0] ] ]))
  70. pred_stream = from_iterable(iter([ [3.0, 3.0] ]))
  71. model = SGDLinearRegressionModel()
  72. obs_stream.connect(model, port_mapping=('default', 'train'))
  73. obs_stream.connect(print)
  74. pred_stream.connect(model, port_mapping=('default', 'observe'))
  75. model.connect(print, port_mapping=('predict', 'default'))
  76. scheduler = Scheduler(asyncio.get_event_loop())
  77. scheduler.schedule_periodic(obs_stream, 1)
  78. scheduler.schedule_periodic(pred_stream, 5)
  79. scheduler.run_forever()
  80. def main_kalman():
  81. dim_x = 2
  82. dim_u = 1
  83. dim_z = 1
  84. initial_state_mean = np.array([ [1.0] , [0.0] ])
  85. initial_state_covariance = 1000 * np.eye(dim_x)
  86. F = np.array([ [ 1., 1.], [0., 1.] ])
  87. B = np.zeros((2, 1) )
  88. Q = Q_discrete_white_noise(dim=2, dt=0.1, var=0.13)
  89. H = np.array([[1.,0.]])
  90. R = 5 * np.eye(1)
  91. model = KalmanFilterModel(dim_x, dim_u, dim_z, initial_state_mean, initial_state_covariance,
  92. F, B, Q, H, R)
  93. measurement_stream = from_iterable(iter([ [ 1.0 ], [0.0] ]))
  94. # measurement_stream = from_iterable(iter([ np.array([ [1.0, 1.0] ]) ]))
  95. measurement_stream.connect(model, port_mapping=('default', 'observe'))
  96. model.connect(print, port_mapping=('predict', 'default'))
  97. scheduler = Scheduler(asyncio.get_event_loop())
  98. scheduler.schedule_periodic(measurement_stream, 1)
  99. scheduler.run_forever()
  100. def main():
  101. main_kalman()
  102. if __name__ == '__main__':
  103. main()