Source code for wolframclient.utils.externalevaluate

# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, unicode_literals

import ast
import logging
import os
import sys

from wolframclient.deserializers import binary_deserialize
from wolframclient.language import wl
from wolframclient.language.decorators import to_wl
from wolframclient.language.side_effects import side_effect_logger
from wolframclient.serializers import export
from wolframclient.utils import six
from wolframclient.utils.api import zmq
from wolframclient.utils.datastructures import Settings
from wolframclient.utils.encoding import force_text
from wolframclient.utils.functional import last
from wolframclient.utils.importutils import import_string

# Names stored under the '__traceback_hidden_variables__' key of every
# evaluation context built by execute_from_string.
# NOTE(review): presumably the traceback serializer omits these names when
# it renders frame variables -- confirm against the serializer.
HIDDEN_VARIABLES = [
    '__loader__', '__builtins__', '__traceback_hidden_variables__',
    'absolute_import', 'print_function', 'unicode_literals'
]

# Keyword arguments shared by the serialization calls in this module: they
# are passed to the to_wl decorator on handle_message and to export() in
# StdoutProxy.send_side_effect.
EXPORT_KWARGS = {
    'target_format': 'wxf',
    'allow_external_objects': True,
}


class UnprintableContext(dict):
    """Dictionary used as an evaluation namespace, with a constant repr.

    The contents never appear in the representation: ``repr`` always yields
    the same short placeholder, whatever the mapping holds.
    """

    def __repr__(self):
        placeholder = '<evaluation context>'
        return placeholder
def execute_from_file(path, *args, **opts):
    """Read the file at *path* and evaluate its contents.

    The file text is passed through ``force_text`` and handed to
    ``execute_from_string`` together with any extra positional and keyword
    arguments. Whatever that call returns is returned unchanged.
    """
    with open(path, 'r') as handle:
        source = force_text(handle.read())
        return execute_from_string(source, *args, **opts)
def execute_from_string(string, context=UnprintableContext()):
    """Execute Python source text and return the value of its last expression.

    Every statement is run with ``exec``. When the final statement is a bare
    expression it is evaluated separately with ``eval`` so that its value can
    be returned.

    :param string: Python source code to run.
    :param context: mapping used to seed the global namespace. It is copied
        before execution and, unless it is ``None``, updated with the
        resulting namespace afterwards. The default is a single
        ``UnprintableContext`` created when the function is defined, so
        calls that omit ``context`` share state with each other.
    :return: ``None`` when the source contains no statements, ``wl.Null``
        when the last statement is not an expression, otherwise the value of
        the last expression.
    """

    # NOTE(review): presumably read from this frame's locals by the traceback
    # serializer -- confirm before removing this apparently unused variable.
    __traceback_hidden_variables__ = True

    # This is creating a custom __loader__ that returns the source code.
    # The traceback serializer inspects global variables, looking for a
    # standard loader that can return source code.
    current = UnprintableContext(context or ())
    current['__loader__'] = Settings(
        get_source=lambda module, code=string: code)
    current['__traceback_hidden_variables__'] = HIDDEN_VARIABLES

    module = ast.parse(string)
    statements = list(module.body)

    if not statements:
        return

    final = None
    if isinstance(last(statements), ast.Expr):
        final = statements.pop(-1)

    if statements:
        # Reuse the Module node produced by ast.parse instead of building a
        # new ast.Module(statements): since Python 3.8 a Module also requires
        # a ``type_ignores`` field, and compile() rejects a node without it.
        # The parsed node already carries every field the running
        # interpreter expects.
        module.body = statements
        exec(compile(module, '', 'exec'), current)

    if final is not None:
        result = eval(
            compile(ast.Expression(final.value), '', 'eval'), current)
    else:
        result = wl.Null

    if context is not None:
        context.update(current)

    return result
