omegafox/tests/server.py
daijro 68216153f3 Add Playwright tests
Adds Playwright's async tests to ensure that functionality is not broken.
2024-11-03 06:14:39 -06:00

299 lines
9.8 KiB
Python

# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import asyncio
import contextlib
import gzip
import mimetypes
import os
import socket
import threading
from contextlib import closing
from http import HTTPStatus
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
List,
Optional,
Set,
Tuple,
TypeVar,
cast,
)
from urllib.parse import urlparse
from autobahn.twisted.resource import WebSocketResource
from autobahn.twisted.websocket import WebSocketServerFactory, WebSocketServerProtocol
from OpenSSL import crypto
from twisted.internet import reactor as _twisted_reactor
from twisted.internet import ssl
from twisted.internet.selectreactor import SelectReactor
from twisted.web import http
# Directory containing this file; the static assets and the TLS certificate/key
# used further down are resolved relative to it.
_dirname = Path(os.path.abspath(__file__)).parent
# Re-export the imported Twisted reactor under a concrete type for static
# checkers; typing.cast() has no effect at runtime.
reactor = cast(SelectReactor, _twisted_reactor)
def find_free_port() -> int:
    """Return a TCP port number that the OS reported as free.

    A throwaway IPv4 stream socket is bound to port 0 so the OS assigns a
    port, the assigned number is read back, and the socket is closed before
    returning. The port is therefore only known to have been free at the
    moment of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
    return port
T = TypeVar("T")


class ExpectResponse(Generic[T]):
    """Holder for a value that is supplied some time after construction.

    The ``_value`` attribute is only declared in ``__init__``; it is assigned
    later by whoever produces the value.
    """

    def __init__(self) -> None:
        # Annotation only: the attribute does not exist until it is assigned.
        self._value: T

    @property
    def value(self) -> T:
        """Return the stored value.

        Raises:
            ValueError: if no value has been stored yet.
        """
        unset = object()
        if getattr(self, "_value", unset) is unset:
            raise ValueError("no received value")
        return self._value
class TestServerRequest(http.Request):
    """Request class that implements the test server's routing.

    Each request is checked, in order, against the owning server's request
    subscribers, the WebSocket endpoint, per-path basic auth, per-path CSP
    headers and custom routes; anything left over is served as a static file.
    """

    # Presumably keeps test collectors (e.g. pytest) from treating this
    # "Test*"-named class as a test case.
    __test__ = False
    channel: "TestServerHTTPChannel"
    # Raw request body captured by process(); None when there was no content.
    post_body: Optional[bytes] = None

    def process(self) -> None:
        """Handle one request and write the response.

        Static files whose MIME type cannot be guessed are served as
        ``application/octet-stream``; a missing or unreadable file yields a
        404 response.
        """
        server = self.channel.factory.server_instance
        if self.content:
            self.post_body = self.content.read()
            # Rewind so that later consumers can read the body again.
            self.content.seek(0, 0)
        else:
            self.post_body = None
        uri = urlparse(self.uri.decode())
        path = uri.path

        request_subscriber = server.request_subscribers.get(path)
        if request_subscriber:
            # Resolve the future through its own event loop rather than
            # touching it directly from this handler.
            request_subscriber._loop.call_soon_threadsafe(request_subscriber.set_result, self)
            server.request_subscribers.pop(path)

        if path == "/ws":
            server._ws_resource.render(self)
            return

        if server.auth.get(path):
            authorization_header = self.requestHeaders.getRawHeaders("authorization")
            creds_correct = False
            if authorization_header:
                creds_correct = server.auth.get(path) == (
                    self.getUser().decode(),
                    self.getPassword().decode(),
                )
            if not creds_correct:
                self.setHeader(b"www-authenticate", 'Basic realm="Secure Area"')
                self.setResponseCode(HTTPStatus.UNAUTHORIZED)
                self.finish()
                return

        if server.csp.get(path):
            self.setHeader(b"Content-Security-Policy", server.csp[path])

        if server.routes.get(path):
            # Custom route handlers are responsible for finishing the request.
            server.routes[path](self)
            return

        try:
            # path always starts with "/", which is stripped to make it
            # relative to the static directory.
            file_content = (server.static_path / path[1:]).read_bytes()
        except (FileNotFoundError, IsADirectoryError, PermissionError):
            self.setResponseCode(HTTPStatus.NOT_FOUND)
        else:
            # guess_type() yields None for unknown extensions; None is not a
            # usable header value, so fall back to a generic binary type.
            content_type = mimetypes.guess_type(path)[0] or "application/octet-stream"
            if content_type.startswith("text/"):
                content_type += "; charset=utf-8"
            self.setHeader(b"Content-Type", content_type)
            self.setHeader(b"Cache-Control", "no-cache, no-store")
            if path in server.gzip_routes:
                self.setHeader("Content-Encoding", "gzip")
                self.write(gzip.compress(file_content))
            else:
                self.setHeader(b"Content-Length", str(len(file_content)))
                self.write(file_content)
            self.setResponseCode(HTTPStatus.OK)
        self.finish()
class TestServerHTTPChannel(http.HTTPChannel):
    """HTTP channel that creates ``TestServerRequest`` objects for requests."""

    requestFactory = TestServerRequest
    # Typed so request handlers can reach ``factory.server_instance``.
    factory: "TestServerFactory"
class TestServerFactory(http.HTTPFactory):
    """HTTP factory that builds ``TestServerHTTPChannel`` connections."""

    protocol = TestServerHTTPChannel
    # The Server that owns this factory; assigned after construction.
    server_instance: "Server"
class Server:
    """Base class for a configurable local test server.

    Holds per-path configuration (auth, CSP, custom routes, gzip, request
    subscribers, WebSocket handlers) that request handlers consult. Subclasses
    supply ``listen`` to bind the actual listener and may override
    ``protocol`` to change the URL scheme.
    """

    protocol = "http"

    def __init__(self) -> None:
        """Pick a free port, derive the public URLs and register MIME types."""
        self.PORT = find_free_port()
        self.PREFIX = f"{self.protocol}://localhost:{self.PORT}"
        self.CROSS_PROCESS_PREFIX = f"{self.protocol}://127.0.0.1:{self.PORT}"
        self.EMPTY_PAGE = f"{self.protocol}://localhost:{self.PORT}/empty.html"
        # On Windows, this list can be empty, reporting text/plain for scripts.
        known_types = (
            ("text/html", ".html"),
            ("text/css", ".css"),
            ("application/javascript", ".js"),
            ("image/png", ".png"),
            ("font/woff2", ".woff2"),
        )
        for mime_type, extension in known_types:
            mimetypes.add_type(mime_type, extension)

    def __repr__(self) -> str:
        """Represent the server by its URL prefix."""
        return self.PREFIX

    @abc.abstractmethod
    def listen(self, factory: TestServerFactory) -> None:
        """Bind ``factory`` to this server's port; implemented by subclasses.

        NOTE: this class does not derive from ``abc.ABC``, so the decorator
        does not prevent instantiation.
        """

    def start(self) -> None:
        """Create fresh per-path state, build the factories and start listening."""
        self.request_subscribers: Dict[str, asyncio.Future] = {}
        self.auth: Dict[str, Tuple[str, str]] = {}
        self.csp: Dict[str, str] = {}
        self.routes: Dict[str, Callable[[TestServerRequest], Any]] = {}
        self._ws_handlers: List[Callable[["WebSocketProtocol"], None]] = []
        self.gzip_routes: Set[str] = set()
        self.static_path = _dirname / "assets"

        # WebSocket side: a resource that request handlers can render.
        ws_factory = WebSocketServerFactory()
        ws_factory.protocol = WebSocketProtocol
        ws_factory.server_instance = self
        self._ws_resource = WebSocketResource(ws_factory)

        # HTTP side: the factory handed to the concrete listener.
        http_factory = TestServerFactory()
        http_factory.server_instance = self
        self.listen(http_factory)

    async def wait_for_request(self, path: str) -> TestServerRequest:
        """Wait for and return the next request whose URL path is ``path``.

        Concurrent waiters on the same path share one future.
        """
        pending = self.request_subscribers.get(path)
        if pending is None:
            pending = asyncio.Future()
            self.request_subscribers[path] = pending
        return await pending

    @contextlib.contextmanager
    def expect_request(self, path: str) -> Generator[ExpectResponse[TestServerRequest], None, None]:
        """Start waiting for a request to ``path`` and yield a result holder.

        The yielded ``ExpectResponse`` receives the request once the
        background task completes; reading ``value`` before that raises.
        """
        holder: ExpectResponse[TestServerRequest] = ExpectResponse()
        waiter = asyncio.create_task(self.wait_for_request(path))

        def store_result(task: asyncio.Task) -> None:
            holder._value = task.result()

        waiter.add_done_callback(store_result)
        yield holder

    def set_auth(self, path: str, username: str, password: str) -> None:
        """Require the given basic-auth credentials for ``path``."""
        self.auth[path] = (username, password)

    def set_csp(self, path: str, value: str) -> None:
        """Send ``value`` as the Content-Security-Policy header for ``path``."""
        self.csp[path] = value

    def reset(self) -> None:
        """Drop all per-path configuration and pending subscribers/handlers."""
        registries = (
            self.request_subscribers,
            self.auth,
            self.csp,
            self.gzip_routes,
            self.routes,
            self._ws_handlers,
        )
        for registry in registries:
            registry.clear()

    def set_route(self, path: str, callback: Callable[[TestServerRequest], Any]) -> None:
        """Let ``callback`` handle requests to ``path`` instead of static files."""
        self.routes[path] = callback

    def enable_gzip(self, path: str) -> None:
        """Serve the static file at ``path`` gzip-compressed."""
        self.gzip_routes.add(path)

    def set_redirect(self, from_: str, to: str) -> None:
        """Answer requests to ``from_`` with a 302 redirect to ``to``."""

        def handle_redirect(request: http.Request) -> None:
            request.setResponseCode(HTTPStatus.FOUND)
            request.setHeader("location", to)
            request.finish()

        self.set_route(from_, handle_redirect)

    def send_on_web_socket_connection(self, data: bytes) -> None:
        """Send ``data`` on the next WebSocket connection that opens."""

        def send(ws: "WebSocketProtocol") -> None:
            ws.sendMessage(data)

        self.once_web_socket_connection(send)

    def once_web_socket_connection(self, handler: Callable[["WebSocketProtocol"], None]) -> None:
        """Queue ``handler`` to be invoked when a WebSocket connection opens."""
        self._ws_handlers.append(handler)
class HTTPServer(Server):
    """Server variant that listens over plain TCP."""

    def listen(self, factory: http.HTTPFactory) -> None:
        """Listen on the IPv4 loopback and, when possible, the IPv6 loopback."""
        reactor.listenTCP(self.PORT, factory, interface="127.0.0.1")
        # The IPv6 listener is best-effort: any failure to set it up is ignored.
        with contextlib.suppress(Exception):
            reactor.listenTCP(self.PORT, factory, interface="::1")
class HTTPSServer(Server):
    """Server variant that listens over TLS using the bundled certificate."""

    protocol = "https"

    def listen(self, factory: http.HTTPFactory) -> None:
        """Load the PEM certificate and key, then listen on the loopbacks."""
        certificate = ssl.Certificate.loadPEM((_dirname / "testserver" / "cert.pem").read_bytes())
        key_pair = ssl.KeyPair.load(
            (_dirname / "testserver" / "key.pem").read_bytes(), crypto.FILETYPE_PEM
        )
        private_cert = ssl.PrivateCertificate.fromCertificateAndKeyPair(certificate, key_pair)
        tls_options = private_cert.options()

        reactor.listenSSL(self.PORT, factory, tls_options, interface="127.0.0.1")
        # The IPv6 listener is best-effort: any failure to set it up is ignored.
        with contextlib.suppress(Exception):
            reactor.listenSSL(self.PORT, factory, tls_options, interface="::1")
class WebSocketProtocol(WebSocketServerProtocol):
    """WebSocket protocol that runs the server's queued one-shot handlers."""

    def onOpen(self) -> None:
        """Run each queued handler once, removing it before it is called."""
        queued = self.factory.server_instance._ws_handlers
        # Iterate over a snapshot because the live list shrinks as we go.
        for callback in list(queued):
            queued.remove(callback)
            callback(self)
class TestServer:
    """Pair of HTTP and HTTPS test servers driven by one reactor thread."""

    def __init__(self) -> None:
        """Create the HTTP and HTTPS servers without starting them."""
        self.server = HTTPServer()
        self.https_server = HTTPSServer()

    def start(self) -> None:
        """Start both servers and run the reactor in a background thread."""
        self.server.start()
        self.https_server.start()
        # Signal handlers can only be installed from the main thread, so the
        # reactor is told not to install them.
        self.thread = threading.Thread(target=lambda: reactor.run(installSignalHandlers=False))
        self.thread.start()

    def stop(self) -> None:
        """Stop the reactor and wait for its thread to exit."""
        # This runs outside the reactor thread. Reactor APIs are not
        # thread-safe, so the stop is scheduled with callFromThread, which also
        # wakes the reactor if it is blocked waiting for I/O.
        reactor.callFromThread(reactor.stop)
        self.thread.join()

    def reset(self) -> None:
        """Clear the per-test configuration of both servers."""
        self.server.reset()
        self.https_server.reset()
# Module-level instance created at import time; constructing it only builds the
# two server objects, nothing listens until start() is called.
test_server = TestServer()