feat: Upgrade k-core dependency to version 0.1.10, refactor message broker and embedding components, and enhance session store integration
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -1815,15 +1815,27 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "k-core"
|
||||
version = "0.1.5"
|
||||
source = "git+https://git.gabrielkaszewski.dev/GKaszewski/k-core#667cae596cf4e6c9c8e4cfa3bd5ee53ffb0796fb"
|
||||
version = "0.1.10"
|
||||
source = "git+https://git.gabrielkaszewski.dev/GKaszewski/k-core#7a72f5f54ad45ba82f451e90c44c0581d13194d9"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-nats",
|
||||
"async-trait",
|
||||
"axum 0.8.8",
|
||||
"chrono",
|
||||
"fastembed",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"qdrant-client",
|
||||
"serde",
|
||||
"sqlx",
|
||||
"thiserror 2.0.17",
|
||||
"time",
|
||||
"tokio",
|
||||
"tower 0.5.2",
|
||||
"tower-http",
|
||||
"tower-sessions",
|
||||
"tower-sessions-sqlx-store",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
@@ -2196,21 +2208,18 @@ dependencies = [
|
||||
name = "notes-infra"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-nats",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"fastembed",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"k-core",
|
||||
"notes-domain",
|
||||
"qdrant-client",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tower-sessions",
|
||||
"tower-sessions-sqlx-store",
|
||||
"tracing",
|
||||
"uuid",
|
||||
|
||||
@@ -62,4 +62,7 @@ dotenvy = "0.15.7"
|
||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||
"logging",
|
||||
"db-sqlx",
|
||||
"sqlite",
|
||||
"http",
|
||||
"auth","sessions-db"
|
||||
] }
|
||||
@@ -2,14 +2,16 @@
|
||||
//!
|
||||
//! A high-performance, self-hosted note-taking API following hexagonal architecture.
|
||||
|
||||
use k_core::db::DatabasePool;
|
||||
use k_core::{
|
||||
db::DatabasePool,
|
||||
http::server::{ServerConfig, apply_standard_middleware},
|
||||
};
|
||||
use std::{sync::Arc, time::Duration as StdDuration};
|
||||
use time::Duration;
|
||||
|
||||
use axum::Router;
|
||||
use axum_login::AuthManagerLayerBuilder;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use tower_sessions::{Expiry, SessionManagerLayer};
|
||||
|
||||
use notes_infra::run_migrations;
|
||||
@@ -113,7 +115,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
|
||||
// Auth backend
|
||||
let backend = AuthBackend::new(user_repo);
|
||||
let backend = AuthBackend::new(user_repo); // no idea what now with this
|
||||
|
||||
// Session layer
|
||||
// Use the factory to build the session store, agnostic of the underlying DB
|
||||
@@ -126,47 +128,23 @@ async fn main() -> anyhow::Result<()> {
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let session_layer = SessionManagerLayer::new(session_store)
|
||||
.with_secure(false) // Set to true in production with HTTPS
|
||||
.with_expiry(Expiry::OnInactivity(Duration::seconds(60 * 60 * 24 * 7))); // 7 days
|
||||
.with_secure(false) // Set to true in prod
|
||||
.with_expiry(Expiry::OnInactivity(Duration::days(7)));
|
||||
|
||||
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
|
||||
|
||||
let mut cors = CorsLayer::new()
|
||||
.allow_methods([
|
||||
axum::http::Method::GET,
|
||||
axum::http::Method::POST,
|
||||
axum::http::Method::PATCH,
|
||||
axum::http::Method::DELETE,
|
||||
axum::http::Method::OPTIONS,
|
||||
])
|
||||
.allow_headers([
|
||||
axum::http::header::AUTHORIZATION,
|
||||
axum::http::header::ACCEPT,
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
])
|
||||
.allow_credentials(true);
|
||||
|
||||
let mut allowed_origins = Vec::new();
|
||||
for origin in &config.cors_allowed_origins {
|
||||
tracing::debug!("Allowing CORS origin: {}", origin);
|
||||
if let Ok(value) = origin.parse::<axum::http::HeaderValue>() {
|
||||
allowed_origins.push(value);
|
||||
} else {
|
||||
tracing::warn!("Invalid CORS origin: {}", origin);
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed_origins.is_empty() {
|
||||
cors = cors.allow_origin(allowed_origins);
|
||||
}
|
||||
let server_config = ServerConfig {
|
||||
cors_origins: config.cors_allowed_origins.clone(),
|
||||
session_secret: Some(config.session_secret.clone()),
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", routes::api_v1_router())
|
||||
.layer(auth_layer)
|
||||
.layer(cors)
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(state);
|
||||
|
||||
let app = apply_standard_middleware(app, &server_config);
|
||||
|
||||
let addr = format!("{}:{}", config.host, config.port);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
|
||||
|
||||
@@ -5,31 +5,30 @@ edition = "2024"
|
||||
|
||||
[features]
|
||||
default = ["sqlite", "smart-features", "broker-nats"]
|
||||
sqlite = ["sqlx/sqlite", "tower-sessions-sqlx-store/sqlite"]
|
||||
postgres = ["sqlx/postgres", "tower-sessions-sqlx-store/postgres"]
|
||||
smart-features = ["dep:fastembed", "dep:qdrant-client"]
|
||||
broker-nats = ["dep:async-nats", "dep:futures-util"]
|
||||
sqlite = ["sqlx/sqlite", "k-core/sqlite", "tower-sessions-sqlx-store", "k-core/sessions-db"]
|
||||
postgres = ["sqlx/postgres", "k-core/postgres", "tower-sessions-sqlx-store", "k-core/sessions-db"]
|
||||
smart-features = ["k-core/ai"]
|
||||
broker-nats = ["dep:futures-util", "k-core/broker-nats"]
|
||||
|
||||
[dependencies]
|
||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||
"logging",
|
||||
"db-sqlx",
|
||||
"sessions-db"
|
||||
], version = "*"}
|
||||
notes-domain = { path = "../notes-domain" }
|
||||
async-trait = "0.1.89"
|
||||
|
||||
chrono = { version = "0.4.42", features = ["serde"] }
|
||||
sqlx = { version = "0.8.6", features = ["runtime-tokio", "chrono", "migrate"] }
|
||||
thiserror = "2.0.17"
|
||||
tokio = { version = "1.48.0", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
uuid = { version = "1.19.0", features = ["v4", "serde"] }
|
||||
tower-sessions = "0.14.0"
|
||||
tower-sessions-sqlx-store = { version = "0.15.0", default-features = false }
|
||||
fastembed = { version = "5.4", optional = true }
|
||||
qdrant-client = { version = "1.16", optional = true }
|
||||
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
async-nats = { version = "0.45", optional = true }
|
||||
futures-util = { version = "0.3", optional = true }
|
||||
futures-core = "0.3"
|
||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||
"logging",
|
||||
"db-sqlx",
|
||||
"sqlite"
|
||||
], version = "*"}
|
||||
async-trait = "0.1.89"
|
||||
anyhow = "1.0.100"
|
||||
tower-sessions-sqlx-store = { version = "0.15.0", optional = true}
|
||||
|
||||
@@ -6,23 +6,16 @@ use std::pin::Pin;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use k_core::broker::{MessageBroker as CoreBroker, nats::NatsBroker};
|
||||
use notes_domain::{DomainError, DomainResult, MessageBroker, Note};
|
||||
|
||||
/// NATS adapter implementing the MessageBroker port.
|
||||
pub struct NatsMessageBroker {
|
||||
client: async_nats::Client,
|
||||
inner: NatsBroker,
|
||||
}
|
||||
|
||||
impl NatsMessageBroker {
|
||||
/// Create a new NATS message broker by connecting to the given URL.
|
||||
pub async fn connect(url: &str) -> Result<Self, async_nats::ConnectError> {
|
||||
let client = async_nats::connect(url).await?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
/// Create a NATS message broker from an existing client.
|
||||
pub fn from_client(client: async_nats::Client) -> Self {
|
||||
Self { client }
|
||||
pub fn new(broker: NatsBroker) -> Self {
|
||||
Self { inner: broker }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +26,7 @@ impl MessageBroker for NatsMessageBroker {
|
||||
DomainError::RepositoryError(format!("Failed to serialize note: {}", e))
|
||||
})?;
|
||||
|
||||
self.client
|
||||
self.inner
|
||||
.publish("notes.updated", payload.into())
|
||||
.await
|
||||
.map_err(|e| DomainError::RepositoryError(format!("Failed to publish event: {}", e)))?;
|
||||
@@ -44,15 +37,14 @@ impl MessageBroker for NatsMessageBroker {
|
||||
async fn subscribe_note_updates(
|
||||
&self,
|
||||
) -> DomainResult<Pin<Box<dyn futures_core::Stream<Item = Note> + Send>>> {
|
||||
let subscriber = self
|
||||
.client
|
||||
.subscribe("notes.updated")
|
||||
.await
|
||||
.map_err(|e| DomainError::RepositoryError(format!("Failed to subscribe: {}", e)))?;
|
||||
let stream =
|
||||
self.inner.subscribe("notes.updated").await.map_err(|e| {
|
||||
DomainError::RepositoryError(format!("Broker subscribe error: {}", e))
|
||||
})?;
|
||||
|
||||
// Transform the NATS message stream into a Note stream
|
||||
let note_stream = subscriber.filter_map(|msg| async move {
|
||||
match serde_json::from_slice::<Note>(&msg.payload) {
|
||||
// Map generic bytes back to Domain Note
|
||||
let note_stream = stream.filter_map(|bytes| async move {
|
||||
match serde_json::from_slice::<Note>(&bytes) {
|
||||
Ok(note) => Some(note),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to deserialize note from message: {}", e);
|
||||
|
||||
@@ -1,48 +1,29 @@
|
||||
use async_trait::async_trait;
|
||||
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
||||
use k_core::ai::embeddings::FastEmbedAdapter as CoreFastEmbed;
|
||||
use notes_domain::errors::{DomainError, DomainResult};
|
||||
use notes_domain::ports::EmbeddingGenerator;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct FastEmbedAdapter {
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
inner: Arc<CoreFastEmbed>,
|
||||
}
|
||||
|
||||
impl FastEmbedAdapter {
|
||||
pub fn new() -> DomainResult<Self> {
|
||||
let mut options = InitOptions::default();
|
||||
options.model_name = EmbeddingModel::AllMiniLML6V2;
|
||||
options.show_download_progress = false;
|
||||
|
||||
let model = TextEmbedding::try_new(options).map_err(|e| {
|
||||
DomainError::InfrastructureError(format!("Failed to init fastembed: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
model: Arc::new(Mutex::new(model)),
|
||||
})
|
||||
pub fn new(inner: CoreFastEmbed) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(inner),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingGenerator for FastEmbedAdapter {
|
||||
async fn generate_embedding(&self, text: &str) -> DomainResult<Vec<f32>> {
|
||||
let model = self.model.clone();
|
||||
let text = text.to_string();
|
||||
|
||||
let embeddings = tokio::task::spawn_blocking(move || {
|
||||
let mut model = model.lock().map_err(|e| format!("Lock error: {}", e))?;
|
||||
model
|
||||
.embed(vec![text], None)
|
||||
.map_err(|e| format!("Embed error: {}", e))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| DomainError::InfrastructureError(format!("Join error: {}", e)))?
|
||||
.map_err(|e| DomainError::InfrastructureError(e))?;
|
||||
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| DomainError::InfrastructureError("No embedding generated".to_string()))
|
||||
self.inner
|
||||
.generate_embedding_async(text)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DomainError::InfrastructureError(format!("Embedding generation failed: {}", e))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,21 @@ use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
use crate::{SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository};
|
||||
use k_core::db::DatabaseConfig;
|
||||
use k_core::db::DatabasePool;
|
||||
use k_core::session::store::InfraSessionStore;
|
||||
use notes_domain::{NoteRepository, TagRepository, UserRepository};
|
||||
|
||||
#[cfg(feature = "smart-features")]
|
||||
use crate::embeddings::fastembed::FastEmbedAdapter;
|
||||
#[cfg(feature = "smart-features")]
|
||||
use crate::vector::qdrant::QdrantVectorAdapter;
|
||||
#[cfg(feature = "smart-features")]
|
||||
use k_core::ai::{
|
||||
embeddings::FastEmbedAdapter as CoreFastEmbed, qdrant::QdrantAdapter as CoreQdrant,
|
||||
};
|
||||
#[cfg(feature = "smart-features")]
|
||||
use k_core::broker::nats::NatsBroker;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FactoryError {
|
||||
#[error("Database error: {0}")]
|
||||
@@ -16,7 +27,7 @@ pub enum FactoryError {
|
||||
Infrastructure(#[from] notes_domain::DomainError),
|
||||
}
|
||||
|
||||
pub type FactoryResult<T> = Result<T, FactoryError>;
|
||||
pub type FactoryResult<T> = anyhow::Result<T>;
|
||||
|
||||
#[cfg(feature = "smart-features")]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -39,8 +50,8 @@ pub async fn build_embedding_generator(
|
||||
) -> FactoryResult<Arc<dyn notes_domain::ports::EmbeddingGenerator>> {
|
||||
match provider {
|
||||
EmbeddingProvider::FastEmbed => {
|
||||
let adapter = crate::embeddings::fastembed::FastEmbedAdapter::new()?;
|
||||
Ok(Arc::new(adapter))
|
||||
let core_embed = CoreFastEmbed::new()?;
|
||||
Ok(Arc::new(FastEmbedAdapter::new(core_embed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -51,8 +62,9 @@ pub async fn build_vector_store(
|
||||
) -> FactoryResult<Arc<dyn notes_domain::ports::VectorStore>> {
|
||||
match provider {
|
||||
VectorProvider::Qdrant { url, collection } => {
|
||||
let adapter = crate::vector::qdrant::QdrantVectorAdapter::new(url, collection)?;
|
||||
adapter.create_collection_if_not_exists().await?;
|
||||
let core_qdrant = CoreQdrant::new(url, collection)?;
|
||||
let adapter = QdrantVectorAdapter::new(core_qdrant);
|
||||
adapter.init().await.map_err(|e| anyhow::anyhow!(e))?;
|
||||
Ok(Arc::new(adapter))
|
||||
}
|
||||
}
|
||||
@@ -76,14 +88,9 @@ pub async fn build_message_broker(
|
||||
match provider {
|
||||
#[cfg(feature = "broker-nats")]
|
||||
BrokerProvider::Nats { url } => {
|
||||
let broker = crate::broker::nats::NatsMessageBroker::connect(url)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
FactoryError::Infrastructure(notes_domain::DomainError::RepositoryError(
|
||||
format!("NATS connection failed: {}", e),
|
||||
))
|
||||
})?;
|
||||
Ok(Some(Arc::new(broker)))
|
||||
let core_broker = NatsBroker::connect(url).await?;
|
||||
let adapter = crate::broker::nats::NatsMessageBroker::new(core_broker);
|
||||
Ok(Some(Arc::new(adapter)))
|
||||
}
|
||||
BrokerProvider::None => Ok(None),
|
||||
}
|
||||
@@ -100,53 +107,14 @@ pub async fn build_link_repository(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn build_database_pool(db_config: &DatabaseConfig) -> FactoryResult<DatabasePool> {
|
||||
if db_config.url.starts_with("sqlite:") {
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
let pool = sqlx::sqlite::SqlitePoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&db_config.url)
|
||||
.await?;
|
||||
Ok(DatabasePool::Sqlite(pool))
|
||||
}
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
Err(FactoryError::NotImplemented(
|
||||
"SQLite feature not enabled".to_string(),
|
||||
))
|
||||
} else if db_config.url.starts_with("postgres:") {
|
||||
#[cfg(feature = "postgres")]
|
||||
{
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&db_config.url)
|
||||
.await?;
|
||||
Ok(DatabasePool::Postgres(pool))
|
||||
}
|
||||
#[cfg(not(feature = "postgres"))]
|
||||
Err(FactoryError::NotImplemented(
|
||||
"Postgres feature not enabled".to_string(),
|
||||
))
|
||||
} else {
|
||||
Err(FactoryError::NotImplemented(format!(
|
||||
"Unsupported database URL scheme in: {}",
|
||||
db_config.url
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn build_note_repository(pool: &DatabasePool) -> FactoryResult<Arc<dyn NoteRepository>> {
|
||||
match pool {
|
||||
#[cfg(feature = "sqlite")]
|
||||
DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteNoteRepository::new(pool.clone()))),
|
||||
#[cfg(feature = "postgres")]
|
||||
DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented(
|
||||
"Postgres NoteRepository".to_string(),
|
||||
)),
|
||||
DatabasePool::Postgres(_) => anyhow::bail!("Postgres NoteRepository not implemented"),
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(FactoryError::NotImplemented(
|
||||
"No database feature enabled".to_string(),
|
||||
)),
|
||||
_ => anyhow::bail!("No database feature enabled"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,13 +123,9 @@ pub async fn build_tag_repository(pool: &DatabasePool) -> FactoryResult<Arc<dyn
|
||||
#[cfg(feature = "sqlite")]
|
||||
DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteTagRepository::new(pool.clone()))),
|
||||
#[cfg(feature = "postgres")]
|
||||
DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented(
|
||||
"Postgres TagRepository".to_string(),
|
||||
)),
|
||||
DatabasePool::Postgres(_) => anyhow::bail!("Postgres TagRepository not implemented"),
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(FactoryError::NotImplemented(
|
||||
"No database feature enabled".to_string(),
|
||||
)),
|
||||
_ => anyhow::bail!("No database feature enabled"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,33 +134,21 @@ pub async fn build_user_repository(pool: &DatabasePool) -> FactoryResult<Arc<dyn
|
||||
#[cfg(feature = "sqlite")]
|
||||
DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteUserRepository::new(pool.clone()))),
|
||||
#[cfg(feature = "postgres")]
|
||||
DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented(
|
||||
"Postgres UserRepository".to_string(),
|
||||
)),
|
||||
DatabasePool::Postgres(_) => anyhow::bail!("Postgres UserRepository not implemented"),
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(FactoryError::NotImplemented(
|
||||
"No database feature enabled".to_string(),
|
||||
)),
|
||||
_ => anyhow::bail!("No database feature enabled"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn build_session_store(
|
||||
pool: &DatabasePool,
|
||||
) -> FactoryResult<crate::session_store::InfraSessionStore> {
|
||||
match pool {
|
||||
pub async fn build_session_store(pool: &DatabasePool) -> Result<InfraSessionStore, sqlx::Error> {
|
||||
Ok(match pool {
|
||||
#[cfg(feature = "sqlite")]
|
||||
DatabasePool::Sqlite(pool) => {
|
||||
let store = tower_sessions_sqlx_store::SqliteStore::new(pool.clone());
|
||||
Ok(crate::session_store::InfraSessionStore::Sqlite(store))
|
||||
DatabasePool::Sqlite(p) => {
|
||||
InfraSessionStore::Sqlite(tower_sessions_sqlx_store::SqliteStore::new(p.clone()))
|
||||
}
|
||||
#[cfg(feature = "postgres")]
|
||||
DatabasePool::Postgres(pool) => {
|
||||
let store = tower_sessions_sqlx_store::PostgresStore::new(pool.clone());
|
||||
Ok(crate::session_store::InfraSessionStore::Postgres(store))
|
||||
DatabasePool::Postgres(p) => {
|
||||
InfraSessionStore::Postgres(tower_sessions_sqlx_store::PostgresStore::new(p.clone()))
|
||||
}
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(FactoryError::NotImplemented(
|
||||
"No database feature enabled".to_string(),
|
||||
)),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,73 +1 @@
|
||||
use async_trait::async_trait;
|
||||
use sqlx;
|
||||
use tower_sessions::{
|
||||
SessionStore,
|
||||
session::{Id, Record},
|
||||
};
|
||||
#[cfg(feature = "postgres")]
|
||||
use tower_sessions_sqlx_store::PostgresStore;
|
||||
#[cfg(feature = "sqlite")]
|
||||
use tower_sessions_sqlx_store::SqliteStore;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum InfraSessionStore {
|
||||
#[cfg(feature = "sqlite")]
|
||||
Sqlite(SqliteStore),
|
||||
#[cfg(feature = "postgres")]
|
||||
Postgres(PostgresStore),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionStore for InfraSessionStore {
|
||||
async fn save(&self, session_record: &Record) -> tower_sessions::session_store::Result<()> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlite")]
|
||||
Self::Sqlite(store) => store.save(session_record).await,
|
||||
#[cfg(feature = "postgres")]
|
||||
Self::Postgres(store) => store.save(session_record).await,
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(tower_sessions::session_store::Error::Backend(
|
||||
"No backend enabled".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn load(&self, session_id: &Id) -> tower_sessions::session_store::Result<Option<Record>> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlite")]
|
||||
Self::Sqlite(store) => store.load(session_id).await,
|
||||
#[cfg(feature = "postgres")]
|
||||
Self::Postgres(store) => store.load(session_id).await,
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(tower_sessions::session_store::Error::Backend(
|
||||
"No backend enabled".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete(&self, session_id: &Id) -> tower_sessions::session_store::Result<()> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlite")]
|
||||
Self::Sqlite(store) => store.delete(session_id).await,
|
||||
#[cfg(feature = "postgres")]
|
||||
Self::Postgres(store) => store.delete(session_id).await,
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(tower_sessions::session_store::Error::Backend(
|
||||
"No backend enabled".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InfraSessionStore {
|
||||
pub async fn migrate(&self) -> Result<(), sqlx::Error> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlite")]
|
||||
Self::Sqlite(store) => store.migrate().await,
|
||||
#[cfg(feature = "postgres")]
|
||||
Self::Postgres(store) => store.migrate().await,
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => Err(sqlx::Error::Configuration("No backend enabled".into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use k_core::session::store::InfraSessionStore;
|
||||
|
||||
@@ -1,101 +1,45 @@
|
||||
use async_trait::async_trait;
|
||||
use k_core::ai::qdrant::QdrantAdapter as CoreQdrant;
|
||||
use notes_domain::errors::{DomainError, DomainResult};
|
||||
use notes_domain::ports::VectorStore;
|
||||
use qdrant_client::Qdrant;
|
||||
use qdrant_client::qdrant::{
|
||||
CreateCollectionBuilder, Distance, PointStruct, SearchPointsBuilder, UpsertPointsBuilder,
|
||||
Value, VectorParamsBuilder,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct QdrantVectorAdapter {
|
||||
client: Arc<Qdrant>,
|
||||
collection_name: String,
|
||||
inner: Arc<CoreQdrant>,
|
||||
}
|
||||
|
||||
impl QdrantVectorAdapter {
|
||||
pub fn new(url: &str, collection_name: &str) -> DomainResult<Self> {
|
||||
let client = Qdrant::from_url(url).build().map_err(|e| {
|
||||
DomainError::InfrastructureError(format!("Failed to create Qdrant client: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
client: Arc::new(client),
|
||||
collection_name: collection_name.to_string(),
|
||||
})
|
||||
pub fn new(inner: CoreQdrant) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_collection_if_not_exists(&self) -> DomainResult<()> {
|
||||
if !self
|
||||
.client
|
||||
.collection_exists(&self.collection_name)
|
||||
pub async fn init(&self) -> DomainResult<()> {
|
||||
self.inner
|
||||
.create_collection_if_not_exists(384)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DomainError::InfrastructureError(format!(
|
||||
"Failed to check collection existence: {}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
{
|
||||
self.client
|
||||
.create_collection(
|
||||
CreateCollectionBuilder::new(self.collection_name.clone())
|
||||
.vectors_config(VectorParamsBuilder::new(384, Distance::Cosine)),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
DomainError::InfrastructureError(format!("Failed to create collection: {}", e))
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
.map_err(|e| DomainError::InfrastructureError(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VectorStore for QdrantVectorAdapter {
|
||||
async fn upsert(&self, id: Uuid, vector: &[f32]) -> DomainResult<()> {
|
||||
let payload: HashMap<String, Value> = HashMap::new();
|
||||
let payload = HashMap::new();
|
||||
|
||||
let point = PointStruct::new(id.to_string(), vector.to_vec(), payload);
|
||||
|
||||
let upsert_points = UpsertPointsBuilder::new(self.collection_name.clone(), vec![point]);
|
||||
|
||||
self.client
|
||||
.upsert_points(upsert_points)
|
||||
self.inner
|
||||
.upsert(id, vector.to_vec(), payload)
|
||||
.await
|
||||
.map_err(|e| DomainError::InfrastructureError(format!("Qdrant upsert error: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
.map_err(|e| DomainError::InfrastructureError(format!("Qdrant upsert error: {}", e)))
|
||||
}
|
||||
|
||||
async fn find_similar(&self, vector: &[f32], limit: usize) -> DomainResult<Vec<(Uuid, f32)>> {
|
||||
let search_points =
|
||||
SearchPointsBuilder::new(self.collection_name.clone(), vector.to_vec(), limit as u64)
|
||||
.with_payload(true);
|
||||
|
||||
let search_result = self
|
||||
.client
|
||||
.search_points(search_points)
|
||||
self.inner
|
||||
.search(vector.to_vec(), limit as u64)
|
||||
.await
|
||||
.map_err(|e| DomainError::InfrastructureError(format!("Qdrant search error: {}", e)))?;
|
||||
|
||||
let results = search_result
|
||||
.result
|
||||
.into_iter()
|
||||
.filter_map(|point| {
|
||||
let id = point.id?;
|
||||
let uuid_str = match id.point_id_options? {
|
||||
qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let uuid = Uuid::parse_str(&uuid_str).ok()?;
|
||||
Some((uuid, point.score))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
.map_err(|e| DomainError::InfrastructureError(format!("Qdrant search error: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ use futures_util::StreamExt;
|
||||
use notes_domain::services::SmartNoteService;
|
||||
#[cfg(feature = "smart-features")]
|
||||
use notes_infra::factory::{
|
||||
BrokerProvider, build_database_pool, build_embedding_generator, build_link_repository,
|
||||
build_message_broker, build_vector_store,
|
||||
BrokerProvider, build_embedding_generator, build_link_repository, build_message_broker,
|
||||
build_vector_store,
|
||||
};
|
||||
|
||||
use crate::config::Config;
|
||||
@@ -31,7 +31,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
.expect("Message broker required for worker");
|
||||
|
||||
let db_config = DatabaseConfig::new(config.database_url.clone());
|
||||
let db_pool = build_database_pool(&db_config).await?;
|
||||
let db_pool = k_core::db::connect(&db_config).await?;
|
||||
|
||||
// Initialize smart feature adapters
|
||||
let embedding_generator = build_embedding_generator(&config.embedding_provider).await?;
|
||||
|
||||
Reference in New Issue
Block a user