001/*
002 * Copyright 2022-2026 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.soklet;
018
019import org.jspecify.annotations.NonNull;
020
021import javax.annotation.concurrent.ThreadSafe;
022import java.time.Duration;
023import java.time.Instant;
024import java.util.Optional;
025import java.util.concurrent.ConcurrentHashMap;
026import java.util.concurrent.ConcurrentMap;
027import java.util.function.Predicate;
028
029import static java.time.Duration.ZERO;
030import static java.time.Duration.between;
031import static java.time.Duration.ofHours;
032import static java.time.Duration.ofMinutes;
033import static java.util.Objects.requireNonNull;
034
035/**
036 * Compare-and-set persistence contract for MCP sessions.
037 *
038 * @author <a href="https://www.revetkn.com">Mark Allen</a>
039 */
040@ThreadSafe
041public interface McpSessionStore {
042        /**
043         * Persists a newly-created session.
044         *
045         * @param session the session to create
046         */
047        void create(@NonNull McpStoredSession session);
048
049        /**
050         * Loads a session by ID.
051         *
052         * @param sessionId the session ID
053         * @return the stored session, if it exists and has not expired
054         */
055        @NonNull
056        Optional<McpStoredSession> findBySessionId(@NonNull String sessionId);
057
058        /**
059         * Replaces a session using compare-and-set semantics.
060         *
061         * @param expected the currently-stored session snapshot
062         * @param updated the replacement session snapshot
063         * @return {@code true} if the replacement succeeded
064         */
065        @NonNull
066        Boolean replace(@NonNull McpStoredSession expected,
067                                                                        @NonNull McpStoredSession updated);
068
069        /**
070         * Deletes a session by ID.
071         *
072         * @param sessionId the session ID to delete
073         */
074        void deleteBySessionId(@NonNull String sessionId);
075
076        /**
077         * Acquires the default in-memory session store using Soklet's default idle timeout.
078         * <p>
079         * Expired sessions are reclaimed opportunistically during lookup and subsequent create activity; exact deletion timing is therefore best-effort rather than timer-driven.
080         *
081         * @return a new in-memory session store
082         */
083        @NonNull
084        static McpSessionStore fromInMemory() {
085                return new DefaultMcpSessionStore(ofHours(24));
086        }
087
088        /**
089         * Acquires the default in-memory session store using a caller-supplied idle timeout.
090         * <p>
091         * Expired sessions are reclaimed opportunistically during lookup and subsequent create activity; exact deletion timing is therefore best-effort rather than timer-driven.
092         *
093         * @param idleTimeout the idle timeout, or {@code Duration.ZERO} to disable idle expiry
094         * @return a new in-memory session store
095         */
096        @NonNull
097        static McpSessionStore fromInMemory(@NonNull Duration idleTimeout) {
098                requireNonNull(idleTimeout);
099
100                if (idleTimeout.isNegative())
101                        throw new IllegalArgumentException("Idle timeout must not be negative.");
102
103                return new DefaultMcpSessionStore(idleTimeout);
104        }
105}
106
107final class DefaultMcpSessionStore implements McpSessionStore {
108        @NonNull
109        private static final Duration DEFAULT_SWEEP_INTERVAL;
110        @NonNull
111        private final Duration idleTimeout;
112        @NonNull
113        private final ConcurrentMap<String, McpStoredSession> sessions;
114        @NonNull
115        private volatile Predicate<String> pinnedSessionPredicate;
116        @NonNull
117        private volatile Instant lastSweepAt;
118
119        static {
120                DEFAULT_SWEEP_INTERVAL = ofMinutes(1);
121        }
122
123        DefaultMcpSessionStore(@NonNull Duration idleTimeout) {
124                requireNonNull(idleTimeout);
125                this.idleTimeout = idleTimeout;
126                this.sessions = new ConcurrentHashMap<>();
127                this.pinnedSessionPredicate = sessionId -> false;
128                this.lastSweepAt = Instant.EPOCH;
129        }
130
131        @Override
132        public void create(@NonNull McpStoredSession session) {
133                requireNonNull(session);
134                maybeSweepExpiredSessions();
135
136                McpStoredSession previous = this.sessions.putIfAbsent(session.sessionId(), session);
137
138                if (previous != null)
139                        throw new IllegalStateException("Session with ID '%s' already exists".formatted(session.sessionId()));
140        }
141
142        @NonNull
143        @Override
144        public Optional<McpStoredSession> findBySessionId(@NonNull String sessionId) {
145                requireNonNull(sessionId);
146
147                McpStoredSession storedSession = this.sessions.get(sessionId);
148
149                if (storedSession == null)
150                        return Optional.empty();
151
152                if (isExpired(storedSession)) {
153                        this.sessions.remove(sessionId, storedSession);
154                        return Optional.empty();
155                }
156
157                return Optional.of(storedSession);
158        }
159
160        @NonNull
161        @Override
162        public Boolean replace(@NonNull McpStoredSession expected,
163                                                                                                 @NonNull McpStoredSession updated) {
164                requireNonNull(expected);
165                requireNonNull(updated);
166
167                if (!expected.sessionId().equals(updated.sessionId()))
168                        throw new IllegalArgumentException("Expected and updated sessions must have the same session ID.");
169
170                if (updated.version().longValue() <= expected.version().longValue())
171                        throw new IllegalArgumentException("Updated session version must be strictly greater than expected version.");
172
173                if (isExpired(expected))
174                        return false;
175
176                return this.sessions.replace(expected.sessionId(), expected, updated);
177        }
178
179        @Override
180        public void deleteBySessionId(@NonNull String sessionId) {
181                requireNonNull(sessionId);
182                this.sessions.remove(sessionId);
183        }
184
185        void pinnedSessionPredicate(@NonNull Predicate<String> pinnedSessionPredicate) {
186                requireNonNull(pinnedSessionPredicate);
187                this.pinnedSessionPredicate = pinnedSessionPredicate;
188        }
189
190        boolean containsSessionId(@NonNull String sessionId) {
191                requireNonNull(sessionId);
192                return this.sessions.containsKey(sessionId);
193        }
194
195        @NonNull
196        Optional<McpStoredSession> takeExpiredSession(@NonNull String sessionId) {
197                requireNonNull(sessionId);
198
199                McpStoredSession storedSession = this.sessions.get(sessionId);
200
201                if (storedSession == null || !isExpired(storedSession))
202                        return Optional.empty();
203
204                return this.sessions.remove(sessionId, storedSession)
205                                ? Optional.of(storedSession)
206                                : Optional.empty();
207        }
208
209        private boolean isExpired(@NonNull McpStoredSession storedSession) {
210                requireNonNull(storedSession);
211
212                if (ZERO.equals(this.idleTimeout))
213                        return false;
214
215                if (storedSession.terminatedAt() != null)
216                        return false;
217
218                if (this.pinnedSessionPredicate.test(storedSession.sessionId()))
219                        return false;
220
221                Duration idleDuration = between(storedSession.lastActivityAt(), Instant.now());
222                return idleDuration.compareTo(this.idleTimeout) > 0;
223        }
224
225        private void maybeSweepExpiredSessions() {
226                if (ZERO.equals(this.idleTimeout))
227                        return;
228
229                Instant now = Instant.now();
230                Duration sweepInterval = sweepInterval();
231
232                if (between(this.lastSweepAt, now).compareTo(sweepInterval) < 0)
233                        return;
234
235                this.lastSweepAt = now;
236
237                for (var entry : this.sessions.entrySet()) {
238                        McpStoredSession storedSession = entry.getValue();
239
240                        if (isExpired(storedSession))
241                                this.sessions.remove(entry.getKey(), storedSession);
242                }
243        }
244
245        @NonNull
246        private Duration sweepInterval() {
247                return this.idleTimeout.compareTo(DEFAULT_SWEEP_INTERVAL) < 0 ? this.idleTimeout : DEFAULT_SWEEP_INTERVAL;
248        }
249}