#-----------------------------------------------------------------------------
# Copyright (c) Anaconda, Inc., and Bokeh Contributors.
# All rights reserved.
#
# The full license is in the file LICENSE.txt, distributed with this software.
#-----------------------------------------------------------------------------
''' Provide some utility functions useful for implementing different
components in ``bokeh.server``.
'''
#-----------------------------------------------------------------------------
# Boilerplate
#-----------------------------------------------------------------------------
from __future__ import annotations
import logging # isort:skip
# Module-level logger named after this module; used below to warn about
# permissive wildcard host entries.
log = logging.getLogger(__name__)
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
# Standard library imports
from typing import TYPE_CHECKING, Sequence
if TYPE_CHECKING:
from socket import socket
# External imports
from tornado import netutil
#-----------------------------------------------------------------------------
# Globals and constants
#-----------------------------------------------------------------------------
# Public API of this module (the names exported by ``import *``).
__all__ = (
    'bind_sockets', 'check_allowlist', 'create_hosts_allowlist', 'match_host',
)
#-----------------------------------------------------------------------------
# General API
#-----------------------------------------------------------------------------
def bind_sockets(address: str | None, port: int) -> tuple[list[socket], int]:
    ''' Bind sockets to a port on an address.

    Args:
        address (str or None) :
            An address to bind a port on, e.g. ``"localhost"``

        port (int) :
            A port number to bind.

            Pass 0 to have the OS automatically choose a free port.

    This function returns a 2-tuple with the list of newly bound sockets as
    the first element, and the port that was bound as the second. (Useful
    when passing 0 as a port number to bind any free port.)

    Returns:
        (list[socket], port)

    Raises:
        RuntimeError, if no sockets were bound, if the bound sockets ended up
        on more than one port, or if a specific non-zero ``port`` was
        requested but a different one was bound. Any sockets that were bound
        are closed before the error propagates.
    '''
    ss = netutil.bind_sockets(port=port or 0, address=address)

    # Explicit checks (not ``assert``) so they survive ``python -O``. On any
    # failure, close whatever was bound so the sockets and port do not leak.
    try:
        if not ss:
            raise RuntimeError("No sockets were bound")
        ports = {s.getsockname()[1] for s in ss}
        if len(ports) != 1:
            raise RuntimeError("Multiple ports assigned??")
        actual_port = ports.pop()
        if port and actual_port != port:
            raise RuntimeError(f"Requested port {port} but bound port {actual_port}")
    except BaseException:
        for s in ss:
            s.close()
        raise

    return ss, actual_port
def check_allowlist(host: str, allowlist: Sequence[str]) -> bool:
    ''' Check a given request host against an allowlist.

    Args:
        host (str) :
            A host string to compare against an allowlist.

            If the host does not specify a port, then ``":80"`` is implicitly
            assumed.

        allowlist (seq[str]) :
            A list of host patterns to match against

    Returns:
        ``True``, if ``host`` matches any pattern in ``allowlist``, otherwise
        ``False``
    '''
    # Normalize to an explicit "host:port" form; a bare host means port 80.
    candidate = host if ':' in host else f"{host}:80"

    # Try the cheap literal membership test first, and only fall back to
    # pattern matching (wildcards, port-less patterns) when it misses.
    return candidate in allowlist or any(
        match_host(candidate, pattern) for pattern in allowlist
    )
