You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

386 lines
14 KiB

import secrets
import warnings
from abc import ABC, abstractmethod
from contextlib import suppress
try:
import cPickle as pickle
except ImportError:
import pickle
import random
from datetime import timedelta as TimeDelta
from typing import Any, Dict, Optional
import msgspec
from flask import Flask, Request, Response
from flask.sessions import SessionInterface as FlaskSessionInterface
from flask.sessions import SessionMixin
from itsdangerous import BadSignature, Signer, want_bytes
from werkzeug.datastructures import CallbackDict
from ._utils import retry_query
from .defaults import Defaults
class ServerSideSession(CallbackDict, SessionMixin):
    """Base class for server-side sessions, available to views as ``flask.session``.

    .. attribute:: sid

        The session id. :class:`.ServerSideSessionInterface` creates new ids
        with :func:`secrets.token_urlsafe`.

    .. attribute:: modified

        Becomes ``True`` whenever the session dictionary itself is written to.
        Only top-level changes are tracked: after mutating a nested mutable
        value (for example a dict stored in the session) you must set this to
        ``True`` yourself. The session interface consults this flag when
        deciding whether to write the session back.

    .. attribute:: accessed

        Becomes ``True`` on any read (including an attempted read of a missing
        key) or write. :class:`.ServerSideSessionInterface` uses it to add a
        ``Vary: Cookie`` header so caching proxies keep per-user pages apart.
        Starts out ``False``.

    .. attribute:: permanent

        Mirrors the ``'_permanent'`` key of the dict. Starts out ``False``.
    """

    def __bool__(self) -> bool:
        # The session is "empty" unless it holds some key other than the
        # bookkeeping "_permanent" marker.
        return bool(self.keys() - {"_permanent"})

    def __init__(
        self,
        initial: Optional[Dict[str, Any]] = None,
        sid: Optional[str] = None,
        permanent: Optional[bool] = None,
    ):
        def _mark_changed(changed_session: "ServerSideSession") -> None:
            # Every write through the dict API counts as both a change and an access.
            changed_session.modified = True
            changed_session.accessed = True

        CallbackDict.__init__(self, initial, _mark_changed)
        self.sid = sid
        if permanent:
            self.permanent = permanent
        # Assigning ``permanent`` writes the '_permanent' key and so fires the
        # callback; reset both flags so a freshly built session starts untouched.
        self.modified = False
        self.accessed = False

    def __getitem__(self, key: str) -> Any:
        # Flag the access first so that a failed lookup still counts as a read.
        self.accessed = True
        return super().__getitem__(key)

    def get(self, key: str, default: Any = None) -> Any:
        self.accessed = True
        return super().get(key, default)

    def setdefault(self, key: str, default: Any = None) -> Any:
        self.accessed = True
        return super().setdefault(key, default)

    def clear(self) -> None:
        """Remove every key from the session, then restore its '_permanent' flag."""
        keep_permanent = self.get("_permanent", False)
        super().clear()
        self["_permanent"] = keep_permanent
class Serializer(ABC):
    """Abstract contract for converting session data to bytes and back."""

    @abstractmethod
    def encode(self, session: ServerSideSession) -> bytes:
        """Return *session* serialized to bytes for storage."""
        raise NotImplementedError()

    @abstractmethod
    def decode(self, serialized_data: bytes) -> dict:
        """Return the session dict recovered from *serialized_data*."""
        raise NotImplementedError()
