Coverage for lib/utils/socket.py: 41%

187 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-07-28 07:25 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3 

4# Hermes : Change Data Capture (CDC) tool from any source(s) to any target 

5# Copyright (C) 2023, 2024 INSA Strasbourg 

6# 

7# This file is part of Hermes. 

8# 

9# Hermes is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# Hermes is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with Hermes. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23import argparse 

24import atexit 

25import grp 

26import logging 

27import os 

28import pwd 

29import socket 

30import threading 

31 

32from stat import S_ISSOCK 

33from time import sleep 

34from typing import Any, Callable, IO 

35 

36from lib.datamodel.serialization import JSONSerializable 

37 

38 

class InvalidSocketMessageError(Exception):
    """Signals that a message received on the socket is malformed"""

41 

42 

class SocketNotFoundError(Exception):
    """Signals that a client attempted to connect to a socket file that doesn't
    exist"""

45 

46 

class SystemdSocketError(Exception):
    """Signals that the socket bound by systemd can't be used (e.g. the number
    of sockets announced by systemd is not exactly one)"""

49 

50 

class SocketParsingError(Exception):
    """Signals that argparse failed to parse the received arguments.

    The exception's string form is the message to report: the parser help text
    followed by the argparse error message.
    """

54 

55 

class SocketParsingMessage(Exception):
    """Signals that argparse wanted to print a message.

    Instead of being printed, the message text is carried as the exception's
    content.
    """

59 

60 

class InvalidOwnerError(Exception):
    """Signals that the user requested as socket owner doesn't exist"""

63 

64 

class InvalidGroupError(Exception):
    """Signals that the group requested as socket group doesn't exist"""

67 

68 

69class SocketArgumentParser(argparse.ArgumentParser): 

70 """Subclass of argument parser to avoid exiting on error. Will parse arguments 

71 received on server socket""" 

72 

73 def format_error(self, message: str) -> str: 

74 """Format error message""" 

75 return self.format_help() + "\n" + message 

76 

77 def _print_message(self, message: str, file: IO[str] | None = None): 

78 """Override print message to store message in SocketParsingMessage exception 

79 instead of printing it""" 

80 if message: 

81 raise SocketParsingMessage(message) 

82 

83 def exit(self, status=0, message=None): 

84 """Prevent argparser from exiting app""" 

85 pass 

86 

87 def error(self, message: str): 

88 """Raise a SocketParsingError containing error message instead of exiting""" 

89 raise SocketParsingError(self.format_error(message)) 

90 

91 

class SocketMessageToServer(JSONSerializable):
    """Serializable message that SockServer can understand.

    It is intended to be equivalent to sys.argv: a list of strings.
    """

    def __init__(
        self,
        argv: list[str] | None = None,
        from_json_dict: dict[str, Any] | None = None,
    ):
        """Create a new message from an argv list or from a deserialized json dict.

        Exactly one of ``argv`` and ``from_json_dict`` must be specified.

        Raises:
            AttributeError: if neither or both data sources are specified.
            InvalidSocketMessageError: if ``from_json_dict`` has no "argv" entry
                (or isn't subscriptable by a str key), if argv isn't a list, or
                if one of its items isn't a str.
        """
        super().__init__(jsondataattr=["argv"])

        if argv is None and from_json_dict is None:
            err = (
                "Cannot instantiante object from nothing:"
                " you must specify one data source"
            )
            __hermes__.logger.critical(err)
            raise AttributeError(err)

        if argv is not None and from_json_dict is not None:
            err = "Cannot instantiante object from multiple data sources at once"
            __hermes__.logger.critical(err)
            raise AttributeError(err)

        if argv is not None:
            self.argv: list[str] = argv
        else:
            # from_json_dict is presumably built by JSONSerializable.from_json()
            # from data received on the socket (see SockServer). A missing key or
            # a non-dict payload must be reported as an invalid message, as
            # callers only expect InvalidSocketMessageError for bad input.
            try:
                self.argv = from_json_dict["argv"]
            except (KeyError, TypeError):
                err = "Invalid message: no 'argv' entry found"
                __hermes__.logger.warning(err)
                raise InvalidSocketMessageError(err) from None

        if type(self.argv) is not list:
            err = f"Invalid type for argv: {type(self.argv)} instead of list"
            __hermes__.logger.warning(err)
            raise InvalidSocketMessageError(err)

        for item in self.argv:
            if type(item) is not str:
                err = f"Invalid type in argv: {type(item)} instead of str"
                __hermes__.logger.warning(err)
                raise InvalidSocketMessageError(err)

