# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
from os import pathsep
# Command-line contract: argv[1] / argv[2] are the address and port that the
# Wrapper connects its socket to; argv[3] is a list of directories separated
# by os.pathsep that are appended to sys.path. This runs before the imports
# below, presumably so modules such as msgpack can be resolved from those
# directories -- confirm with the launcher.
addr, port = str(sys.argv[1]), str(sys.argv[2])
paths = sys.argv[3]
sys.path.extend(paths.split(pathsep))
from struct import *
import signal
import msgpack
import socket
import traceback
from importlib import import_module
from pathlib import Path
from enum import IntEnum
from io import BytesIO

# Wire-protocol constants.
PROTO_VERSION = 1
# Header bytes that follow the 4-byte size field: two 8-byte ids and a 1-byte flag.
HEADER_SZ = calcsize("!qqb")
# Full header as it appears on the wire, size field ("i") included.
REAL_HEADER_SZ = calcsize("!iqqb")
# Initial capacity of the stdin read buffer (it is grown for larger frames).
FRAMESZ = 32 * 1024


class MessageType(IntEnum):
    """Tag carried as the first msgpack element of every message body.

    The integer values are packed onto the wire, so their order must not change.
    """
    HELO, QUIT, INIT, INIT_RSP, CALL, CALL_RSP, ERROR = range(7)


class MessageFlags(IntEnum):
    """Value written into the one-byte flag field of a frame header.

    The integer values are part of the wire format, so their order must not change.
    """
    NORMAL, INITIAL_REQ, INITIAL_ACK, ERROR = range(4)


