001/*
002 * Copyright 2022-2026 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.soklet;
018
import com.soklet.SseRequestResult.HandshakeAccepted;
import com.soklet.SseRequestResult.HandshakeRejected;
import com.soklet.annotation.SseEventSource;
import com.soklet.internal.spring.LinkedCaseInsensitiveMap;
import org.jspecify.annotations.NonNull;
import org.jspecify.annotations.Nullable;

import javax.annotation.concurrent.ThreadSafe;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.soklet.Utilities.emptyByteArray;
import static com.soklet.Utilities.extractContentTypeFromHeaders;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
065
066/**
067 * Soklet's main class - manages one or more configured transport servers ({@link HttpServer}, {@link SseServer}, and/or {@link McpServer})
068 * using the provided system configuration.
069 * <p>
070 * <pre>{@code // Use out-of-the-box defaults
071 * SokletConfig config = SokletConfig.withHttpServer(
072 *   HttpServer.fromPort(8080)
073 * ).build();
074 *
075 * try (Soklet soklet = Soklet.fromConfig(config)) {
076 *   soklet.start();
077 *   System.out.println("Soklet started, press [enter] to exit");
078 *   soklet.awaitShutdown(ShutdownTrigger.ENTER_KEY);
079 * }}</pre>
080 * <p>
081 * Soklet also offers an off-network {@link Simulator} concept via {@link #runSimulator(SokletConfig, Consumer)}, useful for integration testing.
082 * <p>
083 * Given a <em>Resource Method</em>...
084 * <pre>{@code public class HelloResource {
085 *   @GET("/hello")
086 *   public String hello(@QueryParameter String name) {
087 *     return String.format("Hello, %s", name);
088 *   }
089 * }}</pre>
090 * ...we might test it like this:
091 * <pre>{@code @Test
092 * public void integrationTest() {
093 *   // Just use your app's existing configuration
094 *   SokletConfig config = obtainMySokletConfig();
095 *
096 *   // Instead of running on a real HTTP server that listens on a port,
097 *   // a non-network Simulator is provided against which you can
098 *   // issue requests and receive responses.
099 *   Soklet.runSimulator(config, (simulator -> {
100 *     // Construct a request
101 *     Request request = Request.withPath(HttpMethod.GET, "/hello")
102 *       .queryParameters(Map.of("name", Set.of("Mark")))
103 *       .build();
104 *
105 *     // Perform the request and get a handle to the response
106 *     HttpRequestResult result = simulator.performHttpRequest(request);
107 *     MarshaledResponse marshaledResponse = result.getMarshaledResponse();
108 *
109 *     // Verify status code
110 *     Integer expectedCode = 200;
111 *     Integer actualCode = marshaledResponse.getStatusCode();
112 *     assertEquals(expectedCode, actualCode, "Bad status code");
113 *
114 *     // Verify response body
115 *     marshaledResponse.getBody().ifPresentOrElse(body -> {
116 *       String expectedBody = "Hello, Mark";
117 *       byte[] bytes = ((MarshaledResponseBody.Bytes) body).getBytes();
118 *       String actualBody = new String(bytes, StandardCharsets.UTF_8);
119 *       assertEquals(expectedBody, actualBody, "Bad response body");
120 *     }, () -> {
121 *       Assertions.fail("No response body");
122 *     });
123 *   }));
124 * }}</pre>
125 * <p>
126 * The {@link Simulator} also supports Server-Sent Events.
127 * <p>
128 * Integration testing documentation is available at <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a>.
129 *
130 * @author <a href="https://www.revetkn.com">Mark Allen</a>
131 */
132@ThreadSafe
133public final class Soklet implements AutoCloseable {
134        @NonNull
135        private static final Map<@NonNull String, @NonNull Set<@NonNull String>> DEFAULT_ACCEPTED_HANDSHAKE_HEADERS;
136
137        static {
138                // Generally speaking, we always want these headers for SSE streaming responses.
139                // Users can override if they think necessary
140                LinkedCaseInsensitiveMap<Set<String>> defaultAcceptedHandshakeHeaders = new LinkedCaseInsensitiveMap<>(4);
141                defaultAcceptedHandshakeHeaders.put("Content-Type", Set.of("text/event-stream; charset=UTF-8"));
142                defaultAcceptedHandshakeHeaders.put("Cache-Control", Set.of("no-cache", "no-transform"));
143                defaultAcceptedHandshakeHeaders.put("Connection", Set.of("keep-alive"));
144                defaultAcceptedHandshakeHeaders.put("X-Accel-Buffering", Set.of("no"));
145
146                DEFAULT_ACCEPTED_HANDSHAKE_HEADERS = Collections.unmodifiableMap(defaultAcceptedHandshakeHeaders);
147        }
148
149        /**
150         * Acquires a Soklet instance with the given configuration.
151         *
152         * @param sokletConfig configuration that drives the Soklet system
153         * @return a Soklet instance
154         */
155        @NonNull
156        public static Soklet fromConfig(@NonNull SokletConfig sokletConfig) {
157                requireNonNull(sokletConfig);
158                return new Soklet(sokletConfig);
159        }
160
161        @NonNull
162        private final SokletConfig sokletConfig;
163        @NonNull
164        private final ReentrantLock lock;
165        @NonNull
166        private final AtomicReference<CountDownLatch> awaitShutdownLatchReference;
167        @NonNull
168        private final DefaultMcpRuntime defaultMcpRuntime;
169
170        /**
171         * Creates a Soklet instance with the given configuration.
172         *
173         * @param sokletConfig configuration that drives the Soklet system
174         */
175        private Soklet(@NonNull SokletConfig sokletConfig) {
176                requireNonNull(sokletConfig);
177
178                this.sokletConfig = sokletConfig;
179                this.lock = new ReentrantLock();
180                this.awaitShutdownLatchReference = new AtomicReference<>(new CountDownLatch(1));
181                this.defaultMcpRuntime = new DefaultMcpRuntime(this);
182
183                sokletConfig.getMcpServer()
184                                .map(McpServer::getSessionStore)
185                                .filter(DefaultMcpSessionStore.class::isInstance)
186                                .map(DefaultMcpSessionStore.class::cast)
187                                .ifPresent(sessionStore -> sessionStore.pinnedSessionPredicate(this.defaultMcpRuntime::hasActiveStream));
188
189                sokletConfig.getMcpServer()
190                                .map(mcpServer -> mcpServer instanceof McpServerProxy mcpServerProxy ? mcpServerProxy.getRealImplementation() : mcpServer)
191                                .filter(DefaultMcpServer.class::isInstance)
192                                .map(DefaultMcpServer.class::cast)
193                                .ifPresent(defaultMcpServer -> defaultMcpServer.mcpRuntime(this.defaultMcpRuntime));
194
195                Set<ResourceMethod> resourceMethods = sokletConfig.getResourceMethodResolver().getResourceMethods();
196
197                // Fail fast in the event that Soklet appears misconfigured
198                if (resourceMethods.size() == 0
199                                && sokletConfig.getMcpServer().isEmpty())
200                        throw new IllegalStateException(format("No Soklet Resource Methods were found. First, try to rebuild and see if that solves the problem. If not, please ensure your %s is configured correctly. "
201                                        + "See https://www.soklet.com/docs/request-handling#resource-method-resolution for details.", ResourceMethodResolver.class.getSimpleName()));
202
203                boolean hasStandardHttpResourceMethods = resourceMethods.stream()
204                                .anyMatch(resourceMethod -> !resourceMethod.isSseEventSource());
205
206                if (hasStandardHttpResourceMethods && sokletConfig.getHttpServer().isEmpty())
207                        throw new IllegalStateException(format("Resource Methods were found, but no %s is configured. See https://www.soklet.com/docs/server-configuration for details.",
208                                        HttpServer.class.getSimpleName()));
209
210                // SSE misconfiguration check: @SseEventSource resource methods are declared, but not SseServer exists
211                boolean hasSseResourceMethods = resourceMethods.stream()
212                                .anyMatch(resourceMethod -> resourceMethod.isSseEventSource());
213
214                if (hasSseResourceMethods && sokletConfig.getSseServer().isEmpty())
215                        throw new IllegalStateException(format("Resource Methods annotated with @%s were found, but no %s is configured. See https://www.soklet.com/docs/server-sent-events for details.",
216                                        SseEventSource.class.getSimpleName(), SseServer.class.getSimpleName()));
217
218                MetricsCollector metricsCollector = sokletConfig.getMetricsCollector();
219
220                if (metricsCollector instanceof DefaultMetricsCollector defaultMetricsCollector) {
221                        try {
222                                defaultMetricsCollector.initialize(sokletConfig);
223                        } catch (Throwable t) {
224                                sokletConfig.getLifecycleObserver().didReceiveLogEvent(
225                                                LogEvent.with(LogEventType.METRICS_COLLECTOR_FAILED,
226                                                                                format("An exception occurred while initializing %s", metricsCollector.getClass().getSimpleName()))
227                                                                .throwable(t)
228                                                                .build());
229                        }
230                }
231
232                // Use a layer of indirection here so the Soklet type does not need to directly implement the `RequestHandler` interface.
233                // Reasoning: the `handleRequest` method for Soklet should not be public, which might lead to accidental invocation by users.
234                // That method should only be called by the managed `HttpServer` instance.
235                Soklet soklet = this;
236
237                sokletConfig.getHttpServer().ifPresent(server -> server.initialize(getSokletConfig(), (request, marshaledResponseConsumer) -> {
238                        // Delegate to Soklet's internal request handling method
239                        soklet.handleRequest(request, ServerType.STANDARD_HTTP, marshaledResponseConsumer);
240                }));
241
242                SseServer sseServer = sokletConfig.getSseServer().orElse(null);
243
244                if (sseServer != null)
245                        sseServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> {
246                                // Delegate to Soklet's internal request handling method
247                                soklet.handleRequest(request, ServerType.SSE, marshaledResponseConsumer);
248                        });
249
250                McpServer mcpServer = sokletConfig.getMcpServer().orElse(null);
251
252                if (mcpServer != null)
253                        mcpServer.initialize(sokletConfig, soklet::handleMcpRequest);
254        }
255
256        /**
257         * Starts the managed server instance[s].
258         * <p>
259         * If the managed server[s] are already started, this is a no-op.
260         */
261        public void start() {
262                getLock().lock();
263
264                try {
265                        if (isStarted())
266                                return;
267
268                        getAwaitShutdownLatchReference().set(new CountDownLatch(1));
269
270                        SokletConfig sokletConfig = getSokletConfig();
271                        LifecycleObserver lifecycleObserver = sokletConfig.getLifecycleObserver();
272
273                                // 1. Notify global intent to start
274                                lifecycleObserver.willStartSoklet(this);
275
276                                HttpServer httpServer = sokletConfig.getHttpServer().orElse(null);
277                                SseServer sseServer = sokletConfig.getSseServer().orElse(null);
278                                McpServer mcpServer = sokletConfig.getMcpServer().orElse(null);
279                                boolean httpServerStarted = false;
280                                boolean sseServerStarted = false;
281                                boolean mcpServerStarted = false;
282
283                                try {
284                                        // 2. Attempt to start Main HttpServer
285                                        if (httpServer != null) {
286                                                lifecycleObserver.willStartHttpServer(httpServer);
287                                                try {
288                                                        httpServer.start();
289                                                        httpServerStarted = true;
290                                                        lifecycleObserver.didStartHttpServer(httpServer);
291                                                } catch (Throwable t) {
292                                                        lifecycleObserver.didFailToStartHttpServer(httpServer, t);
293                                                        throw t; // Rethrow to trigger outer catch block
294                                                }
295                                        }
296
297                                        // 3. Attempt to start SSE HttpServer (if present)
298                                        if (sseServer != null) {
299                                                lifecycleObserver.willStartSseServer(sseServer);
300                                                try {
301                                                        sseServer.start();
302                                                        sseServerStarted = true;
303                                                        lifecycleObserver.didStartSseServer(sseServer);
304                                                } catch (Throwable t) {
305                                                        lifecycleObserver.didFailToStartSseServer(sseServer, t);
306                                                        throw t; // Rethrow to trigger outer catch block
307                                                }
308                                        }
309
310                                        if (mcpServer != null) {
311                                                lifecycleObserver.willStartMcpServer(mcpServer);
312                                                try {
313                                                        mcpServer.start();
314                                                        mcpServerStarted = true;
315                                                        lifecycleObserver.didStartMcpServer(mcpServer);
316                                                } catch (Throwable t) {
317                                                        lifecycleObserver.didFailToStartMcpServer(mcpServer, t);
318                                                        throw t;
319                                                }
320                                }
321
322                                        // 4. Global success
323                                        lifecycleObserver.didStartSoklet(this);
324                                } catch (Throwable t) {
325                                        rollbackStartedServersAfterFailedStart(lifecycleObserver,
326                                                        httpServer, httpServerStarted,
327                                                        sseServer, sseServerStarted,
328                                                        mcpServer, mcpServerStarted,
329                                                        t);
330
331                                        // 5. Global failure
332                                        lifecycleObserver.didFailToStartSoklet(this, t);
333
334                                        // Ensure the exception bubbles up so the application knows startup failed
335                                if (t instanceof RuntimeException)
336                                        throw (RuntimeException) t;
337
338                                throw new RuntimeException(t);
339                        }
340                } finally {
341                        getLock().unlock();
342                }
343        }
344
345        private void rollbackStartedServersAfterFailedStart(@NonNull LifecycleObserver lifecycleObserver,
346                                                                                                                                                                                                                @Nullable HttpServer httpServer,
347                                                                                                                                                                                                                boolean httpServerStarted,
348                                                                                                                                                                                                                @Nullable SseServer sseServer,
349                                                                                                                                                                                                                boolean sseServerStarted,
350                                                                                                                                                                                                                @Nullable McpServer mcpServer,
351                                                                                                                                                                                                                boolean mcpServerStarted,
352                                                                                                                                                                                                                @NonNull Throwable startupFailure) {
353                requireNonNull(lifecycleObserver);
354                requireNonNull(startupFailure);
355
356                if (mcpServerStarted && mcpServer != null)
357                        stopStartedMcpServerForRollback(lifecycleObserver, mcpServer, startupFailure);
358
359                if (sseServerStarted && sseServer != null)
360                        stopStartedSseServerForRollback(lifecycleObserver, sseServer, startupFailure);
361
362                if (httpServerStarted && httpServer != null)
363                        stopStartedHttpServerForRollback(lifecycleObserver, httpServer, startupFailure);
364
365                CountDownLatch awaitShutdownLatch = getAwaitShutdownLatchReference().get();
366
367                if (awaitShutdownLatch != null)
368                        awaitShutdownLatch.countDown();
369        }
370
371        private void stopStartedHttpServerForRollback(@NonNull LifecycleObserver lifecycleObserver,
372                                                                                                                                                                                                @NonNull HttpServer httpServer,
373                                                                                                                                                                                                @NonNull Throwable startupFailure) {
374                requireNonNull(lifecycleObserver);
375                requireNonNull(httpServer);
376                requireNonNull(startupFailure);
377
378                try {
379                        lifecycleObserver.willStopHttpServer(httpServer);
380                } catch (Throwable t) {
381                        startupFailure.addSuppressed(t);
382                }
383
384                try {
385                        httpServer.stop();
386                        try {
387                                lifecycleObserver.didStopHttpServer(httpServer);
388                        } catch (Throwable t) {
389                                startupFailure.addSuppressed(t);
390                        }
391                } catch (Throwable t) {
392                        startupFailure.addSuppressed(t);
393
394                        try {
395                                lifecycleObserver.didFailToStopHttpServer(httpServer, t);
396                        } catch (Throwable t2) {
397                                startupFailure.addSuppressed(t2);
398                        }
399                }
400        }
401
402        private void stopStartedSseServerForRollback(@NonNull LifecycleObserver lifecycleObserver,
403                                                                                                                                                                                         @NonNull SseServer sseServer,
404                                                                                                                                                                                         @NonNull Throwable startupFailure) {
405                requireNonNull(lifecycleObserver);
406                requireNonNull(sseServer);
407                requireNonNull(startupFailure);
408
409                try {
410                        lifecycleObserver.willStopSseServer(sseServer);
411                } catch (Throwable t) {
412                        startupFailure.addSuppressed(t);
413                }
414
415                try {
416                        sseServer.stop();
417                        try {
418                                lifecycleObserver.didStopSseServer(sseServer);
419                        } catch (Throwable t) {
420                                startupFailure.addSuppressed(t);
421                        }
422                } catch (Throwable t) {
423                        startupFailure.addSuppressed(t);
424
425                        try {
426                                lifecycleObserver.didFailToStopSseServer(sseServer, t);
427                        } catch (Throwable t2) {
428                                startupFailure.addSuppressed(t2);
429                        }
430                }
431        }
432
433        private void stopStartedMcpServerForRollback(@NonNull LifecycleObserver lifecycleObserver,
434                                                                                                                                                                                         @NonNull McpServer mcpServer,
435                                                                                                                                                                                         @NonNull Throwable startupFailure) {
436                requireNonNull(lifecycleObserver);
437                requireNonNull(mcpServer);
438                requireNonNull(startupFailure);
439
440                try {
441                        lifecycleObserver.willStopMcpServer(mcpServer);
442                } catch (Throwable t) {
443                        startupFailure.addSuppressed(t);
444                }
445
446                try {
447                        mcpServer.stop();
448                        try {
449                                lifecycleObserver.didStopMcpServer(mcpServer);
450                        } catch (Throwable t) {
451                                startupFailure.addSuppressed(t);
452                        }
453                } catch (Throwable t) {
454                        startupFailure.addSuppressed(t);
455
456                        try {
457                                lifecycleObserver.didFailToStopMcpServer(mcpServer, t);
458                        } catch (Throwable t2) {
459                                startupFailure.addSuppressed(t2);
460                        }
461                }
462        }
463
464        /**
465         * Stops the managed server instance[s].
466         * <p>
467         * If the managed server[s] are already stopped, this is a no-op.
468         */
469        public void stop() {
470                getLock().lock();
471
472                try {
473                        if (isStarted()) {
474                                SokletConfig sokletConfig = getSokletConfig();
475                                LifecycleObserver lifecycleObserver = sokletConfig.getLifecycleObserver();
476
477                                // 1. Notify global intent to stop
478                                lifecycleObserver.willStopSoklet(this);
479
480                                Throwable firstEncounteredException = null;
481                                        HttpServer httpServer = sokletConfig.getHttpServer().orElse(null);
482
483                                        // 2. Attempt to stop Main HttpServer
484                                        if (httpServer != null && httpServer.isStarted()) {
485                                                lifecycleObserver.willStopHttpServer(httpServer);
486                                                try {
487                                                        httpServer.stop();
488                                                lifecycleObserver.didStopHttpServer(httpServer);
489                                        } catch (Throwable t) {
490                                                firstEncounteredException = t;
491                                                lifecycleObserver.didFailToStopHttpServer(httpServer, t);
492                                        }
493                                }
494
495                                // 3. Attempt to stop SSE HttpServer
496                                SseServer sseServer = sokletConfig.getSseServer().orElse(null);
497
498                                if (sseServer != null && sseServer.isStarted()) {
499                                        lifecycleObserver.willStopSseServer(sseServer);
500                                        try {
501                                                sseServer.stop();
502                                                lifecycleObserver.didStopSseServer(sseServer);
503                                        } catch (Throwable t) {
504                                                if (firstEncounteredException == null)
505                                                        firstEncounteredException = t;
506
507                                                lifecycleObserver.didFailToStopSseServer(sseServer, t);
508                                        }
509                                }
510
511                                McpServer mcpServer = sokletConfig.getMcpServer().orElse(null);
512
513                                if (mcpServer != null && mcpServer.isStarted()) {
514                                        lifecycleObserver.willStopMcpServer(mcpServer);
515                                        try {
516                                                mcpServer.stop();
517                                                lifecycleObserver.didStopMcpServer(mcpServer);
518                                        } catch (Throwable t) {
519                                                if (firstEncounteredException == null)
520                                                        firstEncounteredException = t;
521
522                                                lifecycleObserver.didFailToStopMcpServer(mcpServer, t);
523                                        }
524                                }
525
526                                // 4. Global completion (Success or Failure)
527                                if (firstEncounteredException == null)
528                                        lifecycleObserver.didStopSoklet(this);
529                                else
530                                        lifecycleObserver.didFailToStopSoklet(this, firstEncounteredException);
531                        }
532                } finally {
533                        try {
534                                getAwaitShutdownLatchReference().get().countDown();
535                        } finally {
536                                getLock().unlock();
537                        }
538                }
539        }
540
541        /**
542         * Blocks the current thread until JVM shutdown ({@code SIGTERM/SIGINT/System.exit(...)} and so forth), <strong>or</strong> if one of the provided {@code shutdownTriggers} occurs.
543         * <p>
544         * This method will automatically invoke this instance's {@link #stop()} method once it becomes unblocked.
545         * <p>
546         * <strong>Notes regarding {@link ShutdownTrigger#ENTER_KEY}:</strong>
547         * <ul>
548         *   <li>It will invoke {@link #stop()} on <i>all</i> Soklet instances, as stdin is process-wide</li>
549         *   <li>It is only supported for environments with an interactive TTY and will be ignored if none exists (e.g. running in a Docker container) - Soklet will detect this and fire {@link LifecycleObserver#didReceiveLogEvent(LogEvent)} with an event of type {@link LogEventType#CONFIGURATION_UNSUPPORTED}</li>
550         * </ul>
551         *
552         * @param shutdownTriggers additional trigger[s] which signal that shutdown should occur, e.g. {@link ShutdownTrigger#ENTER_KEY} for "enter key pressed"
553         * @throws InterruptedException if the current thread has its interrupted status set on entry to this method, or is interrupted while waiting
554         */
555        public void awaitShutdown(@Nullable ShutdownTrigger... shutdownTriggers) throws InterruptedException {
556                Thread shutdownHook = null;
557                boolean registeredEnterKeyShutdownTrigger = false;
558                Set<ShutdownTrigger> shutdownTriggersAsSet = shutdownTriggers == null || shutdownTriggers.length == 0 ? Set.of() : EnumSet.copyOf(Set.of(shutdownTriggers));
559
560                try {
561                        // Optionally listen for enter key
562                        if (shutdownTriggersAsSet.contains(ShutdownTrigger.ENTER_KEY)) {
563                                registeredEnterKeyShutdownTrigger = KeypressManager.register(this); // returns false if stdin unusable/disabled
564
565                                if (!registeredEnterKeyShutdownTrigger) {
566                                        LogEvent logEvent = LogEvent.with(
567                                                        LogEventType.CONFIGURATION_UNSUPPORTED,
568                                                        format("Ignoring request for %s.%s - it is unsupported in this environment (no interactive TTY detected)", ShutdownTrigger.class.getSimpleName(), ShutdownTrigger.ENTER_KEY.name())
569                                        ).build();
570
571                                        getSokletConfig().getLifecycleObserver().didReceiveLogEvent(logEvent);
572                                }
573                        }
574
575                        // Always register a shutdown hook
576                        shutdownHook = new Thread(() -> {
577                                try {
578                                        stop();
579                                } catch (Throwable ignored) {
580                                        // Nothing to do
581                                }
582                        }, "soklet-shutdown-hook");
583
584                        Runtime.getRuntime().addShutdownHook(shutdownHook);
585
586                        // Wait until "close" finishes
587                        getAwaitShutdownLatchReference().get().await();
588                } finally {
589                        if (registeredEnterKeyShutdownTrigger)
590                                KeypressManager.unregister(this);
591
592                        try {
593                                Runtime.getRuntime().removeShutdownHook(shutdownHook);
594                        } catch (IllegalStateException ignored) {
595                                // JVM shutting down
596                        }
597                }
598        }
599
600        /**
601         * Handles "awaitShutdown" for {@link ShutdownTrigger#ENTER_KEY} by listening to stdin - all Soklet instances are terminated on keypress.
602         */
603        @ThreadSafe
604        private static final class KeypressManager {
605                @NonNull
606                private static final Set<@NonNull Soklet> SOKLET_REGISTRY;
607                @NonNull
608                private static final AtomicBoolean LISTENER_STARTED;
609
610                static {
611                        SOKLET_REGISTRY = new CopyOnWriteArraySet<>();
612                        LISTENER_STARTED = new AtomicBoolean(false);
613                }
614
615                /**
616                 * Register a Soklet for Enter-to-stop support. Returns true iff a listener is (or was already) active.
617                 * If System.in is not usable (or disabled), returns false and does nothing.
618                 */
619                @NonNull
620                synchronized static Boolean register(@NonNull Soklet soklet) {
621                        requireNonNull(soklet);
622
623                        // If stdin is not readable (e.g., container with no TTY), don't start a listener.
624                        if (!canReadFromStdin())
625                                return false;
626
627                        SOKLET_REGISTRY.add(soklet);
628
629                        // Start a single process-wide listener once.
630                        if (LISTENER_STARTED.compareAndSet(false, true)) {
631                                Thread thread = new Thread(KeypressManager::runLoop, "soklet-keypress-shutdown-listener");
632                                thread.setDaemon(true); // never block JVM exit
633                                thread.start();
634                        }
635
636                        return true;
637                }
638
639                synchronized static void unregister(@NonNull Soklet soklet) {
640                        SOKLET_REGISTRY.remove(soklet);
641                        // We intentionally keep the listener alive; it's daemon and cheap.
642                        // If stdin hits EOF, the listener exits on its own.
643                }
644
645                /**
646                 * Heuristic: if System.in is present and calling available() doesn't throw,
647                 * treat it as readable. Works even in IDEs where System.console() is null.
648                 */
649                @NonNull
650                private static Boolean canReadFromStdin() {
651                        if (System.in == null)
652                                return false;
653
654                        try {
655                                // available() >= 0 means stream is open; 0 means no buffered data (that’s fine).
656                                return System.in.available() >= 0;
657                        } catch (IOException e) {
658                                return false;
659                        }
660                }
661
662                /**
663                 * Single blocking read on stdin. On any line (or EOF), stop all registered servers.
664                 */
665                private static void runLoop() {
666                        try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8))) {
667                                // Blocks until newline or EOF; EOF (null) happens with /dev/null or closed pipe.
668                                bufferedReader.readLine();
669                                stopAllSoklets();
670                        } catch (Throwable ignored) {
671                                // If stdin is closed mid-run, just exit quietly.
672                        }
673                }
674
675                synchronized private static void stopAllSoklets() {
676                        // Either a line or EOF → stop everything that’s currently registered.
677                        for (Soklet soklet : SOKLET_REGISTRY) {
678                                try {
679                                        soklet.stop();
680                                } catch (Throwable ignored) {
681                                        // Nothing to do
682                                }
683                        }
684                }
685
686                private KeypressManager() {}
687        }
688
689        /**
690         * Nonpublic "informal" implementation of {@link com.soklet.HttpServer.RequestHandler} so Soklet does not need to expose {@code handleRequest} publicly.
691         * Reasoning: users of this library should never call {@code handleRequest} directly - it should only be invoked in response to events
692         * provided by a {@link HttpServer} or {@link SseServer} implementation.
693         */
694        protected void handleRequest(@NonNull Request request,
695                                                                                                                         @NonNull ServerType serverType,
696                                                                                                                         @NonNull Consumer<HttpRequestResult> requestResultConsumer) {
697                requireNonNull(request);
698                requireNonNull(serverType);
699                requireNonNull(requestResultConsumer);
700
701                Instant processingStarted = Instant.now();
702
703                SokletConfig sokletConfig = getSokletConfig();
704                ResourceMethodResolver resourceMethodResolver = sokletConfig.getResourceMethodResolver();
705                ResponseMarshaler responseMarshaler = sokletConfig.getResponseMarshaler();
706                LifecycleObserver lifecycleObserver = sokletConfig.getLifecycleObserver();
707                RequestInterceptor requestInterceptor = sokletConfig.getRequestInterceptor();
708                MetricsCollector metricsCollector = sokletConfig.getMetricsCollector();
709
710                // Holders to permit mutable effectively-final variables
711                AtomicReference<MarshaledResponse> marshaledResponseHolder = new AtomicReference<>();
712                AtomicReference<Throwable> resourceMethodResolutionExceptionHolder = new AtomicReference<>();
713                AtomicReference<Request> requestHolder = new AtomicReference<>(request);
714                AtomicReference<ResourceMethod> resourceMethodHolder = new AtomicReference<>();
715                AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>();
716
717                // Holders to permit mutable effectively-final state tracking
718                AtomicBoolean willStartResponseWritingCompleted = new AtomicBoolean(false);
719                AtomicBoolean didFinishResponseWritingCompleted = new AtomicBoolean(false);
720                AtomicBoolean didFinishRequestHandlingCompleted = new AtomicBoolean(false);
721                AtomicBoolean didInvokeWrapRequestConsumer = new AtomicBoolean(false);
722
723                List<Throwable> throwables = new ArrayList<>(10);
724
725                Consumer<LogEvent> safelyLog = (logEvent -> {
726                        try {
727                                lifecycleObserver.didReceiveLogEvent(logEvent);
728                        } catch (Throwable throwable) {
729                                throwable.printStackTrace();
730                                throwables.add(throwable);
731                        }
732                });
733
734                BiConsumer<String, Consumer<MetricsCollector>> safelyCollectMetrics = (message, metricsInvocation) -> {
735                        if (metricsCollector == null)
736                                return;
737
738                        try {
739                                metricsInvocation.accept(metricsCollector);
740                        } catch (Throwable throwable) {
741                                safelyLog.accept(LogEvent.with(LogEventType.METRICS_COLLECTOR_FAILED, message)
742                                                .throwable(throwable)
743                                                .request(requestHolder.get())
744                                                .resourceMethod(resourceMethodHolder.get())
745                                                .marshaledResponse(marshaledResponseHolder.get())
746                                                .build());
747                        }
748                };
749
750                requestHolder.set(request);
751
752                try {
753                        requestInterceptor.wrapRequest(serverType, request, (wrappedRequest) -> {
754                                didInvokeWrapRequestConsumer.set(true);
755                                requestHolder.set(wrappedRequest);
756
757                                try {
758                                        // Resolve after wrapping so path/method rewrites affect routing.
759                                        resourceMethodHolder.set(resourceMethodResolver.resourceMethodForRequest(requestHolder.get(), serverType).orElse(null));
760                                        resourceMethodResolutionExceptionHolder.set(null);
761                                } catch (Throwable t) {
762                                        safelyLog.accept(LogEvent.with(LogEventType.RESOURCE_METHOD_RESOLUTION_FAILED, "Unable to resolve Resource Method")
763                                                        .throwable(t)
764                                                        .request(requestHolder.get())
765                                                        .build());
766
767                                        // If an exception occurs here, keep track of it - we will surface them after letting LifecycleObserver
768                                        // see that a request has come in.
769                                        throwables.add(t);
770                                        resourceMethodResolutionExceptionHolder.set(t);
771                                        resourceMethodHolder.set(null);
772                                }
773
774                                try {
775                                        lifecycleObserver.didStartRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get());
776                                } catch (Throwable t) {
777                                        safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_START_REQUEST_HANDLING_FAILED,
778                                                                        format("An exception occurred while invoking %s::didStartRequestHandling",
779                                                                                        LifecycleObserver.class.getSimpleName()))
780                                                        .throwable(t)
781                                                        .request(requestHolder.get())
782                                                        .resourceMethod(resourceMethodHolder.get())
783                                                        .build());
784
785                                        throwables.add(t);
786                                }
787
788                                safelyCollectMetrics.accept(
789                                                format("An exception occurred while invoking %s::didStartRequestHandling", MetricsCollector.class.getSimpleName()),
790                                                (metricsInvocation) -> metricsInvocation.didStartRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get()));
791
792                                try {
793                                        AtomicBoolean didInvokeMarshaledResponseConsumer = new AtomicBoolean(false);
794
795                                        requestInterceptor.interceptRequest(serverType, requestHolder.get(), resourceMethodHolder.get(), (interceptorRequest) -> {
796                                                requestHolder.set(interceptorRequest);
797
798                                                try {
799                                                        if (resourceMethodResolutionExceptionHolder.get() != null)
800                                                                throw resourceMethodResolutionExceptionHolder.get();
801
802                                                        HttpRequestResult requestResult = toHttpRequestResult(requestHolder.get(), resourceMethodHolder.get(), serverType);
803                                                        requestResultHolder.set(requestResult);
804
805                                                        MarshaledResponse originalMarshaledResponse = requestResult.getMarshaledResponse();
806                                                        MarshaledResponse updatedMarshaledResponse = requestResult.getMarshaledResponse();
807
808                                                        // A few special cases that are "global" in that they can affect all requests and
809                                                        // need to happen after marshaling the response...
810
811                                                        // 1. Customize response for HEAD (e.g. remove body, set Content-Length header)
812                                                        updatedMarshaledResponse = applyHeadResponseIfApplicable(requestHolder.get(), updatedMarshaledResponse);
813
814                                                        // 2. Apply other standard response customizations (CORS, Content-Length)
815                                                        // Note that we don't want to write Content-Length for SSE "accepted" handshakes
816                                                        SseHandshakeResult sseHandshakeResult = requestResult.getSseHandshakeResult().orElse(null);
817                                                        boolean suppressContentLength = sseHandshakeResult != null && sseHandshakeResult instanceof SseHandshakeResult.Accepted;
818
819                                                        updatedMarshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), updatedMarshaledResponse, suppressContentLength);
820
821                                                        // Update our result holder with the modified response if necessary
822                                                        if (originalMarshaledResponse != updatedMarshaledResponse) {
823                                                                marshaledResponseHolder.set(updatedMarshaledResponse);
824                                                                requestResultHolder.set(requestResult.copy()
825                                                                                .marshaledResponse(updatedMarshaledResponse)
826                                                                                .finish());
827                                                        }
828
829                                                        return updatedMarshaledResponse;
830                                                        } catch (Throwable t) {
831                                                                if (t != resourceMethodResolutionExceptionHolder.get()) {
832                                                                        throwables.add(t);
833
834                                                                safelyLog.accept(LogEvent.with(LogEventType.REQUEST_PROCESSING_FAILED,
835                                                                                                "An exception occurred while processing request")
836                                                                                .throwable(t)
837                                                                                .request(requestHolder.get())
838                                                                                .resourceMethod(resourceMethodHolder.get())
839                                                                                .build());
840                                                        }
841
842                                                        // Unhappy path.  Try to use configuration's exception response marshaler...
843                                                        try {
844                                                                MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get());
845                                                                marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse);
846                                                                marshaledResponseHolder.set(marshaledResponse);
847
848                                                                return marshaledResponse;
849                                                        } catch (Throwable t2) {
850                                                                throwables.add(t2);
851
852                                                                safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED,
853                                                                                                format("An exception occurred while trying to write an exception response for %s", t))
854                                                                                .throwable(t2)
855                                                                                .request(requestHolder.get())
856                                                                                .resourceMethod(resourceMethodHolder.get())
857                                                                                .build());
858
859                                                                // The configuration's exception response marshaler failed - provide a failsafe response to recover
860                                                                return provideFailsafeMarshaledResponse(requestHolder.get(), t2);
861                                                        }
862                                                }
863                                        }, (interceptorMarshaledResponse) -> {
864                                                requireNonNull(interceptorMarshaledResponse);
865                                                didInvokeMarshaledResponseConsumer.set(true);
866                                                marshaledResponseHolder.set(interceptorMarshaledResponse);
867                                        });
868
869                                        if (!didInvokeMarshaledResponseConsumer.get()) {
870                                                requestResultHolder.set(null);
871                                                throw new IllegalStateException(format("%s::interceptRequest must call responseWriter", RequestInterceptor.class.getSimpleName()));
872                                        }
873                                } catch (Throwable t) {
874                                        throwables.add(t);
875
876                                        try {
877                                                // In the event that an error occurs during processing of a RequestInterceptor method, for example
878                                                safelyLog.accept(LogEvent.with(LogEventType.REQUEST_INTERCEPTOR_INTERCEPT_REQUEST_FAILED,
879                                                                                format("An exception occurred while invoking %s::interceptRequest", RequestInterceptor.class.getSimpleName()))
880                                                                .throwable(t)
881                                                                .request(requestHolder.get())
882                                                                .resourceMethod(resourceMethodHolder.get())
883                                                                .build());
884
885                                                MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get());
886                                                marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse);
887                                                marshaledResponseHolder.set(marshaledResponse);
888                                        } catch (Throwable t2) {
889                                                throwables.add(t2);
890
891                                                safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED,
892                                                                                format("An exception occurred while invoking %s::forThrowable when trying to write an exception response for %s", ResponseMarshaler.class.getSimpleName(), t))
893                                                                .throwable(t2)
894                                                                .request(requestHolder.get())
895                                                                .resourceMethod(resourceMethodHolder.get())
896                                                                .build());
897
898                                                marshaledResponseHolder.set(provideFailsafeMarshaledResponse(requestHolder.get(), t2));
899                                        }
900                                } finally {
901                                        try {
902                                                try {
903                                                        lifecycleObserver.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get());
904                                                } finally {
905                                                        willStartResponseWritingCompleted.set(true);
906                                                }
907
908                                                safelyCollectMetrics.accept(
909                                                                format("An exception occurred while invoking %s::willWriteResponse", MetricsCollector.class.getSimpleName()),
910                                                                (metricsInvocation) -> metricsInvocation.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get()));
911
912                                                Instant responseWriteStarted = Instant.now();
913
914                                                try {
915                                                        HttpRequestResult requestResult = requestResultHolder.get();
916
917                                                        if (requestResult != null)
918                                                                requestResultConsumer.accept(requestResult);
919                                                        else
920                                                                requestResultConsumer.accept(HttpRequestResult.withMarshaledResponse(marshaledResponseHolder.get())
921                                                                                .resourceMethod(resourceMethodHolder.get())
922                                                                                .build());
923
924                                                        Instant responseWriteFinished = Instant.now();
925                                                        Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished);
926
927                                                        try {
928                                                                lifecycleObserver.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration);
929                                                        } catch (Throwable t) {
930                                                                throwables.add(t);
931
932                                                                safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED,
933                                                                                                format("An exception occurred while invoking %s::didWriteResponse",
934                                                                                                                LifecycleObserver.class.getSimpleName()))
935                                                                                .throwable(t)
936                                                                                .request(requestHolder.get())
937                                                                                .resourceMethod(resourceMethodHolder.get())
938                                                                                .marshaledResponse(marshaledResponseHolder.get())
939                                                                                .build());
940                                                        } finally {
941                                                                didFinishResponseWritingCompleted.set(true);
942                                                        }
943
944                                                        safelyCollectMetrics.accept(
945                                                                        format("An exception occurred while invoking %s::didWriteResponse", MetricsCollector.class.getSimpleName()),
946                                                                        (metricsInvocation) -> metricsInvocation.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(),
947                                                                                        marshaledResponseHolder.get(), responseWriteDuration));
948                                                } catch (Throwable t) {
949                                                        throwables.add(t);
950
951                                                        Instant responseWriteFinished = Instant.now();
952                                                        Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished);
953
954                                                        try {
955                                                                lifecycleObserver.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration, t);
956                                                        } catch (Throwable t2) {
957                                                                throwables.add(t2);
958
959                                                                safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED,
960                                                                                                format("An exception occurred while invoking %s::didFailToWriteResponse",
961                                                                                                                LifecycleObserver.class.getSimpleName()))
962                                                                                .throwable(t2)
963                                                                                .request(requestHolder.get())
964                                                                                .resourceMethod(resourceMethodHolder.get())
965                                                                                .marshaledResponse(marshaledResponseHolder.get())
966                                                                                .build());
967                                                        }
968
969                                                        safelyCollectMetrics.accept(
970                                                                        format("An exception occurred while invoking %s::didFailToWriteResponse", MetricsCollector.class.getSimpleName()),
971                                                                        (metricsInvocation) -> metricsInvocation.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(),
972                                                                                        marshaledResponseHolder.get(), responseWriteDuration, t));
973                                                }
974                                        } finally {
975                                                Duration processingDuration = Duration.between(processingStarted, Instant.now());
976
977                                                safelyCollectMetrics.accept(
978                                                                format("An exception occurred while invoking %s::didFinishRequestHandling", MetricsCollector.class.getSimpleName()),
979                                                                (metricsInvocation) -> metricsInvocation.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables)));
980
981                                                try {
982                                                        lifecycleObserver.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables));
983                                                } catch (Throwable t) {
984                                                        safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_FINISH_REQUEST_HANDLING_FAILED,
985                                                                                        format("An exception occurred while invoking %s::didFinishRequestHandling",
986                                                                                                        LifecycleObserver.class.getSimpleName()))
987                                                                        .throwable(t)
988                                                                        .request(requestHolder.get())
989                                                                        .resourceMethod(resourceMethodHolder.get())
990                                                                        .marshaledResponse(marshaledResponseHolder.get())
991                                                                        .build());
992                                                } finally {
993                                                        didFinishRequestHandlingCompleted.set(true);
994                                                }
995                                        }
996                                }
997                        });
998
999                        if (!didInvokeWrapRequestConsumer.get())
1000                                throw new IllegalStateException(format("%s::wrapRequest must call requestProcessor", RequestInterceptor.class.getSimpleName()));
1001                } catch (Throwable t) {
1002                        // If an error occurred during request wrapping, it's possible a response was never written/communicated back to LifecycleObserver.
1003                        // Detect that here and inform LifecycleObserver accordingly.
1004                        safelyLog.accept(LogEvent.with(LogEventType.REQUEST_INTERCEPTOR_WRAP_REQUEST_FAILED,
1005                                                        format("An exception occurred while invoking %s::wrapRequest",
1006                                                                        RequestInterceptor.class.getSimpleName()))
1007                                        .throwable(t)
1008                                        .request(requestHolder.get())
1009                                        .resourceMethod(resourceMethodHolder.get())
1010                                        .marshaledResponse(marshaledResponseHolder.get())
1011                                        .build());
1012
1013                        // If we don't have a response, let the marshaler try to make one for the exception.
1014                        // If that fails, use the failsafe.
1015                        if (marshaledResponseHolder.get() == null) {
1016                                try {
1017                                        MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get());
1018                                        marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse);
1019                                        marshaledResponseHolder.set(marshaledResponse);
1020                                } catch (Throwable t2) {
1021                                        throwables.add(t2);
1022
1023                                        safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED,
1024                                                                        format("An exception occurred during request wrapping while invoking %s::forThrowable",
1025                                                                                        ResponseMarshaler.class.getSimpleName()))
1026                                                        .throwable(t2)
1027                                                        .request(requestHolder.get())
1028                                                        .resourceMethod(resourceMethodHolder.get())
1029                                                        .marshaledResponse(marshaledResponseHolder.get())
1030                                                        .build());
1031
1032                                        marshaledResponseHolder.set(provideFailsafeMarshaledResponse(requestHolder.get(), t));
1033                                }
1034                        }
1035
1036                        if (!willStartResponseWritingCompleted.get()) {
1037                                try {
1038                                        lifecycleObserver.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get());
1039                                } catch (Throwable t2) {
1040                                        safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_WILL_WRITE_RESPONSE_FAILED,
1041                                                                        format("An exception occurred while invoking %s::willWriteResponse",
1042                                                                                        LifecycleObserver.class.getSimpleName()))
1043                                                        .throwable(t2)
1044                                                        .request(requestHolder.get())
1045                                                        .resourceMethod(resourceMethodHolder.get())
1046                                                        .marshaledResponse(marshaledResponseHolder.get())
1047                                                        .build());
1048                                }
1049
1050                                safelyCollectMetrics.accept(
1051                                                format("An exception occurred while invoking %s::willWriteResponse", MetricsCollector.class.getSimpleName()),
1052                                                (metricsInvocation) -> metricsInvocation.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get()));
1053                        }
1054
1055                        try {
1056                                Instant responseWriteStarted = Instant.now();
1057
1058                                if (!didFinishResponseWritingCompleted.get()) {
1059                                        try {
1060                                                HttpRequestResult requestResult = requestResultHolder.get();
1061
1062                                                if (requestResult != null)
1063                                                        requestResultConsumer.accept(requestResult);
1064                                                else
1065                                                        requestResultConsumer.accept(HttpRequestResult.withMarshaledResponse(marshaledResponseHolder.get())
1066                                                                        .resourceMethod(resourceMethodHolder.get())
1067                                                                        .build());
1068
1069                                                Instant responseWriteFinished = Instant.now();
1070                                                Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished);
1071
1072                                                try {
1073                                                        lifecycleObserver.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration);
1074                                                } catch (Throwable t2) {
1075                                                        throwables.add(t2);
1076
1077                                                        safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED,
1078                                                                                        format("An exception occurred while invoking %s::didWriteResponse",
1079                                                                                                        LifecycleObserver.class.getSimpleName()))
1080                                                                        .throwable(t2)
1081                                                                        .request(requestHolder.get())
1082                                                                        .resourceMethod(resourceMethodHolder.get())
1083                                                                        .marshaledResponse(marshaledResponseHolder.get())
1084                                                                        .build());
1085                                                }
1086
1087                                                safelyCollectMetrics.accept(
1088                                                                format("An exception occurred while invoking %s::didWriteResponse", MetricsCollector.class.getSimpleName()),
1089                                                                (metricsInvocation) -> metricsInvocation.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(),
1090                                                                                marshaledResponseHolder.get(), responseWriteDuration));
1091                                        } catch (Throwable t2) {
1092                                                throwables.add(t2);
1093
1094                                                Instant responseWriteFinished = Instant.now();
1095                                                Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished);
1096
1097                                                try {
1098                                                        lifecycleObserver.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration, t);
1099                                                } catch (Throwable t3) {
1100                                                        throwables.add(t3);
1101
1102                                                        safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED,
1103                                                                                        format("An exception occurred while invoking %s::didFailToWriteResponse",
1104                                                                                                        LifecycleObserver.class.getSimpleName()))
1105                                                                        .throwable(t3)
1106                                                                        .request(requestHolder.get())
1107                                                                        .resourceMethod(resourceMethodHolder.get())
1108                                                                        .marshaledResponse(marshaledResponseHolder.get())
1109                                                                        .build());
1110                                                }
1111
1112                                                safelyCollectMetrics.accept(
1113                                                                format("An exception occurred while invoking %s::didFailToWriteResponse", MetricsCollector.class.getSimpleName()),
1114                                                                (metricsInvocation) -> metricsInvocation.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(),
1115                                                                                marshaledResponseHolder.get(), responseWriteDuration, t));
1116                                        }
1117                                }
1118                        } finally {
1119                                if (!didFinishRequestHandlingCompleted.get()) {
1120                                        Duration processingDuration = Duration.between(processingStarted, Instant.now());
1121
1122                                        safelyCollectMetrics.accept(
1123                                                        format("An exception occurred while invoking %s::didFinishRequestHandling", MetricsCollector.class.getSimpleName()),
1124                                                        (metricsInvocation) -> metricsInvocation.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables)));
1125
1126                                        try {
1127                                                lifecycleObserver.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables));
1128                                        } catch (Throwable t2) {
1129                                                safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_FINISH_REQUEST_HANDLING_FAILED,
1130                                                                                format("An exception occurred while invoking %s::didFinishRequestHandling",
1131                                                                                                LifecycleObserver.class.getSimpleName()))
1132                                                                .throwable(t2)
1133                                                                .request(requestHolder.get())
1134                                                                .resourceMethod(resourceMethodHolder.get())
1135                                                                .marshaledResponse(marshaledResponseHolder.get())
1136                                                                .build());
1137                                        }
1138                                }
1139                        }
1140                }
1141        }
1142
	/**
	 * Produces the {@link HttpRequestResult} for a request: applies Soklet's built-in routing short-circuits and, when a
	 * resource method applies, invokes it and marshals its return value.
	 * <p>
	 * Checks run in this order:
	 * <ol>
	 *   <li>Content too large: {@link ResponseMarshaler#forContentTooLarge}.</li>
	 *   <li>{@code OPTIONS *}: {@link ResponseMarshaler#forOptionsSplat}.</li>
	 *   <li>No resource method, {@code OPTIONS} request: a CORS preflight is allowed or rejected via the configured
	 *   {@link CorsAuthorizer}; otherwise a matching {@code OPTIONS} resource method is invoked if one exists, else
	 *   {@link ResponseMarshaler#forOptions} is used.</li>
	 *   <li>No resource method, {@code HEAD} request: the matching {@code GET} resource method is invoked if one exists,
	 *   else {@link ResponseMarshaler#forNotFound} is used.</li>
	 *   <li>No resource method, any other HTTP method: {@link ResponseMarshaler#forMethodNotAllowed} if some
	 *   non-{@code OPTIONS} method matches the path, else {@link ResponseMarshaler#forNotFound}.</li>
	 * </ol>
	 * The resource method's return value (after unwrapping one level of {@link Optional}) is mapped as follows:
	 * {@code null} becomes a 204 {@link Response}; a {@link MarshaledResponse} is returned as-is; a {@link Response} is
	 * used as-is; an accepted SSE handshake is turned into a handshake response without going through the marshaler; a
	 * rejected SSE handshake contributes its {@link Response}; anything else becomes the body of a 200 {@link Response}.
	 * Every {@link Response} is then marshaled with {@link ResponseMarshaler#forResourceMethod}.
	 *
	 * @param request        the request being handled
	 * @param resourceMethod the resource method matched for {@code request}, or {@code null} if none matched
	 *                       (presumably resolved by the caller - confirm)
	 * @param serverType     the server type; only used here for resource method resolution
	 * @return the result to write for this request
	 * @throws Throwable the resource method's own exception (unwrapped from {@link InvocationTargetException}),
	 *                   {@link IllegalArgumentException} if the resource class instance cannot be acquired,
	 *                   {@link IllegalStateException} if a body accompanies a status code that forbids one, or anything
	 *                   thrown by the configured collaborators
	 */
	@NonNull
	protected HttpRequestResult toHttpRequestResult(@NonNull Request request,
																																		@Nullable ResourceMethod resourceMethod,
																																		@NonNull ServerType serverType) throws Throwable {
		requireNonNull(request);
		requireNonNull(serverType);

		ResourceMethodParameterProvider resourceMethodParameterProvider = getSokletConfig().getResourceMethodParameterProvider();
		InstanceProvider instanceProvider = getSokletConfig().getInstanceProvider();
		CorsAuthorizer corsAuthorizer = getSokletConfig().getCorsAuthorizer();
		ResourceMethodResolver resourceMethodResolver = getSokletConfig().getResourceMethodResolver();
		ResponseMarshaler responseMarshaler = getSokletConfig().getResponseMarshaler();
		CorsPreflight corsPreflight = request.getCorsPreflight().orElse(null);

		// Special short-circuit for big requests
		// NOTE(review): the marshaler receives a freshly re-resolved resource method while the result records the
		// resourceMethod argument; if the two can ever differ this is inconsistent - confirm intent.
		if (request.isContentTooLarge())
			return HttpRequestResult.withMarshaledResponse(responseMarshaler.forContentTooLarge(request, resourceMethodResolver.resourceMethodForRequest(request, serverType).orElse(null)))
					.resourceMethod(resourceMethod)
					.build();

		// Special short-circuit for OPTIONS *
		// Identity comparison: OPTIONS_SPLAT_RESOURCE_PATH is presumably a shared sentinel instance - confirm.
		if (request.getResourcePath() == ResourcePath.OPTIONS_SPLAT_RESOURCE_PATH)
			return HttpRequestResult.withMarshaledResponse(responseMarshaler.forOptionsSplat(request)).build();

		// No resource method was found for this HTTP method and path.
		if (resourceMethod == null) {
			// If this was an OPTIONS request, do special processing.
			// If not, figure out if we should return a 404 or 405.
			if (request.getHttpMethod() == HttpMethod.OPTIONS) {
				// See what methods are available to us for this request's path
				Map<HttpMethod, ResourceMethod> matchingResourceMethodsByHttpMethod = resolveMatchingResourceMethodsByHttpMethod(request, resourceMethodResolver, serverType);

				// Special handling for CORS preflight requests, if needed
				if (corsPreflight != null) {
					// Let configuration function determine if we should authorize this request.
					// Discard any OPTIONS references - see https://stackoverflow.com/a/68529748
					Map<HttpMethod, ResourceMethod> nonOptionsMatchingResourceMethodsByHttpMethod = matchingResourceMethodsByHttpMethod.entrySet().stream()
							.filter(entry -> entry.getKey() != HttpMethod.OPTIONS)
							.collect(Collectors.toMap(Entry::getKey, Entry::getValue));

					CorsPreflightResponse corsPreflightResponse = corsAuthorizer.authorizePreflight(request, corsPreflight, nonOptionsMatchingResourceMethodsByHttpMethod).orElse(null);

					// Allow or reject CORS depending on what the function said to do
					if (corsPreflightResponse != null) {
						// Allow
						MarshaledResponse marshaledResponse = responseMarshaler.forCorsPreflightAllowed(request, corsPreflight, corsPreflightResponse);

						return HttpRequestResult.withMarshaledResponse(marshaledResponse)
								.corsPreflightResponse(corsPreflightResponse)
								.resourceMethod(resourceMethod)
								.build();
					}

					// Reject
					return HttpRequestResult.withMarshaledResponse(responseMarshaler.forCorsPreflightRejected(request, corsPreflight))
							.resourceMethod(resourceMethod)
							.build();
				} else {
					// Just a normal OPTIONS response (non-CORS-preflight).
					// If there's a matching OPTIONS resource method for this OPTIONS request, then invoke it.
					ResourceMethod optionsResourceMethod = matchingResourceMethodsByHttpMethod.get(HttpMethod.OPTIONS);

					if (optionsResourceMethod != null) {
						// Fall through to the happy path below with the OPTIONS resource method
						resourceMethod = optionsResourceMethod;
					} else {
						Set<HttpMethod> allowedHttpMethods = allowedHttpMethodsForResponse(matchingResourceMethodsByHttpMethod, true);

						return HttpRequestResult.withMarshaledResponse(responseMarshaler.forOptions(request, allowedHttpMethods))
								.resourceMethod(resourceMethod)
								.build();
					}
				}
			} else if (request.getHttpMethod() == HttpMethod.HEAD) {
				// If there's a matching GET resource method for this HEAD request, then invoke it
				Request headGetRequest = request.copy().httpMethod(HttpMethod.GET).finish();
				ResourceMethod headGetResourceMethod = resourceMethodResolver.resourceMethodForRequest(headGetRequest, serverType).orElse(null);

				// NOTE(review): unlike the non-HEAD branch below, this never answers 405 when other methods match the
				// path - confirm that a plain 404 is intended here.
				if (headGetResourceMethod != null)
					resourceMethod = headGetResourceMethod; // fall through to the happy path below
				else
					return HttpRequestResult.withMarshaledResponse(responseMarshaler.forNotFound(request))
							.resourceMethod(resourceMethod)
							.build();
			} else {
				// Not an OPTIONS request, so it's possible we have a 405. See if other HTTP methods match...
				Map<HttpMethod, ResourceMethod> otherMatchingResourceMethodsByHttpMethod = resolveMatchingResourceMethodsByHttpMethod(request, resourceMethodResolver, serverType);

				Set<HttpMethod> matchingNonOptionsHttpMethods = otherMatchingResourceMethodsByHttpMethod.keySet().stream()
						.filter(httpMethod -> httpMethod != HttpMethod.OPTIONS)
						.collect(Collectors.toSet());

				if (matchingNonOptionsHttpMethods.size() > 0) {
					// ...if some do, it's a 405
					Set<HttpMethod> allowedHttpMethods = allowedHttpMethodsForResponse(otherMatchingResourceMethodsByHttpMethod, true);
					return HttpRequestResult.withMarshaledResponse(responseMarshaler.forMethodNotAllowed(request, allowedHttpMethods))
							.resourceMethod(resourceMethod)
							.build();
				} else {
					// no matching resource method found, it's a 404
					return HttpRequestResult.withMarshaledResponse(responseMarshaler.forNotFound(request))
							.resourceMethod(resourceMethod)
							.build();
				}
			}
		}

		// Found a resource method - happy path.
		// 1. Get an instance of the resource class
		// 2. Get values to pass to the resource method on the resource class
		// 3. Invoke the resource method and use its return value to drive a response
		Class<?> resourceClass = resourceMethod.getMethod().getDeclaringClass();
		Object resourceClassInstance;

		try {
			resourceClassInstance = instanceProvider.provide(resourceClass);
		} catch (Exception e) {
			throw new IllegalArgumentException(format("Unable to acquire an instance of %s", resourceClass.getName()), e);
		}

		List<Object> parameterValues = resourceMethodParameterProvider.parameterValuesForResourceMethod(request, resourceMethod);

		Object responseObject;

		try {
			responseObject = resourceMethod.getMethod().invoke(resourceClassInstance, parameterValues.toArray());
		} catch (InvocationTargetException e) {
			// Surface the resource method's own exception rather than the reflective wrapper
			if (e.getTargetException() != null)
				throw e.getTargetException();

			throw e;
		}

		// Unwrap the Optional<T>, if one exists.  We do not recurse deeper than one level
		if (responseObject instanceof Optional<?>)
			responseObject = ((Optional<?>) responseObject).orElse(null);

		Response response;
		SseHandshakeResult sseHandshakeResult = null;

		// If null/void return, it's a 204
		// If it's a MarshaledResponse object, no marshaling + return it immediately - caller knows exactly what it wants to write.
		// If it's a Response object, use as is.
		// If it's an accepted SSE handshake, build the handshake response directly; if rejected, marshal its Response.
		// If it's a non-Response type of object, assume it's the response body and wrap in a Response.
		if (responseObject == null) {
			response = Response.withStatusCode(204).build();
		} else if (responseObject instanceof MarshaledResponse) {
			MarshaledResponse marshaledResponse = (MarshaledResponse) responseObject;
			enforceBodylessStatusCode(marshaledResponse.getStatusCode(), marshaledResponse.getBody().isPresent());

			return HttpRequestResult.withMarshaledResponse(marshaledResponse)
					.resourceMethod(resourceMethod)
					.build();
		} else if (responseObject instanceof Response) {
			response = (Response) responseObject;
		} else if (responseObject instanceof SseHandshakeResult.Accepted accepted) { // SSE "accepted" handshake
			return HttpRequestResult.withMarshaledResponse(toMarshaledResponse(accepted))
					.resourceMethod(resourceMethod)
					.sseHandshakeResult(accepted)
					.build();
		} else if (responseObject instanceof SseHandshakeResult.Rejected rejected) { // SSE "rejected" handshake
			response = rejected.getResponse();
			sseHandshakeResult = rejected;
		} else {
			response = Response.withStatusCode(200).body(responseObject).build();
		}

		// Validate both the Response the application produced and the marshaled form, since the marshaler could
		// change the status code or body.
		enforceBodylessStatusCode(response.getStatusCode(), response.getBody().isPresent());

		MarshaledResponse marshaledResponse = responseMarshaler.forResourceMethod(request, response, resourceMethod);

		enforceBodylessStatusCode(marshaledResponse.getStatusCode(), marshaledResponse.getBody().isPresent());

		return HttpRequestResult.withMarshaledResponse(marshaledResponse)
				.response(response)
				.resourceMethod(resourceMethod)
				.sseHandshakeResult(sseHandshakeResult)
				.build();
	}