133 

134 

class SocketMessageToClient(JSONSerializable):
    """Serializable message (answer) that SockClient can understand.

    It is intended to be equivalent to a command result: a retcode (0 if no
    error) and an output string (retmsg).
    """

    def __init__(
        self,
        retcode: int | None = None,
        retmsg: str | None = None,
        from_json_dict: dict[str, Any] | None = None,
    ):
        """Create a new message with specified retcode and retmsg, or from a
        deserialized json dict.

        Either both ``retcode`` and ``retmsg``, or only ``from_json_dict`` must
        be specified.

        Raises:
            AttributeError: if no complete data source, or several data sources,
                are specified.
            InvalidSocketMessageError: if ``from_json_dict`` lacks the "retcode"
                or "retmsg" entry (or isn't subscriptable by a str key), if
                retcode isn't an int, or if retmsg isn't a str.
        """
        super().__init__(jsondataattr=["retcode", "retmsg"])

        if (retcode is None or retmsg is None) and from_json_dict is None:
            err = (
                "Cannot instantiante object from nothing:"
                " you must specify one data source"
            )
            __hermes__.logger.critical(err)
            raise AttributeError(err)

        if (retcode is not None or retmsg is not None) and from_json_dict is not None:
            err = "Cannot instantiante object from multiple data sources at once"
            __hermes__.logger.critical(err)
            raise AttributeError(err)

        if retcode is not None:
            self.retcode: int = retcode
            self.retmsg: str = retmsg
        else:
            # from_json_dict presumably comes from data received on the socket:
            # report missing entries / non-dict payloads as an invalid message
            # instead of leaking a KeyError or TypeError to the caller.
            try:
                self.retcode = from_json_dict["retcode"]
                self.retmsg = from_json_dict["retmsg"]
            except (KeyError, TypeError):
                err = "Invalid message: 'retcode' and 'retmsg' entries are required"
                __hermes__.logger.warning(err)
                raise InvalidSocketMessageError(err) from None

        # Strict type checks on purpose: isinstance() would accept a bool retcode
        if type(self.retcode) is not int:
            err = f"Invalid type for retcode: {type(self.retcode)} instead of int"
            __hermes__.logger.warning(err)
            raise InvalidSocketMessageError(err)

        if type(self.retmsg) is not str:
            err = f"Invalid type for retmsg: {type(self.retmsg)} instead of str"
            __hermes__.logger.warning(err)
            raise InvalidSocketMessageError(err)

179 

180 