class SideEffectSender(logging.Handler):
    """Logging handler that forwards record messages through the stdout proxy."""

    def emit(self, record):
        """Send ``record.msg`` as a side effect when stdout is a ``StdoutProxy``.

        The record is dropped while ``sys.stdout`` is any other object.
        """
        out = sys.stdout
        if not isinstance(out, StdoutProxy):
            return
        out.send_side_effect(record.msg)
# Attach the handler at import time so that records logged on
# side_effect_logger are passed to SideEffectSender.emit.
side_effect_logger.addHandler(SideEffectSender())
class SocketWriter:
    """Minimal write-only stream adapter around a socket-like object."""

    def __init__(self, socket):
        """Keep a reference to *socket*; it must provide a ``send`` method."""
        self.socket = socket

    def write(self, bytes):
        """Pass *bytes* to the wrapped socket's ``send`` method."""
        target = self.socket
        target.send(bytes)
class StdoutProxy:
    """File-like replacement for ``sys.stdout`` that ships output over a stream.

    Written text is split into lines. Each completed line is wrapped in
    ``wl.Print``, serialized with ``export`` and written to the underlying
    stream. Text that does not yet end with a newline is buffered until a
    later ``write`` or ``flush``.
    """

    # Wrapper applied to every expression before it is serialized.
    keep_listening = wl.ExternalEvaluate.Private.PythonKeepListening

    def __init__(self, stream):
        """Store the output *stream* and start with empty buffers."""
        self.stream = stream
        self.clear()

    def clear(self):
        """Reset the partial-line buffer and the list of completed lines."""
        self.current_line, self.lines = [], []

    def write(self, message):
        """Buffer *message*, sending everything up to its last newline.

        Text before the first newline completes the pending line, the
        segments between newlines are whole lines, and text after the last
        newline becomes the new pending line.
        """
        pieces = force_text(message).split("\n")

        # The first piece always extends the line currently being built.
        self.current_line.append(pieces[0])
        if len(pieces) == 1:
            # No newline in the message: nothing can be sent yet.
            return

        trailing = pieces[-1]
        self.lines.extend(pieces[1:-1])
        self.flush()
        if trailing:
            self.current_line.append(trailing)

    def flush(self):
        """Send the buffered content, if any, and empty the buffers."""
        if not (self.current_line or self.lines):
            return
        pending = ''.join(self.current_line)
        self.send_lines(pending, *self.lines)
        self.clear()

    def send_lines(self, *lines):
        """Send each of *lines* as a ``wl.Print`` side effect.

        A single line is sent as one ``wl.Print``; several lines are grouped
        in a ``wl.CompoundExpression``. With no lines nothing is sent and
        ``None`` is returned.
        """
        if not lines:
            return None
        if len(lines) == 1:
            expr = wl.Print(*lines)
        else:
            expr = wl.CompoundExpression(*map(wl.Print, lines))
        return self.send_side_effect(expr)

    def send_side_effect(self, expr):
        """Wrap *expr* with ``keep_listening``, serialize it and write it out."""
        wrapped = self.keep_listening(expr)
        payload = export(wrapped, **EXPORT_KWARGS)
        self.stream.write(payload)
def evaluate_message(context,
                     input=None,
                     return_type=None,
                     function=None,
                     is_module=False,
                     args=None,
                     **opts):
    """Evaluate one decoded request against *context*.

    Two kinds of request are handled:

    * a function call, when ``function`` is truthy and ``args`` is not
      ``None``: the callable is obtained with ``import_string`` when
      ``is_module`` is true, otherwise by evaluating ``function`` as source
      in *context*; it is then called with ``*args``.
    * a source evaluation, when ``input`` is a string: it is run with
      ``execute_from_string``. With ``return_type == 'string'`` the ``repr``
      of the result is returned as text instead of the result itself.

    Any other request yields ``None``. Extra keyword arguments are accepted
    and ignored.
    """

    # NOTE(review): presumably read from this frame's locals by the traceback
    # serializer -- confirm before removing this apparently unused variable.
    __traceback_hidden_variables__ = True

    if function and args is not None:
        # A function call was requested: resolve the callable first, then
        # apply it to the supplied arguments.
        target = (import_string(function)
                  if is_module else execute_from_string(function, context))
        return target(*args)

    if not isinstance(input, six.string_types):
        return None

    value = execute_from_string(input, context)
    if return_type != 'string':
        return value

    # bug 354267 repr returns a 'str' even on py2 (i.e. bytes).
    return force_text(repr(value))