class MsgSpecSerializer(Serializer):
    """Serialize sessions with msgspec, in either MessagePack or JSON format.

    Decoding is lenient so that data written under a different configuration
    can still be read. The configured format is tried first, then the other
    msgspec format, and finally (until 1.0.0) :mod:`pickle`.
    """

    def __init__(self, app: Flask, format: str):
        """
        :param app: Flask application; its logger is used to report failures.
        :param format: Either ``"msgpack"`` or ``"json"``.
        :raises ValueError: If *format* is not one of the supported formats.
        """
        self.app: Flask = app
        # Pick the msgspec sub-module for the configured format. The other
        # sub-module is kept as a fallback decoder for data written in the
        # other format.
        if format == "msgpack":
            primary, alternate = msgspec.msgpack, msgspec.json
        elif format == "json":
            primary, alternate = msgspec.json, msgspec.msgpack
        else:
            raise ValueError(f"Unsupported serialization format: {format}")
        self.encoder = primary.Encoder()
        self.decoder = primary.Decoder()
        self.alternate_decoder = alternate.Decoder()

    def encode(self, session: ServerSideSession) -> bytes:
        """Serialize the session data.

        Any error raised by the encoder is logged and then re-raised unchanged.
        """
        try:
            return self.encoder.encode(dict(session))
        except Exception as e:
            self.app.logger.error("Failed to serialize session data: %s", e)
            raise

    def decode(self, serialized_data: bytes) -> dict:
        """Deserialize the session data.

        :raises pickle.UnpicklingError: If none of the decoders can read the
            data. The last decoder's exception is chained as ``__cause__``.
        """
        # TODO: Remove the pickle fallback in 1.0.0
        with suppress(msgspec.DecodeError):
            return self.decoder.decode(serialized_data)
        with suppress(msgspec.DecodeError):
            return self.alternate_decoder.decode(serialized_data)
        # SECURITY NOTE(review): pickle.loads can execute arbitrary code.
        # Presumably this is only ever fed bytes read back from the server-side
        # session store, never client-supplied data -- confirm at call sites.
        try:
            return pickle.loads(serialized_data)
        except Exception as exc:
            # Unpickling garbage raises many exception types (EOFError,
            # ValueError, IndexError, ...), not only UnpicklingError, so catch
            # broadly to report every failure the same way. Pass the exception
            # explicitly: exc_info=True would find no active exception once a
            # suppress() block has exited.
            self.app.logger.error("Failed to deserialize session data", exc_info=exc)
            raise pickle.UnpicklingError("Failed to deserialize session data") from exc
