diff --git a/mmcli.py b/mmcli.py index 56d21b5..56c8970 100755 --- a/mmcli.py +++ b/mmcli.py @@ -6,10 +6,11 @@ import os import json from typing import Dict, Optional, List import re -# from threading import Lock +from time import sleep +import threading import mattermost -# from mmws import MMws +from mmws import MMws class NotFound(Exception): @@ -143,30 +144,54 @@ def cat(mm_api: mattermost.MMApi, parsed): def attribute(key_value): return key_value - # backlog = [] - # backlog_lock = Lock() + + # In a list to allow overwriting from within print_initial_messages without using global + backlog = [ [] ] + backlog_lock = threading.Lock() + + def print_initial_messages(): + posts = get_posts_for_channel(mm_api, channel["id"], after=parsed.after) + for post in posts: + print(str_for_post(attribute, post, parsed)) + + with backlog_lock: + for post in backlog[0]: + print(str_for_post(attribute, post, parsed)) + backlog[0] = None + if parsed.follow: - raise NotImplementedError("--follow is not yet supported") - # def webs_handler(mmws, event_data): - # if event_data["event"] == "posted": - # with backlog_lock: - # if backlog is not None: - # backlog.append(event_data["data"]) - # return - # print(post_str(attribute, event_data["data"], parsed)) + def simple_websocket_callback(mmws, event_data): + if event_data.get("event") == "posted": + post = json.loads(event_data["data"]["post"]) + if post["channel_id"] != channel["id"]: + return + print(str_for_post(attribute, post, parsed), flush=True) - # ws_url = http_to_ws(mm_api._url) + "/v4/websocket" - # MMws(webs_handler, mm_api, ws_url) - # return + def initial_websocket_callback(mmws: MMws, event_data): + if event_data.get("event") == "posted": + post = json.loads(event_data["data"]["post"]) + if post["channel_id"] != channel["id"]: + return + with backlog_lock: + if backlog[0] is not None: + backlog[0].append(post) + return + else: + mmws.ws_handler = simple_websocket_callback + simple_websocket_callback(mmws, event_data) - posts = get_posts_for_channel(mm_api, channel["id"], after=parsed.after) - for post in posts: - print(post_str(attribute, post, parsed)) + ws_url = http_to_ws(mm_api._url) + "/v4/websocket" + mmws = MMws(initial_websocket_callback, mm_api.access_token, ws_url) - # with backlog_lock: - # for post in backlog: - # print(post_str(attribute, post, parsed)) - # backlog = None + if parsed.follow: + thread = threading.Thread(target=print_initial_messages) + thread.setDaemon(True) + thread.start() + + mmws.run_websocket() + + else: + print_initial_messages() def send(mm_api: mattermost.MMApi, parsed): @@ -200,7 +225,7 @@ def send(mm_api: mattermost.MMApi, parsed): print(sent) -def post_str(attribute, post, parsed): +def str_for_post(attribute, post, parsed): obj = { k: v for k, v in map(attribute, post.items()) @@ -296,6 +321,7 @@ Hint: JSON output can be filtered on the command line with jq(1). server = parsed.server if re.match(r"^[a-z]+://", parsed.server) else f"https://{parsed.server}" mm_api = mattermost.MMApi(f"{server}/api") + mm_api.access_token = access_token if access_token: mm_api._headers.update({"Authorization": f"Bearer {access_token}"}) diff --git a/mmws.py b/mmws.py index cf5531e..3cd46dd 100644 --- a/mmws.py +++ b/mmws.py @@ -1,52 +1,39 @@ import sys import json -import threading import websocket -websocket.enableTrace(True) - class MMws: """ - Websocket client. + Mattermost websocket client """ - def __init__(self, ws_handler, api, ws_url): - self.api = api + def __init__(self, ws_handler, token, ws_url): + """ + @param ws_handler: callback when new data is received on websocket + @param token: Mattermost access token + @param ws_url: websocket URL to connect to + """ + self.token = token self.ws_url = ws_url self.ws_handler = ws_handler self.ws_app = None - self.thread = threading.Thread(target=self._open_websocket) - self.thread.setName("websocket") - self.thread.setDaemon(False) - self.thread.start() - - def _open_websocket(self): + def run_websocket(self): def on_open(ws): - print("Opened") ws.send(json.dumps({ - "seq": 1, "action": "authentication_challenge", "data": {"token": self.api._bearer} + "seq": 1, "action": "authentication_challenge", "data": {"token": self.token} })) def on_message(ws, msg): - print(msg) - self.ws_handler(self, msg) + self.ws_handler(self, json.loads(msg)) def on_error(ws, error): - print(error, file=sys.stderr) - sys.exit(1) + raise error - self.ws_app = websocket.WebSocketApp(self.ws_url, on_open=on_open, on_message=on_message, - on_close=lambda ws: print("Closed")) - print("Start", flush=True) + self.ws_app = websocket.WebSocketApp( + self.ws_url, on_open=on_open, on_message=on_message, on_error=on_error + ) self.ws_app.run_forever() - print("Done", flush=True) - - - - def close_websocket(self): - self.ws_app.close() - self.thread.join()