001/*
002 * Copyright 2022-2026 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.soklet;
018
019import org.jspecify.annotations.NonNull;
020import org.jspecify.annotations.Nullable;
021
022import javax.annotation.concurrent.NotThreadSafe;
023import javax.annotation.concurrent.ThreadSafe;
024import java.time.Duration;
025import java.time.Instant;
026import java.util.ArrayList;
027import java.util.List;
028import java.util.Optional;
029import java.util.concurrent.ConcurrentHashMap;
030import java.util.concurrent.ConcurrentMap;
031import java.util.concurrent.atomic.AtomicInteger;
032import java.util.function.Predicate;
033
034import static java.time.Duration.ZERO;
035import static java.time.Duration.between;
036import static java.time.Duration.ofHours;
037import static java.time.Duration.ofMinutes;
038import static java.util.Objects.requireNonNull;
039
040/**
041 * Compare-and-set persistence contract for MCP sessions.
042 *
043 * @author <a href="https://www.revetkn.com">Mark Allen</a>
044 */
045@ThreadSafe
046public interface McpSessionStore {
047        /**
048         * Creates and admits a new session for the given request and endpoint class.
049         * <p>
050         * Implementations are responsible for generating a valid MCP session ID and
051         * enforcing their own concurrency limits atomically with persistence. Returning
052         * {@link Optional#empty()} declines admission before session state is created
053         * (for example, when a concurrent-session limit is reached).
054         * <p>
055         * Returned sessions must be new, uninitialized, unterminated sessions for
056         * {@code endpointClass}. Soklet will complete initialization with
057         * {@link #replace(McpStoredSession, McpStoredSession)} after the endpoint's
058         * {@link McpEndpoint#initialize(McpInitializationContext, McpSessionContext)}
059         * callback succeeds.
060         *
061         * @param request the initialize request
062         * @param endpointClass the MCP endpoint class that will own the session
063         * @return the newly-created stored session, or empty if the store declined admission
064         */
065        @NonNull
066        Optional<McpStoredSession> create(@NonNull Request request,
067                                                                                                                                                @NonNull Class<? extends McpEndpoint> endpointClass);
068
069        /**
070         * Loads a session by ID.
071         *
072         * @param sessionId the session ID
073         * @return the stored session, if it exists and has not expired
074         */
075        @NonNull
076        Optional<McpStoredSession> findBySessionId(@NonNull String sessionId);
077
078        /**
079         * Replaces a session using compare-and-set semantics.
080         *
081         * @param expected the currently-stored session snapshot
082         * @param updated the replacement session snapshot
083         * @return {@code true} if the replacement succeeded
084         */
085        @NonNull
086        Boolean replace(@NonNull McpStoredSession expected,
087                                                                        @NonNull McpStoredSession updated);
088
089        /**
090         * Deletes a session by ID.
091         *
092         * @param sessionId the session ID to delete
093         */
094        void deleteBySessionId(@NonNull String sessionId);
095
096        /**
097         * Acquires a builder for Soklet's default in-memory MCP session store.
098         *
099         * @return a new MCP session store builder
100         */
101        @NonNull
102        static Builder builder() {
103                return new Builder();
104        }
105
106        /**
107         * Acquires Soklet's default in-memory MCP session store.
108         *
109         * @return a new in-memory session store
110         */
111        @NonNull
112        static McpSessionStore fromDefaults() {
113                return builder().build();
114        }
115
116        /**
117         * Acquires the default in-memory session store using Soklet's default idle timeout.
118         * <p>
119         * Expired sessions are reclaimed opportunistically during lookup and subsequent
120         * session-creation activity; exact deletion timing is therefore best-effort
121         * rather than timer-driven.
122         *
123         * @return a new in-memory session store
124         */
125        @NonNull
126        static McpSessionStore fromInMemory() {
127                return fromDefaults();
128        }
129
130        /**
131         * Builder for Soklet's default in-memory MCP session store.
132         */
133        @NotThreadSafe
134        final class Builder {
135                @Nullable
136                Duration idleTimeout;
137                @Nullable
138                IdGenerator<String> sessionIdGenerator;
139                @Nullable
140                Integer concurrentSessionLimit;
141
142                private Builder() {}
143
144                /**
145                 * Sets the idle timeout, or {@link Duration#ZERO} to disable idle expiry.
146                 *
147                 * @param idleTimeout the idle timeout, or {@code null} for the default
148                 * @return this builder
149                 */
150                @NonNull
151                public Builder idleTimeout(@Nullable Duration idleTimeout) {
152                        this.idleTimeout = idleTimeout;
153                        return this;
154                }
155
156                /**
157                 * Sets the generator used for newly-created MCP session IDs.
158                 * <p>
159                 * Custom generators must return globally unique, cryptographically strong,
160                 * visible-ASCII IDs suitable for {@code MCP-Session-Id} header values.
161                 *
162                 * @param sessionIdGenerator the session ID generator, or {@code null} for the default
163                 * @return this builder
164                 */
165                @NonNull
166                public Builder sessionIdGenerator(@Nullable IdGenerator<String> sessionIdGenerator) {
167                        this.sessionIdGenerator = sessionIdGenerator;
168                        return this;
169                }
170
171                /**
172                 * Sets the concurrent MCP session limit.
173                 * <p>
174                 * A value of {@code 0} disables the in-memory store's session cap.
175                 *
176                 * @param concurrentSessionLimit the concurrent MCP session limit, or {@code null} for the default
177                 * @return this builder
178                 */
179                @NonNull
180                public Builder concurrentSessionLimit(@Nullable Integer concurrentSessionLimit) {
181                        this.concurrentSessionLimit = concurrentSessionLimit;
182                        return this;
183                }
184
185                /**
186                 * Builds the MCP session store.
187                 *
188                 * @return the built MCP session store
189                 */
190                @NonNull
191                public McpSessionStore build() {
192                        return new DefaultMcpSessionStore(this);
193                }
194        }
195}
196
197final class DefaultMcpSessionStore implements McpSessionStore {
198        @NonNull
199        private static final Duration DEFAULT_IDLE_TIMEOUT;
200        @NonNull
201        private static final Duration DEFAULT_SWEEP_INTERVAL;
202        @NonNull
203        private static final Integer DEFAULT_CONCURRENT_SESSION_LIMIT;
204        @NonNull
205        private final Duration idleTimeout;
206        @NonNull
207        private final IdGenerator<String> sessionIdGenerator;
208        @NonNull
209        private final Integer concurrentSessionLimit;
210        @NonNull
211        private final ConcurrentMap<String, McpStoredSession> sessions;
212        @NonNull
213        private final ConcurrentMap<String, Boolean> activeLimitedSessionIds;
214        @NonNull
215        private final AtomicInteger activeLimitedSessionCount;
216        @NonNull
217        private volatile Predicate<String> pinnedSessionPredicate;
218        @NonNull
219        private volatile Instant lastSweepAt;
220
221        static {
222                DEFAULT_IDLE_TIMEOUT = ofHours(24);
223                DEFAULT_SWEEP_INTERVAL = ofMinutes(1);
224                DEFAULT_CONCURRENT_SESSION_LIMIT = 8_192;
225        }
226
227        DefaultMcpSessionStore(McpSessionStore.@NonNull Builder builder) {
228                requireNonNull(builder);
229                this.idleTimeout = builder.idleTimeout != null ? builder.idleTimeout : DEFAULT_IDLE_TIMEOUT;
230                this.sessionIdGenerator = builder.sessionIdGenerator != null ? builder.sessionIdGenerator : IdGenerator.defaultSessionInstance();
231                this.concurrentSessionLimit = builder.concurrentSessionLimit != null ? builder.concurrentSessionLimit : DEFAULT_CONCURRENT_SESSION_LIMIT;
232                this.sessions = new ConcurrentHashMap<>();
233                this.activeLimitedSessionIds = new ConcurrentHashMap<>();
234                this.activeLimitedSessionCount = new AtomicInteger(0);
235                this.pinnedSessionPredicate = sessionId -> false;
236                this.lastSweepAt = Instant.EPOCH;
237
238                if (this.idleTimeout.isNegative())
239                        throw new IllegalArgumentException("Idle timeout must not be negative.");
240
241                if (this.concurrentSessionLimit < 0)
242                        throw new IllegalArgumentException("Concurrent session limit must be >= 0");
243        }
244
245        @NonNull
246        @Override
247        public synchronized Optional<McpStoredSession> create(@NonNull Request request,
248                                                                                                                                                                                                                                @NonNull Class<? extends McpEndpoint> endpointClass) {
249                requireNonNull(request);
250                requireNonNull(endpointClass);
251                takeExpiredSessionsIfSweepDue();
252
253                if (!reserveSessionSlot())
254                        return Optional.empty();
255
256                String sessionId = this.sessionIdGenerator.generateId(request);
257
258                if (!DefaultMcpRuntime.isValidMcpSessionId(sessionId)) {
259                        releaseReservedSessionSlot();
260                        throw new IllegalStateException("MCP session ID generator produced an invalid session ID.");
261                }
262
263                takeExpiredSession(sessionId);
264
265                Instant now = Instant.now();
266                McpStoredSession session = new McpStoredSession(
267                                sessionId,
268                                endpointClass,
269                                now,
270                                now,
271                                false,
272                                false,
273                                null,
274                                null,
275                                null,
276                                McpSessionContext.fromBlankSlate(),
277                                null,
278                                0L
279                );
280
281                try {
282                        put(session);
283                } catch (Throwable throwable) {
284                        releaseReservedSessionSlot();
285                        throw throwable;
286                }
287
288                this.activeLimitedSessionIds.put(sessionId, Boolean.TRUE);
289                return Optional.of(session);
290        }
291
292        synchronized void create(@NonNull McpStoredSession session) {
293                requireNonNull(session);
294                takeExpiredSessionsIfSweepDue();
295
296                boolean slotReserved = false;
297
298                if (session.terminatedAt() == null) {
299                        if (!reserveSessionSlot())
300                                throw new IllegalStateException("MCP session limit reached.");
301
302                        slotReserved = true;
303                }
304
305                try {
306                        put(session);
307                } catch (Throwable throwable) {
308                        if (slotReserved)
309                                releaseReservedSessionSlot();
310
311                        throw throwable;
312                }
313
314                if (slotReserved)
315                        this.activeLimitedSessionIds.put(session.sessionId(), Boolean.TRUE);
316        }
317
318        @NonNull
319        @Override
320        public synchronized Optional<McpStoredSession> findBySessionId(@NonNull String sessionId) {
321                requireNonNull(sessionId);
322
323                McpStoredSession storedSession = this.sessions.get(sessionId);
324
325                if (storedSession == null)
326                        return Optional.empty();
327
328                if (isExpired(storedSession)) {
329                        if (this.sessions.remove(sessionId, storedSession))
330                                releaseSessionSlot(sessionId);
331
332                        return Optional.empty();
333                }
334
335                return Optional.of(storedSession);
336        }
337
338        @NonNull
339        @Override
340        public synchronized Boolean replace(@NonNull McpStoredSession expected,
341                                                                                                                                                         @NonNull McpStoredSession updated) {
342                requireNonNull(expected);
343                requireNonNull(updated);
344
345                if (!expected.sessionId().equals(updated.sessionId()))
346                        throw new IllegalArgumentException("Expected and updated sessions must have the same session ID.");
347
348                if (updated.version().longValue() <= expected.version().longValue())
349                        throw new IllegalArgumentException("Updated session version must be strictly greater than expected version.");
350
351                if (isExpired(expected))
352                        return false;
353
354                boolean replaced = this.sessions.replace(expected.sessionId(), expected, updated);
355
356                if (replaced && expected.terminatedAt() == null && updated.terminatedAt() != null)
357                        releaseSessionSlot(updated.sessionId());
358
359                return replaced;
360        }
361
362        @Override
363        public synchronized void deleteBySessionId(@NonNull String sessionId) {
364                requireNonNull(sessionId);
365
366                if (this.sessions.remove(sessionId) != null)
367                        releaseSessionSlot(sessionId);
368        }
369
370        synchronized void pinnedSessionPredicate(@NonNull Predicate<String> pinnedSessionPredicate) {
371                requireNonNull(pinnedSessionPredicate);
372                this.pinnedSessionPredicate = pinnedSessionPredicate;
373        }
374
375        synchronized boolean containsSessionId(@NonNull String sessionId) {
376                requireNonNull(sessionId);
377                return this.sessions.containsKey(sessionId);
378        }
379
380        @NonNull
381        synchronized Optional<McpStoredSession> takeExpiredSession(@NonNull String sessionId) {
382                requireNonNull(sessionId);
383
384                McpStoredSession storedSession = this.sessions.get(sessionId);
385
386                if (storedSession == null || !isExpired(storedSession))
387                        return Optional.empty();
388
389                if (!this.sessions.remove(sessionId, storedSession))
390                        return Optional.empty();
391
392                releaseSessionSlot(sessionId);
393                return Optional.of(storedSession);
394        }
395
396        @NonNull
397        synchronized List<McpStoredSession> takeExpiredSessionsIfSweepDue() {
398                if (ZERO.equals(this.idleTimeout))
399                        return List.of();
400
401                Instant now = Instant.now();
402                Duration sweepInterval = sweepInterval();
403
404                if (between(this.lastSweepAt, now).compareTo(sweepInterval) < 0)
405                        return List.of();
406
407                this.lastSweepAt = now;
408
409                List<McpStoredSession> expiredSessions = new ArrayList<>();
410
411                for (var entry : this.sessions.entrySet()) {
412                        McpStoredSession storedSession = entry.getValue();
413
414                        if (isExpired(storedSession) && this.sessions.remove(entry.getKey(), storedSession)) {
415                                expiredSessions.add(storedSession);
416                                releaseSessionSlot(entry.getKey());
417                        }
418                }
419
420                return expiredSessions;
421        }
422
423        private void put(@NonNull McpStoredSession session) {
424                requireNonNull(session);
425
426                McpStoredSession previous = this.sessions.putIfAbsent(session.sessionId(), session);
427
428                if (previous != null)
429                        throw new IllegalStateException("Session with ID '%s' already exists".formatted(session.sessionId()));
430        }
431
432        private boolean reserveSessionSlot() {
433                if (this.concurrentSessionLimit == 0)
434                        return true;
435
436                while (true) {
437                        int current = this.activeLimitedSessionCount.get();
438
439                        if (current >= this.concurrentSessionLimit)
440                                return false;
441
442                        if (this.activeLimitedSessionCount.compareAndSet(current, current + 1))
443                                return true;
444                }
445        }
446
447        private void releaseSessionSlot(@NonNull String sessionId) {
448                requireNonNull(sessionId);
449
450                if (this.activeLimitedSessionIds.remove(sessionId) != null)
451                        releaseReservedSessionSlot();
452        }
453
454        private void releaseReservedSessionSlot() {
455                while (true) {
456                        int current = this.activeLimitedSessionCount.get();
457
458                        if (current <= 0)
459                                return;
460
461                        if (this.activeLimitedSessionCount.compareAndSet(current, current - 1))
462                                return;
463                }
464        }
465
466        private boolean isExpired(@NonNull McpStoredSession storedSession) {
467                requireNonNull(storedSession);
468
469                if (ZERO.equals(this.idleTimeout))
470                        return false;
471
472                if (storedSession.terminatedAt() != null)
473                        return false;
474
475                if (this.pinnedSessionPredicate.test(storedSession.sessionId()))
476                        return false;
477
478                Duration idleDuration = between(storedSession.lastActivityAt(), Instant.now());
479                return idleDuration.compareTo(this.idleTimeout) > 0;
480        }
481
482        @NonNull
483        private Duration sweepInterval() {
484                return this.idleTimeout.compareTo(DEFAULT_SWEEP_INTERVAL) < 0 ? this.idleTimeout : DEFAULT_SWEEP_INTERVAL;
485        }
486}