def create_hosts_allowlist(host_list: Sequence[str] | None, port: int | None) -> list[str]:
    ''' Create an allowlist of ``host:port`` entries from a list of hosts.

    This allowlist can be used to restrict websocket or other connections to
    only those explicitly originating from approved hosts.

    Args:
        host_list (seq[str]) :
            A list of string `<name>` or `<name>:<port>` values to add to the
            allowlist.

            If no port is specified in a host string, then ``":80"`` is
            implicitly assumed. A port, when given, must be a plain decimal
            integer between 1 and 65535.

            The single value ``"*"`` is kept as-is, without a port suffix.

        port (int) :
            If ``host_list`` is empty or ``None``, then the allowlist will
            be the single item list `` [ 'localhost:<port>' ]``

            If ``host_list`` is not empty, this parameter has no effect.

    Returns:
        list[str]

    Raises:
        ValueError, if host or port values are invalid

    Note:
        If any host in ``host_list`` contains a wildcard ``*`` a warning will
        be logged regarding permissive websocket connections.
    '''
    if not host_list:
        # NOTE: ``port`` is stringified as-is, so ``None`` yields the entry
        # ``'localhost:None'``.
        return ['localhost:' + str(port)]

    hosts: list[str] = []
    for host in host_list:
        if '*' in host:
            log.warning(
                "Host wildcard %r will allow connections originating "
                "from multiple (or possibly all) hostnames or IPs. Use non-wildcard "
                "values to restrict access explicitly", host)
        if host == '*':
            # do not append the :80 port suffix in that case: any port is
            # accepted
            hosts.append(host)
            continue
        parts = host.split(':')
        if len(parts) == 1:
            if parts[0] == "":
                raise ValueError("Empty host value")
            hosts.append(host+":80")
        elif len(parts) == 2:
            port_str = parts[1]
            # Require plain ASCII digits in the TCP port range. ``int()`` alone
            # would also accept " 80", "+80", "8_0", "-1" or "99999"; those are
            # stored verbatim and could never match a request host.
            if not (port_str.isascii() and port_str.isdigit()
                    and 0 < int(port_str) <= 65535):
                raise ValueError(f"Invalid port in host value: {host}")
            if parts[0] == "":
                raise ValueError("Empty host value")
            hosts.append(host)
        else:
            raise ValueError(f"Invalid host value: {host}")
    return hosts
def match_host(host: str, pattern: str) -> bool:
    ''' Match a host string against a pattern

    Args:
        host (str)
            A hostname to compare to the given pattern

        pattern (str)
            A string representing a hostname pattern, possibly including
            wildcards for ip address octets or ports.

    This function will return ``True`` if the hostname matches the pattern,
    including any wildcards. If the pattern does not include any wildcards,
    then the number of host parts and pattern parts must match identically.

    If the pattern contains a port, the host string must also contain a
    matching port.

    Hostnames are compared case-insensitively.

    Returns:
        bool

    Examples:

        >>> match_host('192.168.0.1:80', '192.168.0.1:80')
        True
        >>> match_host('192.168.0.1:80', '192.168.0.1')
        True
        >>> match_host('192.168.0.1:80', '192.168.0.1:8080')
        False
        >>> match_host('192.168.0.1', '192.168.0.2')
        False
        >>> match_host('192.168.0.1', '192.168.*.*')
        True
        >>> match_host('alice', 'alice')
        True
        >>> match_host('alice:80', 'alice')
        True
        >>> match_host('alice', 'bob')
        False
        >>> match_host('example.com', 'example.com.net')
        False
        >>> match_host('example.com.bad.com', 'example.com')
        False
        >>> match_host('alice', '*')
        True
        >>> match_host('alice', '*:*')
        True
        >>> match_host('alice:80', '*')
        True
        >>> match_host('alice:80', '*:80')
        True
        >>> match_host('alice:8080', '*:80')
        False
        >>> match_host('Example.COM', 'example.com')
        True

    '''
    # The host component of a URI is case-insensitive (RFC 3986 3.2.2), so
    # normalize both sides before comparing.
    host = host.lower()
    pattern = pattern.lower()

    # This is for a wildcard match without any port restriction
    if pattern == "*":
        return True

    host_port: str | None = None
    if ':' in host:
        host, host_port = host.rsplit(':', 1)

    pattern_port: str | None = None
    if ':' in pattern:
        pattern, pattern_port = pattern.rsplit(':', 1)
        # A "*" port places no restriction on the host's port
        if pattern_port == '*':
            pattern_port = None

    if pattern_port is not None and host_port != pattern_port:
        return False

    # This is for a wildcard match including any port restriction
    if pattern == "*":
        return True

    host_parts = host.split('.')
    pattern_parts = pattern.split('.')

    # since the pattern is not '*', we must enforce that the host and
    # pattern have the same number of parts, to avoid matching subdomains
    # unintentionally.
    if len(pattern_parts) != len(host_parts):
        return False

    return all(p == '*' or h == p for h, p in zip(host_parts, pattern_parts))
#-----------------------------------------------------------------------------
# Dev API
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
# Private API
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
# Code
#-----------------------------------------------------------------------------