# storage.py
  1. from typing import Any
  2. from datetime import datetime
  3. import numpy as np
  4. from yarl import URL
  5. from config import settings
  6. from media_observer.article import (
  7. ArchiveCollection,
  8. FrontPage,
  9. Article,
  10. )
  11. from media_observer.storage_abstraction import (
  12. Table,
  13. Reference,
  14. ColumnType,
  15. Column,
  16. UniqueIndex,
  17. View,
  18. StorageAbc,
  19. )
  20. from media_observer.db.sqlite import SqliteBackend
  21. from media_observer.db.postgres import PostgresBackend
  22. from media_observer.internet_archive import InternetArchiveSnapshotId
  23. table_sites = Table(
  24. name="sites",
  25. columns=[
  26. Column(name="id", primary_key=True),
  27. Column(name="name", type_=ColumnType.Text),
  28. Column(name="original_url", type_=ColumnType.Url),
  29. ],
  30. )
  31. table_frontpages = Table(
  32. name="frontpages",
  33. columns=[
  34. Column(name="id", primary_key=True),
  35. Column(
  36. name="site_id",
  37. references=Reference("sites", "id", on_delete="cascade"),
  38. ),
  39. Column(name="timestamp", type_=ColumnType.TimestampTz),
  40. Column(name="timestamp_virtual", type_=ColumnType.TimestampTz),
  41. Column(name="url_original", type_=ColumnType.Url),
  42. Column(name="url_snapshot", type_=ColumnType.Url),
  43. ],
  44. )
  45. table_articles = Table(
  46. name="articles",
  47. columns=[
  48. Column(name="id", primary_key=True),
  49. Column(name="url", type_=ColumnType.Url),
  50. ],
  51. )
  52. table_titles = Table(
  53. name="titles",
  54. columns=[
  55. Column(name="id", primary_key=True),
  56. Column(name="text", type_=ColumnType.Text),
  57. ],
  58. )
  59. table_main_articles = Table(
  60. name="main_articles",
  61. columns=[
  62. Column(name="id", primary_key=True),
  63. Column(name="url", type_=ColumnType.Url),
  64. Column(
  65. name="frontpage_id",
  66. references=Reference("frontpages", "id", on_delete="cascade"),
  67. ),
  68. Column(
  69. name="article_id",
  70. references=Reference("articles", "id", on_delete="cascade"),
  71. ),
  72. Column(
  73. name="title_id",
  74. references=Reference("titles", "id", on_delete="cascade"),
  75. ),
  76. ],
  77. )
  78. table_top_articles = Table(
  79. name="top_articles",
  80. columns=[
  81. Column(name="id", primary_key=True),
  82. Column(name="url", type_=ColumnType.Url),
  83. Column(name="rank", type_=ColumnType.Integer),
  84. Column(
  85. name="frontpage_id",
  86. references=Reference("frontpages", "id", on_delete="cascade"),
  87. ),
  88. Column(
  89. name="article_id",
  90. references=Reference("articles", "id", on_delete="cascade"),
  91. ),
  92. Column(
  93. name="title_id",
  94. references=Reference("titles", "id", on_delete="cascade"),
  95. ),
  96. ],
  97. )
  98. table_embeddings = Table(
  99. name="embeddings",
  100. columns=[
  101. Column(name="id", primary_key=True),
  102. Column(
  103. name="title_id", references=Reference("titles", "id", on_delete="cascade")
  104. ),
  105. Column(name="vector", type_=ColumnType.Vector),
  106. ],
  107. )
  108. view_frontpages = View(
  109. name="frontpages_view",
  110. column_names=[
  111. "id",
  112. "site_id",
  113. "site_name",
  114. "site_original_url",
  115. "timestamp",
  116. "timestamp_virtual",
  117. "archive_snapshot_url",
  118. ],
  119. create_stmt="""
  120. SELECT
  121. fp.id,
  122. si.id AS site_id,
  123. si.name AS site_name,
  124. si.original_url AS site_original_url,
  125. fp.timestamp,
  126. fp.timestamp_virtual,
  127. fp.url_snapshot AS archive_snapshot_url
  128. FROM
  129. frontpages AS fp
  130. JOIN
  131. sites AS si ON si.id = fp.site_id
  132. """,
  133. )
  134. view_articles = View(
  135. name="articles_view",
  136. column_names=[
  137. "id",
  138. "title",
  139. "title_id",
  140. "url_archive",
  141. "url_article",
  142. "main_in_frontpage_id",
  143. "top_in_frontpage_id",
  144. "rank",
  145. ],
  146. create_stmt="""
  147. SELECT
  148. a.id,
  149. t.text AS title,
  150. t.id AS title_id,
  151. ma.url AS url_archive,
  152. a.url AS url_article,
  153. ma.frontpage_id AS main_in_frontpage_id,
  154. NULL AS top_in_frontpage_id,
  155. NULL AS rank
  156. FROM articles a
  157. JOIN main_articles ma ON ma.article_id = a.id
  158. JOIN titles t ON t.id = ma.title_id
  159. UNION ALL
  160. SELECT
  161. a.id,
  162. t.text AS title,
  163. t.id AS title_id,
  164. ta.url AS url_archive,
  165. a.url AS url_article,
  166. NULL AS main_in_frontpage_id,
  167. ta.frontpage_id AS top_in_frontpage_id,
  168. ta.rank
  169. FROM articles a
  170. JOIN top_articles ta ON ta.article_id = a.id
  171. JOIN titles t ON t.id = ta.title_id
  172. """,
  173. )
  174. view_articles_on_frontpage = View(
  175. name="articles_on_frontpage_view",
  176. column_names=[
  177. "frontpage_id",
  178. "site_id",
  179. "site_name",
  180. "site_original_url",
  181. "timestamp",
  182. "timestamp_virtual",
  183. "archive_snapshot_url",
  184. "article_id",
  185. "title",
  186. "title_id",
  187. "url_archive",
  188. "url_article",
  189. "is_main",
  190. "rank",
  191. ],
  192. create_stmt="""
  193. SELECT
  194. fpv.id AS frontpage_id,
  195. fpv.site_id,
  196. fpv.site_name,
  197. fpv.site_original_url,
  198. fpv."timestamp",
  199. fpv.timestamp_virtual,
  200. fpv.archive_snapshot_url,
  201. av.id AS article_id,
  202. av.title,
  203. av.title_id,
  204. av.url_archive,
  205. av.url_article,
  206. av.main_in_frontpage_id IS NOT NULL AS is_main,
  207. av.rank
  208. FROM articles_view av
  209. JOIN frontpages_view fpv ON fpv.id = av.main_in_frontpage_id OR fpv.id = av.top_in_frontpage_id
  210. """,
  211. )
  212. class Storage(StorageAbc):
  213. tables = [
  214. table_sites,
  215. table_frontpages,
  216. table_articles,
  217. table_titles,
  218. table_main_articles,
  219. table_top_articles,
  220. table_embeddings,
  221. ]
  222. views = [
  223. view_frontpages,
  224. view_articles,
  225. view_articles_on_frontpage,
  226. ]
  227. indexes = [
  228. UniqueIndex(table="sites", columns=["name"]),
  229. UniqueIndex(table="frontpages", columns=["timestamp_virtual", "site_id"]),
  230. UniqueIndex(table="articles", columns=["url"]),
  231. UniqueIndex(table="titles", columns=["text"]),
  232. UniqueIndex(table="main_articles", columns=["frontpage_id", "article_id"]),
  233. UniqueIndex(
  234. table="top_articles", columns=["frontpage_id", "article_id", "rank"]
  235. ),
  236. UniqueIndex(table="embeddings", columns=["title_id"]),
  237. ]
  238. def __init__(self, backend):
  239. self.backend = backend
  240. async def close(self):
  241. await self.backend.close()
  242. @staticmethod
  243. async def create():
  244. # We try to reproduce the scheme used by SQLAlchemy for Database-URLs
  245. # https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls
  246. conn_url = URL(settings.database_url)
  247. backend = None
  248. if conn_url.scheme == "sqlite":
  249. if conn_url.path.startswith("//"):
  250. raise ValueError("Absolute URLs not supported for sqlite")
  251. elif conn_url.path.startswith("/"):
  252. backend = await SqliteBackend.create(conn_url.path[1:])
  253. elif conn_url.scheme == "postgresql":
  254. backend = await PostgresBackend.create(settings.database_url)
  255. else:
  256. raise ValueError("Only the SQLite backend is supported")
  257. storage = Storage(backend)
  258. await storage._create_db()
  259. return storage
  260. async def _create_db(self):
  261. async with self.backend.get_connection() as conn:
  262. for t in self.tables:
  263. await t.create_if_not_exists(conn)
  264. for i in self.indexes:
  265. await i.create_if_not_exists(conn)
  266. for v in self.views:
  267. await v.create_if_not_exists(conn)
  268. async def exists_frontpage(self, name: str, dt: datetime):
  269. async with self.backend.get_connection() as conn:
  270. exists = await conn.execute_fetchall(
  271. """
  272. SELECT 1
  273. FROM frontpages f
  274. JOIN sites s ON s.id = f.site_id
  275. WHERE s.name = $1 AND timestamp_virtual = $2
  276. """,
  277. name,
  278. dt,
  279. )
  280. return exists != []
  281. @classmethod
  282. def _from_row(cls, r, table_or_view: Table | View):
  283. columns = table_or_view.column_names
  284. return {col: r[idx] for idx, col in enumerate(columns)}
  285. async def list_neighbouring_main_articles(
  286. self,
  287. site_id: int,
  288. timestamp: datetime | None = None,
  289. ):
  290. async with self.backend.get_connection() as conn:
  291. if timestamp is None:
  292. [row] = await conn.execute_fetchall(
  293. """
  294. SELECT timestamp_virtual
  295. FROM frontpages_view
  296. WHERE site_id = $1
  297. ORDER BY timestamp_virtual DESC
  298. LIMIT 1
  299. """,
  300. site_id,
  301. )
  302. timestamp = row["timestamp_virtual"]
  303. # This query is the union of 3 queries that respectively fetch :
  304. # * articles published at the same time as the queried article (including the queried article)
  305. # * the article published just after, on the same site
  306. # *the article published just before, on the same site
  307. main_articles = await conn.execute_fetchall(
  308. """
  309. WITH aof_diff AS (
  310. SELECT aof.*, EXTRACT(EPOCH FROM aof.timestamp_virtual - $2) :: integer AS time_diff
  311. FROM articles_on_frontpage_view aof
  312. )
  313. SELECT * FROM (
  314. SELECT * FROM aof_diff
  315. WHERE is_main AND time_diff = 0
  316. )
  317. UNION ALL
  318. SELECT * FROM (
  319. SELECT * FROM aof_diff
  320. WHERE is_main AND site_id = $1 AND time_diff > 0
  321. ORDER BY time_diff
  322. LIMIT 1
  323. )
  324. UNION ALL
  325. SELECT * FROM (
  326. SELECT * FROM aof_diff
  327. WHERE is_main AND site_id = $1 AND time_diff < 0
  328. ORDER BY time_diff DESC
  329. LIMIT 1
  330. )
  331. """,
  332. site_id,
  333. timestamp,
  334. )
  335. return [
  336. self._from_row(a, self._view_by_name["articles_on_frontpage_view"])
  337. | {"time_diff": a[14]}
  338. for a in main_articles
  339. ]
  340. async def list_all_titles_without_embedding(self):
  341. async with self.backend.get_connection() as conn:
  342. rows = await conn.execute_fetchall("""
  343. SELECT t.*
  344. FROM public.titles AS t
  345. WHERE NOT EXISTS (SELECT 1 FROM embeddings WHERE title_id = t.id)
  346. """)
  347. return [self._from_row(r, self._table_by_name["titles"]) for r in rows]
  348. async def list_all_embeddings(self):
  349. async with self.backend.get_connection() as conn:
  350. rows = await conn.execute_fetchall(
  351. """
  352. SELECT *
  353. FROM embeddings
  354. """,
  355. )
  356. return [self._from_embeddings_row(r) for r in rows]
  357. async def list_articles_on_frontpage(self, title_ids: list[int]):
  358. if len(title_ids) == 0:
  359. return []
  360. async with self.backend.get_connection() as conn:
  361. rows = await conn.execute_fetchall(
  362. f"""
  363. SELECT *
  364. FROM articles_on_frontpage_view
  365. WHERE title_id IN ({self._placeholders(*title_ids)})
  366. """,
  367. *title_ids,
  368. )
  369. return [
  370. self._from_row(r, self._view_by_name["articles_on_frontpage_view"])
  371. for r in rows
  372. ]
  373. @classmethod
  374. def _from_embeddings_row(cls, r):
  375. [embeds_table] = [t for t in cls.tables if t.name == "embeddings"]
  376. d = cls._from_row(r, embeds_table)
  377. d.update(vector=np.frombuffer(d["vector"], dtype="float32"))
  378. return d
  379. async def add_embedding(self, title_id: int, embedding):
  380. async with self.backend.get_connection() as conn:
  381. await conn.execute_insert(
  382. self._insert_stmt(
  383. "embeddings",
  384. ["title_id", "vector"],
  385. ),
  386. title_id,
  387. embedding,
  388. )
  389. async def list_sites(self):
  390. async with self.backend.get_connection() as conn:
  391. sites = await conn.execute_fetchall("SELECT * FROM sites")
  392. return [self._from_row(s, self._table_by_name["sites"]) for s in sites]
  393. async def add_page(
  394. self, collection: ArchiveCollection, page: FrontPage, dt: datetime
  395. ):
  396. assert dt.tzinfo is not None
  397. async with self.backend.get_connection() as conn:
  398. async with conn.transaction():
  399. site_id = await self._add_site(conn, collection.name, collection.url)
  400. frontpage_id = await self._add_frontpage(
  401. conn, site_id, page.snapshot.id, dt
  402. )
  403. article_id = await self._add_article(
  404. conn, page.main_article.article.original
  405. )
  406. title_id = await self._add_title(conn, page.main_article.article.title)
  407. await self._add_main_article(
  408. conn,
  409. frontpage_id,
  410. article_id,
  411. title_id,
  412. page.main_article.article.url,
  413. )
  414. for t in page.top_articles:
  415. article_id = await self._add_article(conn, t.article.original)
  416. title_id = await self._add_title(conn, t.article.title)
  417. await self._add_top_article(
  418. conn, frontpage_id, article_id, title_id, t.article.url, t.rank
  419. )
  420. return site_id
  421. async def _add_site(self, conn, name: str, original_url: str) -> int:
  422. return await self._insert_or_get(
  423. conn,
  424. self._insert_stmt("sites", ["name", "original_url"]),
  425. [name, original_url],
  426. "SELECT id FROM sites WHERE name = $1",
  427. [name],
  428. )
  429. async def _add_frontpage(
  430. self, conn, site_id: int, snapshot: InternetArchiveSnapshotId, virtual: datetime
  431. ) -> int:
  432. return await self._insert_or_get(
  433. conn,
  434. self._insert_stmt(
  435. "frontpages",
  436. [
  437. "timestamp",
  438. "site_id",
  439. "timestamp_virtual",
  440. "url_original",
  441. "url_snapshot",
  442. ],
  443. ),
  444. [snapshot.timestamp, site_id, virtual, snapshot.original, snapshot.url],
  445. "SELECT id FROM frontpages WHERE timestamp_virtual = $1 AND site_id = $2",
  446. [virtual, site_id],
  447. )
  448. async def _add_article(self, conn, article: Article):
  449. return await self._insert_or_get(
  450. conn,
  451. self._insert_stmt("articles", ["url"]),
  452. [str(article.url)],
  453. "SELECT id FROM articles WHERE url = $1",
  454. [str(article.url)],
  455. )
  456. async def _add_title(self, conn, title: str):
  457. return await self._insert_or_get(
  458. conn,
  459. self._insert_stmt("titles", ["text"]),
  460. [title],
  461. "SELECT id FROM titles WHERE text = $1",
  462. [title],
  463. )
  464. async def _add_main_article(
  465. self, conn, frontpage_id: int, article_id: int, title_id: int, url: URL
  466. ):
  467. await conn.execute_insert(
  468. self._insert_stmt(
  469. "main_articles", ["frontpage_id", "article_id", "title_id", "url"]
  470. ),
  471. frontpage_id,
  472. article_id,
  473. title_id,
  474. str(url),
  475. )
  476. async def _add_top_article(
  477. self,
  478. conn,
  479. frontpage_id: int,
  480. article_id: int,
  481. title_id: int,
  482. url: URL,
  483. rank: int,
  484. ):
  485. await conn.execute_insert(
  486. self._insert_stmt(
  487. "top_articles",
  488. ["frontpage_id", "article_id", "title_id", "url", "rank"],
  489. ),
  490. frontpage_id,
  491. article_id,
  492. title_id,
  493. str(url),
  494. rank,
  495. )
  496. async def _insert_or_get(
  497. self,
  498. conn,
  499. insert_stmt: str,
  500. insert_args: list[Any],
  501. select_stmt: str,
  502. select_args: list[Any],
  503. ) -> int:
  504. await conn.execute_insert(insert_stmt, *insert_args)
  505. [(id_,)] = await conn.execute_fetchall(select_stmt, *select_args)
  506. return id_
  507. @staticmethod
  508. def _insert_stmt(table, cols):
  509. cols_str = ", ".join(cols)
  510. return f"""
  511. INSERT INTO {table} ({cols_str})
  512. VALUES ({Storage._placeholders(*cols)})
  513. ON CONFLICT DO NOTHING
  514. """
  515. @staticmethod
  516. def _placeholders(*args):
  517. return ", ".join([f"${idx + 1}" for idx, _ in enumerate(args)])
  518. @property
  519. def _table_by_name(self):
  520. return {t.name: t for t in self.tables}
  521. @property
  522. def _view_by_name(self):
  523. return {v.name: v for v in self.views}