| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import sys |
| from os import pathsep |
| addr = str(sys.argv[1]) |
| port = str(sys.argv[2]) |
| paths = sys.argv[3] |
| for p in paths.split(pathsep): |
| sys.path.append(p) |
| from struct import * |
| import signal |
| import msgpack |
| import socket |
| import traceback |
| from importlib import import_module |
| from pathlib import Path |
| from enum import IntEnum |
| from io import BytesIO |
| |
| PROTO_VERSION = 1 |
| HEADER_SZ = 8 + 8 + 1 |
| REAL_HEADER_SZ = 4 + 8 + 8 + 1 |
| FRAMESZ = 32768 |
| |
| |
| class MessageType(IntEnum): |
| HELO = 0 |
| QUIT = 1 |
| INIT = 2 |
| INIT_RSP = 3 |
| CALL = 4 |
| CALL_RSP = 5 |
| ERROR = 6 |
| |
| |
| class MessageFlags(IntEnum): |
| NORMAL = 0 |
| INITIAL_REQ = 1 |
| INITIAL_ACK = 2 |
| ERROR = 3 |
| |
| |
| class Wrapper(object): |
| wrapped_module = None |
| wrapped_class = None |
| wrapped_fn = None |
| sz = None |
| mid = None |
| rmid = None |
| flag = None |
| resp = None |
| unpacked_msg = None |
| msg_type = None |
| packer = msgpack.Packer(autoreset=False) |
| unpacker = msgpack.Unpacker() |
| response_buf = BytesIO() |
| stdin_buf = BytesIO() |
| wrapped_fns = {} |
| alive = True |
| readbuf = bytearray(FRAMESZ) |
| readview = memoryview(readbuf) |
| |
| |
| def init(self, module_name, class_name, fn_name): |
| self.wrapped_module = import_module(module_name) |
| # do not allow modules to be called that are not part of the uploaded module |
| wrapped_fn = None |
| if not self.check_module_path(self.wrapped_module): |
| self.wrapped_module = None |
| raise ImportError("Module was not found in library") |
| if class_name is not None: |
| self.wrapped_class = getattr( |
| import_module(module_name), class_name)() |
| if self.wrapped_class is not None: |
| wrapped_fn = getattr(self.wrapped_class, fn_name) |
| else: |
| wrapped_fn = getattr(import_module(module_name), fn_name) |
| if wrapped_fn is None: |
| raise ImportError( |
| "Could not find class or function in specified module") |
| self.wrapped_fns[self.mid] = wrapped_fn |
| |
| def next_tuple(self, *args, key=None): |
| return self.wrapped_fns[key](*args) |
| |
| def check_module_path(self, module): |
| cwd = Path('.').resolve() |
| module_path = Path(module.__file__).resolve() |
| return cwd in module_path.parents |
| |
| def read_header(self, readbuf): |
| self.sz, self.mid, self.rmid, self.flag = unpack( |
| "!iqqb", readbuf[0:REAL_HEADER_SZ]) |
| return True |
| |
| def write_header(self, response_buf, dlen): |
| total_len = dlen + HEADER_SZ |
| header = pack("!iqqb", total_len, int(-1), int(self.rmid), self.flag) |
| self.response_buf.write(header) |
| return total_len + 4 |
| |
| def get_ver_hlen(self, hlen): |
| return hlen + (PROTO_VERSION << 4) |
| |
| def get_hlen(self): |
| return self.ver_hlen - (PROTO_VERSION << 4) |
| |
| def init_remote_ipc(self): |
| self.response_buf.seek(0) |
| self.flag = MessageFlags.INITIAL_REQ |
| dlen = len(self.unpacked_msg[1]) |
| resp_len = self.write_header(self.response_buf, dlen) |
| self.response_buf.write(self.unpacked_msg[1]) |
| self.resp = self.response_buf.getbuffer()[0:resp_len] |
| self.send_msg() |
| self.packer.reset() |
| |
| def helo(self): |
| # need to ack the connection back before sending actual HELO |
| self.init_remote_ipc() |
| self.flag = MessageFlags.NORMAL |
| self.response_buf.seek(0) |
| self.packer.pack(int(MessageType.HELO)) |
| self.packer.pack("HELO") |
| dlen = 5 # tag(1) + body(4) |
| resp_len = self.write_header(self.response_buf, dlen) |
| self.response_buf.write(self.packer.bytes()) |
| self.resp = self.response_buf.getbuffer()[0:resp_len] |
| self.send_msg() |
| self.packer.reset() |
| return True |
| |
| def handle_init(self): |
| self.flag = MessageFlags.NORMAL |
| self.response_buf.seek(0) |
| args = self.unpacked_msg[1] |
| module = args[0] |
| if len(args) == 3: |
| clazz = args[1] |
| fn = args[2] |
| else: |
| clazz = None |
| fn = args[1] |
| self.init(module, clazz, fn) |
| self.packer.pack(int(MessageType.INIT_RSP)) |
| dlen = 1 # just the tag. |
| resp_len = self.write_header(self.response_buf, dlen) |
| self.response_buf.write(self.packer.bytes()) |
| self.resp = self.response_buf.getbuffer()[0:resp_len] |
| self.send_msg() |
| self.packer.reset() |
| return True |
| |
| def quit(self): |
| self.alive = False |
| return True |
| |
| def handle_call(self): |
| self.flag = MessageFlags.NORMAL |
| result = ([], []) |
| if len(self.unpacked_msg) > 1: |
| args = self.unpacked_msg[1] |
| if args is not None: |
| for arg in args: |
| try: |
| result[0].append(self.next_tuple(*arg, key=self.mid)) |
| except BaseException as e: |
| result[1].append(traceback.format_exc()) |
| self.packer.reset() |
| self.response_buf.seek(0) |
| body = msgpack.packb(result) |
| dlen = len(body) + 1 # 1 for tag |
| resp_len = self.write_header(self.response_buf, dlen) |
| self.packer.pack(int(MessageType.CALL_RSP)) |
| self.response_buf.write(self.packer.bytes()) |
| self.response_buf.write(body) |
| self.resp = self.response_buf.getbuffer()[0:resp_len] |
| self.send_msg() |
| self.packer.reset() |
| return True |
| |
| def handle_error(self, e): |
| self.flag = MessageFlags.NORMAL |
| self.packer.reset() |
| self.response_buf.seek(0) |
| body = msgpack.packb(e) |
| dlen = len(body) + 1 # 1 for tag |
| resp_len = self.write_header(self.response_buf, dlen) |
| self.packer.pack(int(MessageType.ERROR)) |
| self.response_buf.write(self.packer.bytes()) |
| self.response_buf.write(body) |
| self.resp = self.response_buf.getbuffer()[0:resp_len] |
| self.send_msg() |
| self.packer.reset() |
| self.alive = False |
| return True |
| |
| type_handler = { |
| MessageType.HELO: helo, |
| MessageType.QUIT: quit, |
| MessageType.INIT: handle_init, |
| MessageType.CALL: handle_call |
| } |
| |
| def connect_sock(self, addr, port): |
| self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| self.sock.connect((addr, int(port))) |
| |
| def disconnect_sock(self, *args): |
| self.sock.shutdown(socket.SHUT_RDWR) |
| self.sock.close() |
| |
| def recv_msg(self): |
| while self.alive: |
| pos = sys.stdin.buffer.readinto1(self.readbuf) |
| if pos <= 0: |
| self.alive = False |
| return |
| try: |
| while pos < REAL_HEADER_SZ: |
| read = sys.stdin.buffer.readinto1(self.readview[pos:]) |
| if read <= 0: |
| self.alive = False |
| return |
| pos += read |
| self.read_header(self.readview) |
| while pos < self.sz and len(self.readbuf) - pos > 0: |
| read = sys.stdin.buffer.readinto1(self.readview[pos:]) |
| if read <= 0: |
| self.alive = False |
| return |
| pos += read |
| while pos < self.sz: |
| vszchunk = sys.stdin.buffer.read1(FRAMESZ) |
| if len(vszchunk) == 0: |
| self.alive = False |
| return |
| self.readview.release() |
| self.readbuf.extend(vszchunk) |
| self.readview = memoryview(self.readbuf) |
| pos += len(vszchunk) |
| self.unpacker.feed(self.readview[REAL_HEADER_SZ:self.sz]) |
| self.unpacked_msg = list(self.unpacker) |
| self.msg_type = MessageType(self.unpacked_msg[0]) |
| self.type_handler[self.msg_type](self) |
| except BaseException: |
| self.handle_error(traceback.format_exc()) |
| |
| def send_msg(self): |
| self.sock.sendall(self.resp) |
| self.resp = None |
| return |
| |
| def recv_loop(self): |
| while self.alive: |
| self.recv_msg() |
| self.disconnect_sock() |
| |
| |
| wrap = Wrapper() |
| wrap.connect_sock(addr, port) |
| signal.signal(signal.SIGTERM, wrap.disconnect_sock) |
| wrap.recv_loop() |