This commit is contained in:
@@ -7,7 +7,7 @@ use activitypub::RemoteReviewRepository;
|
||||
use activitypub_base::{
|
||||
BlockedDomain, FederationRepository, Follower, FollowerStatus, FollowingStatus, RemoteActor,
|
||||
};
|
||||
use domain::models::{Review, ReviewSource, RemoteWatchlistEntry};
|
||||
use domain::models::{RemoteWatchlistEntry, Review, ReviewSource};
|
||||
use domain::ports::RemoteWatchlistRepository;
|
||||
|
||||
fn datetime_to_str(dt: &NaiveDateTime) -> String {
|
||||
@@ -178,7 +178,8 @@ impl FederationRepository for SqliteFederationRepository {
|
||||
let status_str: String = row.get("status");
|
||||
let handle: String = row.try_get("handle").unwrap_or_default();
|
||||
let inbox_url: String = row.try_get("inbox_url").unwrap_or_default();
|
||||
let shared_inbox_url: Option<String> = row.try_get("shared_inbox_url").ok().flatten();
|
||||
let shared_inbox_url: Option<String> =
|
||||
row.try_get("shared_inbox_url").ok().flatten();
|
||||
let display_name: Option<String> = row.try_get("display_name").ok().flatten();
|
||||
let avatar_url: Option<String> = row.try_get("avatar_url").ok().flatten();
|
||||
Follower {
|
||||
@@ -595,12 +596,11 @@ impl FederationRepository for SqliteFederationRepository {
|
||||
}
|
||||
|
||||
async fn is_domain_blocked(&self, domain: &str) -> Result<bool> {
|
||||
let count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM blocked_domains WHERE domain = ?1",
|
||||
)
|
||||
.bind(domain)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
let count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM blocked_domains WHERE domain = ?1")
|
||||
.bind(domain)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
Ok(count > 0)
|
||||
}
|
||||
|
||||
@@ -639,7 +639,10 @@ impl FederationRepository for SqliteFederationRepository {
|
||||
.bind(&uid)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
Ok(rows.iter().map(|r| r.get::<String, _>("remote_actor_url")).collect())
|
||||
Ok(rows
|
||||
.iter()
|
||||
.map(|r| r.get::<String, _>("remote_actor_url"))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn is_actor_blocked(&self, local_user_id: uuid::Uuid, actor_url: &str) -> Result<bool> {
|
||||
@@ -789,11 +792,13 @@ impl domain::ports::SocialQueryPort for SqliteFederationRepository {
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|(url, handle, display_name)| domain::ports::RemoteActorInfo {
|
||||
url,
|
||||
handle,
|
||||
display_name,
|
||||
})
|
||||
.map(
|
||||
|(url, handle, display_name)| domain::ports::RemoteActorInfo {
|
||||
url,
|
||||
handle,
|
||||
display_name,
|
||||
},
|
||||
)
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
@@ -822,19 +827,24 @@ impl RemoteWatchlistRepository for SqliteFederationRepository {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_by_ap_id(&self, ap_id: &str, actor_url: &str) -> Result<(), domain::errors::DomainError> {
|
||||
sqlx::query(
|
||||
"DELETE FROM ap_remote_watchlist_entries WHERE ap_id = ? AND actor_url = ?",
|
||||
)
|
||||
.bind(ap_id)
|
||||
.bind(actor_url)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| domain::errors::DomainError::InfrastructureError(e.to_string()))?;
|
||||
async fn remove_by_ap_id(
|
||||
&self,
|
||||
ap_id: &str,
|
||||
actor_url: &str,
|
||||
) -> Result<(), domain::errors::DomainError> {
|
||||
sqlx::query("DELETE FROM ap_remote_watchlist_entries WHERE ap_id = ? AND actor_url = ?")
|
||||
.bind(ap_id)
|
||||
.bind(actor_url)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| domain::errors::DomainError::InfrastructureError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_by_actor_url(&self, actor_url: &str) -> Result<Vec<RemoteWatchlistEntry>, domain::errors::DomainError> {
|
||||
async fn get_by_actor_url(
|
||||
&self,
|
||||
actor_url: &str,
|
||||
) -> Result<Vec<RemoteWatchlistEntry>, domain::errors::DomainError> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT ap_id, actor_url, movie_title, release_year, external_metadata_id, poster_url, added_at \
|
||||
FROM ap_remote_watchlist_entries WHERE actor_url = ? ORDER BY added_at DESC",
|
||||
@@ -844,24 +854,35 @@ impl RemoteWatchlistRepository for SqliteFederationRepository {
|
||||
.await
|
||||
.map_err(|e| domain::errors::DomainError::InfrastructureError(e.to_string()))?;
|
||||
|
||||
rows.into_iter().map(|row| {
|
||||
let added_at_str: String = row.try_get("added_at").unwrap_or_default();
|
||||
let added_at = chrono::NaiveDateTime::parse_from_str(&added_at_str, "%Y-%m-%d %H:%M:%S")
|
||||
.map(|dt| chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(dt, chrono::Utc))
|
||||
.unwrap_or_else(|_| chrono::Utc::now());
|
||||
Ok(RemoteWatchlistEntry {
|
||||
ap_id: row.try_get("ap_id").unwrap_or_default(),
|
||||
actor_url: row.try_get("actor_url").unwrap_or_default(),
|
||||
movie_title: row.try_get("movie_title").unwrap_or_default(),
|
||||
release_year: row.try_get::<i64, _>("release_year").unwrap_or(0) as u16,
|
||||
external_metadata_id: row.try_get("external_metadata_id").ok().flatten(),
|
||||
poster_url: row.try_get("poster_url").ok().flatten(),
|
||||
added_at,
|
||||
rows.into_iter()
|
||||
.map(|row| {
|
||||
let added_at_str: String = row.try_get("added_at").unwrap_or_default();
|
||||
let added_at =
|
||||
chrono::NaiveDateTime::parse_from_str(&added_at_str, "%Y-%m-%d %H:%M:%S")
|
||||
.map(|dt| {
|
||||
chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(
|
||||
dt,
|
||||
chrono::Utc,
|
||||
)
|
||||
})
|
||||
.unwrap_or_else(|_| chrono::Utc::now());
|
||||
Ok(RemoteWatchlistEntry {
|
||||
ap_id: row.try_get("ap_id").unwrap_or_default(),
|
||||
actor_url: row.try_get("actor_url").unwrap_or_default(),
|
||||
movie_title: row.try_get("movie_title").unwrap_or_default(),
|
||||
release_year: row.try_get::<i64, _>("release_year").unwrap_or(0) as u16,
|
||||
external_metadata_id: row.try_get("external_metadata_id").ok().flatten(),
|
||||
poster_url: row.try_get("poster_url").ok().flatten(),
|
||||
added_at,
|
||||
})
|
||||
})
|
||||
}).collect()
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn remove_all_by_actor(&self, actor_url: &str) -> Result<(), domain::errors::DomainError> {
|
||||
async fn remove_all_by_actor(
|
||||
&self,
|
||||
actor_url: &str,
|
||||
) -> Result<(), domain::errors::DomainError> {
|
||||
sqlx::query("DELETE FROM ap_remote_watchlist_entries WHERE actor_url = ?")
|
||||
.bind(actor_url)
|
||||
.execute(&self.pool)
|
||||
@@ -870,18 +891,22 @@ impl RemoteWatchlistRepository for SqliteFederationRepository {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_by_derived_uuid(&self, uuid: uuid::Uuid) -> Result<Vec<RemoteWatchlistEntry>, domain::errors::DomainError> {
|
||||
let actors: Vec<String> = sqlx::query("SELECT DISTINCT actor_url FROM ap_remote_watchlist_entries")
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| domain::errors::DomainError::InfrastructureError(e.to_string()))?
|
||||
.into_iter()
|
||||
.filter_map(|row| row.try_get::<String, _>("actor_url").ok())
|
||||
.collect();
|
||||
async fn get_by_derived_uuid(
|
||||
&self,
|
||||
uuid: uuid::Uuid,
|
||||
) -> Result<Vec<RemoteWatchlistEntry>, domain::errors::DomainError> {
|
||||
let actors: Vec<String> =
|
||||
sqlx::query("SELECT DISTINCT actor_url FROM ap_remote_watchlist_entries")
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| domain::errors::DomainError::InfrastructureError(e.to_string()))?
|
||||
.into_iter()
|
||||
.filter_map(|row| row.try_get::<String, _>("actor_url").ok())
|
||||
.collect();
|
||||
|
||||
let target = actors.into_iter().find(|url| {
|
||||
uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_URL, url.as_bytes()) == uuid
|
||||
});
|
||||
let target = actors
|
||||
.into_iter()
|
||||
.find(|url| uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_URL, url.as_bytes()) == uuid);
|
||||
|
||||
match target {
|
||||
None => Ok(vec![]),
|
||||
@@ -890,7 +915,9 @@ impl RemoteWatchlistRepository for SqliteFederationRepository {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wire(pool: sqlx::SqlitePool) -> (
|
||||
pub fn wire(
|
||||
pool: sqlx::SqlitePool,
|
||||
) -> (
|
||||
std::sync::Arc<dyn activitypub::FederationRepository>,
|
||||
std::sync::Arc<dyn domain::ports::SocialQueryPort>,
|
||||
std::sync::Arc<dyn activitypub::RemoteReviewRepository>,
|
||||
|
||||
@@ -3,14 +3,23 @@ use sqlx::SqlitePool;
|
||||
|
||||
async fn test_pool() -> SqlitePool {
|
||||
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||
sqlx::query("CREATE TABLE users (id TEXT PRIMARY KEY, email TEXT, password_hash TEXT, created_at TEXT)")
|
||||
.execute(&pool).await.unwrap();
|
||||
sqlx::query(
|
||||
"CREATE TABLE users (id TEXT PRIMARY KEY, email TEXT, password_hash TEXT, created_at TEXT)",
|
||||
)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
sqlx::query("CREATE TABLE blocked_actors (local_user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, remote_actor_url TEXT NOT NULL, blocked_at TEXT NOT NULL, PRIMARY KEY (local_user_id, remote_actor_url))")
|
||||
.execute(&pool).await.unwrap();
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query("INSERT INTO users (id, email, password_hash, created_at) VALUES (?, ?, ?, ?)")
|
||||
.bind(&uid).bind("a@b.com").bind("hash").bind("2024-01-01")
|
||||
.execute(&pool).await.unwrap();
|
||||
.bind(&uid)
|
||||
.bind("a@b.com")
|
||||
.bind("hash")
|
||||
.bind("2024-01-01")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
pool
|
||||
}
|
||||
|
||||
@@ -19,8 +28,11 @@ async fn block_and_check_actor() {
|
||||
let pool = test_pool().await;
|
||||
let user_id = uuid::Uuid::parse_str(
|
||||
&sqlx::query_scalar::<_, String>("SELECT id FROM users LIMIT 1")
|
||||
.fetch_one(&pool).await.unwrap()
|
||||
).unwrap();
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let repo = SqliteFederationRepository::new(pool);
|
||||
let actor_url = "https://mastodon.social/users/alice";
|
||||
assert!(!repo.is_actor_blocked(user_id, actor_url).await.unwrap());
|
||||
|
||||
@@ -13,7 +13,9 @@ async fn blocked_domain_is_detected() {
|
||||
let pool = test_pool().await;
|
||||
let repo = SqliteFederationRepository::new(pool);
|
||||
assert!(!repo.is_domain_blocked("mastodon.social").await.unwrap());
|
||||
repo.add_blocked_domain("mastodon.social", Some("spam")).await.unwrap();
|
||||
repo.add_blocked_domain("mastodon.social", Some("spam"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(repo.is_domain_blocked("mastodon.social").await.unwrap());
|
||||
}
|
||||
|
||||
@@ -30,7 +32,9 @@ async fn remove_unblocks_domain() {
|
||||
async fn get_blocked_domains_returns_all() {
|
||||
let pool = test_pool().await;
|
||||
let repo = SqliteFederationRepository::new(pool);
|
||||
repo.add_blocked_domain("a.com", Some("reason a")).await.unwrap();
|
||||
repo.add_blocked_domain("a.com", Some("reason a"))
|
||||
.await
|
||||
.unwrap();
|
||||
repo.add_blocked_domain("b.com", None).await.unwrap();
|
||||
let domains = repo.get_blocked_domains().await.unwrap();
|
||||
assert_eq!(domains.len(), 2);
|
||||
|
||||
@@ -14,7 +14,14 @@ async fn test_pool() -> SqlitePool {
|
||||
async fn add_announce_stores_and_counts() {
|
||||
let pool = test_pool().await;
|
||||
let repo = SqliteFederationRepository::new(pool);
|
||||
repo.add_announce("https://remote/ann/1", "https://local/r/1", "https://remote/u/1", Utc::now()).await.unwrap();
|
||||
repo.add_announce(
|
||||
"https://remote/ann/1",
|
||||
"https://local/r/1",
|
||||
"https://remote/u/1",
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(repo.count_announces("https://local/r/1").await.unwrap(), 1);
|
||||
}
|
||||
|
||||
@@ -22,8 +29,22 @@ async fn add_announce_stores_and_counts() {
|
||||
async fn duplicate_announce_is_ignored() {
|
||||
let pool = test_pool().await;
|
||||
let repo = SqliteFederationRepository::new(pool);
|
||||
repo.add_announce("https://remote/ann/1", "https://local/r/1", "https://remote/u/1", Utc::now()).await.unwrap();
|
||||
repo.add_announce("https://remote/ann/1", "https://local/r/1", "https://remote/u/1", Utc::now()).await.unwrap();
|
||||
repo.add_announce(
|
||||
"https://remote/ann/1",
|
||||
"https://local/r/1",
|
||||
"https://remote/u/1",
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
repo.add_announce(
|
||||
"https://remote/ann/1",
|
||||
"https://local/r/1",
|
||||
"https://remote/u/1",
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(repo.count_announces("https://local/r/1").await.unwrap(), 1);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user