diff --git a/Cargo.lock b/Cargo.lock index a9d9ffe..12a86d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -518,6 +518,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -525,6 +540,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -533,6 +549,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -551,10 +595,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1258,6 +1308,7 @@ dependencies = [ "derive_more", "either", "fetch", + "futures", "serde_json", "store", "tracing", diff --git a/lib/fetch/src/client.rs b/lib/fetch/src/client.rs index f3d628f..ec88149 100644 --- a/lib/fetch/src/client.rs +++ b/lib/fetch/src/client.rs @@ -38,13 +38,13 @@ impl Client { /// /// Note that in order for the request to be considered valid by most implementations, `key.owner` /// must equal `payload.actor`. - #[instrument(target = "fetch.delivery", skip_all, fields(activity = payload.id, url = inbox, key = key.id))] - pub async fn deliver(&self, key: &SigningKey, payload: &Activity, inbox: &str) { + #[instrument(target = "fetch.delivery", skip_all, fields(activity = payload.id, url = inbox.as_ref(), key = key.id))] + pub async fn deliver(&self, key: &SigningKey, payload: &Activity, inbox: impl AsRef) { let system = Subsystem::Delivery; let body = serde_json::to_string(&payload.to_json_ld()).unwrap(); let mut req = system - .new_request(inbox) + .new_request(inbox.as_ref()) .unwrap() .method(Method::POST) .header("content-type", ACTIVITYPUB_TYPE) diff --git a/lib/puppy/Cargo.toml b/lib/puppy/Cargo.toml index acef129..942f391 100644 --- a/lib/puppy/Cargo.toml +++ b/lib/puppy/Cargo.toml @@ -14,3 +14,4 @@ either = "*" derive_more = "*" serde_json = "*" tracing = "*" +futures = "*" diff --git a/lib/puppy/src/lib.rs b/lib/puppy/src/lib.rs index 4e885a5..e2c7852 100644 --- a/lib/puppy/src/lib.rs +++ b/lib/puppy/src/lib.rs @@ -2,7 +2,13 @@ //! you should take a look at [`fetch`]. // Working with result types is such a bitch without these. -#![feature(iterator_try_collect, try_blocks, once_cell_try, box_into_inner)] +#![feature( + iterator_try_collect, + try_blocks, + once_cell_try, + box_into_inner, + type_changing_struct_update +)] use std::hint::unreachable_unchecked; @@ -336,7 +342,7 @@ pub async fn ingest(cx: &Context, auth: Key, activity: &Activity) -> Result<()> match object { Object::Activity(a) => interpret(&cx, a)?, Object::Actor(a) => cx.run(|tx| actor::create_remote(tx, a).map(void))?, - Object::Note(a) => post::create_post_from_note(cx, a).map(void)?, + Object::Note(a) => cx.run(|tx| post::create_post_from_note(tx, a).map(void))?, _ => todo!(), } } @@ -345,3 +351,613 @@ pub async fn ingest(cx: &Context, auth: Key, activity: &Activity) -> Result<()> /// Discard the argument. fn void(_: T) -> () {} + +pub mod systems { + //! Logic containment zone. + + use fetch::Client; + use store::{Key, Transaction}; + use tracing::warn; + + /// Allows subsystems to interact with the [fetch] and [store] components. + pub trait Context { + /// Access the transaction. + fn db(&self) -> &Transaction<'_>; + /// Access the federation client. + fn client(&self) -> &Client; + /// Get the ActivityPub domain of this server. + fn domain(&self) -> &str; + /// Format `key` as an object ID. + fn make_object_id(&self, key: Key) -> String { + format!("https://{}/o/{key}", self.domain()) + } + /// Check whether we already know about an object. + fn is_known(&self, url: &str) -> bool { + use crate::data::Id; + + match self.db().lookup(Id(url.to_string())) { + Ok(Some(_)) => true, + _ => false, + } + } + /// Check whether `url` refers to an object that is considered "local" to this server. + fn is_local(&self, url: &str) -> bool { + use crate::data::{Id, Object}; + + let Ok(Some(key)) = self.db().lookup(Id(url.to_string())) else { + return false; + }; + + match self.db().get_mixin(key) { + Ok(Some(Object { local, .. })) => local, + Ok(None) => false, + Err(err) => { + warn!("error while trying to determine origin of {key}: {err}"); + false + } + } + } + } + + pub mod notification { + //! Manages the delivery of notifications to local users. + + use tracing::info; + + use crate::Result; + use super::Context; + + /// Something that can be sent to a local user as a (push) notification. + pub trait Notification {} + + /// Get the notification where it needs to go. + pub fn dispatch(cx: &C, event: &impl Notification) -> Result<()> + where + C: Context, + { + info!("ding!"); + Ok(()) + } + } + + pub mod delivery { + //! A mid-level system for delivering messages to remote servers. + + use fetch::{ + object::{Activity, Object}, + signatures::SigningKey, + }; + use futures::{stream::FuturesUnordered, Future, StreamExt as _}; + use derive_more::Display; + use store::{arrow::Multi, Key}; + use tracing::error; + + use crate::{entities::Accept, Error, Result}; + + use super::Context; + + /// Type tag for an activity. + #[derive(Clone, Copy, Display)] + pub enum Tag { + Follow, + Accept, + Reject, + Create, + Delete, + Bite, + } + + /// Implemented for types that can represent an activity that can be delivered to a remote server. + pub trait Payload { + /// Construct an [`Activity`] for delivery. + fn prepare(self, cx: &C) -> Result + where + C: Context + ?Sized; + /// Call the delivery subsystem to this payload everywhere it needs to go. + /// + /// A convenience method to call [`deliver`] as a method instead of a free-standing function. + fn deliver(self, cx: &C) -> impl Future> + Send + Sync + where + Self: Sized + Send + Sync, + C: Context + Send + Sync + ?Sized, + { + deliver(cx, self) + } + } + + impl Payload for Accept { + fn prepare(self, cx: &C) -> Result + where + C: Context + ?Sized, + { + let (actor, object, tag) = get_activity_data(cx, self.into())?; + Ok(Activity { + id: cx.make_object_id(self.into()), + actor, + object: Box::new(Object::Id { id: object }), + kind: tag.to_string(), + }) + } + } + + /// Deliver the activity-like payload to all its intended recipients. + pub async fn deliver(cx: &C, payload: impl Payload) -> Result<()> + where + C: Context + ?Sized, + { + let activity = payload.prepare(cx)?; + + if !cx.is_local(&activity.id) { + error!("delivery of non-local activity!!"); + return Ok(()); + } + + let signing_key = get_keypair(cx, &activity.actor)?; + get_targets(cx, &activity)? + .into_iter() + .map(|inbox| cx.client().deliver(&signing_key, &activity, inbox)) + .collect::>() + .collect::<()>() + .await; + + Ok(()) + } + + /// Calculate the list of inboxes to send the activity to. + fn get_targets(cx: &C, activity: &Activity) -> Result> + where + C: Context + ?Sized, + { + use crate::data::{Channel, Id}; + + let get_inbox = |id: &str| -> Result> { + let Some(key) = cx.db().lookup(Id(id.to_owned()))? else { + return Ok(None); + }; + cx.db() + .get_mixin(key) + .map(|o| o.map(|Channel { inbox }| inbox)) + .map_err(Error::Store) + }; + + match *activity.object { + Object::Id { ref id } => get_inbox(&id), + Object::Activity(ref a) => get_inbox(&a.actor), + Object::Actor(ref a) => Ok(Some(a.inbox.clone())), + Object::Note(ref a) => get_inbox(&a.author), + } + } + + /// Get a signing key for the given actor id. + /// + /// Will fail if the actor isn't local. + fn get_keypair(cx: &C, actor_id: &str) -> Result + where + C: Context + ?Sized, + { + use crate::{ + actor::{get_signing_key, Actor}, + data::Id, + }; + let Some(actor_key) = cx.db().lookup(Id(actor_id.to_string()))? else { + panic!("could not get db key for {actor_id}"); + }; + get_signing_key(cx.db(), Actor { key: actor_key }).map_err(Error::Store) + } + + /// Get actor id, the object id, and the type of the activity. + fn get_activity_data(cx: &C, key: Key) -> Result<(String, String, Tag)> + where + C: Context + ?Sized, + { + // TODO: instead of panicking, return a normal error. + use crate::data::{Id, Object, ActivityKind, ObjectKind}; + + let Some(Object { kind, .. }) = cx.db().get_mixin(key)? else { + panic!("activity must be an object") + }; + let Some(Multi { origin, target, .. }) = cx.db().get_arrow_raw(key)? else { + panic!("expected activity to have multi-arrow repr") + }; + let Some(Id(object)) = cx.db().get_alias(target)? else { + panic!("expected object of activity to have an id") + }; + let Some(Id(actor)) = cx.db().get_alias(origin)? else { + panic!("expected author of activity to have an id") + }; + + let tag = match kind { + ObjectKind::Activity(tag) => match tag { + ActivityKind::Create => Tag::Create, + ActivityKind::Follow => Tag::Follow, + ActivityKind::Accept => Tag::Accept, + ActivityKind::Reject => Tag::Reject, + ActivityKind::Bite => Tag::Bite, + }, + _ => panic!("invalid kind for activity"), + }; + + Ok((actor, object, tag)) + } + } + + pub mod ingestion { + //! Processing remote objects to make them part of our local graph. + } + + pub mod processor { + //! The structure and interpretation of activity graphs. + //! + //! This is where the magic happens. + //! + //! This module contains the code that assembles and executes a graph of activities. This is a higher + //! level system, that depends on the following sibling systems: + //! + //! - [`following`][super::following] + //! - [`delivery`][super::delivery] + //! + //! Its main purpose is to integrate all the above system in the context of receiving an activity in + //! an inbox, and the subsequent processing that needs to happen. In order to *correctly* interpret an + //! activity, we need to know about the other nodes that it references. + + use fetch::{ + object::{Activity, Actor, Note, Object}, + signatures::SigningKey, + FetchError, + }; + use store::Key; + use tracing::{debug, info, trace, warn}; + + use crate::{data, Result}; + + use super::{ + delivery::{self, Tag}, + following, Context, + }; + + /// Process an incoming activity. + #[tracing::instrument(target = "puppy.processor", skip_all, fields(activity = root.id))] + pub async fn process_incoming( + cx: &C, + on_behalf_of: &SigningKey, + root: Activity, + ) -> Result<(Key, Tag)> + where + C: Context, + { + if cx.is_local(&root.id) || cx.is_known(&root.id) { + panic!("could not process activity, it already exists"); + } + + // Fetch all transitive dependencies of `root`, returned in the order that they need to be processed. + let (actors, notes, activities) = + fetch_dependencies(cx, &root, &on_behalf_of, 3).await?; + + trace!("storing actors"); + for actor in actors { + store_actor(cx, actor)?; + } + + trace!("storing notes"); + for note in notes { + store_note(cx, note)?; + } + + // Process the activities that have to execute before `activity`. These dependencies are returned + // in the exact order that they need to be applied. `fetch_dependencies` only returns stuff that is + // missing from the graph. + trace!("applying dependency activities"); + for dep in activities { + let (key, tag) = apply_activity(cx, dep).await?; + debug!( + dependency_of = root.id, + kind = tag.to_string(), + "processed activity {key}" + ); + } + + // Now that all activities that this depends on have been executed, we can execute the original + // one. + trace!("executing target activity"); + apply_activity(cx, root).await + } + + /// Execute the activity. + /// + /// # Panics + /// + /// This function assumes that everything the activity references by ActivityPub ID is already present in the context. + /// If this is not the case, it will panic. + /// + /// Specifically, + #[tracing::instrument(level = "TRACE", target = "puppy.processor", skip_all, fields(activity = activity.id))] + async fn apply_activity(cx: &C, activity: Activity) -> Result<(Key, Tag)> + where + C: Context, + { + use crate::data::{Id, Create}; + // Get a key and error out if it does not exist + let get_key = |url: &str| -> Key { + cx.db() + .lookup(data::Id(url.to_string())) + .expect("database should be operable") + .expect("url should already have been inserted") + }; + + let tag = tagof(&activity); + let key = match tag { + Tag::Follow => { + let requester = get_key(&activity.actor); + let target = get_key(activity.object.id()); + + let req = following::create_follow_request(cx, requester, target)?; + debug!("created follow request {req}"); + + // For now, automatically accept follow requests. + if cx.is_local(activity.object.id()) { + debug!("auto-accepting follow request for local actor {target}",); + let accept = following::accept_follow_request(cx, req)?; + if !cx.is_local(&activity.actor) { + debug!("delivering to remote actor"); + delivery::deliver(cx, accept).await?; + } + } + + req.into() + } + Tag::Accept => { + let req_id = get_key(activity.object.id()); + following::accept_follow_request(cx, req_id.into())?.into() + } + Tag::Reject => { + let req_id = get_key(activity.object.id()); + following::reject_follow_request(cx, req_id.into())?.into() + } + Tag::Create => { + let id = Key::gen(); + let actor = get_key(&activity.actor); + let object = get_key(activity.object.id()); + cx.db().add_alias(id, Id(cx.make_object_id(id)))?; + cx.db().create(Create { id, actor, object })?; + id.into() + } + Tag::Delete => todo!(), + Tag::Bite => todo!(), + }; + + Ok((key, tag)) + } + + /// Fetch all transitive dependencies of `root` using the given `auth` as the signing authority. + /// + /// The `budget` parameter specifies the maximum number of recursive calls. + #[tracing::instrument(target = "puppy.processor", skip_all, fields(budget = budget, target = root.id))] + async fn fetch_dependencies( + cx: &C, + root: &Activity, + auth: &SigningKey, + budget: usize, + ) -> Result<(Vec, Vec, Vec)> + where + C: Context, + { + let mut actors = Vec::new(); + let mut notes = Vec::new(); + let mut activities = Vec::new(); + + for url in [root.actor.as_str(), root.object.id()] { + if cx.is_known(url) { + debug!(parent = root.id, url, "already known, skipping"); + continue; + } else { + debug!(parent = root.id, url, budget, "fetching dependency"); + } + + let json = cx.client().resolve(auth, url).await?; + let object = Object::from_json(json).map_err(FetchError::BadJson)?; + + match object { + Object::Id { id } => { + warn!(parent = root.id, url, "could not fetch {id}"); + } + Object::Activity(a) if budget == 0 => { + info!(parent = root.id, url, "exceeded budget, skipping"); + activities.push(a); + continue; + } + Object::Activity(a) => { + debug!(parent = root.id, url, "fetching dependencies"); + // BUG: this won't hit cache because none of the stuff is in there yet. + let (x, y, z) = + Box::pin(fetch_dependencies(cx, &a, auth, budget - 1)).await?; + trace!( + parent = root.id, + url, + "fetch completed, total: {}", + x.len() + y.len() + z.len() + ); + actors.extend(x); + notes.extend(y); + activities.extend(z); + activities.push(a); + } + Object::Actor(a) => actors.push(a), + Object::Note(a) => notes.push(a), + } + } + + Ok((actors, notes, activities)) + } + + fn store_actor(cx: &C, actor: Actor) -> Result<()> + where + C: Context, + { + if !cx.is_known(&actor.id) { + crate::actor::create_remote(cx.db(), actor)?; + } else { + debug!("actor {} is already known", actor.id); + } + Ok(()) + } + + fn store_note(cx: &C, note: Note) -> Result<()> + where + C: Context, + { + if !cx.is_known(¬e.id) { + crate::post::create_post_from_note(cx.db(), note)?; + } else { + debug!("note {} is already known", note.id); + } + Ok(()) + } + + /// Get the tag of the activity. + fn tagof(a: &Activity) -> Tag { + match a.kind.as_str() { + "Create" => Tag::Create, + "Delete" => Tag::Delete, + "Follow" => Tag::Follow, + "Accept" => Tag::Accept, + "Reject" => Tag::Reject, + "Bite" => Tag::Bite, + _ => todo!(), + } + } + } + + pub mod timelines { + //! Manages the rendering of timelines. + } + + pub mod following { + //! Follow requests and management thereof. + + use store::Key; + + use super::Context; + use crate::{ + entities::{Accept, FollowRequest, Reject, Undo}, + Result, + }; + + /// Create a follow request. + pub fn create_follow_request(cx: &C, follower: Key, target: Key) -> Result + where + C: Context + ?Sized, + { + Ok(Key::gen().into()) + } + + /// Cancel a follow request. + /// + /// If the follow request was already accepted, the follow request's target is unfollowed by the actor. Otherwise, + /// the follow request is withdrawn. + /// + /// This creates a new [`Undo`] entry in the database to which data may be attached. + pub fn cancel_follow_request(cx: &C, req: FollowRequest) -> Result + where + C: Context + ?Sized, + { + Ok(Key::gen().into()) + } + + /// Apply the changes related to accepting a follow request to the social graph and create a new node representing + /// the event. + pub fn accept_follow_request(cx: &C, req: FollowRequest) -> Result + where + C: Context + ?Sized, + { + Ok(Key::gen().into()) + } + + /// Apply the changes related to rejecting a follow request to the social graph and create a new node representing + /// the event. + pub fn reject_follow_request(cx: &C, req: FollowRequest) -> Result + where + C: Context + ?Sized, + { + Ok(Key::gen().into()) + } + } + + pub mod reverse { + //! Undoing operations. + //! + //! This module defines the behavior of the [`Undo`] activity. + + use crate::entities::{FollowRequest, Undo}; + + use super::{following, Context}; + + /// Describes objects which have a "revert" operation defined (that is, they can be the target of an [`Undo`] activity). + pub trait Reversible { + /// Undo `self` and generate a corresponding [`Undo`] object recording this fact. + fn revert(&self, cx: &dyn Context) -> crate::Result; + } + + impl Reversible for FollowRequest { + /// Withdraw a follow request if it wasn't yet accepted, or unfollow someone. + fn revert(&self, cx: &dyn Context) -> crate::Result { + following::cancel_follow_request(cx, *self) + } + } + } +} + +pub mod entities { + //! Virtual data types composed from [components] and operated on by [systems]. + //! + //! [components]: crate::components + //! [systems]: crate::systems + + use store::Key; + use derive_more::{From, Into, Display}; + + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Poster(Key); + + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Server(Key); + + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Post(Key); + + /// Represents a `Bite` activity. + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Bite(Key); + + /// Represents an `Undo` activity. + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Undo(Key); + + /// Represents an `Accept` activity. + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Accept(Key); + + /// Represents a `Reject` activity. + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Reject(Key); + + /// Represents a `Create` activity. + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct Create(Key); + + /// Represents a `Follow` activity. + /// + /// Also see the [`following`][crate::systems::following] module, which defines the logic for following, follow requests + /// and other related stuff. + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct FollowRequest(Key); + + #[derive(Clone, Copy, Debug, From, Into, Display)] + pub struct PublicKey(Key); + + /// A key newtype that represents an object. + pub trait Entity: Into + From + Copy {} + + impl Entity for T where T: Into + From + Copy {} +} + +pub mod components {} diff --git a/lib/puppy/src/post.rs b/lib/puppy/src/post.rs index b6a8bc1..2ca936e 100644 --- a/lib/puppy/src/post.rs +++ b/lib/puppy/src/post.rs @@ -152,29 +152,27 @@ pub fn create_local_post( } /// Assumes all objects referenced already exist. -#[tracing::instrument(skip(cx))] -pub fn create_post_from_note(cx: &Context, note: Note) -> crate::Result { - cx.run(|tx| { - let Some(author) = tx.lookup(Id(note.author))? else { - panic!("needed author to already exist") - }; +#[tracing::instrument(skip(tx))] +pub fn create_post_from_note(tx: &Transaction<'_>, note: Note) -> crate::Result { + let Some(author) = tx.lookup(Id(note.author))? else { + panic!("needed author to already exist") + }; - let key = Key::gen(); + let key = Key::gen(); - tx.add_alias(key, Id(note.id.clone()))?; - tx.create(AuthorOf { object: key, author })?; - tx.add_mixin(key, Content { - content: note.content, - warning: note.summary, - })?; - tx.add_mixin(key, data::Object { - kind: ObjectKind::Notelike(note.kind), - id: Id(note.id), - local: false, - })?; + tx.add_alias(key, Id(note.id.clone()))?; + tx.create(AuthorOf { object: key, author })?; + tx.add_mixin(key, Content { + content: note.content, + warning: note.summary, + })?; + tx.add_mixin(key, data::Object { + kind: ObjectKind::Notelike(note.kind), + id: Id(note.id), + local: false, + })?; - Ok(Post { key }) - }) + Ok(Post { key }) } #[tracing::instrument(skip(cx))] diff --git a/lib/store/src/arrow.rs b/lib/store/src/arrow.rs index 96bf90e..89696a9 100644 --- a/lib/store/src/arrow.rs +++ b/lib/store/src/arrow.rs @@ -222,11 +222,15 @@ impl Transaction<'_> { where A: Arrow, { + Ok(self.get_arrow_raw(key)?.map(A::from)) + } + /// Construct the arrow from its identifier. + pub fn get_arrow_raw(&self, key: Key) -> Result> { let arrow = self .open(crate::types::MULTIEDGE_HEADERS) .get(key)? .map(|v| Key::split(v.as_ref())) - .map(|(origin, target)| A::from(Multi { origin, target, identity: key })); + .map(|(origin, target)| Multi { origin, target, identity: key }); Ok(arrow) } } diff --git a/lib/store/src/key.rs b/lib/store/src/key.rs index 92ef0d1..1ebca3f 100644 --- a/lib/store/src/key.rs +++ b/lib/store/src/key.rs @@ -3,13 +3,14 @@ use std::{ str::FromStr, }; +use bincode::{Decode, Encode}; use chrono::{DateTime, Utc}; use ulid::Ulid; use crate::StoreError; /// A unique identifier for vertices in the database. -#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Encode, Decode)] pub struct Key(pub(crate) [u8; 16]); impl Key { diff --git a/lib/store/src/lib.rs b/lib/store/src/lib.rs index 67e9808..04288d7 100644 --- a/lib/store/src/lib.rs +++ b/lib/store/src/lib.rs @@ -68,7 +68,7 @@ pub struct Batch { } impl Store { - /// Run a [transaction][Transaction]. + /// Run a [transaction][Transaction] and ensure that it is either committed or rolled back. /// /// In a transaction, either all writes succeed, or the transaction is aborted and the changes are not /// recorded. Changes made inside a transaction can be read from within that transaction before they are @@ -85,15 +85,20 @@ impl Store { store: &self, }; let r = f(&tx); - if let Err(e) = if r.is_err() { - tx.inner.rollback() + if r.is_err() { + tx.cancel()?; } else { - tx.inner.commit() - } { - return Err(E::from(StoreError::Internal(e))); + tx.commit()?; } r } + /// Begin a transaction. + pub fn start(&self) -> Transaction<'_> { + Transaction { + inner: self.inner.transaction(), + store: &self, + } + } /// Apply a batch of changes atomically. pub fn apply(&self, batch: Batch) -> Result<()> { self.inner.write(batch.inner.into_inner())?; @@ -132,6 +137,17 @@ impl Store { } } +impl<'db> Transaction<'db> { + /// Complete the transaction successfully. + pub fn commit(self) -> Result<(), StoreError> { + self.inner.commit().map_err(StoreError::from) + } + /// Cancel the transaction. + pub fn cancel(self) -> Result<(), StoreError> { + self.inner.rollback().map_err(StoreError::from) + } +} + /// A shorthand for committing a [`Transaction`] (because I think `Ok(())` is ugly). pub const OK: Result<()> = Ok(());