Source code for bokeh.server.views.ws

#-----------------------------------------------------------------------------
# Copyright (c) 2012 - 2023, Anaconda, Inc., and Bokeh Contributors.
# All rights reserved.
#
# The full license is in the file LICENSE.txt, distributed with this software.
#-----------------------------------------------------------------------------
''' Provide a web socket handler for the Bokeh Server application.

'''

#-----------------------------------------------------------------------------
# Boilerplate
#-----------------------------------------------------------------------------
from __future__ import annotations

import logging # isort:skip
log = logging.getLogger(__name__)

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

# Standard library imports
import calendar
import datetime as dt
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse

# External imports
from tornado import locks, web
from tornado.websocket import WebSocketClosedError, WebSocketHandler

# Bokeh imports
from bokeh.settings import settings
from bokeh.util.token import check_token_signature, get_session_id, get_token_payload

# Bokeh imports
from ...protocol import Protocol
from ...protocol.exceptions import MessageError, ProtocolError, ValidationError
from ...protocol.message import Message
from ...protocol.receiver import Receiver
from ...util.dataclasses import dataclass
from ..protocol_handler import ProtocolHandler
from .auth_request_handler import AuthRequestHandler

if TYPE_CHECKING:
    from ..connection import ServerConnection

#-----------------------------------------------------------------------------
# Globals and constants
#-----------------------------------------------------------------------------

__all__ = (
    'WSHandler',
)

#-----------------------------------------------------------------------------
# General API
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Dev API
#-----------------------------------------------------------------------------