class SockServer:
    """Create a server awaiting messages on a Unix socket, and passing them to a
    specified handler at each call of processMessagesInQueue()"""

    def __init__(
        self,
        path: str,
        processHdlr: Callable[[SocketMessageToServer], SocketMessageToClient],
        owner: str | None = None,
        group: str | None = None,
        mode: int = 0o0700,
        dontManageSockfile: bool = False,
    ):
        """Create a new server listening on a Unix socket.

        Args:
            path: path of the socket file.
            processHdlr: called with every valid message received; its return
                value is sent back to the client as the reply.
            owner: name of the user to set as socket file owner (only used when
                the socket file is managed by this server).
            group: name of the group to set as socket file group (only used when
                the socket file is managed by this server).
            mode: permission bits to set on the socket file (only used when the
                socket file is managed by this server).
            dontManageSockfile: if True, don't create the socket file but attach
                to the single socket bound by systemd (file descriptor 3, whose
                presence is announced by the LISTEN_FDS env var).

        Raises:
            SystemdSocketError: dontManageSockfile is True and LISTEN_FDS isn't "1".
            FileExistsError: path already exists and isn't a socket.
            InvalidOwnerError: specified owner doesn't exist.
            InvalidGroupError: specified group doesn't exist.
        """
        atexit.register(self._cleanup)  # Do our best to delete sock file at exit
        self._manageSockFile: bool = not dontManageSockfile
        self._sockpath: str = path
        self._processHdlr: Callable[[SocketMessageToServer], SocketMessageToClient] = (
            processHdlr
        )
        self._sock = None

        if self._manageSockFile:
            self._createManagedSocket(owner, group, mode)
        else:
            self._attachSystemdSocket()

        self._sock.listen()  # Listen for incoming connections

    def _attachSystemdSocket(self):
        """Attach to the single unix stream socket bound by systemd (file
        descriptor 3) and make it non-blocking.

        Raises:
            SystemdSocketError: if the LISTEN_FDS env var is missing or isn't "1".
        """
        # Ensure one and only one socket was attached by systemd
        listen_fds = os.environ.get("LISTEN_FDS")

        if listen_fds == "1":
            errmsg = None
        elif listen_fds is None:
            errmsg = (
                "No env var 'LISTEN_FDS' found."
                " Unable to use sockfile bound by systemd"
            )
        elif listen_fds == "0":
            errmsg = (
                "'LISTEN_FDS' env var is '0', indicating that no sockfile was bound"
                " by systemd. Check your socket unit file"
            )
        else:
            errmsg = (
                f"'LISTEN_FDS' env var is '{listen_fds}', indicating that more than"
                " one sockfile was bound by systemd. Only one is supported."
                " Check your socket unit file"
            )

        if errmsg is not None:
            raise SystemdSocketError(errmsg) from None

        # Attach to existing unix stream socket
        self._sock = socket.socket(fileno=3)
        self._sock.setblocking(False)

    def _createManagedSocket(self, owner: str | None, group: str | None, mode: int):
        """Create a non-blocking unix stream socket bound to self._sockpath, then
        apply the requested ownership and mode to the socket file.

        Raises:
            FileExistsError: if self._sockpath exists and isn't a socket.
            InvalidOwnerError: if owner doesn't exist.
            InvalidGroupError: if group doesn't exist.
        """
        self._removeSocket()  # Try to remove the socket if it already exist

        # Create a non blocking unix stream socket
        self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self._sock.setblocking(False)

        # Bind the socket to the specified path
        self._sock.bind(self._sockpath)

        # Set socket rights as requested. __init__ only calls listen() once this
        # method has returned, i.e. after rights are set.
        try:
            uid = pwd.getpwnam(owner).pw_uid if owner else -1
        except KeyError:
            raise InvalidOwnerError(
                f"Specified socket {owner=} doesn't exists"
            ) from None

        try:
            gid = grp.getgrnam(group).gr_gid if group else -1
        except KeyError:
            raise InvalidGroupError(
                f"Specified socket {group=} doesn't exists"
            ) from None

        if uid != -1 or gid != -1:
            os.chown(self._sockpath, uid, gid)  # -1 leaves that id unchanged

        os.chmod(self._sockpath, mode)

    def _removeSocket(self):
        """Try to remove the socket file.

        Raises:
            FileExistsError: if the path exists and isn't a socket.
        """
        if not os.path.exists(self._sockpath):
            return  # Path doesn't exists, nothing to do

        # Is path a socket ?
        st = os.stat(self._sockpath)
        if not S_ISSOCK(st.st_mode):
            # Not a socket, raise an exception
            raise FileExistsError(
                f"The specified path for the unix socket '{self._sockpath}'"
                f" already exists and is not a socket"
            )

        try:  # Is a socket, try to delete it
            os.unlink(self._sockpath)
        except OSError:
            # Only an error if the file is still there (it may have been removed
            # concurrently)
            if os.path.exists(self._sockpath):
                raise

    def _cleanup(self):
        """Close the socket and, when this server manages the socket file, try to
        remove it. Registered with atexit by __init__."""
        if self._sock is not None:
            self._sock.close()  # Close the socket

        if self._manageSockFile:
            self._removeSocket()  # Try to remove the socket file

    def processMessagesInQueue(self):
        """Process every connection waiting on socket and send its message to
        the handler. Returns when no pending connection is left.

        Malformed messages are ignored: the connection is closed without reply.
        Exceptions raised by the handler are propagated to the caller, after the
        client connection has been closed.
        """
        while True:
            try:
                # Check for new incoming connection
                connection, _ = self._sock.accept()
            except BlockingIOError:
                break  # No pending connection left

            try:
                self._handleConnection(connection)
            finally:
                # Always release the client connection, whatever happened
                try:
                    connection.close()
                except Exception as e:
                    __hermes__.logger.warning(f"Got exception during close: {str(e)}")

    def _handleConnection(self, connection: socket.socket):
        """Read a whole message from connection, pass it to the handler and send
        the handler's reply back. Does not close the connection."""
        # Set a reasonnable timeout to prevent blocking whole app if a client
        # doesn't close its sending pipe
        connection.settimeout(1)

        __hermes__.logger.debug("New CLI connection")
        # Receive the data until the client closes its sending pipe (EOF)
        chunks: list[bytes] = []
        try:
            while data := connection.recv(9999):
                chunks.append(data)
        except Exception as e:
            __hermes__.logger.warning(f"Got exception during receive: {str(e)}")
            return

        try:
            m = SocketMessageToServer.from_json(b"".join(chunks).decode())
        except InvalidSocketMessageError:
            # Ignoring message (SocketMessageToServer logs a warning before
            # raising this exception)
            return
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers UnicodeDecodeError from decode(); the others are
            # presumably raised by from_json() on non-JSON or unexpected payloads
            __hermes__.logger.warning(
                f"Ignoring malformed message received on socket: {str(e)}"
            )
            return

        # Process message, and generate reply
        reply: SocketMessageToClient = self._processHdlr(m)
        try:
            connection.sendall(reply.to_json().encode())  # send reply
        except Exception as e:
            __hermes__.logger.warning(f"Got exception during send: {str(e)}")

    def startProcessMessagesDaemon(self, appname: str | None = None):
        """Will call undefinitly processMessagesInQueue() in a separate daemon
        thread. It's the caller's responsability to ensure there will be no race
        condition between threads.

        If appname is specified, the daemon loop will fill local thread attributes
        of builtin var "__hermes__" at start, and the thread will be named
        "<appname>-cli-listener".
        """
        threadname = f"{appname}-cli-listener" if appname else None

        t = threading.Thread(
            target=self.__daemonLoop,
            name=threadname,
            kwargs={"appname": appname},
            daemon=True,
        )
        t.start()

    def __daemonLoop(self, appname: str | None = None):
        """Thread body started by startProcessMessagesDaemon(): process pending
        messages, then sleep 0.5s, forever (until the socket is closed)."""
        if appname:
            __hermes__.appname = appname
            __hermes__.logger = logging.getLogger(appname)

        while True:
            try:
                self.processMessagesInQueue()
            except Exception:
                if self._sock.fileno() == -1:
                    break  # Socket was closed (e.g. by _cleanup() at exit): stop
                # Keep listening: one failing connection must not silently
                # disable the CLI for the rest of the application lifetime
                __hermes__.logger.exception(
                    "Unhandled exception while processing CLI messages"
                )
            sleep(0.5)

369 

370 

371class SockClient: 

372 """Create a client sending a command on Unix socket, and waiting for result""" 

373 

374 @classmethod 

375 def send( 

376 cls, sockpath: str, message: SocketMessageToServer 

377 ) -> SocketMessageToClient: 

378 """Send specified message to server via specified unix sockpath, block until 

379 result is received, and returns it""" 

380 # Create a blocking unix stream socket 

381 with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: 

382 try: 

383 sock.connect(sockpath) # Connect to the socket file 

384 except FileNotFoundError: 

385 raise SocketNotFoundError() 

386 sock.sendall(message.to_json().encode()) # Send message 

387 sock.shutdown(socket.SHUT_WR) # Close the sending pipe 

388 

389 reply = b"" 

390 while True: 

391 data = sock.recv(9999) 

392 if not data: # EOF 

393 return SocketMessageToClient.from_json(reply.decode()) 

394 reply += data