class Wrapper(object):
    """Serve calls to functions imported from user modules over framed IPC.

    Requests are read from stdin and replies are written to a TCP socket
    opened with connect_sock. Every frame starts with a "!iqqb" header
    (size, message id, request id, flag) followed by a msgpack body whose
    first element is a MessageType tag.
    """
    # Immutable per-message state; instances overwrite these as they go.
    wrapped_module = None
    wrapped_class = None
    wrapped_fn = None
    sz = None
    mid = None
    rmid = None
    flag = None
    resp = None
    unpacked_msg = None
    msg_type = None
    alive = True
    sock = None

    def __init__(self):
        # Mutable buffers live on the instance rather than the class so that
        # separate Wrapper objects never share (and corrupt) each other's state.
        self.packer = msgpack.Packer(autoreset=False)
        self.unpacker = msgpack.Unpacker()
        self.response_buf = BytesIO()
        self.stdin_buf = BytesIO()
        # Maps message id -> callable registered by handle_init.
        self.wrapped_fns = {}
        self.readbuf = bytearray(FRAMESZ)
        self.readview = memoryview(self.readbuf)

    def init(self, module_name, class_name, fn_name):
        """Import a module and register a callable under the current message id.

        If class_name is given, that class is instantiated with no arguments
        and fn_name is looked up on the instance; otherwise fn_name is looked
        up on the module. The result is stored in wrapped_fns[self.mid].

        Raises:
            ImportError: if the module does not live below the current working
                directory, or the class / function cannot be found.
        """
        self.wrapped_module = import_module(module_name)
        # do not allow modules to be called that are not part of the uploaded module
        if not self.check_module_path(self.wrapped_module):
            self.wrapped_module = None
            raise ImportError("Module was not found in library")
        # Resolve the target afresh on every INIT: a class instance left over
        # from an earlier INIT must not capture a later function-only lookup.
        self.wrapped_class = None
        if class_name is not None:
            clazz = getattr(self.wrapped_module, class_name, None)
            if clazz is None:
                raise ImportError(
                    "Could not find class or function in specified module")
            self.wrapped_class = clazz()
        if self.wrapped_class is not None:
            target = self.wrapped_class
        else:
            target = self.wrapped_module
        wrapped_fn = getattr(target, fn_name, None)
        if wrapped_fn is None:
            raise ImportError(
                "Could not find class or function in specified module")
        self.wrapped_fns[self.mid] = wrapped_fn

    def next_tuple(self, *args, key=None):
        """Invoke the callable registered under *key* with *args*."""
        return self.wrapped_fns[key](*args)

    def check_module_path(self, module):
        """Return True if *module* was loaded from a file below the working directory."""
        module_file = getattr(module, '__file__', None)
        if module_file is None:
            # Built-in and namespace modules have no file, so they cannot be
            # part of the uploaded library.
            return False
        cwd = Path('.').resolve()
        module_path = Path(module_file).resolve()
        return cwd in module_path.parents

    def read_header(self, readbuf):
        """Unpack the incoming "!iqqb" header into sz, mid, rmid and flag."""
        self.sz, self.mid, self.rmid, self.flag = unpack(
            "!iqqb", readbuf[0:REAL_HEADER_SZ])
        return True

    def write_header(self, response_buf, dlen):
        """Write an outgoing "!iqqb" header for a body of *dlen* bytes.

        The size field counts HEADER_SZ plus the body, excluding its own 4
        bytes. The message id is always -1 and the request id is self.rmid.

        Returns:
            The total frame length in bytes, size field included.
        """
        total_len = dlen + HEADER_SZ
        header = pack("!iqqb", total_len, int(-1), int(self.rmid), self.flag)
        response_buf.write(header)
        return total_len + 4

    def get_ver_hlen(self, hlen):
        """Fold PROTO_VERSION into the high bits of *hlen*."""
        return hlen + (PROTO_VERSION << 4)

    def get_hlen(self):
        """Inverse of get_ver_hlen applied to self.ver_hlen."""
        # NOTE(review): ver_hlen is not assigned anywhere visible here, so
        # this raises AttributeError unless it is set elsewhere.
        return self.ver_hlen - (PROTO_VERSION << 4)

    def _send_frame(self, *chunks):
        """Write a header plus *chunks* into response_buf and send the frame.

        self.flag must already be set; it is copied into the header.
        """
        self.response_buf.seek(0)
        dlen = sum(len(chunk) for chunk in chunks)
        resp_len = self.write_header(self.response_buf, dlen)
        for chunk in chunks:
            self.response_buf.write(chunk)
        # Stale bytes past resp_len from an earlier, longer frame are not sent.
        self.resp = self.response_buf.getbuffer()[0:resp_len]
        self.send_msg()

    def init_remote_ipc(self):
        """Send unpacked_msg[1] back verbatim as a frame flagged INITIAL_REQ."""
        self.flag = MessageFlags.INITIAL_REQ
        self._send_frame(self.unpacked_msg[1])
        self.packer.reset()

    def helo(self):
        """Handle HELO: acknowledge the connection, then reply with HELO."""
        # need to ack the connection back before sending actual HELO
        self.init_remote_ipc()
        self.flag = MessageFlags.NORMAL
        self.packer.pack(int(MessageType.HELO))
        self.packer.pack("HELO")
        # The body length is taken from the packed bytes: the tag packs to 1
        # byte and "HELO" to 5 (fixstr marker + 4), so a hard-coded 5 would
        # drop the last byte of the frame.
        self._send_frame(self.packer.bytes())
        self.packer.reset()
        return True

    def handle_init(self):
        """Handle INIT: register the requested callable and reply INIT_RSP.

        unpacked_msg[1] is [module, fn] or [module, class, fn].
        """
        self.flag = MessageFlags.NORMAL
        args = self.unpacked_msg[1]
        module = args[0]
        if len(args) == 3:
            clazz, fn = args[1], args[2]
        else:
            clazz, fn = None, args[1]
        self.init(module, clazz, fn)
        self.packer.pack(int(MessageType.INIT_RSP))
        self._send_frame(self.packer.bytes())
        self.packer.reset()
        return True

    def quit(self):
        """Handle QUIT: stop the receive loop and close the socket."""
        self.alive = False
        self.disconnect_sock()
        return True

    def handle_call(self):
        """Handle CALL: invoke the registered callable once per argument tuple.

        Replies with CALL_RSP whose body is (results, errors): return values
        of successful calls and formatted tracebacks of failed ones.
        """
        self.flag = MessageFlags.NORMAL
        results, errors = [], []
        args = self.unpacked_msg[1] if len(self.unpacked_msg) > 1 else None
        if args is not None:
            for arg in args:
                try:
                    results.append(self.next_tuple(*arg, key=self.mid))
                except BaseException:
                    # Record this call's failure and continue with the rest.
                    errors.append(traceback.format_exc())
        self.packer.reset()
        body = msgpack.packb((results, errors))
        self.packer.pack(int(MessageType.CALL_RSP))
        self._send_frame(self.packer.bytes(), body)
        self.packer.reset()
        return True

    def handle_error(self, e):
        """Send *e* to the peer as an ERROR message and stop the receive loop."""
        self.flag = MessageFlags.NORMAL
        self.packer.reset()
        body = msgpack.packb(e)
        self.packer.pack(int(MessageType.ERROR))
        self._send_frame(self.packer.bytes(), body)
        self.packer.reset()
        self.alive = False
        return True

    # Dispatch table used by recv_msg, keyed by the incoming message tag.
    type_handler = {
        MessageType.HELO: helo,
        MessageType.QUIT: quit,
        MessageType.INIT: handle_init,
        MessageType.CALL: handle_call
    }

    def connect_sock(self, addr, port):
        """Open the TCP socket that replies are written to."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.connect((addr, int(port)))

    def disconnect_sock(self, *args):
        """Shut down and close the socket; safe to call more than once.

        Extra positional arguments are ignored so the method can be used
        directly as a signal handler.
        """
        sock = self.sock
        if sock is None or sock.fileno() == -1:
            return  # never connected, or already closed
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # The peer may already have dropped the connection; still close.
            pass
        finally:
            sock.close()

    def recv_msg(self):
        """Read frames from stdin and dispatch them while alive.

        Returns on EOF (clearing alive) or once a handler clears alive. Any
        exception raised while decoding or handling a frame is reported to
        the peer through handle_error.
        """
        while self.alive:
            pos = sys.stdin.buffer.readinto1(self.readbuf)
            if pos <= 0:
                self.alive = False
                return
            try:
                # Ensure the complete fixed-size header has been read.
                while pos < REAL_HEADER_SZ:
                    read = sys.stdin.buffer.readinto1(self.readview[pos:])
                    if read <= 0:
                        self.alive = False
                        return
                    pos += read
                self.read_header(self.readview)
                # As handled here the incoming size counts the whole frame,
                # size field included (unlike write_header's outgoing size).
                # First fill whatever capacity the buffer already has.
                while pos < self.sz and len(self.readbuf) - pos > 0:
                    read = sys.stdin.buffer.readinto1(self.readview[pos:])
                    if read <= 0:
                        self.alive = False
                        return
                    pos += read
                # Frame larger than the buffer: grow it. The view must be
                # released first because a bytearray with live exports
                # cannot be resized.
                while pos < self.sz:
                    vszchunk = sys.stdin.buffer.read1(FRAMESZ)
                    if len(vszchunk) == 0:
                        self.alive = False
                        return
                    self.readview.release()
                    self.readbuf.extend(vszchunk)
                    self.readview = memoryview(self.readbuf)
                    pos += len(vszchunk)
                # NOTE(review): assumes one frame in flight at a time; bytes
                # past self.sz from the same read would be discarded.
                self.unpacker.feed(self.readview[REAL_HEADER_SZ:self.sz])
                self.unpacked_msg = list(self.unpacker)
                self.msg_type = MessageType(self.unpacked_msg[0])
                self.type_handler[self.msg_type](self)
            except BaseException:
                self.handle_error(traceback.format_exc())

    def send_msg(self):
        """Send the prepared frame in self.resp over the socket."""
        try:
            self.sock.sendall(self.resp)
        finally:
            # Always drop the view on response_buf, even if sending failed:
            # a BytesIO cannot be resized while a getbuffer() view is alive.
            self.resp = None
        return

    def recv_loop(self):
        """Process frames until stopped, then close the socket."""
        try:
            while self.alive:
                self.recv_msg()
        finally:
            self.disconnect_sock()


if __name__ == "__main__":
    # Connect to the peer, close the socket on SIGTERM, and serve requests
    # from stdin until QUIT, EOF, or an error stops the loop.
    wrap = Wrapper()
    wrap.connect_sock(addr, port)
    signal.signal(signal.SIGTERM, wrap.disconnect_sock)
    wrap.recv_loop()