@to_wl(**EXPORT_KWARGS)
def handle_message(socket, context=UnprintableContext()):
    """Receive one message from *socket*, evaluate it and return the result.

    The raw payload is decoded with ``binary_deserialize`` and its items are
    passed to ``evaluate_message`` as keyword arguments. ``sys.stdout`` is
    flushed once the evaluation has finished.

    The default *context* is a single ``UnprintableContext`` created when the
    function is defined, so calls that omit it evaluate against the same
    namespace.

    NOTE(review): the ``to_wl`` decorator presumably serializes the returned
    value using ``EXPORT_KWARGS`` -- confirm against its implementation.
    """

    # NOTE(review): presumably read from this frame's locals by the traceback
    # serializer -- confirm before removing this apparently unused variable.
    __traceback_hidden_variables__ = True

    payload = binary_deserialize(socket.recv())
    outcome = evaluate_message(context=context, **payload)

    # Push out any buffered output before the result is handed back.
    sys.stdout.flush()
    return outcome
def start_zmq_instance(port=None, write_to_stdout=True, **opts):
    """Create a ZeroMQ PAIR socket bound to localhost and return it.

    :param port: TCP port to bind on 127.0.0.1. When falsy, a random port is
        chosen with ``bind_to_random_port``.
    :param write_to_stdout: when true, the bound endpoint is written to
        ``sys.stdout`` followed by a newline, and the stream is flushed.
    :param opts: accepted and ignored.
    :return: the bound socket.
    """
    # make a pair socket
    sock = zmq.Context.instance().socket(zmq.PAIR)

    # now bind to an open port on localhost
    if port:
        sock.bind('tcp://127.0.0.1:%s' % port)
    else:
        sock.bind_to_random_port('tcp://127.0.0.1')

    if write_to_stdout:
        sys.stdout.write(force_text(sock.getsockopt(zmq.LAST_ENDPOINT)))
        # A text-mode stream already translates '\n' into the platform line
        # ending; writing os.linesep here would emit '\r\r\n' on Windows.
        sys.stdout.write('\n')
        sys.stdout.flush()
    return sock
def start_zmq_loop(message_limit=float('inf'), redirect_stdout=True, **opts):
    """Open a socket and answer incoming messages until the limit is reached.

    :param message_limit: number of messages to handle before returning; the
        default never stops.
    :param redirect_stdout: when true, ``sys.stdout`` is replaced by a
        ``StdoutProxy`` writing to the socket while the loop runs.
    :param opts: forwarded to ``start_zmq_instance``.

    The stream that was installed as ``sys.stdout`` before the loop started
    is put back when the loop ends, including when it ends with an exception.
    """
    socket = start_zmq_instance(**opts)
    stream = SocketWriter(socket)

    # Remember the stream in place now, so a caller that had already
    # redirected stdout gets its own stream back rather than sys.__stdout__.
    previous_stdout = sys.stdout
    if redirect_stdout:
        sys.stdout = StdoutProxy(stream)

    try:
        # now sit in a while loop, evaluating input
        messages = 0
        while messages < message_limit:
            stream.write(handle_message(socket))
            messages += 1
    finally:
        # Restore stdout even if handling a message raised; otherwise later
        # prints would keep going to the socket.
        if redirect_stdout:
            sys.stdout = previous_stdout