mqtt.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2016 by MPI-SWS and Data-Ken Research.
  2. # Licensed under the Apache 2.0 License.
  3. import time
  4. from collections import namedtuple
  5. try:
  6. import paho.mqtt.client as paho
  7. except ImportError:
  8. print("could not import paho.mqtt.client")
  9. import ssl
  10. from thingflow.base import InputThing, OutputThing, EventLoopOutputThingMixin
  11. MQTTEvent = namedtuple('MQTTEvent', ['timestamp', 'state', 'mid', 'topic', 'payload', 'qos', 'dup', 'retain' ])
  12. import random
  13. random.seed()
  14. import datetime
  15. class MockMQTTClient(object):
  16. def __init__(self, client_id=""):
  17. self.userdata = None
  18. self.client_id = client_id
  19. self.on_message = None
  20. self.on_connect = None
  21. self.on_publish = None
  22. def connect(self, host, port=1883):
  23. if self.on_connect:
  24. self.on_connect(self, self.userdata, None, 0)
  25. return 0
  26. def subscribe(self, topics):
  27. pass
  28. def publish(self, topic, payload, qos, retain=False):
  29. if self.on_publish:
  30. self.on_publish(self, self.userdata, 0)
  31. def username_pw_set(self, username, password=""):
  32. pass
  33. def loop(self, timeout=1.0, max_packets=1):
  34. s = random.randint(1, max_packets)
  35. for i in range(0, s):
  36. msg = MQTTEvent(datetime.datetime.now(), 0, i, 'bogus/bogus', 'xxx', 0, False, False)
  37. if self.on_message:
  38. self.on_message(self, self.userdata, msg)
  39. time.sleep(timeout)
  40. return 0
  41. def disconnect(self):
  42. pass
  43. class MQTTWriter(InputThing):
  44. """Subscribes to internal events and pushes them out to MQTT.
  45. The topics parameter is a list of (topic, qos) pairs.
  46. Events should be serialized before passing them to the writer.
  47. """
  48. def __init__(self, host, port=1883, client_id="", client_username="", client_password=None, server_tls=False, server_cert=None, topics=[], mock_class=None):
  49. self.host = host
  50. self.port = port
  51. self.client_id = client_id
  52. self.client_username = client_id
  53. self.client_password = client_password
  54. self.topics = topics
  55. self.server_tls = server_tls
  56. self.server_cert = server_cert
  57. if mock_class:
  58. self.client = MockMQTTClient(self.client_id)
  59. else:
  60. self.client = paho.Client(self.client_id)
  61. if self.client_username:
  62. self.client.username_pw_set(self.client_username, password=self.client_password)
  63. self._connect()
  64. def _connect(self):
  65. if self.server_tls:
  66. raise Exception("TBD")
  67. print(self.client.tls_set(self.server_tls.server_cert, cert_reqs=ssl.CERT_OPTIONAL))
  68. print(self.client.connect(self.host, self.port))
  69. else:
  70. self.client.connect(self.host, self.port)
  71. self.client.subscribe(self.topics)
  72. def on_connect(client, userdata, flags, rc):
  73. print("Connected with result code "+str(rc))
  74. self.client.on_connect = on_connect
  75. def on_publish(client, userdata, mid):
  76. print("Successfully published mid %d" % mid)
  77. self.client.on_publish = on_publish
  78. def on_next(self, msg):
  79. """Note that the message is passed directly to paho.mqtt.client. As such,
  80. it must be a string, a bytearray, an int, a float or None. Usually, you would
  81. use something like to_json (in thingflow.filters.json) to do the
  82. serialization of events.
  83. """
  84. # publish the message to the topics
  85. retain = msg.retain if hasattr(msg, 'retain') else False
  86. for (topic, qos) in self.topics:
  87. self.client.publish(topic, msg, qos, retain)
  88. def on_error(self, e):
  89. self.client.disconnect()
  90. def on_completed(self):
  91. self.client.disconnect()
  92. def __str__(self):
  93. return 'MQTTWriter(%s)' % ', '.join([topic for (topic,qos) in self.topics])
  94. class MQTTReader(OutputThing, EventLoopOutputThingMixin):
  95. """An reader that creates a stream from an MQTT broker. Initialize the
  96. reader with a list of topics to subscribe to. The topics parameter
  97. is a list of (topic, qos) pairs.
  98. Pre-requisites: An MQTT broker (on host:port) --- tested with mosquitto
  99. The paho.mqtt python client for mqtt (pip install paho-mqtt)
  100. """
  101. def __init__(self, host, port=1883, client_id="", client_username="", client_password=None, server_tls=False, server_cert=None, topics=[], mock_class=None):
  102. super().__init__()
  103. self.stop_requested = False
  104. self.host = host
  105. self.port = port
  106. self.client_id = client_id
  107. self.client_username = client_id
  108. self.client_password = client_password
  109. self.topics = topics
  110. self.server_tls = server_tls
  111. self.server_cert = server_cert
  112. if mock_class:
  113. self.client = MockMQTTClient(self.client_id)
  114. else:
  115. self.client = paho.Client(self.client_id)
  116. if self.client_username:
  117. self.client.username_pw_set(self.client_username, password=self.client_password)
  118. self._connect()
  119. def on_message(client, userdata, msg):
  120. m = MQTTEvent(msg.timestamp, msg.state, msg.mid, msg.topic, msg.payload, msg.qos, msg.dup, msg.retain)
  121. self._dispatch_next(m)
  122. self.client.on_message = on_message
  123. def _connect(self):
  124. if self.server_tls:
  125. raise Exception("TBD")
  126. print(self.client.tls_set(self.server_tls.server_cert, cert_reqs=ssl.CERT_OPTIONAL))
  127. print(self.client.connect(self.host, self.port))
  128. else:
  129. self.client.connect(self.host, self.port)
  130. def on_connect(client, userdata, flags, rc):
  131. print("Connected with result code "+str(rc))
  132. # Subscribing in on_connect() means that if we lose the connection and
  133. # reconnect then subscriptions will be renewed.
  134. client.subscribe(self.topics)
  135. self.client.on_connect = on_connect
  136. def _observe_event_loop(self):
  137. print("starting event loop")
  138. while True:
  139. if self.stop_requested:
  140. break
  141. result = self.client.loop(1)
  142. if result != 0:
  143. self._connect()
  144. self.stop_requested = False
  145. self.client.disconnect()
  146. print("Stopped private event loop")
  147. def _stop_loop(self):
  148. self.stop_requested = True
  149. print("requesting stop")
  150. def __str__(self):
  151. return 'MQTTReader(%s)' % ', '.join([topic for (topic,qos) in self.topics])