2023-06-18 10:20:11 +01:00
{- # LANGUAGE DeriveAnyClass # -}
{- # LANGUAGE DeriveGeneric # -}
{- # LANGUAGE DuplicateRecordFields # -}
{- # LANGUAGE LambdaCase # -}
{- # LANGUAGE NamedFieldPuns # -}
{- # LANGUAGE OverloadedStrings # -}
{- # LANGUAGE QuasiQuotes # -}
{- # LANGUAGE ScopedTypeVariables # -}
{- # LANGUAGE TypeOperators # -}
module Simplex.Chat.Store.Shared where
import Control.Concurrent.STM ( stateTVar )
import Control.Exception ( Exception )
import qualified Control.Exception as E
import Control.Monad.Except
import Crypto.Random ( ChaChaDRG , randomBytesGenerate )
import Data.Aeson ( ToJSON )
import qualified Data.Aeson as J
import Data.ByteString.Char8 ( ByteString )
import qualified Data.ByteString.Base64 as B64
import Data.Int ( Int64 )
import Data.Maybe ( fromMaybe , isJust , listToMaybe )
import Data.Text ( Text )
import qualified Data.Text as T
import Data.Time.Clock ( UTCTime ( .. ) , getCurrentTime )
import Database.SQLite.Simple ( NamedParam ( .. ) , Only ( .. ) , Query , SQLError , ( :. ) ( .. ) )
import qualified Database.SQLite.Simple as DB
import Database.SQLite.Simple.QQ ( sql )
import GHC.Generics ( Generic )
import Simplex.Chat.Messages
import Simplex.Chat.Protocol
import Simplex.Chat.Types
2023-07-21 21:32:28 +01:00
import Simplex.Chat.Types.Preferences
2023-06-18 10:20:11 +01:00
import Simplex.Messaging.Agent.Protocol ( AgentMsgId , ConnId , UserId )
import Simplex.Messaging.Agent.Store.SQLite ( firstRow , maybeFirstRow )
import Simplex.Messaging.Parsers ( dropPrefix , sumTypeJSON )
2023-07-09 23:24:38 +01:00
import Simplex.Messaging.Util ( allFinally )
2023-06-18 10:20:11 +01:00
import UnliftIO.STM
-- These error type constructors must be added to mobile apps
data StoreError
= SEDuplicateName
| SEUserNotFound { userId :: UserId }
| SEUserNotFoundByName { contactName :: ContactName }
| SEUserNotFoundByContactId { contactId :: ContactId }
| SEUserNotFoundByGroupId { groupId :: GroupId }
| SEUserNotFoundByFileId { fileId :: FileTransferId }
| SEUserNotFoundByContactRequestId { contactRequestId :: Int64 }
| SEContactNotFound { contactId :: ContactId }
| SEContactNotFoundByName { contactName :: ContactName }
| SEContactNotReady { contactName :: ContactName }
| SEDuplicateContactLink
| SEUserContactLinkNotFound
| SEContactRequestNotFound { contactRequestId :: Int64 }
| SEContactRequestNotFoundByName { contactName :: ContactName }
| SEGroupNotFound { groupId :: GroupId }
| SEGroupNotFoundByName { groupName :: GroupName }
| SEGroupMemberNameNotFound { groupId :: GroupId , groupMemberName :: ContactName }
| SEGroupMemberNotFound { groupMemberId :: GroupMemberId }
2023-06-22 20:38:09 +04:00
| SEGroupMemberNotFoundByMemberId { memberId :: MemberId }
2023-06-18 10:20:11 +01:00
| SEGroupWithoutUser
| SEDuplicateGroupMember
| SEGroupAlreadyJoined
| SEGroupInvitationNotFound
| SESndFileNotFound { fileId :: FileTransferId }
| SESndFileInvalid { fileId :: FileTransferId }
| SERcvFileNotFound { fileId :: FileTransferId }
| SERcvFileDescrNotFound { fileId :: FileTransferId }
| SEFileNotFound { fileId :: FileTransferId }
| SERcvFileInvalid { fileId :: FileTransferId }
| SERcvFileInvalidDescrPart
| SESharedMsgIdNotFoundByFileId { fileId :: FileTransferId }
| SEFileIdNotFoundBySharedMsgId { sharedMsgId :: SharedMsgId }
| SESndFileNotFoundXFTP { agentSndFileId :: AgentSndFileId }
| SERcvFileNotFoundXFTP { agentRcvFileId :: AgentRcvFileId }
| SEConnectionNotFound { agentConnId :: AgentConnId }
| SEConnectionNotFoundById { connId :: Int64 }
| SEPendingConnectionNotFound { connId :: Int64 }
| SEIntroNotFound
| SEUniqueID
| SEInternalError { message :: String }
| SENoMsgDelivery { connId :: Int64 , agentMsgId :: AgentMsgId }
| SEBadChatItem { itemId :: ChatItemId }
| SEChatItemNotFound { itemId :: ChatItemId }
| SEChatItemNotFoundByText { text :: Text }
| SEChatItemSharedMsgIdNotFound { sharedMsgId :: SharedMsgId }
| SEChatItemNotFoundByFileId { fileId :: FileTransferId }
| SEChatItemNotFoundByGroupId { groupId :: GroupId }
| SEProfileNotFound { profileId :: Int64 }
| SEDuplicateGroupLink { groupInfo :: GroupInfo }
| SEGroupLinkNotFound { groupInfo :: GroupInfo }
| SEHostMemberIdNotFound { groupId :: Int64 }
| SEContactNotFoundByFileId { fileId :: FileTransferId }
deriving ( Show , Exception , Generic )
instance ToJSON StoreError where
toJSON = J . genericToJSON . sumTypeJSON $ dropPrefix " SE "
toEncoding = J . genericToEncoding . sumTypeJSON $ dropPrefix " SE "
insertedRowId :: DB . Connection -> IO Int64
insertedRowId db = fromOnly . head <$> DB . query_ db " SELECT last_insert_rowid() "
checkConstraint :: StoreError -> ExceptT StoreError IO a -> ExceptT StoreError IO a
checkConstraint err action = ExceptT $ runExceptT action ` E . catch ` ( pure . Left . handleSQLError err )
handleSQLError :: StoreError -> SQLError -> StoreError
handleSQLError err e
| DB . sqlError e == DB . ErrorConstraint = err
| otherwise = SEInternalError $ show e
2023-07-09 23:24:38 +01:00
storeFinally :: ExceptT StoreError IO a -> ExceptT StoreError IO b -> ExceptT StoreError IO a
storeFinally = allFinally mkStoreError
{- # INLINE storeFinally # -}
mkStoreError :: E . SomeException -> StoreError
mkStoreError = SEInternalError . show
{- # INLINE mkStoreError # -}
2023-06-18 10:20:11 +01:00
fileInfoQuery :: Query
fileInfoQuery =
[ sql |
SELECT f . file_id , f . ci_file_status , f . file_path
FROM chat_items i
JOIN files f ON f . chat_item_id = i . chat_item_id
| ]
toFileInfo :: ( Int64 , Maybe ACIFileStatus , Maybe FilePath ) -> CIFileInfo
toFileInfo ( fileId , fileStatus , filePath ) = CIFileInfo { fileId , fileStatus , filePath }
type EntityIdsRow = ( Maybe Int64 , Maybe Int64 , Maybe Int64 , Maybe Int64 , Maybe Int64 )
type ConnectionRow = ( Int64 , ConnId , Int , Maybe Int64 , Maybe Int64 , Bool , Maybe GroupLinkId , Maybe Int64 , ConnStatus , ConnType , LocalAlias ) :. EntityIdsRow :. ( UTCTime , Maybe Text , Maybe UTCTime , Int )
type MaybeConnectionRow = ( Maybe Int64 , Maybe ConnId , Maybe Int , Maybe Int64 , Maybe Int64 , Maybe Bool , Maybe GroupLinkId , Maybe Int64 , Maybe ConnStatus , Maybe ConnType , Maybe LocalAlias ) :. EntityIdsRow :. ( Maybe UTCTime , Maybe Text , Maybe UTCTime , Maybe Int )
toConnection :: ConnectionRow -> Connection
toConnection ( ( connId , acId , connLevel , viaContact , viaUserContactLink , viaGroupLink , groupLinkId , customUserProfileId , connStatus , connType , localAlias ) :. ( contactId , groupMemberId , sndFileId , rcvFileId , userContactLinkId ) :. ( createdAt , code_ , verifiedAt_ , authErrCounter ) ) =
let entityId = entityId_ connType
connectionCode = SecurityCode <$> code_ <*> verifiedAt_
in Connection { connId , agentConnId = AgentConnId acId , connLevel , viaContact , viaUserContactLink , viaGroupLink , groupLinkId , customUserProfileId , connStatus , connType , localAlias , entityId , connectionCode , authErrCounter , createdAt }
where
entityId_ :: ConnType -> Maybe Int64
entityId_ ConnContact = contactId
entityId_ ConnMember = groupMemberId
entityId_ ConnRcvFile = rcvFileId
entityId_ ConnSndFile = sndFileId
entityId_ ConnUserContact = userContactLinkId
toMaybeConnection :: MaybeConnectionRow -> Maybe Connection
toMaybeConnection ( ( Just connId , Just agentConnId , Just connLevel , viaContact , viaUserContactLink , Just viaGroupLink , groupLinkId , customUserProfileId , Just connStatus , Just connType , Just localAlias ) :. ( contactId , groupMemberId , sndFileId , rcvFileId , userContactLinkId ) :. ( Just createdAt , code_ , verifiedAt_ , Just authErrCounter ) ) =
Just $ toConnection ( ( connId , agentConnId , connLevel , viaContact , viaUserContactLink , viaGroupLink , groupLinkId , customUserProfileId , connStatus , connType , localAlias ) :. ( contactId , groupMemberId , sndFileId , rcvFileId , userContactLinkId ) :. ( createdAt , code_ , verifiedAt_ , authErrCounter ) )
toMaybeConnection _ = Nothing
createConnection_ :: DB . Connection -> UserId -> ConnType -> Maybe Int64 -> ConnId -> Maybe ContactId -> Maybe Int64 -> Maybe ProfileId -> Int -> UTCTime -> IO Connection
createConnection_ db userId connType entityId acId viaContact viaUserContactLink customUserProfileId connLevel currentTs = do
viaLinkGroupId :: Maybe Int64 <- fmap join . forM viaUserContactLink $ \ ucLinkId ->
maybeFirstRow fromOnly $ DB . query db " SELECT group_id FROM user_contact_links WHERE user_id = ? AND user_contact_link_id = ? AND group_id IS NOT NULL " ( userId , ucLinkId )
let viaGroupLink = isJust viaLinkGroupId
DB . execute
db
[ sql |
INSERT INTO connections (
user_id , agent_conn_id , conn_level , via_contact , via_user_contact_link , via_group_link , custom_user_profile_id , conn_status , conn_type ,
contact_id , group_member_id , snd_file_id , rcv_file_id , user_contact_link_id , created_at , updated_at
) VALUES ( ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? )
| ]
( ( userId , acId , connLevel , viaContact , viaUserContactLink , viaGroupLink , customUserProfileId , ConnNew , connType )
:. ( ent ConnContact , ent ConnMember , ent ConnSndFile , ent ConnRcvFile , ent ConnUserContact , currentTs , currentTs )
)
connId <- insertedRowId db
pure Connection { connId , agentConnId = AgentConnId acId , connType , entityId , viaContact , viaUserContactLink , viaGroupLink , groupLinkId = Nothing , customUserProfileId , connLevel , connStatus = ConnNew , localAlias = " " , createdAt = currentTs , connectionCode = Nothing , authErrCounter = 0 }
where
ent ct = if connType == ct then entityId else Nothing
setCommandConnId :: DB . Connection -> User -> CommandId -> Int64 -> IO ()
setCommandConnId db User { userId } cmdId connId = do
updatedAt <- getCurrentTime
DB . execute
db
[ sql |
UPDATE commands
SET connection_id = ? , updated_at = ?
WHERE user_id = ? AND command_id = ?
| ]
( connId , updatedAt , userId , cmdId )
createContact_ :: DB . Connection -> UserId -> Int64 -> Profile -> LocalAlias -> Maybe Int64 -> UTCTime -> Maybe UTCTime -> ExceptT StoreError IO ( Text , ContactId , ProfileId )
createContact_ db userId connId Profile { displayName , fullName , image , contactLink , preferences } localAlias viaGroup currentTs chatTs =
ExceptT . withLocalDisplayName db userId displayName $ \ ldn -> do
DB . execute
db
" INSERT INTO contact_profiles (display_name, full_name, image, contact_link, user_id, local_alias, preferences, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?) "
( displayName , fullName , image , contactLink , userId , localAlias , preferences , currentTs , currentTs )
profileId <- insertedRowId db
DB . execute
db
" INSERT INTO contacts (contact_profile_id, local_display_name, user_id, via_group, created_at, updated_at, chat_ts) VALUES (?,?,?,?,?,?,?) "
( profileId , ldn , userId , viaGroup , currentTs , currentTs , chatTs )
contactId <- insertedRowId db
DB . execute db " UPDATE connections SET contact_id = ?, updated_at = ? WHERE connection_id = ? " ( contactId , currentTs , connId )
pure $ Right ( ldn , contactId , profileId )
deleteUnusedIncognitoProfileById_ :: DB . Connection -> User -> ProfileId -> IO ()
deleteUnusedIncognitoProfileById_ db User { userId } profile_id =
DB . executeNamed
db
[ sql |
DELETE FROM contact_profiles
WHERE user_id = : user_id AND contact_profile_id = : profile_id AND incognito = 1
AND 1 NOT IN (
SELECT 1 FROM connections
WHERE user_id = : user_id AND custom_user_profile_id = : profile_id LIMIT 1
)
AND 1 NOT IN (
SELECT 1 FROM group_members
WHERE user_id = : user_id AND member_profile_id = : profile_id LIMIT 1
)
| ]
[ " :user_id " := userId , " :profile_id " := profile_id ]
2023-07-13 23:48:25 +01:00
type ContactRow = ( ContactId , ProfileId , ContactName , Maybe Int64 , ContactName , Text , Maybe ImageData , Maybe ConnReqContact , LocalAlias , Bool ) :. ( Maybe Bool , Maybe Bool , Bool , Maybe Preferences , Preferences , UTCTime , UTCTime , Maybe UTCTime )
2023-06-18 10:20:11 +01:00
toContact :: User -> ContactRow :. ConnectionRow -> Contact
2023-07-13 23:48:25 +01:00
toContact user ( ( ( contactId , profileId , localDisplayName , viaGroup , displayName , fullName , image , contactLink , localAlias , contactUsed ) :. ( enableNtfs_ , sendRcpts , favorite , preferences , userPreferences , createdAt , updatedAt , chatTs ) ) :. connRow ) =
2023-06-18 10:20:11 +01:00
let profile = LocalProfile { profileId , displayName , fullName , image , contactLink , preferences , localAlias }
activeConn = toConnection connRow
2023-07-13 23:48:25 +01:00
chatSettings = ChatSettings { enableNtfs = fromMaybe True enableNtfs_ , sendRcpts , favorite }
2023-06-18 10:20:11 +01:00
mergedPreferences = contactUserPreferences user userPreferences preferences $ connIncognito activeConn
in Contact { contactId , localDisplayName , profile , activeConn , viaGroup , contactUsed , chatSettings , userPreferences , mergedPreferences , createdAt , updatedAt , chatTs }
toContactOrError :: User -> ContactRow :. MaybeConnectionRow -> Either StoreError Contact
2023-07-13 23:48:25 +01:00
toContactOrError user ( ( ( contactId , profileId , localDisplayName , viaGroup , displayName , fullName , image , contactLink , localAlias , contactUsed ) :. ( enableNtfs_ , sendRcpts , favorite , preferences , userPreferences , createdAt , updatedAt , chatTs ) ) :. connRow ) =
2023-06-18 10:20:11 +01:00
let profile = LocalProfile { profileId , displayName , fullName , image , contactLink , preferences , localAlias }
2023-07-13 23:48:25 +01:00
chatSettings = ChatSettings { enableNtfs = fromMaybe True enableNtfs_ , sendRcpts , favorite }
2023-06-18 10:20:11 +01:00
in case toMaybeConnection connRow of
Just activeConn ->
let mergedPreferences = contactUserPreferences user userPreferences preferences $ connIncognito activeConn
in Right Contact { contactId , localDisplayName , profile , activeConn , viaGroup , contactUsed , chatSettings , userPreferences , mergedPreferences , createdAt , updatedAt , chatTs }
_ -> Left $ SEContactNotReady localDisplayName
getProfileById :: DB . Connection -> UserId -> Int64 -> ExceptT StoreError IO LocalProfile
getProfileById db userId profileId =
ExceptT . firstRow toProfile ( SEProfileNotFound profileId ) $
DB . query
db
[ sql |
SELECT cp . display_name , cp . full_name , cp . image , cp . contact_link , cp . local_alias , cp . preferences -- , ct.user_preferences
FROM contact_profiles cp
WHERE cp . user_id = ? AND cp . contact_profile_id = ?
| ]
( userId , profileId )
where
toProfile :: ( ContactName , Text , Maybe ImageData , Maybe ConnReqContact , LocalAlias , Maybe Preferences ) -> LocalProfile
toProfile ( displayName , fullName , image , contactLink , localAlias , preferences ) = LocalProfile { profileId , displayName , fullName , image , contactLink , preferences , localAlias }
type ContactRequestRow = ( Int64 , ContactName , AgentInvId , Int64 , AgentConnId , Int64 , ContactName , Text , Maybe ImageData , Maybe ConnReqContact ) :. ( Maybe XContactId , Maybe Preferences , UTCTime , UTCTime )
toContactRequest :: ContactRequestRow -> UserContactRequest
toContactRequest ( ( contactRequestId , localDisplayName , agentInvitationId , userContactLinkId , agentContactConnId , profileId , displayName , fullName , image , contactLink ) :. ( xContactId , preferences , createdAt , updatedAt ) ) = do
let profile = Profile { displayName , fullName , image , contactLink , preferences }
in UserContactRequest { contactRequestId , agentInvitationId , userContactLinkId , agentContactConnId , localDisplayName , profileId , profile , xContactId , createdAt , updatedAt }
userQuery :: Query
userQuery =
[ sql |
2023-07-13 23:48:25 +01:00
SELECT u . user_id , u . agent_user_id , u . contact_id , ucp . contact_profile_id , u . active_user , u . local_display_name , ucp . full_name , ucp . image , ucp . contact_link , ucp . preferences ,
u . show_ntfs , u . send_rcpts_contacts , u . send_rcpts_small_groups , u . view_pwd_hash , u . view_pwd_salt
2023-06-18 10:20:11 +01:00
FROM users u
JOIN contacts uct ON uct . contact_id = u . contact_id
JOIN contact_profiles ucp ON ucp . contact_profile_id = uct . contact_profile_id
| ]
2023-07-13 23:48:25 +01:00
toUser :: ( UserId , UserId , ContactId , ProfileId , Bool , ContactName , Text , Maybe ImageData , Maybe ConnReqContact , Maybe Preferences ) :. ( Bool , Bool , Bool , Maybe B64UrlByteString , Maybe B64UrlByteString ) -> User
toUser ( ( userId , auId , userContactId , profileId , activeUser , displayName , fullName , image , contactLink , userPreferences ) :. ( showNtfs , sendRcptsContacts , sendRcptsSmallGroups , viewPwdHash_ , viewPwdSalt_ ) ) =
User { userId , agentUserId = AgentUserId auId , userContactId , localDisplayName = displayName , profile , activeUser , fullPreferences , showNtfs , sendRcptsContacts , sendRcptsSmallGroups , viewPwdHash }
2023-06-18 10:20:11 +01:00
where
profile = LocalProfile { profileId , displayName , fullName , image , contactLink , preferences = userPreferences , localAlias = " " }
fullPreferences = mergePreferences Nothing userPreferences
viewPwdHash = UserPwdHash <$> viewPwdHash_ <*> viewPwdSalt_
toPendingContactConnection :: ( Int64 , ConnId , ConnStatus , Maybe ByteString , Maybe Int64 , Maybe GroupLinkId , Maybe Int64 , Maybe ConnReqInvitation , LocalAlias , UTCTime , UTCTime ) -> PendingContactConnection
toPendingContactConnection ( pccConnId , acId , pccConnStatus , connReqHash , viaUserContactLink , groupLinkId , customUserProfileId , connReqInv , localAlias , createdAt , updatedAt ) =
PendingContactConnection { pccConnId , pccAgentConnId = AgentConnId acId , pccConnStatus , viaContactUri = isJust connReqHash , viaUserContactLink , groupLinkId , customUserProfileId , connReqInv , localAlias , createdAt , updatedAt }
-- | Saves unique local display name based on passed displayName, suffixed with _N if required.
-- This function should be called inside transaction.
withLocalDisplayName :: forall a . DB . Connection -> UserId -> Text -> ( Text -> IO ( Either StoreError a ) ) -> IO ( Either StoreError a )
withLocalDisplayName db userId displayName action = getLdnSuffix >>= ( ` tryCreateName ` 20 )
where
getLdnSuffix :: IO Int
getLdnSuffix =
maybe 0 ( ( + 1 ) . fromOnly ) . listToMaybe
<$> DB . queryNamed
db
[ sql |
SELECT ldn_suffix FROM display_names
WHERE user_id = : user_id AND ldn_base = : display_name
ORDER BY ldn_suffix DESC
LIMIT 1
| ]
[ " :user_id " := userId , " :display_name " := displayName ]
tryCreateName :: Int -> Int -> IO ( Either StoreError a )
tryCreateName _ 0 = pure $ Left SEDuplicateName
tryCreateName ldnSuffix attempts = do
currentTs <- getCurrentTime
let ldn = displayName <> ( if ldnSuffix == 0 then " " else T . pack $ '_' : show ldnSuffix )
E . try ( insertName ldn currentTs ) >>= \ case
Right () -> action ldn
Left e
| DB . sqlError e == DB . ErrorConstraint -> tryCreateName ( ldnSuffix + 1 ) ( attempts - 1 )
| otherwise -> E . throwIO e
where
insertName ldn ts =
DB . execute
db
[ sql |
INSERT INTO display_names
( local_display_name , ldn_base , ldn_suffix , user_id , created_at , updated_at )
VALUES ( ? , ? , ? , ? , ? , ? )
| ]
( ldn , displayName , ldnSuffix , userId , ts , ts )
createWithRandomId :: forall a . TVar ChaChaDRG -> ( ByteString -> IO a ) -> ExceptT StoreError IO a
createWithRandomId = createWithRandomBytes 12
createWithRandomBytes :: forall a . Int -> TVar ChaChaDRG -> ( ByteString -> IO a ) -> ExceptT StoreError IO a
createWithRandomBytes size gVar create = tryCreate 3
where
tryCreate :: Int -> ExceptT StoreError IO a
tryCreate 0 = throwError SEUniqueID
tryCreate n = do
id' <- liftIO $ encodedRandomBytes gVar size
liftIO ( E . try $ create id' ) >>= \ case
Right x -> pure x
Left e
| DB . sqlError e == DB . ErrorConstraint -> tryCreate ( n - 1 )
| otherwise -> throwError . SEInternalError $ show e
encodedRandomBytes :: TVar ChaChaDRG -> Int -> IO ByteString
encodedRandomBytes gVar = fmap B64 . encode . randomBytes gVar
randomBytes :: TVar ChaChaDRG -> Int -> IO ByteString
randomBytes gVar = atomically . stateTVar gVar . randomBytesGenerate