puppy/lib/store/src/mixin.rs

251 lines
7.6 KiB
Rust

use std::ops::RangeBounds;
use bincode::{Decode, Encode};
use chrono::{DateTime, Utc};
use super::{
types::{DataType, MixinSpec},
Batch, Store, Transaction,
};
use crate::{util::IterExt as _, Error, Key, Result};
/// Mixins are the simplest pieces of data in the store.
///
/// A mixin is a single value attached to a node's [`Key`]. Each mixin type is a
/// [`DataType`] whose spec is a [`MixinSpec`]; the spec's `keyspace` selects where
/// values of that type are stored. Values are serialized with `bincode`, which is
/// why [`Encode`] and [`Decode`] are required as supertraits.
///
/// This is a marker trait with no methods of its own.
pub trait Mixin: DataType<Type = MixinSpec> + Encode + Decode {}
/// Derive a [`Mixin`] implementation.
///
/// In addition to deriving `Mixin`, you will need to derive or implement [`Encode`]
/// and [`Decode`].
pub use r#macro::Mixin;
impl Store {
    /// Look up the mixin of type `M` stored for `node`.
    ///
    /// Returns `Ok(None)` when `node` has no value of this type. Read and decoding
    /// failures are reported as errors.
    pub fn get_mixin<M: Mixin>(&self, node: Key) -> Result<Option<M>> {
        op::get_mixin::<M>(self, node)
    }

    /// Report whether `node` has a mixin of type `M`.
    pub fn has_mixin<M: Mixin>(&self, node: Key) -> Result<bool> {
        op::has_mixin::<M>(self, node)
    }

    /// Iterate over every `M` whose key's timestamp falls within `range`.
    ///
    /// Each item is either a `(key, mixin)` pair or the error encountered while
    /// reading or decoding that entry.
    pub fn range<M: Mixin>(
        &self,
        range: impl RangeBounds<DateTime<Utc>>,
    ) -> impl Iterator<Item = Result<(Key, M)>> + '_ {
        op::get_range::<M>(self, range)
    }

    /// Think "LEFT JOIN": every key from `iter` appears in the output, paired with
    /// `Some(mixin)` if it has an `M` and `None` otherwise.
    ///
    /// Errors from the input iterator, from the lookup, or from decoding are propagated.
    pub fn join_on<M: Mixin>(
        &self,
        iter: impl IntoIterator<Item = Result<Key>>,
    ) -> Result<Vec<(Key, Option<M>)>> {
        op::join_on::<M>(self, iter)
    }
}
impl Transaction<'_> {
    /// Apply an update function to the mixin `M` of `node`.
    ///
    /// The current value is read and decoded, passed through `update`, and the
    /// result is written back under the same key.
    ///
    /// # Errors
    ///
    /// - [`Error::Missing`]: if `node` does not have a mixin of this type.
    ///
    /// [`Error::Missing`]: crate::Error::Missing
    pub fn update<M>(&self, node: Key, update: impl FnOnce(M) -> M) -> Result<()>
    where
        M: Mixin,
    {
        op::update(self, node, update)
    }

    /// Get the mixin of the specified type associated with `node`.
    ///
    /// Returns `Ok(None)` when `node` has no mixin of type `M`.
    pub fn get_mixin<M>(&self, node: Key) -> Result<Option<M>>
    where
        M: Mixin,
    {
        op::get_mixin(self, node)
    }

    /// Add a mixin to `node`.
    ///
    /// # Errors
    ///
    /// - [`Error::Conflict`]: if `node` already has a mixin of type `M`.
    ///
    /// [`Error::Conflict`]: crate::Error::Conflict
    pub fn add_mixin<M>(&self, node: Key, mixin: M) -> Result<()>
    where
        M: Mixin,
    {
        // Refuse to overwrite: an existing value of this type is a conflict.
        if op::has_mixin::<M>(self, node)? {
            return Err(Error::Conflict);
        }
        op::add_mixin::<M>(self, node, mixin)
    }

    /// Check whether `node` has an `M` defined for it.
    pub fn has_mixin<M>(&self, node: Key) -> Result<bool>
    where
        M: Mixin,
    {
        op::has_mixin::<M>(self, node)
    }

    /// Get all `M`s where the key's timestamp is within the `range`.
    pub fn range<M>(
        &self,
        range: impl RangeBounds<DateTime<Utc>>,
    ) -> impl Iterator<Item = Result<(Key, M)>> + '_
    where
        M: Mixin,
    {
        op::get_range(self, range)
    }

    /// Think "LEFT JOIN". In goes an iterator over items, out come all the associated results.
    ///
    /// `f` maps each successful item of `iter` to the [`Key`] to look up. Every key
    /// appears in the output, paired with `Some(mixin)` if it has an `M` and `None`
    /// otherwise.
    pub fn join_on<M, T>(
        &self,
        f: impl Fn(T) -> Key,
        iter: impl IntoIterator<Item = Result<T>>,
    ) -> Result<Vec<(Key, Option<M>)>>
    where
        M: Mixin,
    {
        op::join_on(self, iter.into_iter().map_ok(f))
    }
}
impl Batch {
    /// Add a mixin to the `node`.
    ///
    /// **Note**: unlike [`Transaction::add_mixin`], this will *not* return an error if the key already has a mixin
    /// of this type. This *should* not cause inconsistency.
    ///
    /// # Panics
    ///
    /// Panics if encoding `mixin` or recording the write in the batch returns an error.
    pub fn put_mixin<M>(&mut self, node: Key, mixin: M)
    where
        M: Mixin,
    {
        // This method has no way to report an error, so failure surfaces as a panic
        // with a message naming the operation that failed.
        op::add_mixin(self, node, mixin).expect("failed to encode or write mixin into batch")
    }
}
/// Mixin operations shared by the store, transaction and batch front-ends.
mod op {
    use std::ops::{Bound, RangeBounds};
    use chrono::{DateTime, TimeDelta, Utc};
    use either::Either;
    use super::Mixin;
    use crate::{internal::*, util::IterExt as _, Error, Key, Result};

