Skip to content

Refactor: change row_factory from tuple to dict #161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 47 additions & 46 deletions src/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def protected(*args, **kwargs):
return decorate

class DatabaseDatabase:
def __init__(self, db):
def __init__(self, db: sqlite3.Connection):
self._db = db
self.q = db.cursor()
self._db.row_factory = sqlite3.Row
self.q = self._db.cursor()

# Set up collations
self._db.create_collation("alphanum", _collate_alphanum)
Expand All @@ -63,7 +64,7 @@ def __getattr__(self, attr):
return getattr(self._db, attr)

def get_count(self):
return self.q.fetchone()[0]
return self.q.fetchone()['count(*)']

def save(self):
self.commit()
Expand Down Expand Up @@ -217,19 +218,19 @@ def get_service(self, id=None, key=None) -> Optional[Service]:
error("ID or key required to get service")
return None
service = self.q.fetchone()
return Service(*service)
return Service(**service)

@db_error_default(list())
def get_services(self, enabled=True, disabled=False) -> List[Service]:
services = list()
if enabled:
self.q.execute("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 1")
for service in self.q.fetchall():
services.append(Service(*service))
services.append(Service(**service))
if disabled:
self.q.execute("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 0")
for service in self.q.fetchall():
services.append(Service(*service))
services.append(Service(**service))
return services

