Skip to content

Commit

Permalink
Properly type db on the model, sort out generic types issues (#5545)
Browse files Browse the repository at this point in the history
Thanks to @wisp3rwind's suggestion this PR adds types to the
relationship between `Model`, `Database` and `Library`.

Then I worked through the rest of the issues found in the edited files.
Most of this involved providing type parameters for generic types (or
defining defaults, rather 😉).

There `queryparse` module had a somewhat significant issue where the
sorting construction logic only expected to receive `FieldSort`
subclasses, while `SmartArtistSort` was not one. Thus `SmartArtistSort`
has now been forced to behave and is a `FieldSort` subclass. It's also
been moved to `query.py` module which is where the rest of sorts are
defined.
  • Loading branch information
snejus committed Dec 19, 2024
1 parent 994f9b8 commit bcf516b
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 109 deletions.
55 changes: 35 additions & 20 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from .query import (
AndQuery,
FieldQuery,
FieldQueryType,
FieldSort,
MatchQuery,
NullSort,
Query,
Expand All @@ -47,6 +49,15 @@
if TYPE_CHECKING:
from types import TracebackType

from .query import SQLiteType

D = TypeVar("D", bound="Database", default=Any)
else:
D = TypeVar("D", bound="Database")


FlexAttrs = dict[str, str]


class DBAccessError(Exception):
"""The SQLite database became inaccessible.
Expand Down Expand Up @@ -236,7 +247,7 @@ def __len__(self) -> int:
# Abstract base for model classes.


class Model(ABC):
class Model(ABC, Generic[D]):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
Expand Down Expand Up @@ -284,12 +295,12 @@ class Model(ABC):
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""

_sorts: dict[str, type[Sort]] = {}
_sorts: dict[str, type[FieldSort]] = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""

_queries: dict[str, type[FieldQuery]] = {}
_queries: dict[str, FieldQueryType] = {}
"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
Expand All @@ -306,7 +317,7 @@ class Model(ABC):
"""

@cached_classproperty
def _relation(cls) -> type[Model]:
def _relation(cls):
"""The model that this model is closely related to."""
return cls

Expand Down Expand Up @@ -347,7 +358,7 @@ def _template_funcs(self) -> Mapping[str, Callable[[str], str]]:

# Basic operation.

def __init__(self, db: Database | None = None, **values):
def __init__(self, db: D | None = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
Expand All @@ -363,7 +374,7 @@ def __init__(self, db: Database | None = None, **values):
@classmethod
def _awaken(
cls: type[AnyModel],
db: Database | None = None,
db: D | None = None,
fixed_values: dict[str, Any] = {},
flex_values: dict[str, Any] = {},
) -> AnyModel:
Expand Down Expand Up @@ -393,7 +404,7 @@ def clear_dirty(self):
if self._db:
self._revision = self._db.revision

def _check_db(self, need_id: bool = True) -> Database:
def _check_db(self, need_id: bool = True) -> D:
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
Expand Down Expand Up @@ -574,7 +585,7 @@ def store(self, fields: Iterable[str] | None = None):

# Build assignments for query.
assignments = []
subvars = []
subvars: list[SQLiteType] = []
for key in fields:
if key != "id" and key in self._dirty:
self._dirty.remove(key)
Expand Down Expand Up @@ -637,7 +648,7 @@ def remove(self):
f"DELETE FROM {self._flex_table} WHERE entity_id=?", (self.id,)
)

def add(self, db: Database | None = None):
def add(self, db: D | None = None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
Expand Down Expand Up @@ -714,16 +725,16 @@ def field_query(
cls,
field,
pattern,
query_cls: type[FieldQuery] = MatchQuery,
query_cls: FieldQueryType = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)

@classmethod
def all_fields_query(
cls: type[Model],
pats: Mapping,
query_cls: type[FieldQuery] = MatchQuery,
pats: Mapping[str, str],
query_cls: FieldQueryType = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
Expand All @@ -749,8 +760,8 @@ class Results(Generic[AnyModel]):
def __init__(
self,
model_class: type[AnyModel],
rows: list[Mapping],
db: Database,
rows: list[sqlite3.Row],
db: D,
flex_rows,
query: Query | None = None,
sort=None,
Expand Down Expand Up @@ -834,9 +845,9 @@ def __iter__(self) -> Iterator[AnyModel]:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()

def _get_indexed_flex_attrs(self) -> Mapping:
def _get_indexed_flex_attrs(self) -> dict[int, FlexAttrs]:
"""Index flexible attributes by the entity id they belong to"""
flex_values: dict[int, dict[str, Any]] = {}
flex_values: dict[int, FlexAttrs] = {}
for row in self.flex_rows:
if row["entity_id"] not in flex_values:
flex_values[row["entity_id"]] = {}
Expand All @@ -845,7 +856,9 @@ def _get_indexed_flex_attrs(self) -> Mapping:

return flex_values

def _make_model(self, row, flex_values: dict = {}) -> AnyModel:
def _make_model(
self, row: sqlite3.Row, flex_values: FlexAttrs = {}
) -> AnyModel:
"""Create a Model object for the given row"""
cols = dict(row)
values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"}
Expand Down Expand Up @@ -954,14 +967,16 @@ def __exit__(
self._mutated = False
self.db._db_lock.release()

def query(self, statement: str, subvals: Sequence = ()) -> list:
def query(
self, statement: str, subvals: Sequence[SQLiteType] = ()
) -> list[sqlite3.Row]:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()

def mutate(self, statement: str, subvals: Sequence = ()) -> Any:
def mutate(self, statement: str, subvals: Sequence[SQLiteType] = ()) -> Any:
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
Expand Down Expand Up @@ -1122,7 +1137,7 @@ def _close(self):
conn.close()

@contextlib.contextmanager
def _tx_stack(self) -> Generator[list]:
def _tx_stack(self) -> Generator[list[Transaction]]:
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
Expand Down
49 changes: 37 additions & 12 deletions beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import unicodedata
from abc import ABC, abstractmethod
from collections.abc import Collection, Iterator, MutableSequence, Sequence
from collections.abc import Iterator, MutableSequence, Sequence
from datetime import datetime, timedelta
from functools import reduce
from operator import mul, or_
Expand All @@ -30,6 +30,11 @@

if TYPE_CHECKING:
from beets.dbcore import Model
from beets.dbcore.db import AnyModel

P = TypeVar("P", default=Any)
else:
P = TypeVar("P")


class ParsingError(ValueError):
Expand Down Expand Up @@ -107,9 +112,9 @@ def __hash__(self) -> int:
return hash(type(self))


P = TypeVar("P")
SQLiteType = Union[str, bytes, float, int, memoryview]
SQLiteType = Union[str, bytes, float, int, memoryview, None]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
FieldQueryType = type["FieldQuery"]


class FieldQuery(Query, Generic[P]):
Expand Down Expand Up @@ -289,7 +294,7 @@ def _normalize(s: str) -> str:
return unicodedata.normalize("NFC", s)

@classmethod
def string_match(cls, pattern: Pattern, value: str) -> bool:
def string_match(cls, pattern: Pattern[str], value: str) -> bool:
return pattern.search(cls._normalize(value)) is not None


Expand Down Expand Up @@ -451,7 +456,7 @@ def field_names(self) -> set[str]:
"""Return a set with field names that this query operates on."""
return reduce(or_, (sq.field_names for sq in self.subqueries))

def __init__(self, subqueries: Sequence = ()):
def __init__(self, subqueries: Sequence[Query] = ()):
self.subqueries = subqueries

# Act like a sequence.
Expand All @@ -462,7 +467,7 @@ def __len__(self) -> int:
def __getitem__(self, key):
return self.subqueries[key]

def __iter__(self) -> Iterator:
def __iter__(self) -> Iterator[Query]:
return iter(self.subqueries)

def __contains__(self, subq) -> bool:
Expand All @@ -476,7 +481,7 @@ def clause_with_joiner(
all subqueries with the string joiner (padded by spaces).
"""
clause_parts = []
subvals = []
subvals: list[SQLiteType] = []
for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause()
if not subq_clause:
Expand Down Expand Up @@ -511,7 +516,7 @@ def field_names(self) -> set[str]:
"""Return a set with field names that this query operates on."""
return set(self.fields)

def __init__(self, pattern, fields, cls: type[FieldQuery]):
def __init__(self, pattern, fields, cls: FieldQueryType):
self.pattern = pattern
self.fields = fields
self.query_class = cls
Expand Down Expand Up @@ -549,7 +554,7 @@ class MutableCollectionQuery(CollectionQuery):
query is initialized.
"""

subqueries: MutableSequence
subqueries: MutableSequence[Query]

def __setitem__(self, key, value):
self.subqueries[key] = value
Expand Down Expand Up @@ -894,7 +899,7 @@ def order_clause(self) -> str | None:
"""
return None

def sort(self, items: list) -> list:
def sort(self, items: list[AnyModel]) -> list[AnyModel]:
"""Sort the list of objects and return a list."""
return sorted(items)

Expand Down Expand Up @@ -988,7 +993,7 @@ def __init__(
self.ascending = ascending
self.case_insensitive = case_insensitive

def sort(self, objs: Collection):
def sort(self, objs: list[AnyModel]) -> list[AnyModel]:
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
Expand Down Expand Up @@ -1047,7 +1052,7 @@ def is_slow(self) -> bool:
class NullSort(Sort):
"""No sorting. Leave results unsorted."""

def sort(self, items: list) -> list:
def sort(self, items: list[AnyModel]) -> list[AnyModel]:
return items

def __nonzero__(self) -> bool:
Expand All @@ -1061,3 +1066,23 @@ def __eq__(self, other) -> bool:

def __hash__(self) -> int:
return 0


class SmartArtistSort(FieldSort):
"""Sort by artist (either album artist or track artist),
prioritizing the sort field over the raw field.
"""

def order_clause(self):
order = "ASC" if self.ascending else "DESC"
collate = "COLLATE NOCASE" if self.case_insensitive else ""
field = self.field

return f"COALESCE(NULLIF({field}_sort, ''), {field}) {collate} {order}"

def sort(self, objs: list[AnyModel]) -> list[AnyModel]:
def key(o):
val = o[f"{self.field}_sort"] or o[self.field]
return val.lower() if self.case_insensitive else val

return sorted(objs, key=key, reverse=not self.ascending)
Loading

0 comments on commit bcf516b

Please sign in to comment.