    /// Read the `M` stored for `node`, run it through `update`, and write the result back.
    ///
    /// Returns [`Error::Missing`] if `node` has no value in the keyspace of `M`.
    pub fn update<M>(
        cx: &(impl Query + Write),
        node: Key,
        update: impl FnOnce(M) -> M,
    ) -> Result<()>
    where
        M: Mixin,
    {
        // TODO: implement in terms of a merge operator instead of separate query and write ops.
        // this would let us remove the `Query` bound, which would in turn let us update from within
        // a batch.
        //
        // See https://github.com/facebook/rocksdb/wiki/Merge-Operator
        //
        // It looks like rocksdb allows you to specify a merge operator per column family.[^1]
        // This means we can construct our column families with a merge operator that knows how to encode and decode mixins.
        //
        // [^1]: https://github.com/facebook/rocksdb/blob/9d37408f9af15c7a1ae42f9b94d06b27d98a011a/include/rocksdb/options.h#L128
        let tree = cx.open(M::SPEC.keyspace);
        let Some(stored) = tree.get(node.as_ref())? else {
            return Err(Error::Missing);
        };
        let current: M = decode(stored)?;
        let bytes = encode(update(current))?;
        tree.set(node, bytes)
    }

    /// Fetch and decode the `M` stored for `node`, or `None` if there is no entry.
    pub fn get_mixin<M: Mixin>(cx: &impl Query, node: Key) -> Result<Option<M>> {
        match cx.open(M::SPEC.keyspace).get(node)? {
            Some(stored) => decode(stored).map(Some),
            None => Ok(None),
        }
    }

    /// Encode `mixin` and store it under `node`, without checking for an existing value.
    pub fn add_mixin<M: Mixin>(cx: &impl Write, node: Key, mixin: M) -> Result<()> {
        let tree = cx.open(M::SPEC.keyspace);
        let bytes = encode(mixin)?;
        tree.set(node, bytes)
    }

    /// Check whether the keyspace of `M` has an entry for `node`.
    pub fn has_mixin<M: Mixin>(cx: &impl Query, node: Key) -> Result<bool> {
        let tree = cx.open(M::SPEC.keyspace);
        tree.has(node)
    }

    /// Iterate over the entries of `M` whose keys fall within the timestamp `range`.
    ///
    /// A fully unbounded range lists the whole keyspace. Otherwise the timestamp
    /// bounds are translated into a pair of 16-byte key bounds and a range scan is
    /// performed between them.
    pub fn get_range<M: Mixin>(
        cx: &impl Query,
        range: impl RangeBounds<DateTime<Utc>>,
    ) -> impl Iterator<Item = Result<(Key, M)>> + '_ {
        // TODO: Test this thoroughly
        const ONE_MS: TimeDelta = TimeDelta::milliseconds(1);
        let (start, end) = (range.start_bound(), range.end_bound());
        let entries = if matches!((start, end), (Bound::Unbounded, Bound::Unbounded)) {
            Either::Left(cx.open(M::SPEC.keyspace).list())
        } else {
            // An excluded bound is shifted one millisecond inwards before being
            // converted, so that it can be expressed through `Key::range`.
            let lower = match start {
                Bound::Included(ts) => Key::range(*ts).0,
                Bound::Excluded(ts) => Key::range(*ts + ONE_MS).0,
                Bound::Unbounded => [u8::MIN; 16],
            };
            let upper = match end {
                Bound::Included(ts) => Key::range(*ts).1,
                Bound::Excluded(ts) => Key::range(*ts - ONE_MS).1,
                Bound::Unbounded => [u8::MAX; 16],
            };
            Either::Right(cx.open(M::SPEC.keyspace).range(lower, upper))
        };
        entries.bind_results(|(raw_key, raw_val)| {
            let key = Key::from_slice(raw_key.as_ref());
            decode(raw_val).map(|val| (key, val))
        })
    }

    /// Look up `M` for every key in `iter`, pairing each key with its decoded value
    /// or `None` when the key has no entry.
    pub fn join_on<M>(
        cx: &impl Query,
        iter: impl IntoIterator<Item = Result<Key>>,
    ) -> Result<Vec<(Key, Option<M>)>>
    where
        M: Mixin,
    {
        // The keys are collected up front: they are needed both for the lookup and
        // for labelling each result afterwards.
        let keys: Vec<Key> = iter.into_iter().try_collect()?;
        cx.open(M::SPEC.keyspace)
            .join(keys.iter())
            .into_iter()
            .zip(keys)
            .map(|(slot, key)| match slot? {
                Some(stored) => decode(stored).map(|val| (key, Some(val))),
                None => Ok((key, None)),
            })
            .try_collect()
    }

    /// Serialize `data` with bincode's standard configuration.
    pub(super) fn encode(data: impl bincode::Encode) -> Result<Vec<u8>> {
        let config = bincode::config::standard();
        bincode::encode_to_vec(data, config).map_err(Error::Encoding)
    }

    /// Deserialize a `T` from `data` with bincode's standard configuration.
    ///
    /// The number of bytes consumed is discarded.
    pub(super) fn decode<T>(data: impl AsRef<[u8]>) -> Result<T>
    where
        T: bincode::Decode,
    {
        let config = bincode::config::standard();
        let (value, _consumed) =
            bincode::decode_from_slice(data.as_ref(), config).map_err(Error::Decoding)?;
        Ok(value)
    }
}