@db_error_default(None)
Expand All @@ -242,7 +243,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
if stream is None:
error("Stream {} not found".format(id))
return None
stream = Stream(*stream)
stream = Stream(**stream)
elif service_tuple is not None:
service, show_key = service_tuple
debug("Getting stream for {}/{}".format(service, show_key))
Expand All @@ -252,7 +253,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
if stream is None:
error("Stream {} not found".format(id))
return None
stream = Stream(*stream)
stream = Stream(**stream)
else:
error("Nothing provided to get stream")
return None
Expand Down Expand Up @@ -299,7 +300,7 @@ def get_streams(self, service=None, show=None, active=True, unmatched=False, mis
return list()

streams = self.q.fetchall()
streams = [Stream(*stream) for stream in streams]
streams = [Stream(**stream) for stream in streams]
for stream in streams:
stream.show = self.get_show(id=stream.show) # convert show id to show model
return streams
Expand Down Expand Up @@ -359,7 +360,7 @@ def get_lite_streams(self, service=None, show=None, missing_link=False) -> List[
return list()

lite_streams = self.q.fetchall()
lite_streams = [LiteStream(*lite_stream) for lite_stream in lite_streams]
lite_streams = [LiteStream(**lite_stream) for lite_stream in lite_streams]
return lite_streams

@db_error
Expand All @@ -381,19 +382,19 @@ def get_link_site(self, id:str=None, key:str=None) -> Optional[LinkSite]:
site = self.q.fetchone()
if site is None:
return None
return LinkSite(*site)
return LinkSite(**site)

@db_error_default(list())
def get_link_sites(self, enabled=True, disabled=False) -> List[LinkSite]:
sites = list()
if enabled:
self.q.execute("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 1")
for link in self.q.fetchall():
sites.append(LinkSite(*link))
sites.append(LinkSite(**link))
if disabled:
self.q.execute("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 0")
for link in self.q.fetchall():
sites.append(LinkSite(*link))
sites.append(LinkSite(**link))
return sites

@db_error_default(list())
Expand All @@ -404,7 +405,7 @@ def get_links(self, show:Show=None) -> List[Link]:
# Get all streams with show ID
self.q.execute("SELECT site, show, site_key FROM Links WHERE show = ?", (show.id,))
links = self.q.fetchall()
links = [Link(*link) for link in links]
links = [Link(**link) for link in links]
return links
else:
error("A show must be provided to get links")
Expand All @@ -418,7 +419,7 @@ def get_link(self, show: Show, link_site: LinkSite) -> Optional[Link]:
link = self.q.fetchone()
if link is None:
return None
link = Link(*link)
link = Link(**link)
return link

@db_error_default(False)
Expand Down Expand Up @@ -449,15 +450,15 @@ def add_link(self, raw_show: UnprocessedShow, show_id, commit=True):

# Shows
@db_error_default(list())
def get_shows(self, missing_length=False, missing_stream=False, enabled=True, delayed=False) -> [Show]:
def get_shows(self, missing_length=False, missing_stream=False, enabled=True, delayed=False) -> list[Show]:
shows = list()
if missing_length:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE (length IS NULL OR length = '' OR length = 0) AND enabled = ?", (enabled,))
elif missing_stream:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows show\
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows show\
WHERE (SELECT count(*) FROM Streams stream, Services service \
WHERE stream.show = show.id \
AND stream.active = 1 \
Expand All @@ -467,14 +468,14 @@ def get_shows(self, missing_length=False, missing_stream=False, enabled=True, de
(enabled,))
elif delayed:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE delayed = 1 AND enabled = ?", (enabled,))
else:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE enabled = ?", (enabled,))
for show in self.q.fetchall():
show = Show(*show)
show = Show(**show)
show.aliases = self.get_aliases(show)
shows.append(show)
return shows
Expand All @@ -492,12 +493,12 @@ def get_show(self, id=None, stream=None) -> Optional[Show]:
error("Show ID not provided to get_show")
return None
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE id = ?", (id,))
show = self.q.fetchone()
if show is None:
return None
show = Show(*show)
show = Show(**show)
show.aliases = self.get_aliases(show)
return show

Expand All @@ -506,19 +507,19 @@ def get_show_by_name(self, name) -> Optional[Show]:
#debug("Getting show from database")

self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE name = ?", (name,))
show = self.q.fetchone()
if show is None:
return None
show = Show(*show)
show = Show(**show)
show.aliases = self.get_aliases(show)
return show

@db_error_default(list())
def get_aliases(self, show: Show) -> [str]:
def get_aliases(self, show: Show) -> list[str]:
self.q.execute("SELECT alias FROM Aliases where show = ?", (show.id,))
return [s for s, in self.q.fetchall()]
return [s["alias"] for s in self.q.fetchall()]

@db_error_default(None)
def add_show(self, raw_show: UnprocessedShow, commit=True) -> int:
Expand Down Expand Up @@ -556,7 +557,7 @@ def update_show(self, show_id: str, raw_show: UnprocessedShow, commit=True):
is_nsfw = raw_show.is_nsfw

if name_en:
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
if length != 0:
self.q.execute("UPDATE Shows SET length = ? WHERE id = ?", (length, show_id))
self.q.execute("UPDATE Shows SET type = ?, has_source = ?, is_nsfw = ? WHERE id = ?", (show_type, has_source, is_nsfw, show_id))
Expand Down Expand Up @@ -599,10 +600,10 @@ def stream_has_episode(self, stream: Stream, episode_num) -> bool:

@db_error_default(None)
def get_latest_episode(self, show: Show) -> Optional[Episode]:
self.q.execute("SELECT episode, post_url FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1", (show.id,))
self.q.execute("SELECT episode AS number, post_url AS link FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1", (show.id,))
data = self.q.fetchone()
if data is not None:
return Episode(data[0], None, data[1], None)
return Episode(**data)
return None

@db_error
Expand All @@ -614,9 +615,9 @@ def add_episode(self, show, episode_num, post_url):
@db_error_default(list())
def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
episodes = list()
self.q.execute("SELECT episode, post_url FROM Episodes WHERE show = ?", (show.id,))
self.q.execute("SELECT episode AS number, post_url AS link FROM Episodes WHERE show = ?", (show.id,))
for data in self.q.fetchall():
episodes.append(Episode(data[0], None, data[1], None))
episodes.append(Episode(**data))

if ensure_sorted:
episodes = sorted(episodes, key=lambda e: e.number)
Expand All @@ -625,23 +626,23 @@ def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
# Scores
@db_error_default(list())
def get_show_scores(self, show: Show) -> List[EpisodeScore]:
self.q.execute("SELECT episode, site, score FROM Scores WHERE show=?", (show.id,))
return [EpisodeScore(show.id, *s) for s in self.q.fetchall()]
self.q.execute("SELECT episode, site AS site_id, score FROM Scores WHERE show=?", (show.id,))
return [EpisodeScore(show_id=show.id, **s) for s in self.q.fetchall()]

@db_error_default(list())
def get_episode_scores(self, show: Show, episode: Episode) -> List[EpisodeScore]:
self.q.execute("SELECT site, score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
return [EpisodeScore(show.id, episode.number, *s) for s in self.q.fetchall()]
self.q.execute("SELECT site AS site_id, score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
return [EpisodeScore(show_id=show.id, episode=episode.number, **s) for s in self.q.fetchall()]

@db_error_default(None)
def get_episode_score_avg(self, show: Show, episode: Episode) -> Optional[EpisodeScore]:
debug("Calculating avg score for {} ({})".format(show.name, show.id))
self.q.execute("SELECT score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
scores = [s[0] for s in self.q.fetchall()]
scores = [s["score"] for s in self.q.fetchall()]
if len(scores) > 0:
score = sum(scores)/len(scores)
debug(" Score: {} (from {} scores)".format(score, len(scores)))
return EpisodeScore(show.id, episode.number, None, score)
return EpisodeScore(show_id=show.id, episode=episode.number, score=score)
return None

@db_error
Expand All @@ -664,7 +665,7 @@ def get_poll_site(self, id:str=None, key:str=None) -> Optional[PollSite]:
site = self.q.fetchone()
if site is None:
return None
return PollSite(*site)
return PollSite(**site)

@db_error
def add_poll(self, show: Show, episode: Episode, site: PollSite, poll_id, commit=True):
Expand All @@ -681,24 +682,24 @@ def update_poll_score(self, poll: Poll, score, commit=True):

@db_error_default(None)
def get_poll(self, show: Show, episode: Episode):
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ? AND episode = ?", (show.id, episode.number))
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE show = ? AND episode = ?", (show.id, episode.number))
poll = self.q.fetchone()
if poll is None:
return None
return Poll(*poll)
return Poll(**poll)

@db_error_default(list())
def get_polls(self, show: Show=None, missing_score=False):
polls = list()
if show is not None:
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ?", (show.id,))
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE show = ?", (show.id,))
elif missing_score:
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)")
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)")
else:
error("Need to select a show to get polls")
return list()
for poll in self.q.fetchall():
polls.append(Poll(*poll))
polls.append(Poll(**poll))
return polls

# Searching
Expand All @@ -713,8 +714,8 @@ def search_show_ids_by_names(self, *names, exact=False) -> Set[Show]:
self.q.execute("SELECT show, name FROM ShowNames WHERE name = ? COLLATE alphanum", (name,))
matched = self.q.fetchall()
for match in matched:
debug(" Found match: {} | {}".format(match[0], match[1]))
shows.add(match[0])
debug(" Found match: {} | {}".format(match['show'], match['name']))
shows.add(match['show'])
return shows

# Helper methods
Expand Down
Loading