You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
6.8 KiB
190 lines
6.8 KiB
import warnings
|
|
from datetime import datetime
|
|
from datetime import timedelta as TimeDelta
|
|
from typing import Any, Optional
|
|
|
|
from flask import Flask
|
|
from flask_sqlalchemy import SQLAlchemy
|
|
from itsdangerous import want_bytes
|
|
from sqlalchemy import Column, DateTime, Integer, LargeBinary, Sequence, String
|
|
|
|
from .._utils import retry_query
|
|
from ..base import ServerSideSession, ServerSideSessionInterface
|
|
from ..defaults import Defaults
|
|
|
|
|
|
class SqlAlchemySession(ServerSideSession):
|
|
pass
|
|
|
|
|
|
def create_session_model(db, table_name, schema=None, bind_key=None, sequence=None):
|
|
class Session(db.Model):
|
|
__tablename__ = table_name
|
|
__table_args__ = {"schema": schema} if schema else {}
|
|
__bind_key__ = bind_key
|
|
|
|
id = (
|
|
Column(Integer, Sequence(sequence), primary_key=True)
|
|
if sequence
|
|
else Column(Integer, primary_key=True)
|
|
)
|
|
session_id = Column(String(255), unique=True)
|
|
data = Column(LargeBinary)
|
|
expiry = Column(DateTime)
|
|
|
|
def __init__(self, session_id: str, data: Any, expiry: datetime):
|
|
self.session_id = session_id
|
|
self.data = data
|
|
self.expiry = expiry
|
|
|
|
def __repr__(self):
|
|
return f"<Session data {self.data}>"
|
|
|
|
return Session
|
|
|
|
|
|
class SqlAlchemySessionInterface(ServerSideSessionInterface):
|
|
"""Uses the Flask-SQLAlchemy from a flask app as session storage.
|
|
|
|
:param app: A Flask app instance.
|
|
:param client: A Flask-SQLAlchemy instance.
|
|
:param key_prefix: A prefix that is added to all storage keys.
|
|
:param use_signer: Whether to sign the session id cookie or not.
|
|
:param permanent: Whether to use permanent session or not.
|
|
:param sid_length: The length of the generated session id in bytes.
|
|
:param serialization_format: The serialization format to use for the session data.
|
|
:param table: The table name you want to use.
|
|
:param sequence: The sequence to use for the primary key if needed.
|
|
:param schema: The db schema to use.
|
|
:param bind_key: The db bind key to use.
|
|
:param cleanup_n_requests: Delete expired sessions on average every N requests.
|
|
|
|
.. versionadded:: 0.7
|
|
db changed to client to be standard on all session interfaces.
|
|
The `cleanup_n_request` parameter was added.
|
|
|
|
.. versionadded:: 0.6
|
|
The `sid_length`, `sequence`, `schema` and `bind_key` parameters were added.
|
|
|
|
.. versionadded:: 0.2
|
|
The `use_signer` parameter was added.
|
|
"""
|
|
|
|
session_class = SqlAlchemySession
|
|
ttl = False
|
|
|
|
def __init__(
|
|
self,
|
|
app: Optional[Flask],
|
|
client: Optional[SQLAlchemy] = Defaults.SESSION_SQLALCHEMY,
|
|
key_prefix: str = Defaults.SESSION_KEY_PREFIX,
|
|
use_signer: bool = Defaults.SESSION_USE_SIGNER,
|
|
permanent: bool = Defaults.SESSION_PERMANENT,
|
|
sid_length: int = Defaults.SESSION_ID_LENGTH,
|
|
serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT,
|
|
table: str = Defaults.SESSION_SQLALCHEMY_TABLE,
|
|
sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE,
|
|
schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA,
|
|
bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY,
|
|
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
|
|
):
|
|
self.app = app
|
|
|
|
if client is None or not isinstance(client, SQLAlchemy):
|
|
warnings.warn(
|
|
"No valid SQLAlchemy instance provided, attempting to create a new instance on localhost with default settings.",
|
|
RuntimeWarning,
|
|
stacklevel=1,
|
|
)
|
|
client = SQLAlchemy(app)
|
|
self.client = client
|
|
|
|
# Create the session model
|
|
self.sql_session_model = create_session_model(
|
|
client, table, schema, bind_key, sequence
|
|
)
|
|
# Create the table if it does not exist
|
|
with app.app_context():
|
|
if bind_key:
|
|
engine = self.client.get_engine(app, bind=bind_key)
|
|
else:
|
|
engine = self.client.engine
|
|
self.sql_session_model.__table__.create(bind=engine, checkfirst=True)
|
|
|
|
super().__init__(
|
|
app,
|
|
key_prefix,
|
|
use_signer,
|
|
permanent,
|
|
sid_length,
|
|
serialization_format,
|
|
cleanup_n_requests,
|
|
)
|
|
|
|
@retry_query()
|
|
def _delete_expired_sessions(self) -> None:
|
|
try:
|
|
self.client.session.query(self.sql_session_model).filter(
|
|
self.sql_session_model.expiry <= datetime.utcnow()
|
|
).delete(synchronize_session=False)
|
|
self.client.session.commit()
|
|
except Exception:
|
|
self.client.session.rollback()
|
|
raise
|
|
|
|
@retry_query()
|
|
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
|
|
# Get the saved session (record) from the database
|
|
record = self.sql_session_model.query.filter_by(session_id=store_id).first()
|
|
|
|
# "Delete the session record if it is expired as SQL has no TTL ability
|
|
if record and (record.expiry is None or record.expiry <= datetime.utcnow()):
|
|
try:
|
|
self.client.session.delete(record)
|
|
self.client.session.commit()
|
|
except Exception:
|
|
self.client.session.rollback()
|
|
raise
|
|
record = None
|
|
|
|
if record:
|
|
serialized_session_data = want_bytes(record.data)
|
|
return self.serializer.decode(serialized_session_data)
|
|
return None
|
|
|
|
@retry_query()
|
|
def _delete_session(self, store_id: str) -> None:
|
|
try:
|
|
self.sql_session_model.query.filter_by(session_id=store_id).delete()
|
|
self.client.session.commit()
|
|
except Exception:
|
|
self.client.session.rollback()
|
|
raise
|
|
|
|
@retry_query()
|
|
def _upsert_session(
|
|
self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str
|
|
) -> None:
|
|
storage_expiration_datetime = datetime.utcnow() + session_lifetime
|
|
|
|
# Serialize session data
|
|
serialized_session_data = self.serializer.encode(session)
|
|
|
|
# Update existing or create new session in the database
|
|
try:
|
|
record = self.sql_session_model.query.filter_by(session_id=store_id).first()
|
|
if record:
|
|
record.data = serialized_session_data
|
|
record.expiry = storage_expiration_datetime
|
|
else:
|
|
record = self.sql_session_model(
|
|
session_id=store_id,
|
|
data=serialized_session_data,
|
|
expiry=storage_expiration_datetime,
|
|
)
|
|
self.client.session.add(record)
|
|
self.client.session.commit()
|
|
except Exception:
|
|
self.client.session.rollback()
|
|
raise
|