class WSHandler(AuthRequestHandler, WebSocketHandler):
    ''' Implements a custom Tornado WebSocketHandler for the Bokeh Server.

    The handler validates the session token supplied through the websocket
    subprotocol header, creates the protocol ``Receiver``/``ProtocolHandler``
    pair and a ``ServerConnection`` for the session, and then forwards every
    incoming wire fragment through them.

    '''

    # Set by ``_async_open`` once the session connection has been created;
    # stays ``None`` until then.
    connection: ServerConnection | None

    def __init__(self, tornado_app, *args: Any, **kw: Any) -> None:
        ''' Set up per-connection state.

        Args:
            tornado_app :
                The Tornado application; stored by the base class as
                ``self.application``.

            **kw :
                Must contain ``application_context``. The optional
                ``compression_level`` and ``mem_level`` entries are removed
                here before the remaining keywords are passed to the base
                class.

        '''
        self.receiver = None
        self.handler = None
        self.connection = None
        self.application_context = kw['application_context']
        self.latest_pong = -1
        # write_lock allows us to lock the connection to send multiple
        # messages atomically.
        self.write_lock = locks.Lock()

        # Filled in by select_subprotocol() and validated in open().
        self._token = None

        self._compression_level = kw.pop('compression_level', None)
        self._mem_level = kw.pop('mem_level', None)

        # Note: tornado_app is stored as self.application
        super().__init__(tornado_app, *args, **kw)

    def initialize(self, application_context, bokeh_websocket_path):
        ''' Accept the handler keyword arguments; nothing is done with them
        here (``application_context`` is already stored by ``__init__``).

        '''
        pass

    def check_origin(self, origin: str) -> bool:
        ''' Implement a check_origin policy for Tornado to call.

        The supplied origin will be compared to the Bokeh server allowlist. If
        the origin is not allowed, an error will be logged and ``False`` will
        be returned.

        Args:
            origin (str) :
                The URL of the connection origin

        Returns:
            bool, True if the connection is allowed, False otherwise

        '''
        from ..util import check_allowlist

        origin_host = urlparse(origin).netloc.lower()

        # Origins configured through settings take precedence over the
        # allowlist held by the application.
        allowed_hosts = self.application.websocket_origins
        configured_origins = settings.allowed_ws_origin()
        if configured_origins:
            allowed_hosts = set(configured_origins)

        if check_allowlist(origin_host, allowed_hosts):
            return True

        # Adjacent literals are concatenated by the compiler, which keeps the
        # message free of stray continuation characters and whitespace.
        log.error(
            "Refusing websocket connection from Origin '%s'; "
            "use --allow-websocket-origin=%s or set BOKEH_ALLOW_WS_ORIGIN=%s "
            "to permit this; currently we allow origins %r",
            origin, origin_host, origin_host, allowed_hosts,
        )
        return False

    @web.authenticated
    def open(self) -> None:
        ''' Initialize a connection to a client.

        The connection is closed and a ``ProtocolError`` is raised when the
        selected subprotocol is not ``'bokeh'``, when no token was supplied,
        when the token carries no ``session_expiry``, when it is expired, or
        when its signature is invalid. Otherwise ``_async_open`` is scheduled
        on the application's IO loop.

        Returns:
            None

        Raises:
            ProtocolError

        '''
        log.info('WebSocket connection opened')
        token = self._token

        if self.selected_subprotocol != 'bokeh':
            self.close()
            raise ProtocolError("Subprotocol header is not 'bokeh'")
        elif token is None:
            self.close()
            raise ProtocolError("No token received in subprotocol header")

        now = calendar.timegm(dt.datetime.now(tz=dt.timezone.utc).timetuple())
        payload = get_token_payload(token)
        if 'session_expiry' not in payload:
            self.close()
            raise ProtocolError("Session expiry has not been provided")
        elif now >= payload['session_expiry']:
            self.close()
            raise ProtocolError("Token is expired.")
        elif not check_token_signature(token,
                                       signed=self.application.sign_sessions,
                                       secret_key=self.application.secret_key):
            session_id = get_session_id(token)
            log.error("Token for session %r had invalid signature", session_id)
            # Close explicitly, exactly like every other rejection above.
            self.close()
            raise ProtocolError("Invalid token signature")

        try:
            self.application.io_loop.add_callback(self._async_open, token)
        except Exception as e:
            # this isn't really an error (unless we have a
            # bug), it just means a client disconnected
            # immediately, most likely.
            log.debug("Failed to fully open connection %r", e)

    def select_subprotocol(self, subprotocols: list[str]) -> str | None:
        ''' Choose the subprotocol and capture the session token.

        Exactly two entries are expected: the first is returned as the
        selected subprotocol and the second is stored as the token that
        ``open`` validates.

        Args:
            subprotocols (list[str]) :
                The subprotocol header values supplied by the client

        Returns:
            str or None, the first entry, or None when the header does not
            contain exactly two entries

        '''
        log.debug('Subprotocol header received')
        log.trace('Supplied subprotocol headers: %r', subprotocols)
        if len(subprotocols) != 2:
            return None
        self._token = subprotocols[1]
        return subprotocols[0]

    def get_compression_options(self) -> dict[str, Any] | None:
        ''' Report the websocket compression options.

        Returns:
            dict or None, None when no compression level was configured;
            otherwise a dict with ``compression_level`` and, when configured,
            ``mem_level``

        '''
        if self._compression_level is None:
            return None
        options = {'compression_level': self._compression_level}
        if self._mem_level is not None:
            options['mem_level'] = self._mem_level
        return options

    async def _async_open(self, token: str) -> None:
        ''' Perform the specific steps needed to open a connection to a Bokeh session

        Specifically, this method coordinates:

        * Getting a session for a session ID (creating a new one if needed)
        * Creating a protocol receiver and handler
        * Opening a new ServerConnection and sending it an ACK

        Args:
            token (str) :
                The session token; the ID of the session to connect to is
                extracted from it. If no session exists with that ID, a new
                session is made

        Returns:
            None

        Raises:
            ProtocolError

        '''
        try:
            session_id = get_session_id(token)
            await self.application_context.create_session_if_needed(session_id, self.request, token)
            session = self.application_context.get_session(session_id)

            protocol = Protocol()
            self.receiver = Receiver(protocol)
            log.debug("Receiver created for %r", protocol)

            self.handler = ProtocolHandler()
            log.debug("ProtocolHandler created for %r", protocol)

            self.connection = self.application.new_connection(protocol, self, self.application_context, session)
            log.info("ServerConnection created")

        except ProtocolError as e:
            log.error("Could not create new server session, reason: %s", e)
            self.close()
            raise

        msg = self.connection.protocol.create('ACK')
        await self.send_message(msg)

        return None

    async def on_message(self, fragment: str | bytes) -> None:
        ''' Process an individual wire protocol fragment.

        The websocket RFC specifies opcodes for distinguishing text frames
        from binary frames. Tornado passes us either a text or binary string
        depending on that opcode, we have to look at the type of the fragment
        to see what we got.

        Args:
            fragment (unicode or bytes) : wire fragment to process

        '''

        # We shouldn't throw exceptions from on_message because the caller is
        # just Tornado and it doesn't know what to do with them other than
        # report them as an unhandled Future

        try:
            message = await self._receive(fragment)
        except Exception as e:
            # If you go look at self._receive, it's catching the
            # expected error types... here we have something weird.
            log.error("Unhandled exception receiving a message: %r: %r", e, fragment, exc_info=True)
            self._internal_error("server failed to parse a message")
            message = None

        try:
            if message:
                if _message_test_port is not None:
                    _message_test_port.received.append(message)
                work = await self._handle(message)
                if work:
                    await self._schedule(work)
        except Exception as e:
            log.error("Handler or its work threw an exception: %r: %r", e, message, exc_info=True)
            self._internal_error("server failed to handle a message")

        return None

    def on_pong(self, data: bytes) -> None:
        ''' Record the integer carried by a pong frame in ``latest_pong``.

        Args:
            data (bytes) : the pong payload, expected to be a UTF-8 integer

        '''
        # if we get an invalid integer or utf-8 back, either we
        # sent a buggy ping or the client is evil/broken.
        try:
            self.latest_pong = int(data.decode("utf-8"))
        except UnicodeDecodeError:
            log.trace("received invalid unicode in pong %r", data, exc_info=True)
        except ValueError:
            log.trace("received invalid integer in pong %r", data, exc_info=True)

    async def send_message(self, message: Message[Any]) -> None:
        ''' Send a Bokeh Server protocol message to the connected client.

        A closed websocket is logged as a warning rather than raised.

        Args:
            message (Message) : a message to send

        '''
        try:
            if _message_test_port is not None:
                _message_test_port.sent.append(message)
            await message.send(self)
        except WebSocketClosedError:
            # on_close() is / will be called anyway
            log.warning("Failed sending message as connection was closed")
        return None

    async def write_message(self, message: bytes | str | dict[str, Any],
            binary: bool = False, locked: bool = True) -> None:
        ''' Override parent write_message with a version that acquires a
        write lock before writing.

        Args:
            message (bytes or str or dict) : the payload to write

            binary (bool, optional) : passed through to the parent method

            locked (bool, optional) :
                When True (the default) the write happens while holding
                ``write_lock``; pass False to write without taking the lock

        '''
        if locked:
            # The lock is released when the block exits, even on error.
            async with self.write_lock:
                await super().write_message(message, binary)
        else:
            await super().write_message(message, binary)

    def on_close(self) -> None:
        ''' Clean up when the connection is closed.

        '''
        log.info('WebSocket connection closed: code=%s, reason=%r', self.close_code, self.close_reason)
        if self.connection is not None:
            self.connection.session.notify_connection_lost()
            self.application.client_lost(self.connection)

    async def _receive(self, fragment: str | bytes) -> Message[Any] | None:
        ''' Feed a fragment to the receiver; protocol-level failures are
        reported through ``_protocol_error`` and yield None.

        '''
        # Receive fragments until a complete message is assembled
        try:
            message = await self.receiver.consume(fragment)
            return message
        except (MessageError, ProtocolError, ValidationError) as e:
            self._protocol_error(str(e))
            return None

    async def _handle(self, message: Message[Any]) -> Any | None:
        ''' Pass a message to the protocol handler; protocol-level failures
        are reported through ``_internal_error`` and yield None.

        '''
        # Handle the message, possibly resulting in work to do
        try:
            work = await self.handler.handle(message, self.connection)
            return work
        except (MessageError, ProtocolError, ValidationError) as e: # TODO (other exceptions?)
            self._internal_error(str(e))
            return None

    async def _schedule(self, work: Any) -> None:
        ''' Send ``work`` when it is a Message; anything else is reported as
        an internal error.

        '''
        if isinstance(work, Message):
            await self.send_message(cast(Message[Any], work))
        else:
            self._internal_error(f"expected a Message not {work!r}")

        return None

    def _internal_error(self, message: str) -> None:
        ''' Log an internal error and close the connection with code 10000.

        '''
        log.error("Bokeh Server internal error: %s, closing connection", message)
        self.close(10000, message)

    def _protocol_error(self, message: str) -> None:
        ''' Log a protocol error and close the connection with code 10001.

        '''
        log.error("Bokeh Server protocol error: %s, closing connection", message)
        self.close(10001, message)

#----------------------------------------------------------------------------- # Private API #----------------------------------------------------------------------------- # This is an undocumented API purely for harvesting low level messages # for testing. When needed it will be set by the testing machinery, and # should not be used for any other purpose. @dataclass class MessageTestPort: sent: list[Message[Any]] received: list[Message[Any]] _message_test_port: MessageTestPort | None = None #----------------------------------------------------------------------------- # Code #-----------------------------------------------------------------------------