1321
1322        @NonNull
1323        private MarshaledResponse toMarshaledResponse(SseHandshakeResult.@NonNull Accepted accepted) {
1324                requireNonNull(accepted);
1325
1326                Map<String, Set<String>> headers = accepted.getHeaders();
1327                LinkedCaseInsensitiveMap<Set<String>> finalHeaders = new LinkedCaseInsensitiveMap<>(DEFAULT_ACCEPTED_HANDSHAKE_HEADERS.size() + headers.size());
1328
1329                // Start with defaults
1330                for (Map.Entry<String, Set<String>> e : DEFAULT_ACCEPTED_HANDSHAKE_HEADERS.entrySet())
1331                        finalHeaders.put(e.getKey(), e.getValue()); // values already unmodifiable
1332
1333                // Overlay user-supplied headers (prefer user values on key collision)
1334                for (Map.Entry<String, Set<String>> e : headers.entrySet()) {
1335                        // Defensively copy so callers can't mutate after construction
1336                        Set<String> values = e.getValue() == null ? Set.of() : Set.copyOf(e.getValue());
1337                        finalHeaders.put(e.getKey(), values);
1338                }
1339
1340                return MarshaledResponse.withStatusCode(200)
1341                                .headers(finalHeaders)
1342                                .cookies(accepted.getCookies())
1343                                .build();
1344        }
1345
1346        private static void enforceBodylessStatusCode(@NonNull Integer statusCode,
1347                                                                                                                                                                                                @NonNull Boolean hasBody) {
1348                requireNonNull(statusCode);
1349                requireNonNull(hasBody);
1350
1351                if (hasBody && isBodylessStatusCode(statusCode))
1352                        throw new IllegalStateException(format("HTTP status code %d must not include a response body", statusCode));
1353        }
1354
1355        private static boolean isBodylessStatusCode(@NonNull Integer statusCode) {
1356                requireNonNull(statusCode);
1357                return (statusCode >= 100 && statusCode < 200) || statusCode == 204 || statusCode == 304;
1358        }
1359
1360        @NonNull
1361        protected MarshaledResponse applyHeadResponseIfApplicable(@NonNull Request request,
1362                                                                                                                                                                                                                                                @NonNull MarshaledResponse marshaledResponse) {
1363                if (request.getHttpMethod() != HttpMethod.HEAD)
1364                        return marshaledResponse;
1365
1366                return getSokletConfig().getResponseMarshaler().forHead(request, marshaledResponse);
1367        }
1368
1369        // Hat tip to Aslan Parçası and GrayStar
1370        @NonNull
1371        protected MarshaledResponse applyCommonPropertiesToMarshaledResponse(@NonNull Request request,
1372                                                                                                                                                                                                                                                                                         @NonNull MarshaledResponse marshaledResponse) {
1373                requireNonNull(request);
1374                requireNonNull(marshaledResponse);
1375
1376                return applyCommonPropertiesToMarshaledResponse(request, marshaledResponse, false);
1377        }
1378
1379        protected void handleMcpRequest(@NonNull Request request,
1380                                                                                                                                        @NonNull Consumer<HttpRequestResult> requestResultConsumer) {
1381                requireNonNull(request);
1382                requireNonNull(requestResultConsumer);
1383
1384                requestResultConsumer.accept(this.defaultMcpRuntime.handleRequest(request));
1385        }
1386
1387        protected void handleSimulatedMcpStreamDisconnect(@NonNull Request request,
1388                                                                                                                                                                                                                @NonNull String sessionId) {
1389                requireNonNull(request);
1390                requireNonNull(sessionId);
1391
1392                this.defaultMcpRuntime.handleClientDisconnectedStream(request, sessionId);
1393        }
1394
1395        @NonNull
1396        protected MarshaledResponse applyCommonPropertiesToMarshaledResponse(@NonNull Request request,
1397                                                                                                                                                                                                                                                                                         @NonNull MarshaledResponse marshaledResponse,
1398                                                                                                                                                                                                                                                                                         @NonNull Boolean suppressContentLength) {
1399                requireNonNull(request);
1400                requireNonNull(marshaledResponse);
1401                requireNonNull(suppressContentLength);
1402
1403                // Don't write Content-Length for an accepted SSE Handshake, for example
1404                if (!suppressContentLength)
1405                        marshaledResponse = applyContentLengthIfApplicable(request, marshaledResponse);
1406
1407                // If the Date header is missing, add it using our cached provider
1408                if (!marshaledResponse.getHeaders().containsKey("Date"))
1409                        marshaledResponse = marshaledResponse.copy()
1410                                        .headers(headers -> headers.put("Date", Set.of(CachedHttpDate.getCurrentValue())))
1411                                        .finish();
1412
1413                marshaledResponse = applyCorsResponseIfApplicable(request, marshaledResponse);
1414
1415                return marshaledResponse;
1416        }
1417
1418        @NonNull
1419        protected MarshaledResponse applyContentLengthIfApplicable(@NonNull Request request,
1420                                                                                                                                                                                                                                                 @NonNull MarshaledResponse marshaledResponse) {
1421                requireNonNull(request);
1422                requireNonNull(marshaledResponse);
1423
1424                Set<String> normalizedHeaderNames = marshaledResponse.getHeaders().keySet().stream()
1425                                .map(headerName -> headerName.toLowerCase(Locale.US))
1426                                .collect(Collectors.toSet());
1427
1428                // If Content-Length is already specified, don't do anything
1429                if (normalizedHeaderNames.contains("content-length"))
1430                        return marshaledResponse;
1431
1432                // If Content-Length is not specified, specify as the number of bytes in the body
1433                return marshaledResponse.copy()
1434                                .headers((mutableHeaders) -> {
1435                                        String contentLengthHeaderValue = String.valueOf(marshaledResponse.getBodyLength());
1436                                        mutableHeaders.put("Content-Length", Set.of(contentLengthHeaderValue));
1437                                }).finish();
1438        }
1439
1440        @NonNull
1441        protected MarshaledResponse applyCorsResponseIfApplicable(@NonNull Request request,
1442                                                                                                                                                                                                                                                @NonNull MarshaledResponse marshaledResponse) {
1443                requireNonNull(request);
1444                requireNonNull(marshaledResponse);
1445
1446                Cors cors = request.getCors().orElse(null);
1447
1448                // If non-CORS request, nothing further to do (note that CORS preflight was handled earlier)
1449                if (cors == null)
1450                        return marshaledResponse;
1451
1452                CorsAuthorizer corsAuthorizer = getSokletConfig().getCorsAuthorizer();
1453
1454                // Does the authorizer say we are authorized?
1455                CorsResponse corsResponse = corsAuthorizer.authorize(request, cors).orElse(null);
1456
1457                // Not authorized - don't apply CORS headers to the response
1458                if (corsResponse == null)
1459                        return marshaledResponse;
1460
1461                // Authorized - OK, let's apply the headers to the response
1462                return getSokletConfig().getResponseMarshaler().forCorsAllowed(request, cors, corsResponse, marshaledResponse);
1463        }
1464
1465        @NonNull
1466        protected Map<@NonNull HttpMethod, @NonNull ResourceMethod> resolveMatchingResourceMethodsByHttpMethod(@NonNull Request request,
1467                                                                                                                                                                                                                                                                                                                                                                                                                                 @NonNull ResourceMethodResolver resourceMethodResolver,
1468                                                                                                                                                                                                                                                                                                                                                                                                                                 @NonNull ServerType serverType) {
1469                requireNonNull(request);
1470                requireNonNull(resourceMethodResolver);
1471                requireNonNull(serverType);
1472
1473                // Special handling for OPTIONS *
1474                if (request.getResourcePath() == ResourcePath.OPTIONS_SPLAT_RESOURCE_PATH)
1475                        return new LinkedHashMap<>();
1476
1477                Map<HttpMethod, ResourceMethod> matchingResourceMethodsByHttpMethod = new LinkedHashMap<>(HttpMethod.values().length);
1478
1479                for (HttpMethod httpMethod : HttpMethod.values()) {
1480                        // Make a quick copy of the request to see if other paths match
1481                        Request otherRequest = Request.withPath(httpMethod, request.getPath()).build();
1482                        ResourceMethod resourceMethod = resourceMethodResolver.resourceMethodForRequest(otherRequest, serverType).orElse(null);
1483
1484                        if (resourceMethod != null)
1485                                matchingResourceMethodsByHttpMethod.put(httpMethod, resourceMethod);
1486                }
1487
1488                return matchingResourceMethodsByHttpMethod;
1489        }
1490
1491        @NonNull
1492        private static Set<@NonNull HttpMethod> allowedHttpMethodsForResponse(@NonNull Map<@NonNull HttpMethod, @NonNull ResourceMethod> matchingResourceMethodsByHttpMethod,
1493                                                                                                                                                                                                                                                                                                @NonNull Boolean includeOptions) {
1494                requireNonNull(matchingResourceMethodsByHttpMethod);
1495                requireNonNull(includeOptions);
1496
1497                Set<HttpMethod> allowedHttpMethods = EnumSet.noneOf(HttpMethod.class);
1498                allowedHttpMethods.addAll(matchingResourceMethodsByHttpMethod.keySet());
1499
1500                if (includeOptions)
1501                        allowedHttpMethods.add(HttpMethod.OPTIONS);
1502
1503                if (matchingResourceMethodsByHttpMethod.containsKey(HttpMethod.GET) || matchingResourceMethodsByHttpMethod.containsKey(HttpMethod.HEAD))
1504                        allowedHttpMethods.add(HttpMethod.HEAD);
1505
1506                return allowedHttpMethods;
1507        }
1508
1509        @NonNull
1510        protected MarshaledResponse provideFailsafeMarshaledResponse(@NonNull Request request,
1511                                                                                                                                                                                                                                                         @NonNull Throwable throwable) {
1512                requireNonNull(request);
1513                requireNonNull(throwable);
1514
1515                Integer statusCode = 500;
1516                Charset charset = StandardCharsets.UTF_8;
1517
1518                return MarshaledResponse.withStatusCode(statusCode)
1519                                .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", charset.name()))))
1520                                .body(format("HTTP %d: %s", statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(charset))
1521                                .build();
1522        }
1523
1524        /**
1525         * Synonym for {@link #stop()}.
1526         */
1527        @Override
1528        public void close() {
1529                stop();
1530        }
1531
1532        /**
1533         * Is any managed transport server started?
1534         *
1535         * @return {@code true} if at least one configured transport server is started, {@code false} otherwise
1536         */
1537        @NonNull
1538        public Boolean isStarted() {
1539                getLock().lock();
1540
1541                try {
1542                        HttpServer httpServer = getSokletConfig().getHttpServer().orElse(null);
1543
1544                        if (httpServer != null && httpServer.isStarted())
1545                                return true;
1546
1547                        SseServer sseServer = getSokletConfig().getSseServer().orElse(null);
1548                        if (sseServer != null && sseServer.isStarted())
1549                                return true;
1550
1551                        McpServer mcpServer = getSokletConfig().getMcpServer().orElse(null);
1552                        return mcpServer != null && mcpServer.isStarted();
1553                } finally {
1554                        getLock().unlock();
1555                }
1556        }
1557
1558        /**
1559         * Runs Soklet with special non-network "simulator" implementations of the configured transport servers - useful for integration testing.
1560         * <p>
1561         * See <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a> for how to write these tests.
1562         *
1563         * @param sokletConfig      configuration that drives the Soklet system
1564         * @param simulatorConsumer code to execute within the context of the simulator
1565         */
1566        public static void runSimulator(@NonNull SokletConfig sokletConfig,
1567                                                                                                                                        @NonNull Consumer<Simulator> simulatorConsumer) {
1568                requireNonNull(sokletConfig);
1569                requireNonNull(simulatorConsumer);
1570
1571                // Create Soklet instance - this initializes the REAL implementations through proxies
1572                Soklet soklet = Soklet.fromConfig(sokletConfig);
1573
1574                // Extract proxies (they're guaranteed to be proxies now)
1575                HttpServerProxy serverProxy = sokletConfig.getHttpServer()
1576                                .map(server -> (HttpServerProxy) server)
1577                                .orElse(null);
1578                SseServerProxy sseServerProxy = sokletConfig.getSseServer()
1579                                .map(s -> (SseServerProxy) s)
1580                                .orElse(null);
1581                McpServerProxy mcpServerProxy = sokletConfig.getMcpServer()
1582                                .map(mcpServer -> (McpServerProxy) mcpServer)
1583                                .orElse(null);
1584
1585                // Create mock implementations
1586                MockHttpServer mockServer = serverProxy == null ? null : new MockHttpServer();
1587                MockSseServer mockSseServer = new MockSseServer();
1588                MockMcpServer mockMcpServer = mcpServerProxy == null ? null : new MockMcpServer(mcpServerProxy.getRealImplementation());
1589
1590                // Switch proxies to simulator mode
1591                if (serverProxy != null)
1592                        serverProxy.enableSimulatorMode(mockServer);
1593
1594                if (sseServerProxy != null)
1595                        sseServerProxy.enableSimulatorMode(mockSseServer);
1596
1597                if (mcpServerProxy != null)
1598                        mcpServerProxy.enableSimulatorMode(mockMcpServer);
1599
1600                try {
1601                        // Initialize mocks with request handlers that delegate to Soklet's processing
1602                                if (mockServer != null)
1603                                        mockServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> {
1604                                                // Delegate to Soklet's internal request handling
1605                                                soklet.handleRequest(request, ServerType.STANDARD_HTTP, marshaledResponseConsumer);
1606                                        });
1607
1608                        if (mockSseServer != null)
1609                                mockSseServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> {
1610                                        // Delegate to Soklet's internal request handling for SSE
1611                                        soklet.handleRequest(request, ServerType.SSE, marshaledResponseConsumer);
1612                                });
1613
1614                        if (mockMcpServer != null)
1615                                mockMcpServer.initialize(sokletConfig, soklet::handleMcpRequest);
1616
1617                        if (mockMcpServer != null)
1618                                mockMcpServer.onClientDisconnectedMcpStream(soklet::handleSimulatedMcpStreamDisconnect);
1619
1620                        // Create and provide simulator
1621                        Simulator simulator = new DefaultSimulator(mockServer, mockSseServer, mockMcpServer);
1622                        simulatorConsumer.accept(simulator);
1623                } finally {
1624                        // Always restore to real implementations
1625                                if (serverProxy != null)
1626                                        serverProxy.disableSimulatorMode();
1627
1628                        if (sseServerProxy != null)
1629                                sseServerProxy.disableSimulatorMode();
1630
1631                        if (mcpServerProxy != null)
1632                                mcpServerProxy.disableSimulatorMode();
1633                }
1634        }
1635
1636        @NonNull
1637        protected SokletConfig getSokletConfig() {
1638                return this.sokletConfig;
1639        }
1640
1641        @NonNull
1642        protected ReentrantLock getLock() {
1643                return this.lock;
1644        }
1645
1646        @NonNull
1647        protected AtomicReference<CountDownLatch> getAwaitShutdownLatchReference() {
1648                return this.awaitShutdownLatchReference;
1649        }
1650
1651        @ThreadSafe
1652                static class DefaultSimulator implements Simulator {
1653                        @Nullable
1654                        private MockHttpServer server;
1655                @Nullable
1656                private MockSseServer sseServer;
1657                @Nullable
1658                private MockMcpServer mcpServer;
1659
1660                        public DefaultSimulator(@Nullable MockHttpServer server,
1661                                                                                                                        @Nullable MockSseServer sseServer,
1662                                                                                                                        @Nullable MockMcpServer mcpServer) {
1663                                this.server = server;
1664                                this.sseServer = sseServer;
1665                                this.mcpServer = mcpServer;
1666                }
1667
1668                @NonNull
1669                        @Override
1670                        public HttpRequestResult performHttpRequest(@NonNull Request request) {
1671                                MockHttpServer server = getHttpServer().orElse(null);
1672
1673                                if (server == null)
1674                                        throw new IllegalStateException(format("You must specify a %s in your %s to simulate requests",
1675                                                        HttpServer.class.getSimpleName(), SokletConfig.class.getSimpleName()));
1676
1677                                AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>();
1678                                HttpServer.RequestHandler requestHandler = server.getRequestHandler().orElse(null);
1679
1680                        if (requestHandler == null)
1681                                throw new IllegalStateException("You must register a request handler prior to simulating requests");
1682
1683                        requestHandler.handleRequest(request, (requestResult -> {
1684                                requestResultHolder.set(requestResult);
1685                        }));
1686
1687                        return requestResultHolder.get();
1688                }
1689
1690                @NonNull
1691                @Override
1692                public SseRequestResult performSseRequest(@NonNull Request request) {
1693                        MockSseServer sseServer = getSseServer().orElse(null);
1694
1695                        if (sseServer == null)
1696                                throw new IllegalStateException(format("You must specify a %s in your %s to simulate Server-Sent Event requests",
1697                                                SseServer.class.getSimpleName(), SokletConfig.class.getSimpleName()));
1698
1699                        AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>();
1700                        SseServer.RequestHandler requestHandler = sseServer.getRequestHandler().orElse(null);
1701
1702                        if (requestHandler == null)
1703                                throw new IllegalStateException("You must register a request handler prior to simulating SSE Event Source requests");
1704
1705                        requestHandler.handleRequest(request, (requestResult -> {
1706                                requestResultHolder.set(requestResult);
1707                        }));
1708
1709                        HttpRequestResult requestResult = requestResultHolder.get();
1710                        SseHandshakeResult sseHandshakeResult = requestResult.getSseHandshakeResult().orElse(null);
1711
1712                        if (sseHandshakeResult == null)
1713                                return new SseRequestResult.RequestFailed(requestResult);
1714
1715                        if (sseHandshakeResult instanceof SseHandshakeResult.Accepted acceptedHandshake) {
1716                                Consumer<SseUnicaster> clientInitializer = acceptedHandshake.getClientInitializer().orElse(null);
1717
1718                                // Create a synthetic logical response using values from the accepted handshake
1719                                if (requestResult.getResponse().isEmpty())
1720                                        requestResult = requestResult.copy()
1721                                                        .response(Response.withStatusCode(200)
1722                                                                        .headers(acceptedHandshake.getHeaders())
1723                                                                        .cookies(acceptedHandshake.getCookies())
1724                                                                        .build())
1725                                                        .finish();
1726
1727                                HandshakeAccepted handshakeAccepted = new HandshakeAccepted(acceptedHandshake, request.getResourcePath(), requestResult, this, clientInitializer);
1728                                return handshakeAccepted;
1729                        }
1730
1731                        if (sseHandshakeResult instanceof SseHandshakeResult.Rejected rejectedHandshake)
1732                                return new HandshakeRejected(rejectedHandshake, requestResult);
1733
1734                        throw new IllegalStateException(format("Encountered unexpected %s: %s", SseHandshakeResult.class.getSimpleName(), sseHandshakeResult));
1735                }
1736
1737                @NonNull
1738                @Override
1739                public McpRequestResult performMcpRequest(@NonNull Request request) {
1740                        requireNonNull(request);
1741
1742                        MockMcpServer mcpServer = getMcpServer().orElse(null);
1743
1744                        if (mcpServer == null)
1745                                throw new IllegalStateException(format("You must specify an MCP server in your %s to simulate MCP requests",
1746                                                SokletConfig.class.getSimpleName()));
1747
1748                        AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>();
1749                        McpServer.RequestHandler requestHandler = mcpServer.getRequestHandler().orElse(null);
1750
1751                        if (requestHandler == null)
1752                                throw new IllegalStateException("You must register a request handler prior to simulating MCP requests");
1753
1754                        requestHandler.handleRequest(request, requestResultHolder::set);
1755
1756                        HttpRequestResult requestResult = requestResultHolder.get();
1757
1758                        if (requestResult == null)
1759                                throw new IllegalStateException("No MCP request result was produced by the simulator");
1760
1761                        if (extractContentTypeFromHeaders(requestResult.getMarshaledResponse().getHeaders())
1762                                        .filter(contentType -> contentType.equalsIgnoreCase("text/event-stream"))
1763                                        .isPresent()) {
1764                                McpRequestResult.StreamOpened streamOpened = new McpRequestResult.StreamOpened(
1765                                                requestResult,
1766                                                mcpServer.getMcpStreamErrorHandler(),
1767                                                requestResult.isMcpStreamClosedAfterReplay());
1768
1769                                for (McpObject mcpStreamMessage : requestResult.getMcpStreamMessages())
1770                                        streamOpened.emitMessage(mcpStreamMessage);
1771
1772                                if (!requestResult.isMcpStreamClosedAfterReplay())
1773                                        request.getHeader("MCP-Session-Id").ifPresent(sessionId -> mcpServer.registerOpenStream(sessionId, request, streamOpened));
1774
1775                                return streamOpened;
1776                        }
1777
1778                        if (request.getHttpMethod() == HttpMethod.DELETE
1779                                        && Objects.equals(requestResult.getMarshaledResponse().getStatusCode(), 204))
1780                                request.getHeader("MCP-Session-Id").ifPresent(mcpServer::terminateStreamsForSession);
1781
1782                        return new McpRequestResult.ResponseCompleted(requestResult);
1783                }
1784
1785                @NonNull
1786                @Override
1787                public Simulator onBroadcastError(@Nullable Consumer<Throwable> onBroadcastError) {
1788                        MockSseServer sseServer = getSseServer().orElse(null);
1789
1790                        if (sseServer != null)
1791                                sseServer.onBroadcastError(onBroadcastError);
1792
1793                        return this;
1794                }
1795
1796                @NonNull
1797                @Override
1798                public Simulator onUnicastError(@Nullable Consumer<Throwable> onUnicastError) {
1799                        MockSseServer sseServer = getSseServer().orElse(null);
1800
1801                        if (sseServer != null)
1802                                sseServer.onUnicastError(onUnicastError);
1803
1804                        return this;
1805                }
1806
1807                @NonNull
1808                @Override
1809                public Simulator onMcpStreamError(@Nullable Consumer<Throwable> onMcpStreamError) {
1810                        MockMcpServer mcpServer = getMcpServer().orElse(null);
1811
1812                        if (mcpServer != null)
1813                                mcpServer.onMcpStreamError(onMcpStreamError);
1814
1815                        return this;
1816                }
1817
1818                @NonNull
1819                Optional<MockHttpServer> getHttpServer() {
1820                        return Optional.ofNullable(this.server);
1821                }
1822
1823                @NonNull
1824                Optional<MockSseServer> getSseServer() {
1825                        return Optional.ofNullable(this.sseServer);
1826                }
1827
1828                @NonNull
1829                Optional<MockMcpServer> getMcpServer() {
1830                        return Optional.ofNullable(this.mcpServer);
1831                }
1832        }
1833
1834        /**
1835         * Mock server that doesn't touch the network at all, useful for testing.
1836         *
1837         * @author <a href="https://www.revetkn.com">Mark Allen</a>
1838         */
1839        @ThreadSafe
1840        static class MockHttpServer implements HttpServer {
1841                @Nullable
1842                private SokletConfig sokletConfig;
1843                private HttpServer.@Nullable RequestHandler requestHandler;
1844
1845                @Override
1846                public void start() {
1847                        // No-op
1848                }
1849
1850                @Override
1851                public void stop() {
1852                        // No-op
1853                }
1854
1855                @NonNull
1856                @Override
1857                public Boolean isStarted() {
1858                        return true;
1859                }
1860
1861                @Override
1862                        public void initialize(@NonNull SokletConfig sokletConfig,
1863                                                                                                                 @NonNull RequestHandler requestHandler) {
1864                                requireNonNull(sokletConfig);
1865                                requireNonNull(requestHandler);
1866
1867                                this.sokletConfig = sokletConfig;
1868                                this.requestHandler = requestHandler;
1869                        }
1870
1871                @NonNull
1872                protected Optional<SokletConfig> getSokletConfig() {
1873                        return Optional.ofNullable(this.sokletConfig);
1874                }
1875
1876                @NonNull
1877                protected Optional<RequestHandler> getRequestHandler() {
1878                        return Optional.ofNullable(this.requestHandler);
1879                }
1880        }
1881
1882        /**
1883         * Mock MCP server that doesn't touch the network at all, useful for testing.
1884         */
1885        @ThreadSafe
1886        static class MockMcpServer implements McpServer, InternalMcpSessionMessagePublisher {
1887                @NonNull
1888                private final McpServer realImplementation;
1889                @Nullable
1890                private SokletConfig sokletConfig;
1891                private McpServer.@Nullable RequestHandler requestHandler;
1892                @NonNull
1893                private final AtomicReference<Consumer<Throwable>> mcpStreamErrorHandler;
1894                @NonNull
1895                private final ConcurrentHashMap<@NonNull String, @NonNull CopyOnWriteArrayList<McpRequestResult.StreamOpened>> openStreamsBySessionId;
1896                @NonNull
1897                private final AtomicReference<@Nullable BiConsumer<Request, String>> clientDisconnectedMcpStreamHandler;
1898
1899                public MockMcpServer(@NonNull McpServer realImplementation) {
1900                        requireNonNull(realImplementation);
1901
1902                        this.realImplementation = realImplementation;
1903                        this.mcpStreamErrorHandler = new AtomicReference<>();
1904                        this.openStreamsBySessionId = new ConcurrentHashMap<>();
1905                        this.clientDisconnectedMcpStreamHandler = new AtomicReference<>();
1906                }
1907
1908                @Override
1909                public void start() {
1910                        // No-op
1911                }
1912
1913                @Override
1914                public void stop() {
1915                        // No-op
1916                }
1917
1918                @NonNull
1919                @Override
1920                public Boolean isStarted() {
1921                        return true;
1922                }
1923
1924                @Override
1925                public void initialize(@NonNull SokletConfig sokletConfig,
1926                                                                                                         @NonNull RequestHandler requestHandler) {
1927                        requireNonNull(sokletConfig);
1928                        requireNonNull(requestHandler);
1929
1930                        this.sokletConfig = sokletConfig;
1931                        this.requestHandler = requestHandler;
1932                }
1933
1934                @NonNull
1935                @Override
1936                public McpHandlerResolver getHandlerResolver() {
1937                        return getRealImplementation().getHandlerResolver();
1938                }
1939
1940                @NonNull
1941                @Override
1942                public McpRequestAdmissionPolicy getRequestAdmissionPolicy() {
1943                        return getRealImplementation().getRequestAdmissionPolicy();
1944                }
1945
1946                @NonNull
1947                @Override
1948                public McpRequestInterceptor getRequestInterceptor() {
1949                        return getRealImplementation().getRequestInterceptor();
1950                }
1951
1952                @NonNull
1953                @Override
1954                public McpResponseMarshaler getResponseMarshaler() {
1955                        return getRealImplementation().getResponseMarshaler();
1956                }
1957
1958                @NonNull
1959                @Override
1960                public McpCorsAuthorizer getCorsAuthorizer() {
1961                        return getRealImplementation().getCorsAuthorizer();
1962                }
1963
1964                @NonNull
1965                @Override
1966                public McpSessionStore getSessionStore() {
1967                        return getRealImplementation().getSessionStore();
1968                }
1969
1970                @NonNull
1971                @Override
1972                public IdGenerator<String> getIdGenerator() {
1973                        return getRealImplementation().getIdGenerator();
1974                }
1975
1976                @NonNull
1977                protected McpServer getRealImplementation() {
1978                        return this.realImplementation;
1979                }
1980
1981                @NonNull
1982                protected Optional<SokletConfig> getSokletConfig() {
1983                        return Optional.ofNullable(this.sokletConfig);
1984                }
1985
1986                @NonNull
1987                protected Optional<RequestHandler> getRequestHandler() {
1988                        return Optional.ofNullable(this.requestHandler);
1989                }
1990
1991                protected void onMcpStreamError(@Nullable Consumer<Throwable> onMcpStreamError) {
1992                        this.mcpStreamErrorHandler.set(onMcpStreamError);
1993                }
1994
1995                @NonNull
1996                protected AtomicReference<Consumer<Throwable>> getMcpStreamErrorHandler() {
1997                        return this.mcpStreamErrorHandler;
1998                }
1999
2000                protected void registerOpenStream(@NonNull String sessionId,
2001                                                                                                                                                        @NonNull Request request,
2002                                                                                                                                                        McpRequestResult.StreamOpened streamOpened) {
2003                        requireNonNull(sessionId);
2004                        requireNonNull(request);
2005                        requireNonNull(streamOpened);
2006
2007                        getOpenStreamsBySessionId()
2008                                        .computeIfAbsent(sessionId, ignored -> new CopyOnWriteArrayList<>())
2009                                        .add(streamOpened);
2010
2011                        streamOpened.onClose(() -> closeOpenStream(sessionId, request, streamOpened));
2012                }
2013
2014                protected void terminateStreamsForSession(@NonNull String sessionId) {
2015                        requireNonNull(sessionId);
2016
2017                        CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().remove(sessionId);
2018
2019                        if (streams == null)
2020                                return;
2021
2022                        for (McpRequestResult.StreamOpened streamOpened : streams)
2023                                streamOpened.terminate();
2024                }
2025
2026                @NonNull
2027                @Override
2028                public Boolean publishSessionMessage(@NonNull String sessionId,
2029                                                                                                                                                                 @NonNull McpObject message) {
2030                        requireNonNull(sessionId);
2031                        requireNonNull(message);
2032
2033                        CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().get(sessionId);
2034
2035                        if (streams == null || streams.isEmpty())
2036                                return false;
2037
2038                        for (int i = streams.size() - 1; i >= 0; i--) {
2039                                McpRequestResult.StreamOpened streamOpened = streams.get(i);
2040
2041                                if (streamOpened.isClosed())
2042                                        continue;
2043
2044                                streamOpened.emitMessage(message);
2045                                return true;
2046                        }
2047
2048                        return false;
2049                }
2050
2051                protected void onClientDisconnectedMcpStream(@Nullable BiConsumer<Request, String> clientDisconnectedMcpStreamHandler) {
2052                        this.clientDisconnectedMcpStreamHandler.set(clientDisconnectedMcpStreamHandler);
2053                }
2054
2055                protected void closeOpenStream(@NonNull String sessionId,
2056                                                                                                                                         @NonNull Request request,
2057                                                                                                                                         McpRequestResult.StreamOpened streamOpened) {
2058                        requireNonNull(sessionId);
2059                        requireNonNull(request);
2060                        requireNonNull(streamOpened);
2061
2062                        CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().get(sessionId);
2063
2064                        if (streams != null) {
2065                                streams.remove(streamOpened);
2066
2067                                if (streams.isEmpty())
2068                                        getOpenStreamsBySessionId().remove(sessionId, streams);
2069                        }
2070
2071                        BiConsumer<Request, String> handler = this.clientDisconnectedMcpStreamHandler.get();
2072
2073                        if (handler != null)
2074                                handler.accept(request, sessionId);
2075                }
2076
2077                @NonNull
2078                protected ConcurrentHashMap<@NonNull String, @NonNull CopyOnWriteArrayList<McpRequestResult.StreamOpened>> getOpenStreamsBySessionId() {
2079                        return this.openStreamsBySessionId;
2080                }
2081        }
2082
2083        /**
2084         * Mock Server-Sent Event unicaster that doesn't touch the network at all, useful for testing.
2085         */
2086        @ThreadSafe
2087        static class MockSseUnicaster implements SseUnicaster {
2088                @NonNull
2089                private final ResourcePath resourcePath;
2090                @NonNull
2091                private final Consumer<SseEvent> eventConsumer;
2092                @NonNull
2093                private final Consumer<SseComment> commentConsumer;
2094                @NonNull
2095                private final AtomicReference<Consumer<Throwable>> unicastErrorHandler;
2096
2097                public MockSseUnicaster(@NonNull ResourcePath resourcePath,
2098                                                                                                                                                                @NonNull Consumer<SseEvent> eventConsumer,
2099                                                                                                                                                                @NonNull Consumer<SseComment> commentConsumer,
2100                                                                                                                                                                @NonNull AtomicReference<Consumer<Throwable>> unicastErrorHandler) {
2101                        requireNonNull(resourcePath);
2102                        requireNonNull(eventConsumer);
2103                        requireNonNull(commentConsumer);
2104                        requireNonNull(unicastErrorHandler);
2105
2106                        this.resourcePath = resourcePath;
2107                        this.eventConsumer = eventConsumer;
2108                        this.commentConsumer = commentConsumer;
2109                        this.unicastErrorHandler = unicastErrorHandler;
2110                }
2111
2112                @Override
2113                public void unicastEvent(@NonNull SseEvent sseEvent) {
2114                        requireNonNull(sseEvent);
2115                        try {
2116                                getEventConsumer().accept(sseEvent);
2117                        } catch (Throwable throwable) {
2118                                handleUnicastError(throwable);
2119                        }
2120                }
2121
2122                @Override
2123                public void unicastComment(@NonNull SseComment sseComment) {
2124                        requireNonNull(sseComment);
2125                        try {
2126                                getCommentConsumer().accept(sseComment);
2127                        } catch (Throwable throwable) {
2128                                handleUnicastError(throwable);
2129                        }
2130                }
2131
2132                @NonNull
2133                @Override
2134                public ResourcePath getResourcePath() {
2135                        return this.resourcePath;
2136                }
2137
2138                @NonNull
2139                protected Consumer<SseEvent> getEventConsumer() {
2140                        return this.eventConsumer;
2141                }
2142
2143                @NonNull
2144                protected Consumer<SseComment> getCommentConsumer() {
2145                        return this.commentConsumer;
2146                }
2147
2148                protected void handleUnicastError(@NonNull Throwable throwable) {
2149                        requireNonNull(throwable);
2150                        Consumer<Throwable> handler = this.unicastErrorHandler.get();
2151
2152                        if (handler != null) {
2153                                try {
2154                                        handler.accept(throwable);
2155                                        return;
2156                                } catch (Throwable ignored) {
2157                                        // Fall through to default behavior
2158                                }
2159                        }
2160
2161                        throwable.printStackTrace();
2162                }
2163        }
2164
2165        /**
2166         * Mock Server-Sent Event broadcaster that doesn't touch the network at all, useful for testing.
2167         */
2168        @ThreadSafe
2169        static class MockSseBroadcaster implements SseBroadcaster {
2170                // ConcurrentHashMap doesn't allow null values, so we use a sentinel if context is null
2171                private static final Object NULL_CONTEXT_SENTINEL;
2172
2173                static {
2174                        NULL_CONTEXT_SENTINEL = new Object();
2175                }
2176
2177                @NonNull
2178                private final ResourcePath resourcePath;
2179                // Maps the Consumer (Listener) to its Context object (e.g. Locale)
2180                @NonNull
2181                private final Map<@NonNull Consumer<SseEvent>, @NonNull Object> eventConsumers;
2182                // Same goes for comments
2183                @NonNull
2184                private final Map<@NonNull Consumer<SseComment>, @NonNull Object> commentConsumers;
2185                @NonNull
2186                private final AtomicReference<Consumer<Throwable>> broadcastErrorHandler;
2187
2188                public MockSseBroadcaster(@NonNull ResourcePath resourcePath,
2189                                                                                                                                                                        @NonNull AtomicReference<Consumer<Throwable>> broadcastErrorHandler) {
2190                        requireNonNull(resourcePath);
2191                        requireNonNull(broadcastErrorHandler);
2192
2193                        this.resourcePath = resourcePath;
2194                        this.eventConsumers = new ConcurrentHashMap<>();
2195                        this.commentConsumers = new ConcurrentHashMap<>();
2196                        this.broadcastErrorHandler = broadcastErrorHandler;
2197                }
2198
2199                @NonNull
2200                @Override
2201                public ResourcePath getResourcePath() {
2202                        return this.resourcePath;
2203                }
2204
2205                @NonNull
2206                @Override
2207                public Long getClientCount() {
2208                        return Long.valueOf(getEventConsumers().size() + getCommentConsumers().size());
2209                }
2210
2211                @Override
2212                public void broadcastEvent(@NonNull SseEvent sseEvent) {
2213                        requireNonNull(sseEvent);
2214
2215                        for (Consumer<SseEvent> eventConsumer : getEventConsumers().keySet()) {
2216                                try {
2217                                        eventConsumer.accept(sseEvent);
2218                                } catch (Throwable throwable) {
2219                                        handleBroadcastError(throwable);
2220                                }
2221                        }
2222                }
2223
2224                @Override
2225                public void broadcastComment(@NonNull SseComment sseComment) {
2226                        requireNonNull(sseComment);
2227
2228                        for (Consumer<SseComment> commentConsumer : getCommentConsumers().keySet()) {
2229                                try {
2230                                        commentConsumer.accept(sseComment);
2231                                } catch (Throwable throwable) {
2232                                        handleBroadcastError(throwable);
2233                                }
2234                        }
2235                }
2236
2237                @Override
2238                public <T> void broadcastEvent(
2239                                @NonNull Function<Object, T> keySelector,
2240                                @NonNull Function<T, SseEvent> eventProvider
2241                ) {
2242                        requireNonNull(keySelector);
2243                        requireNonNull(eventProvider);
2244
2245                        // 1. Create a temporary cache for this specific broadcast operation.
2246                        // This ensures we only run the expensive 'eventProvider' once per unique key.
2247                        Map<T, SseEvent> payloadCache = new HashMap<>();
2248
2249                        this.getEventConsumers().forEach((consumer, context) -> {
2250                                try {
2251                                        // 2. Derive the key from the subscriber's context
2252                                        T key = keySelector.apply(context);
2253
2254                                        // 3. Memoize: Generate the payload if we haven't seen this key yet, otherwise reuse it
2255                                        SseEvent event = payloadCache.computeIfAbsent(key, eventProvider);
2256
2257                                        // 4. Dispatch
2258                                        consumer.accept(event);
2259                                } catch (Throwable throwable) {
2260                                        handleBroadcastError(throwable);
2261                                }
2262                        });
2263                }
2264
2265                @Override
2266                public <T> void broadcastComment(
2267                                @NonNull Function<Object, T> keySelector,
2268                                @NonNull Function<T, SseComment> commentProvider
2269                ) {
2270                        requireNonNull(keySelector);
2271                        requireNonNull(commentProvider);
2272
2273                        // 1. Create temporary cache
2274                        Map<T, SseComment> commentCache = new HashMap<>();
2275
2276                        this.getCommentConsumers().forEach((consumer, context) -> {
2277                                try {
2278                                        // 2. Derive key
2279                                        T key = keySelector.apply(context);
2280
2281                                        // 3. Memoize
2282                                        SseComment comment = commentCache.computeIfAbsent(key, commentProvider);
2283
2284                                        // 4. Dispatch
2285                                        consumer.accept(comment);
2286                                } catch (Throwable throwable) {
2287                                        handleBroadcastError(throwable);
2288                                }
2289                        });
2290                }
2291
2292                @NonNull
2293                public Boolean registerEventConsumer(@NonNull Consumer<SseEvent> eventConsumer) {
2294                        return registerEventConsumer(eventConsumer, null);
2295                }
2296
2297                /**
2298                 * Registers a consumer with an associated context, simulating a client with specific traits.
2299                 */
2300                @NonNull
2301                public Boolean registerEventConsumer(@NonNull Consumer<SseEvent> eventConsumer, @Nullable Object context) {
2302                        requireNonNull(eventConsumer);
2303                        // map.put returns null if the key was new, which conceptually matches "add" returning true
2304                        return this.getEventConsumers().put(eventConsumer, context == null ? NULL_CONTEXT_SENTINEL : context) == null;
2305                }
2306
2307                @NonNull
2308                public Boolean unregisterEventConsumer(@NonNull Consumer<SseEvent> eventConsumer) {
2309                        requireNonNull(eventConsumer);
2310                        return this.getEventConsumers().remove(eventConsumer) != null;
2311                }
2312
2313                @NonNull
2314                public Boolean registerCommentConsumer(@NonNull Consumer<SseComment> commentConsumer) {
2315                        return registerCommentConsumer(commentConsumer, null);
2316                }
2317
2318                /**
2319                 * Registers a consumer with an associated context, simulating a client with specific traits.
2320                 */
2321                @NonNull
2322                public Boolean registerCommentConsumer(@NonNull Consumer<SseComment> commentConsumer, @Nullable Object context) {
2323                        requireNonNull(commentConsumer);
2324                        return this.getCommentConsumers().put(commentConsumer, context == null ? NULL_CONTEXT_SENTINEL : context) == null;
2325                }
2326
2327                @NonNull
2328                public Boolean unregisterCommentConsumer(@NonNull Consumer<SseComment> commentConsumer) {
2329                        requireNonNull(commentConsumer);
2330                        return this.getCommentConsumers().remove(commentConsumer) != null;
2331                }
2332
2333                @NonNull
2334                protected Map<@NonNull Consumer<SseEvent>, @NonNull Object> getEventConsumers() {
2335                        return this.eventConsumers;
2336                }
2337
2338                @NonNull
2339                protected Map<@NonNull Consumer<SseComment>, @NonNull Object> getCommentConsumers() {
2340                        return this.commentConsumers;
2341                }
2342
2343                protected void handleBroadcastError(@NonNull Throwable throwable) {
2344                        requireNonNull(throwable);
2345                        Consumer<Throwable> handler = this.broadcastErrorHandler.get();
2346
2347                        if (handler != null) {
2348                                try {
2349                                        handler.accept(throwable);
2350                                        return;
2351                                } catch (Throwable ignored) {
2352                                        // Fall through to default behavior
2353                                }
2354                        }
2355
2356                        throwable.printStackTrace();
2357                }
2358        }
2359
2360        /**
2361         * Mock Server-Sent Event server that doesn't touch the network at all, useful for testing.
2362         *
2363         * @author <a href="https://www.revetkn.com">Mark Allen</a>
2364         */
2365        @ThreadSafe
2366        static class MockSseServer implements SseServer {
2367                @Nullable
2368                private SokletConfig sokletConfig;
2369                private SseServer.@Nullable RequestHandler requestHandler;
2370                @NonNull
2371                private final ConcurrentHashMap<@NonNull ResourcePath, @NonNull MockSseBroadcaster> broadcastersByResourcePath;
2372                @NonNull
2373                private final AtomicReference<Consumer<Throwable>> broadcastErrorHandler;
2374                @NonNull
2375                private final AtomicReference<Consumer<Throwable>> unicastErrorHandler;
2376
2377                public MockSseServer() {
2378                        this.broadcastersByResourcePath = new ConcurrentHashMap<>();
2379                        this.broadcastErrorHandler = new AtomicReference<>();
2380                        this.unicastErrorHandler = new AtomicReference<>();
2381                }
2382
2383                @Override
2384                public void start() {
2385                        // No-op
2386                }
2387
2388                @Override
2389                public void stop() {
2390                        // No-op
2391                }
2392
2393                @NonNull
2394                @Override
2395                public Boolean isStarted() {
2396                        return true;
2397                }
2398
2399                @NonNull
2400                @Override
2401                public Optional<? extends SseBroadcaster> acquireBroadcaster(@Nullable ResourcePath resourcePath) {
2402                        if (resourcePath == null)
2403                                return Optional.empty();
2404
2405                        MockSseBroadcaster broadcaster = getBroadcastersByResourcePath()
2406                                        .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler));
2407
2408                        return Optional.of(broadcaster);
2409                }
2410
2411                public void registerEventConsumer(@NonNull ResourcePath resourcePath,
2412                                                                                                                                                        @NonNull Consumer<SseEvent> eventConsumer) {
2413                        registerEventConsumer(resourcePath, eventConsumer, null);
2414                }
2415
2416                public void registerEventConsumer(@NonNull ResourcePath resourcePath,
2417                                                                                                                                                        @NonNull Consumer<SseEvent> eventConsumer,
2418                                                                                                                                                        @Nullable Object context) {
2419                        requireNonNull(resourcePath);
2420                        requireNonNull(eventConsumer);
2421
2422                        MockSseBroadcaster broadcaster = getBroadcastersByResourcePath()
2423                                        .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler));
2424
2425                        broadcaster.registerEventConsumer(eventConsumer, context);
2426                }
2427
2428                @NonNull
2429                public Boolean unregisterEventConsumer(@NonNull ResourcePath resourcePath,
2430                                                                                                                                                                         @NonNull Consumer<SseEvent> eventConsumer) {
2431                        requireNonNull(resourcePath);
2432                        requireNonNull(eventConsumer);
2433
2434                        MockSseBroadcaster broadcaster = getBroadcastersByResourcePath().get(resourcePath);
2435
2436                        if (broadcaster == null)
2437                                return false;
2438
2439                        return broadcaster.unregisterEventConsumer(eventConsumer);
2440                }
2441
2442                public void registerCommentConsumer(@NonNull ResourcePath resourcePath,
2443                                                                                                                                                                @NonNull Consumer<SseComment> commentConsumer) {
2444                        registerCommentConsumer(resourcePath, commentConsumer, null);
2445                }
2446
2447                public void registerCommentConsumer(@NonNull ResourcePath resourcePath,
2448                                                                                                                                                                @NonNull Consumer<SseComment> commentConsumer,
2449                                                                                                                                                                @Nullable Object context) {
2450                        requireNonNull(resourcePath);
2451                        requireNonNull(commentConsumer);
2452
2453                        MockSseBroadcaster broadcaster = getBroadcastersByResourcePath()
2454                                        .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler));
2455
2456                        broadcaster.registerCommentConsumer(commentConsumer, context);
2457                }
2458
2459                @NonNull
2460                public Boolean unregisterCommentConsumer(@NonNull ResourcePath resourcePath,
2461                                                                                                                                                                                 @NonNull Consumer<SseComment> commentConsumer) {
2462                        requireNonNull(resourcePath);
2463                        requireNonNull(commentConsumer);
2464
2465                        MockSseBroadcaster broadcaster = getBroadcastersByResourcePath().get(resourcePath);
2466
2467                        if (broadcaster == null)
2468                                return false;
2469
2470                        return broadcaster.unregisterCommentConsumer(commentConsumer);
2471                }
2472
2473                @Override
2474                public void initialize(@NonNull SokletConfig sokletConfig,
2475                                                                                                         SseServer.@NonNull RequestHandler requestHandler) {
2476                        requireNonNull(sokletConfig);
2477                        requireNonNull(requestHandler);
2478
2479                        this.sokletConfig = sokletConfig;
2480                        this.requestHandler = requestHandler;
2481                }
2482
2483                public void onBroadcastError(@Nullable Consumer<Throwable> onBroadcastError) {
2484                        this.broadcastErrorHandler.set(onBroadcastError);
2485                }
2486
2487                public void onUnicastError(@Nullable Consumer<Throwable> onUnicastError) {
2488                        this.unicastErrorHandler.set(onUnicastError);
2489                }
2490
2491                @NonNull
2492                protected Optional<SokletConfig> getSokletConfig() {
2493                        return Optional.ofNullable(this.sokletConfig);
2494                }
2495
2496                @NonNull
2497                protected Optional<SseServer.RequestHandler> getRequestHandler() {
2498                        return Optional.ofNullable(this.requestHandler);
2499                }
2500
2501                @NonNull
2502                protected ConcurrentHashMap<@NonNull ResourcePath, @NonNull MockSseBroadcaster> getBroadcastersByResourcePath() {
2503                        return this.broadcastersByResourcePath;
2504                }
2505
2506                @NonNull
2507                protected AtomicReference<Consumer<Throwable>> getUnicastErrorHandler() {
2508                        return this.unicastErrorHandler;
2509                }
2510        }
2511
2512        /**
2513         * Efficiently provides the current time in RFC 1123 HTTP-date format.
2514         * Updates once per second to avoid the overhead of formatting dates on every request.
2515         */
2516        @ThreadSafe
2517        private static final class CachedHttpDate {
2518                @NonNull
2519                private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss z", Locale.ENGLISH)
2520                                .withZone(ZoneId.of("GMT"));
2521                @NonNull
2522                private static volatile String CURRENT_VALUE = FORMATTER.format(Instant.now());
2523
2524                static {
2525                        Thread t = new Thread(() -> {
2526                                while (true) {
2527                                        try {
2528                                                Thread.sleep(1000);
2529                                                CURRENT_VALUE = FORMATTER.format(Instant.now());
2530                                        } catch (InterruptedException e) {
2531                                                break; // Allow thread to die on JVM shutdown
2532                                        }
2533                                }
2534                        }, "soklet-date-header-value-updater");
2535                        t.setDaemon(true);
2536                        t.start();
2537                }
2538
2539                @NonNull
2540                static String getCurrentValue() {
2541                        return CURRENT_VALUE;
2542                }
2543        }
2544}