{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}

{- |
Module      : Hindsight.Store.PostgreSQL.Events.Concurrency
Description : Version validation for PostgreSQL event store
Copyright   : (c) 2024
License     : BSD3
Maintainer  : maintainer@example.com
Stability   : internal

This module implements optimistic concurrency control for the PostgreSQL
backend by validating version expectations before event insertion.

Version checks take a transaction-scoped PostgreSQL advisory lock per stream,
which serializes concurrent writers to the same stream while minimizing
contention between writers to different streams.
-}
module Hindsight.Store.PostgreSQL.Events.Concurrency (
    checkVersions,
)
where

import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Profunctor (dimap)
import Data.UUID (UUID)
import Hasql.Decoders qualified as D
import Hasql.Decoders qualified as Decoders
import Hasql.Encoders qualified as E
import Hasql.Encoders qualified as Encoders
import Hasql.Statement qualified as Statement
import Hasql.TH (maybeStatement, singletonStatement)
import Hasql.Transaction qualified as HasqlTransaction
import Hindsight.Events (SomeLatestEvent)
import Hindsight.Store
import Hindsight.Store.PostgreSQL.Core.Types

{- | Read the global cursor of a stream from @stream_heads@.

The latest transaction number and latest sequence number are combined
into a 'SQLCursor'. Yields 'Nothing' when the stream has no row in
@stream_heads@.
-}
getCurrentVersionStatement :: Statement.Statement UUID (Maybe (Cursor SQLStore))
getCurrentVersionStatement = dimap id toCursor headRow
  where
    -- Turn the raw (transaction_no, seq_no) pair into a cursor, if present.
    toCursor row = uncurry SQLCursor <$> row
    headRow =
        [maybeStatement|
    select 
      latest_transaction_no :: int8,
      latest_seq_no :: int4
    from stream_heads
    where stream_id = $1 :: uuid
  |]

{- | Read a stream's local version counter (@stream_version@) from
@stream_heads@, wrapped as a 'StreamVersion'.

Yields 'Nothing' when the stream has no row in @stream_heads@.
-}
getCurrentStreamVersionStatement :: Statement.Statement UUID (Maybe StreamVersion)
getCurrentStreamVersionStatement = dimap id (StreamVersion <$>) versionRow
  where
    versionRow =
        [maybeStatement|
    select 
      stream_version :: int8
    from stream_heads
    where stream_id = $1 :: uuid
  |]

{- | Report whether @stream_heads@ has a row for the given stream id.

The query always yields exactly one boolean row.
-}
streamExistsStatement :: Statement.Statement UUID Bool
streamExistsStatement = [singletonStatement|
    select exists (
      select 1 
      from stream_heads
      where stream_id = $1 :: uuid
    ) :: bool
  |]

{- | Acquire an advisory lock on a stream for the duration of the transaction.

Uses PostgreSQL's transaction-scoped advisory locks
(@pg_advisory_xact_lock@) to prevent concurrent modifications to the same
stream while allowing parallel writes to different streams. There is no
explicit unlock: the lock is held until the enclosing transaction ends.
-}
lockStreamStatement :: Statement.Statement UUID ()
lockStreamStatement = Statement.Statement sql encoder decoder True
  where
    -- Use the two-key form of pg_advisory_xact_lock. The first key is fixed
    -- to 2 as a namespace for stream locks. The second key is a hash of the
    -- stream UUID's text form.
    --
    -- Distinct streams may hash to the same key. If that happens, the
    -- affected transactions are serialized more than strictly necessary.
    -- NOTE(review): a collision can also cause two transactions that each
    -- lock several streams to request the same keys in different orders.
    -- PostgreSQL then detects a deadlock and aborts one of them.
    sql = "select pg_advisory_xact_lock(2, hashtext($1::text)::int), true"
    encoder = Encoders.param (Encoders.nonNullable Encoders.uuid)
    -- Only the lock-taking side effect matters; the selected row is ignored.
    decoder = Decoders.noResult