class ServerSideSessionInterface(FlaskSessionInterface, ABC):
    """Base Flask session interface that opens and saves :class:`ServerSideSession` objects.

    The browser cookie carries only the session id; the session data lives in
    a backend store. Subclasses implement :meth:`_retrieve_session_data`,
    :meth:`_delete_session` and :meth:`_upsert_session`. Stores without native
    key expiry set ``ttl = False`` and implement :meth:`_delete_expired_sessions`.
    """

    session_class = ServerSideSession
    serializer = None
    # Subclasses that set this to False get an expired-session cleanup hook
    # registered in __init__ (per-request or as a CLI command).
    ttl = True

    def __init__(
        self,
        app: Flask,
        key_prefix: str = Defaults.SESSION_KEY_PREFIX,
        use_signer: bool = Defaults.SESSION_USE_SIGNER,
        permanent: bool = Defaults.SESSION_PERMANENT,
        sid_length: int = Defaults.SESSION_ID_LENGTH,
        serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT,
        cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
    ):
        """
        :param app: The Flask application this interface serves.
        :param key_prefix: Prefix prepended to a session id to form its storage key.
        :param use_signer: Deprecated. Sign the session id stored in the cookie
            with the app's ``secret_key``.
        :param permanent: ``permanent`` value given to newly created sessions.
        :param sid_length: Number of random bytes passed to
            :func:`secrets.token_urlsafe` when generating a session id.
        :param serialization_format: ``"msgpack"`` or ``"json"``.
        :param cleanup_n_requests: Only used when ``ttl`` is False. If truthy,
            expired sessions are purged on average once every this many
            requests. Otherwise a ``flask session_cleanup`` CLI command is
            registered instead.
        """
        self.app = app
        self.key_prefix = key_prefix
        self.use_signer = use_signer
        if use_signer:
            warnings.warn(
                "The 'use_signer' option is deprecated and will be removed in the next minor release. "
                "Please update your configuration accordingly or open an issue.",
                DeprecationWarning,
                stacklevel=1,
            )
        self.permanent = permanent
        self.sid_length = sid_length
        # Probe once for SameSite support; presumably some supported Flask
        # versions lack get_cookie_samesite -- confirm.
        self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
        self.cleanup_n_requests = cleanup_n_requests

        # Cleanup settings for non-TTL databases only
        if getattr(self, "ttl", None) is False:
            if self.cleanup_n_requests:
                self.app.before_request(self._cleanup_n_requests)
            else:
                self._register_cleanup_app_command()

        # Set the serialization format
        self.serializer = MsgSpecSerializer(format=serialization_format, app=app)

    # INTERNAL METHODS

    def _generate_sid(self, session_id_length: int) -> str:
        """Generate a random, URL-safe session id from *session_id_length* random bytes."""
        return secrets.token_urlsafe(session_id_length)

    # TODO: Remove in 1.0.0
    def _get_signer(self, app: Flask) -> Signer:
        """Build the signer used for session ids when ``use_signer`` is enabled.

        :raises KeyError: If the app has no ``secret_key``.
        """
        if not hasattr(app, "secret_key") or not app.secret_key:
            raise KeyError("SECRET_KEY must be set when SESSION_USE_SIGNER=True")
        return Signer(app.secret_key, salt="flask-session", key_derivation="hmac")

    # TODO: Remove in 1.0.0
    def _unsign(self, app, sid: str) -> str:
        """Verify a signed session id and return the bare id.

        :raises itsdangerous.BadSignature: If the signature does not match.
        """
        signer = self._get_signer(app)
        sid_as_bytes = signer.unsign(sid)
        sid = sid_as_bytes.decode()
        return sid

    # TODO: Remove in 1.0.0
    def _sign(self, app, sid: str) -> str:
        """Return *sid* with a signature appended, as text for the cookie."""
        signer = self._get_signer(app)
        sid_as_bytes = want_bytes(sid)
        return signer.sign(sid_as_bytes).decode("utf-8")

    def _get_store_id(self, sid: str) -> str:
        """Return the storage key for *sid* (the session id with ``key_prefix`` prepended)."""
        return self.key_prefix + sid

    def should_set_storage(self, app: Flask, session: ServerSideSession) -> bool:
        """Used by session backends to determine if session in storage
        should be set for this session cookie for this response. If the session
        has been modified, the session is set to storage. If
        the ``SESSION_REFRESH_EACH_REQUEST`` config is true, the session is
        always set to storage. In the second case, this means refreshing the
        storage expiry even if the session has not been modified.

        .. versionadded:: 0.7.0
        """
        return session.modified or app.config["SESSION_REFRESH_EACH_REQUEST"]

    # CLEANUP METHODS FOR NON TTL DATABASES

    def _register_cleanup_app_command(self):
        """
        Register a custom Flask CLI command for cleaning up expired sessions.

        Run the command with `flask session_cleanup`. Run with a cron job
        or scheduler such as Heroku Scheduler to automatically clean up expired sessions.
        """

        @self.app.cli.command("session_cleanup")
        def session_cleanup():
            with self.app.app_context():
                self._delete_expired_sessions()

    def _cleanup_n_requests(self) -> None:
        """
        Delete expired sessions on average once every ``cleanup_n_requests`` requests.

        This is less desirable than using the scheduled app command cleanup as it may
        slow down some requests but may be useful for rapid development.
        """
        # randrange(n) returns 0 with probability exactly 1/n. (randint(0, n)
        # is inclusive at both ends, so it would fire only once in n + 1 requests.)
        if self.cleanup_n_requests and random.randrange(self.cleanup_n_requests) == 0:
            self._delete_expired_sessions()

    # SECURITY API METHODS

    def regenerate(self, session: ServerSideSession) -> None:
        """Give *session* a new id and drop the data stored under the old one.

        Call it as ``app.session_interface.regenerate(session)``, for example
        after a login, to defeat session fixation. Empty sessions are left
        unchanged.
        """
        if session:
            # Remove the old session from storage
            self._delete_session(self._get_store_id(session.sid))
            # Generate a new session ID
            new_sid = self._generate_sid(self.sid_length)
            session.sid = new_sid
            # Mark the session as modified to ensure it gets saved
            session.modified = True

    # METHODS OVERRIDE FLASK SESSION INTERFACE

    def save_session(
        self, app: Flask, session: ServerSideSession, response: Response
    ) -> None:
        """Persist *session* to the store and set, refresh or delete the cookie on *response*.

        If the session is empty and was modified, its stored data is deleted
        and the client is told to drop the cookie. An empty, unmodified session
        is neither stored nor given a cookie.
        """
        # Get the domain and path for the cookie from the app
        domain = self.get_cookie_domain(app)
        path = self.get_cookie_path(app)
        name = self.get_cookie_name(app)

        # Generate a prefixed session id
        store_id = self._get_store_id(session.sid)

        # Add a "Vary: Cookie" header if the session was accessed at all.
        # This assumes the app is checking the session values in a request that
        # behaves differently based on those values. ie. session.get("is_authenticated")
        if session.accessed:
            response.vary.add("Cookie")

        # If the session is empty, do not save it to the database or set a cookie
        if not session:
            # If the session was deleted (empty and modified), delete the saved
            # session from the database and tell the client to delete the cookie
            if session.modified:
                self._delete_session(store_id)
                response.delete_cookie(key=name, domain=domain, path=path)
                response.vary.add("Cookie")
            return

        if not self.should_set_storage(app, session):
            return

        # Update existing or create new session in the database
        self._upsert_session(app.permanent_session_lifetime, session, store_id)

        if not self.should_set_cookie(app, session):
            return

        # Get the additional required cookie settings
        value = self._sign(app, session.sid) if self.use_signer else session.sid
        expires = self.get_expiration_time(app, session)
        httponly = self.get_cookie_httponly(app)
        secure = self.get_cookie_secure(app)
        samesite = (
            self.get_cookie_samesite(app) if self.has_same_site_capability else None
        )

        # Set the browser cookie
        response.set_cookie(
            key=name,
            value=value,
            expires=expires,
            httponly=httponly,
            domain=domain,
            path=path,
            secure=secure,
            samesite=samesite,
        )
        response.vary.add("Cookie")

    def open_session(self, app: Flask, request: Request) -> ServerSideSession:
        """Load the session named by the request's session cookie.

        A new session with a freshly generated id is returned when there is no
        cookie, when a signed cookie fails verification, or when the store has
        no data for the id. In those cases the client-supplied id is never adopted.
        """
        # Read the cookie under the same name that save_session() writes it with.
        sid = request.cookies.get(self.get_cookie_name(app))

        # If there's no session ID, generate a new one
        if not sid:
            sid = self._generate_sid(self.sid_length)
            return self.session_class(sid=sid, permanent=self.permanent)

        # If the session ID is signed, unsign it
        if self.use_signer:
            try:
                sid = self._unsign(app, sid)
            except BadSignature:
                sid = self._generate_sid(self.sid_length)
                return self.session_class(sid=sid, permanent=self.permanent)

        # Retrieve the session data from the database
        store_id = self._get_store_id(sid)
        saved_session_data = self._retrieve_session_data(store_id)

        # If the saved session exists, load the session data from the document
        if saved_session_data is not None:
            return self.session_class(saved_session_data, sid=sid)

        # If the saved session does not exist, create a new session
        sid = self._generate_sid(self.sid_length)
        return self.session_class(sid=sid, permanent=self.permanent)

    # METHODS TO BE IMPLEMENTED BY SUBCLASSES

    @abstractmethod
    @retry_query()  # use only when retry not supported directly by the client
    def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
        """Get the saved session from the session storage, or ``None`` if absent."""
        raise NotImplementedError()

    @abstractmethod
    @retry_query()  # use only when retry not supported directly by the client
    def _delete_session(self, store_id: str) -> None:
        """Delete session from the session storage."""
        raise NotImplementedError()

    @abstractmethod
    @retry_query()  # use only when retry not supported directly by the client
    def _upsert_session(
        self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str
    ) -> None:
        """Update existing or create new session in the session storage."""
        raise NotImplementedError()

    @retry_query()  # use only when retry not supported directly by the client
    def _delete_expired_sessions(self) -> None:
        """Delete expired sessions from the session storage. Only required for non-TTL databases."""
        pass