{- | Validate the version expectation of a single stream.

The stream's advisory lock ('lockStreamStatement') is taken first, whatever
the expectation (including 'Any'). The reads that follow are therefore
serialized with any other writer that takes the same lock.

Returns 'Nothing' when the expectation holds. Otherwise returns 'Just' a
'VersionMismatch' carrying the expectation and, where known, the stream's
actual global cursor:

* 'NoStream' fails when the stream already exists (actual version 'Nothing').
* 'StreamExists' fails when the stream does not exist.
* 'ExactVersion' compares against the stream's global cursor.
* 'ExactStreamVersion' compares against the stream-local version. On a
  mismatch, the global cursor is fetched to fill in the report.
* 'Any' always succeeds.
-}
checkStreamVersion :: UUID -> ExpectedVersion SQLStore -> HasqlTransaction.Transaction (Maybe (VersionMismatch SQLStore))
checkStreamVersion sid expectation = do
    -- Serialize with other writers to this stream before reading its state.
    run lockStreamStatement
    verify expectation
  where
    -- Run a statement parameterized by this stream's id.
    run :: Statement.Statement UUID a -> HasqlTransaction.Transaction a
    run = HasqlTransaction.statement sid

    -- Build the failure report for this stream and expectation.
    mismatchWith :: Maybe (Cursor SQLStore) -> Maybe (VersionMismatch SQLStore)
    mismatchWith actual =
        Just
            VersionMismatch
                { streamId = StreamId sid
                , expectedVersion = expectation
                , actualVersion = actual
                }

    verify :: ExpectedVersion SQLStore -> HasqlTransaction.Transaction (Maybe (VersionMismatch SQLStore))
    verify Any = pure Nothing
    verify NoStream = do
        exists <- run streamExistsStatement
        pure $ if exists then mismatchWith Nothing else Nothing
    verify StreamExists = do
        exists <- run streamExistsStatement
        pure $ if exists then Nothing else mismatchWith Nothing
    verify (ExactVersion expectedCursor) = do
        current <- run getCurrentVersionStatement
        pure $ case current of
            Just actual | expectedCursor == actual -> Nothing
            -- Either no head row (reported as Nothing) or a different cursor.
            _ -> mismatchWith current
    verify (ExactStreamVersion expectedStreamVersion) = do
        current <- run getCurrentStreamVersionStatement
        case current of
            Nothing -> pure (mismatchWith Nothing)
            Just actual
                | expectedStreamVersion == actual -> pure Nothing
                -- Report the global cursor, since that is what the
                -- mismatch record carries.
                | otherwise -> mismatchWith <$> run getStreamCursorStatement

{- | Validate the version expectations of every stream in a write batch.

Each stream is checked with 'checkStreamVersion', which takes a
transaction-scoped PostgreSQL advisory lock on the stream before reading its
current version. Streams are visited in ascending 'StreamId' order (the order
of 'Map.toList'). Concurrent transactions touching overlapping sets of streams
therefore request their per-stream locks in the same order, modulo lock-key
hash collisions.

Returns 'Nothing' when every expectation holds. Otherwise returns all
mismatches, listed in ascending 'StreamId' order.
-}
checkVersions :: forall t. Map StreamId (StreamWrite t SomeLatestEvent SQLStore) -> HasqlTransaction.Transaction (Maybe (ConsistencyErrorInfo SQLStore))
checkVersions batches = do
    -- traverse runs the checks sequentially, preserving the ascending order.
    results <- traverse checkOne (Map.toList batches)
    let mismatches = [mismatch | Just mismatch <- results]
    pure $
        if null mismatches
            then Nothing
            else Just (ConsistencyErrorInfo mismatches)
  where
    checkOne ::
        (StreamId, StreamWrite t SomeLatestEvent SQLStore) ->
        HasqlTransaction.Transaction (Maybe (VersionMismatch SQLStore))
    checkOne (sid, batch) =
        checkStreamVersion sid.toUUID batch.expectedVersion

{- | Look up the global cursor of a stream in @stream_heads@.

The latest transaction number and latest sequence number are combined into
a 'SQLCursor'. Yields 'Nothing' when the stream has no head row.

NOTE(review): this issues the same query as 'getCurrentVersionStatement';
the two could probably be consolidated.
-}
getStreamCursorStatement :: Statement.Statement UUID (Maybe SQLCursor)
getStreamCursorStatement =
    Statement.Statement
        "SELECT latest_transaction_no, latest_seq_no FROM stream_heads WHERE stream_id = $1"
        (E.param (E.nonNullable E.uuid))
        (D.rowMaybe cursorRow)
        True
  where
    -- Columns are decoded left to right: transaction number, then sequence number.
    cursorRow :: D.Row SQLCursor
    cursorRow =
        SQLCursor
            <$> D.column (D.nonNullable D.int8)
            <*> D.column (D.nonNullable D.int4)