001/* 002 * Copyright 2022-2026 Revetware LLC. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package com.soklet; 018 019import com.soklet.SseRequestResult.HandshakeAccepted; 020import com.soklet.SseRequestResult.HandshakeRejected; 021import com.soklet.annotation.SseEventSource; 022import com.soklet.internal.spring.LinkedCaseInsensitiveMap; 023import org.jspecify.annotations.NonNull; 024import org.jspecify.annotations.Nullable; 025 026import javax.annotation.concurrent.ThreadSafe; 027import java.io.BufferedReader; 028import java.io.IOException; 029import java.io.InputStreamReader; 030import java.lang.reflect.InvocationTargetException; 031import java.nio.charset.Charset; 032import java.nio.charset.StandardCharsets; 033import java.time.Duration; 034import java.time.Instant; 035import java.time.ZoneId; 036import java.time.format.DateTimeFormatter; 037import java.util.ArrayList; 038import java.util.Collections; 039import java.util.EnumSet; 040import java.util.HashMap; 041import java.util.LinkedHashMap; 042import java.util.List; 043import java.util.Locale; 044import java.util.Map; 045import java.util.Map.Entry; 046import java.util.Objects; 047import java.util.Optional; 048import java.util.Set; 049import java.util.concurrent.ConcurrentHashMap; 050import java.util.concurrent.CopyOnWriteArrayList; 051import java.util.concurrent.CopyOnWriteArraySet; 052import java.util.concurrent.CountDownLatch; 053import java.util.concurrent.atomic.AtomicBoolean; 054import java.util.concurrent.atomic.AtomicReference; 055import java.util.concurrent.locks.ReentrantLock; 056import java.util.function.BiConsumer; 057import java.util.function.Consumer; 058import java.util.function.Function; 059import java.util.stream.Collectors; 060 061import static com.soklet.Utilities.emptyByteArray; 062import static com.soklet.Utilities.extractContentTypeFromHeaders; 063import static java.lang.String.format; 064import static java.util.Objects.requireNonNull; 065 066/** 067 * Soklet's main class - manages one or more configured transport servers ({@link HttpServer}, {@link SseServer}, and/or {@link McpServer}) 068 * using the provided system configuration. 069 * <p> 070 * <pre>{@code // Use out-of-the-box defaults 071 * SokletConfig config = SokletConfig.withHttpServer( 072 * HttpServer.fromPort(8080) 073 * ).build(); 074 * 075 * try (Soklet soklet = Soklet.fromConfig(config)) { 076 * soklet.start(); 077 * System.out.println("Soklet started, press [enter] to exit"); 078 * soklet.awaitShutdown(ShutdownTrigger.ENTER_KEY); 079 * }}</pre> 080 * <p> 081 * Soklet also offers an off-network {@link Simulator} concept via {@link #runSimulator(SokletConfig, Consumer)}, useful for integration testing. 082 * <p> 083 * Given a <em>Resource Method</em>... 084 * <pre>{@code public class HelloResource { 085 * @GET("/hello") 086 * public String hello(@QueryParameter String name) { 087 * return String.format("Hello, %s", name); 088 * } 089 * }}</pre> 090 * ...we might test it like this: 091 * <pre>{@code @Test 092 * public void integrationTest() { 093 * // Just use your app's existing configuration 094 * SokletConfig config = obtainMySokletConfig(); 095 * 096 * // Instead of running on a real HTTP server that listens on a port, 097 * // a non-network Simulator is provided against which you can 098 * // issue requests and receive responses. 099 * Soklet.runSimulator(config, (simulator -> { 100 * // Construct a request 101 * Request request = Request.withPath(HttpMethod.GET, "/hello") 102 * .queryParameters(Map.of("name", Set.of("Mark"))) 103 * .build(); 104 * 105 * // Perform the request and get a handle to the response 106 * HttpRequestResult result = simulator.performHttpRequest(request); 107 * MarshaledResponse marshaledResponse = result.getMarshaledResponse(); 108 * 109 * // Verify status code 110 * Integer expectedCode = 200; 111 * Integer actualCode = marshaledResponse.getStatusCode(); 112 * assertEquals(expectedCode, actualCode, "Bad status code"); 113 * 114 * // Verify response body 115 * marshaledResponse.getBody().ifPresentOrElse(body -> { 116 * String expectedBody = "Hello, Mark"; 117 * byte[] bytes = ((MarshaledResponseBody.Bytes) body).getBytes(); 118 * String actualBody = new String(bytes, StandardCharsets.UTF_8); 119 * assertEquals(expectedBody, actualBody, "Bad response body"); 120 * }, () -> { 121 * Assertions.fail("No response body"); 122 * }); 123 * })); 124 * }}</pre> 125 * <p> 126 * The {@link Simulator} also supports Server-Sent Events. 127 * <p> 128 * Integration testing documentation is available at <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a>. 129 * 130 * @author <a href="https://www.revetkn.com">Mark Allen</a> 131 */ 132@ThreadSafe 133public final class Soklet implements AutoCloseable { 134 @NonNull 135 private static final Map<@NonNull String, @NonNull Set<@NonNull String>> DEFAULT_ACCEPTED_HANDSHAKE_HEADERS; 136 137 static { 138 // Generally speaking, we always want these headers for SSE streaming responses. 139 // Users can override if they think necessary 140 LinkedCaseInsensitiveMap<Set<String>> defaultAcceptedHandshakeHeaders = new LinkedCaseInsensitiveMap<>(4); 141 defaultAcceptedHandshakeHeaders.put("Content-Type", Set.of("text/event-stream; charset=UTF-8")); 142 defaultAcceptedHandshakeHeaders.put("Cache-Control", Set.of("no-cache", "no-transform")); 143 defaultAcceptedHandshakeHeaders.put("Connection", Set.of("keep-alive")); 144 defaultAcceptedHandshakeHeaders.put("X-Accel-Buffering", Set.of("no")); 145 146 DEFAULT_ACCEPTED_HANDSHAKE_HEADERS = Collections.unmodifiableMap(defaultAcceptedHandshakeHeaders); 147 } 148 149 /** 150 * Acquires a Soklet instance with the given configuration. 151 * 152 * @param sokletConfig configuration that drives the Soklet system 153 * @return a Soklet instance 154 */ 155 @NonNull 156 public static Soklet fromConfig(@NonNull SokletConfig sokletConfig) { 157 requireNonNull(sokletConfig); 158 return new Soklet(sokletConfig); 159 } 160 161 @NonNull 162 private final SokletConfig sokletConfig; 163 @NonNull 164 private final ReentrantLock lock; 165 @NonNull 166 private final AtomicReference<CountDownLatch> awaitShutdownLatchReference; 167 @NonNull 168 private final DefaultMcpRuntime defaultMcpRuntime; 169 170 /** 171 * Creates a Soklet instance with the given configuration. 172 * 173 * @param sokletConfig configuration that drives the Soklet system 174 */ 175 private Soklet(@NonNull SokletConfig sokletConfig) { 176 requireNonNull(sokletConfig); 177 178 this.sokletConfig = sokletConfig; 179 this.lock = new ReentrantLock(); 180 this.awaitShutdownLatchReference = new AtomicReference<>(new CountDownLatch(1)); 181 this.defaultMcpRuntime = new DefaultMcpRuntime(this); 182 183 sokletConfig.getMcpServer() 184 .map(McpServer::getSessionStore) 185 .filter(DefaultMcpSessionStore.class::isInstance) 186 .map(DefaultMcpSessionStore.class::cast) 187 .ifPresent(sessionStore -> sessionStore.pinnedSessionPredicate(this.defaultMcpRuntime::hasActiveStream)); 188 189 sokletConfig.getMcpServer() 190 .map(mcpServer -> mcpServer instanceof McpServerProxy mcpServerProxy ? mcpServerProxy.getRealImplementation() : mcpServer) 191 .filter(DefaultMcpServer.class::isInstance) 192 .map(DefaultMcpServer.class::cast) 193 .ifPresent(defaultMcpServer -> defaultMcpServer.mcpRuntime(this.defaultMcpRuntime)); 194 195 Set<ResourceMethod> resourceMethods = sokletConfig.getResourceMethodResolver().getResourceMethods(); 196 197 // Fail fast in the event that Soklet appears misconfigured 198 if (resourceMethods.size() == 0 199 && sokletConfig.getMcpServer().isEmpty()) 200 throw new IllegalStateException(format("No Soklet Resource Methods were found. First, try to rebuild and see if that solves the problem. If not, please ensure your %s is configured correctly. " 201 + "See https://www.soklet.com/docs/request-handling#resource-method-resolution for details.", ResourceMethodResolver.class.getSimpleName())); 202 203 boolean hasStandardHttpResourceMethods = resourceMethods.stream() 204 .anyMatch(resourceMethod -> !resourceMethod.isSseEventSource()); 205 206 if (hasStandardHttpResourceMethods && sokletConfig.getHttpServer().isEmpty()) 207 throw new IllegalStateException(format("Resource Methods were found, but no %s is configured. See https://www.soklet.com/docs/server-configuration for details.", 208 HttpServer.class.getSimpleName())); 209 210 // SSE misconfiguration check: @SseEventSource resource methods are declared, but not SseServer exists 211 boolean hasSseResourceMethods = resourceMethods.stream() 212 .anyMatch(resourceMethod -> resourceMethod.isSseEventSource()); 213 214 if (hasSseResourceMethods && sokletConfig.getSseServer().isEmpty()) 215 throw new IllegalStateException(format("Resource Methods annotated with @%s were found, but no %s is configured. See https://www.soklet.com/docs/server-sent-events for details.", 216 SseEventSource.class.getSimpleName(), SseServer.class.getSimpleName())); 217 218 MetricsCollector metricsCollector = sokletConfig.getMetricsCollector(); 219 220 if (metricsCollector instanceof DefaultMetricsCollector defaultMetricsCollector) { 221 try { 222 defaultMetricsCollector.initialize(sokletConfig); 223 } catch (Throwable t) { 224 sokletConfig.getLifecycleObserver().didReceiveLogEvent( 225 LogEvent.with(LogEventType.METRICS_COLLECTOR_FAILED, 226 format("An exception occurred while initializing %s", metricsCollector.getClass().getSimpleName())) 227 .throwable(t) 228 .build()); 229 } 230 } 231 232 // Use a layer of indirection here so the Soklet type does not need to directly implement the `RequestHandler` interface. 233 // Reasoning: the `handleRequest` method for Soklet should not be public, which might lead to accidental invocation by users. 234 // That method should only be called by the managed `HttpServer` instance. 235 Soklet soklet = this; 236 237 sokletConfig.getHttpServer().ifPresent(server -> server.initialize(getSokletConfig(), (request, marshaledResponseConsumer) -> { 238 // Delegate to Soklet's internal request handling method 239 soklet.handleRequest(request, ServerType.STANDARD_HTTP, marshaledResponseConsumer); 240 })); 241 242 SseServer sseServer = sokletConfig.getSseServer().orElse(null); 243 244 if (sseServer != null) 245 sseServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> { 246 // Delegate to Soklet's internal request handling method 247 soklet.handleRequest(request, ServerType.SSE, marshaledResponseConsumer); 248 }); 249 250 McpServer mcpServer = sokletConfig.getMcpServer().orElse(null); 251 252 if (mcpServer != null) 253 mcpServer.initialize(sokletConfig, soklet::handleMcpRequest); 254 } 255 256 /** 257 * Starts the managed server instance[s]. 258 * <p> 259 * If the managed server[s] are already started, this is a no-op. 260 */ 261 public void start() { 262 getLock().lock(); 263 264 try { 265 if (isStarted()) 266 return; 267 268 getAwaitShutdownLatchReference().set(new CountDownLatch(1)); 269 270 SokletConfig sokletConfig = getSokletConfig(); 271 LifecycleObserver lifecycleObserver = sokletConfig.getLifecycleObserver(); 272 273 // 1. Notify global intent to start 274 lifecycleObserver.willStartSoklet(this); 275 276 HttpServer httpServer = sokletConfig.getHttpServer().orElse(null); 277 SseServer sseServer = sokletConfig.getSseServer().orElse(null); 278 McpServer mcpServer = sokletConfig.getMcpServer().orElse(null); 279 boolean httpServerStarted = false; 280 boolean sseServerStarted = false; 281 boolean mcpServerStarted = false; 282 283 try { 284 // 2. Attempt to start Main HttpServer 285 if (httpServer != null) { 286 lifecycleObserver.willStartHttpServer(httpServer); 287 try { 288 httpServer.start(); 289 httpServerStarted = true; 290 lifecycleObserver.didStartHttpServer(httpServer); 291 } catch (Throwable t) { 292 lifecycleObserver.didFailToStartHttpServer(httpServer, t); 293 throw t; // Rethrow to trigger outer catch block 294 } 295 } 296 297 // 3. Attempt to start SSE HttpServer (if present) 298 if (sseServer != null) { 299 lifecycleObserver.willStartSseServer(sseServer); 300 try { 301 sseServer.start(); 302 sseServerStarted = true; 303 lifecycleObserver.didStartSseServer(sseServer); 304 } catch (Throwable t) { 305 lifecycleObserver.didFailToStartSseServer(sseServer, t); 306 throw t; // Rethrow to trigger outer catch block 307 } 308 } 309 310 if (mcpServer != null) { 311 lifecycleObserver.willStartMcpServer(mcpServer); 312 try { 313 mcpServer.start(); 314 mcpServerStarted = true; 315 lifecycleObserver.didStartMcpServer(mcpServer); 316 } catch (Throwable t) { 317 lifecycleObserver.didFailToStartMcpServer(mcpServer, t); 318 throw t; 319 } 320 } 321 322 // 4. Global success 323 lifecycleObserver.didStartSoklet(this); 324 } catch (Throwable t) { 325 rollbackStartedServersAfterFailedStart(lifecycleObserver, 326 httpServer, httpServerStarted, 327 sseServer, sseServerStarted, 328 mcpServer, mcpServerStarted, 329 t); 330 331 // 5. Global failure 332 lifecycleObserver.didFailToStartSoklet(this, t); 333 334 // Ensure the exception bubbles up so the application knows startup failed 335 if (t instanceof RuntimeException) 336 throw (RuntimeException) t; 337 338 throw new RuntimeException(t); 339 } 340 } finally { 341 getLock().unlock(); 342 } 343 } 344 345 private void rollbackStartedServersAfterFailedStart(@NonNull LifecycleObserver lifecycleObserver, 346 @Nullable HttpServer httpServer, 347 boolean httpServerStarted, 348 @Nullable SseServer sseServer, 349 boolean sseServerStarted, 350 @Nullable McpServer mcpServer, 351 boolean mcpServerStarted, 352 @NonNull Throwable startupFailure) { 353 requireNonNull(lifecycleObserver); 354 requireNonNull(startupFailure); 355 356 if (mcpServerStarted && mcpServer != null) 357 stopStartedMcpServerForRollback(lifecycleObserver, mcpServer, startupFailure); 358 359 if (sseServerStarted && sseServer != null) 360 stopStartedSseServerForRollback(lifecycleObserver, sseServer, startupFailure); 361 362 if (httpServerStarted && httpServer != null) 363 stopStartedHttpServerForRollback(lifecycleObserver, httpServer, startupFailure); 364 365 CountDownLatch awaitShutdownLatch = getAwaitShutdownLatchReference().get(); 366 367 if (awaitShutdownLatch != null) 368 awaitShutdownLatch.countDown(); 369 } 370 371 private void stopStartedHttpServerForRollback(@NonNull LifecycleObserver lifecycleObserver, 372 @NonNull HttpServer httpServer, 373 @NonNull Throwable startupFailure) { 374 requireNonNull(lifecycleObserver); 375 requireNonNull(httpServer); 376 requireNonNull(startupFailure); 377 378 try { 379 lifecycleObserver.willStopHttpServer(httpServer); 380 } catch (Throwable t) { 381 startupFailure.addSuppressed(t); 382 } 383 384 try { 385 httpServer.stop(); 386 try { 387 lifecycleObserver.didStopHttpServer(httpServer); 388 } catch (Throwable t) { 389 startupFailure.addSuppressed(t); 390 } 391 } catch (Throwable t) { 392 startupFailure.addSuppressed(t); 393 394 try { 395 lifecycleObserver.didFailToStopHttpServer(httpServer, t); 396 } catch (Throwable t2) { 397 startupFailure.addSuppressed(t2); 398 } 399 } 400 } 401 402 private void stopStartedSseServerForRollback(@NonNull LifecycleObserver lifecycleObserver, 403 @NonNull SseServer sseServer, 404 @NonNull Throwable startupFailure) { 405 requireNonNull(lifecycleObserver); 406 requireNonNull(sseServer); 407 requireNonNull(startupFailure); 408 409 try { 410 lifecycleObserver.willStopSseServer(sseServer); 411 } catch (Throwable t) { 412 startupFailure.addSuppressed(t); 413 } 414 415 try { 416 sseServer.stop(); 417 try { 418 lifecycleObserver.didStopSseServer(sseServer); 419 } catch (Throwable t) { 420 startupFailure.addSuppressed(t); 421 } 422 } catch (Throwable t) { 423 startupFailure.addSuppressed(t); 424 425 try { 426 lifecycleObserver.didFailToStopSseServer(sseServer, t); 427 } catch (Throwable t2) { 428 startupFailure.addSuppressed(t2); 429 } 430 } 431 } 432 433 private void stopStartedMcpServerForRollback(@NonNull LifecycleObserver lifecycleObserver, 434 @NonNull McpServer mcpServer, 435 @NonNull Throwable startupFailure) { 436 requireNonNull(lifecycleObserver); 437 requireNonNull(mcpServer); 438 requireNonNull(startupFailure); 439 440 try { 441 lifecycleObserver.willStopMcpServer(mcpServer); 442 } catch (Throwable t) { 443 startupFailure.addSuppressed(t); 444 } 445 446 try { 447 mcpServer.stop(); 448 try { 449 lifecycleObserver.didStopMcpServer(mcpServer); 450 } catch (Throwable t) { 451 startupFailure.addSuppressed(t); 452 } 453 } catch (Throwable t) { 454 startupFailure.addSuppressed(t); 455 456 try { 457 lifecycleObserver.didFailToStopMcpServer(mcpServer, t); 458 } catch (Throwable t2) { 459 startupFailure.addSuppressed(t2); 460 } 461 } 462 } 463 464 /** 465 * Stops the managed server instance[s]. 466 * <p> 467 * If the managed server[s] are already stopped, this is a no-op. 468 */ 469 public void stop() { 470 getLock().lock(); 471 472 try { 473 if (isStarted()) { 474 SokletConfig sokletConfig = getSokletConfig(); 475 LifecycleObserver lifecycleObserver = sokletConfig.getLifecycleObserver(); 476 477 // 1. Notify global intent to stop 478 lifecycleObserver.willStopSoklet(this); 479 480 Throwable firstEncounteredException = null; 481 HttpServer httpServer = sokletConfig.getHttpServer().orElse(null); 482 483 // 2. Attempt to stop Main HttpServer 484 if (httpServer != null && httpServer.isStarted()) { 485 lifecycleObserver.willStopHttpServer(httpServer); 486 try { 487 httpServer.stop(); 488 lifecycleObserver.didStopHttpServer(httpServer); 489 } catch (Throwable t) { 490 firstEncounteredException = t; 491 lifecycleObserver.didFailToStopHttpServer(httpServer, t); 492 } 493 } 494 495 // 3. Attempt to stop SSE HttpServer 496 SseServer sseServer = sokletConfig.getSseServer().orElse(null); 497 498 if (sseServer != null && sseServer.isStarted()) { 499 lifecycleObserver.willStopSseServer(sseServer); 500 try { 501 sseServer.stop(); 502 lifecycleObserver.didStopSseServer(sseServer); 503 } catch (Throwable t) { 504 if (firstEncounteredException == null) 505 firstEncounteredException = t; 506 507 lifecycleObserver.didFailToStopSseServer(sseServer, t); 508 } 509 } 510 511 McpServer mcpServer = sokletConfig.getMcpServer().orElse(null); 512 513 if (mcpServer != null && mcpServer.isStarted()) { 514 lifecycleObserver.willStopMcpServer(mcpServer); 515 try { 516 mcpServer.stop(); 517 lifecycleObserver.didStopMcpServer(mcpServer); 518 } catch (Throwable t) { 519 if (firstEncounteredException == null) 520 firstEncounteredException = t; 521 522 lifecycleObserver.didFailToStopMcpServer(mcpServer, t); 523 } 524 } 525 526 // 4. Global completion (Success or Failure) 527 if (firstEncounteredException == null) 528 lifecycleObserver.didStopSoklet(this); 529 else 530 lifecycleObserver.didFailToStopSoklet(this, firstEncounteredException); 531 } 532 } finally { 533 try { 534 getAwaitShutdownLatchReference().get().countDown(); 535 } finally { 536 getLock().unlock(); 537 } 538 } 539 } 540 541 /** 542 * Blocks the current thread until JVM shutdown ({@code SIGTERM/SIGINT/System.exit(...)} and so forth), <strong>or</strong> if one of the provided {@code shutdownTriggers} occurs. 543 * <p> 544 * This method will automatically invoke this instance's {@link #stop()} method once it becomes unblocked. 545 * <p> 546 * <strong>Notes regarding {@link ShutdownTrigger#ENTER_KEY}:</strong> 547 * <ul> 548 * <li>It will invoke {@link #stop()} on <i>all</i> Soklet instances, as stdin is process-wide</li> 549 * <li>It is only supported for environments with an interactive TTY and will be ignored if none exists (e.g. running in a Docker container) - Soklet will detect this and fire {@link LifecycleObserver#didReceiveLogEvent(LogEvent)} with an event of type {@link LogEventType#CONFIGURATION_UNSUPPORTED}</li> 550 * </ul> 551 * 552 * @param shutdownTriggers additional trigger[s] which signal that shutdown should occur, e.g. {@link ShutdownTrigger#ENTER_KEY} for "enter key pressed" 553 * @throws InterruptedException if the current thread has its interrupted status set on entry to this method, or is interrupted while waiting 554 */ 555 public void awaitShutdown(@Nullable ShutdownTrigger... shutdownTriggers) throws InterruptedException { 556 Thread shutdownHook = null; 557 boolean registeredEnterKeyShutdownTrigger = false; 558 Set<ShutdownTrigger> shutdownTriggersAsSet = shutdownTriggers == null || shutdownTriggers.length == 0 ? Set.of() : EnumSet.copyOf(Set.of(shutdownTriggers)); 559 560 try { 561 // Optionally listen for enter key 562 if (shutdownTriggersAsSet.contains(ShutdownTrigger.ENTER_KEY)) { 563 registeredEnterKeyShutdownTrigger = KeypressManager.register(this); // returns false if stdin unusable/disabled 564 565 if (!registeredEnterKeyShutdownTrigger) { 566 LogEvent logEvent = LogEvent.with( 567 LogEventType.CONFIGURATION_UNSUPPORTED, 568 format("Ignoring request for %s.%s - it is unsupported in this environment (no interactive TTY detected)", ShutdownTrigger.class.getSimpleName(), ShutdownTrigger.ENTER_KEY.name()) 569 ).build(); 570 571 getSokletConfig().getLifecycleObserver().didReceiveLogEvent(logEvent); 572 } 573 } 574 575 // Always register a shutdown hook 576 shutdownHook = new Thread(() -> { 577 try { 578 stop(); 579 } catch (Throwable ignored) { 580 // Nothing to do 581 } 582 }, "soklet-shutdown-hook"); 583 584 Runtime.getRuntime().addShutdownHook(shutdownHook); 585 586 // Wait until "close" finishes 587 getAwaitShutdownLatchReference().get().await(); 588 } finally { 589 if (registeredEnterKeyShutdownTrigger) 590 KeypressManager.unregister(this); 591 592 try { 593 Runtime.getRuntime().removeShutdownHook(shutdownHook); 594 } catch (IllegalStateException ignored) { 595 // JVM shutting down 596 } 597 } 598 } 599 600 /** 601 * Handles "awaitShutdown" for {@link ShutdownTrigger#ENTER_KEY} by listening to stdin - all Soklet instances are terminated on keypress. 602 */ 603 @ThreadSafe 604 private static final class KeypressManager { 605 @NonNull 606 private static final Set<@NonNull Soklet> SOKLET_REGISTRY; 607 @NonNull 608 private static final AtomicBoolean LISTENER_STARTED; 609 610 static { 611 SOKLET_REGISTRY = new CopyOnWriteArraySet<>(); 612 LISTENER_STARTED = new AtomicBoolean(false); 613 } 614 615 /** 616 * Register a Soklet for Enter-to-stop support. Returns true iff a listener is (or was already) active. 617 * If System.in is not usable (or disabled), returns false and does nothing. 618 */ 619 @NonNull 620 synchronized static Boolean register(@NonNull Soklet soklet) { 621 requireNonNull(soklet); 622 623 // If stdin is not readable (e.g., container with no TTY), don't start a listener. 624 if (!canReadFromStdin()) 625 return false; 626 627 SOKLET_REGISTRY.add(soklet); 628 629 // Start a single process-wide listener once. 630 if (LISTENER_STARTED.compareAndSet(false, true)) { 631 Thread thread = new Thread(KeypressManager::runLoop, "soklet-keypress-shutdown-listener"); 632 thread.setDaemon(true); // never block JVM exit 633 thread.start(); 634 } 635 636 return true; 637 } 638 639 synchronized static void unregister(@NonNull Soklet soklet) { 640 SOKLET_REGISTRY.remove(soklet); 641 // We intentionally keep the listener alive; it's daemon and cheap. 642 // If stdin hits EOF, the listener exits on its own. 643 } 644 645 /** 646 * Heuristic: if System.in is present and calling available() doesn't throw, 647 * treat it as readable. Works even in IDEs where System.console() is null. 648 */ 649 @NonNull 650 private static Boolean canReadFromStdin() { 651 if (System.in == null) 652 return false; 653 654 try { 655 // available() >= 0 means stream is open; 0 means no buffered data (that’s fine). 656 return System.in.available() >= 0; 657 } catch (IOException e) { 658 return false; 659 } 660 } 661 662 /** 663 * Single blocking read on stdin. On any line (or EOF), stop all registered servers. 664 */ 665 private static void runLoop() { 666 try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8))) { 667 // Blocks until newline or EOF; EOF (null) happens with /dev/null or closed pipe. 668 bufferedReader.readLine(); 669 stopAllSoklets(); 670 } catch (Throwable ignored) { 671 // If stdin is closed mid-run, just exit quietly. 672 } 673 } 674 675 synchronized private static void stopAllSoklets() { 676 // Either a line or EOF → stop everything that’s currently registered. 677 for (Soklet soklet : SOKLET_REGISTRY) { 678 try { 679 soklet.stop(); 680 } catch (Throwable ignored) { 681 // Nothing to do 682 } 683 } 684 } 685 686 private KeypressManager() {} 687 } 688 689 /** 690 * Nonpublic "informal" implementation of {@link com.soklet.HttpServer.RequestHandler} so Soklet does not need to expose {@code handleRequest} publicly. 691 * Reasoning: users of this library should never call {@code handleRequest} directly - it should only be invoked in response to events 692 * provided by a {@link HttpServer} or {@link SseServer} implementation. 693 */ 694 protected void handleRequest(@NonNull Request request, 695 @NonNull ServerType serverType, 696 @NonNull Consumer<HttpRequestResult> requestResultConsumer) { 697 requireNonNull(request); 698 requireNonNull(serverType); 699 requireNonNull(requestResultConsumer); 700 701 Instant processingStarted = Instant.now(); 702 703 SokletConfig sokletConfig = getSokletConfig(); 704 ResourceMethodResolver resourceMethodResolver = sokletConfig.getResourceMethodResolver(); 705 ResponseMarshaler responseMarshaler = sokletConfig.getResponseMarshaler(); 706 LifecycleObserver lifecycleObserver = sokletConfig.getLifecycleObserver(); 707 RequestInterceptor requestInterceptor = sokletConfig.getRequestInterceptor(); 708 MetricsCollector metricsCollector = sokletConfig.getMetricsCollector(); 709 710 // Holders to permit mutable effectively-final variables 711 AtomicReference<MarshaledResponse> marshaledResponseHolder = new AtomicReference<>(); 712 AtomicReference<Throwable> resourceMethodResolutionExceptionHolder = new AtomicReference<>(); 713 AtomicReference<Request> requestHolder = new AtomicReference<>(request); 714 AtomicReference<ResourceMethod> resourceMethodHolder = new AtomicReference<>(); 715 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 716 717 // Holders to permit mutable effectively-final state tracking 718 AtomicBoolean willStartResponseWritingCompleted = new AtomicBoolean(false); 719 AtomicBoolean didFinishResponseWritingCompleted = new AtomicBoolean(false); 720 AtomicBoolean didFinishRequestHandlingCompleted = new AtomicBoolean(false); 721 AtomicBoolean didInvokeWrapRequestConsumer = new AtomicBoolean(false); 722 723 List<Throwable> throwables = new ArrayList<>(10); 724 725 Consumer<LogEvent> safelyLog = (logEvent -> { 726 try { 727 lifecycleObserver.didReceiveLogEvent(logEvent); 728 } catch (Throwable throwable) { 729 throwable.printStackTrace(); 730 throwables.add(throwable); 731 } 732 }); 733 734 BiConsumer<String, Consumer<MetricsCollector>> safelyCollectMetrics = (message, metricsInvocation) -> { 735 if (metricsCollector == null) 736 return; 737 738 try { 739 metricsInvocation.accept(metricsCollector); 740 } catch (Throwable throwable) { 741 safelyLog.accept(LogEvent.with(LogEventType.METRICS_COLLECTOR_FAILED, message) 742 .throwable(throwable) 743 .request(requestHolder.get()) 744 .resourceMethod(resourceMethodHolder.get()) 745 .marshaledResponse(marshaledResponseHolder.get()) 746 .build()); 747 } 748 }; 749 750 requestHolder.set(request); 751 752 try { 753 requestInterceptor.wrapRequest(serverType, request, (wrappedRequest) -> { 754 didInvokeWrapRequestConsumer.set(true); 755 requestHolder.set(wrappedRequest); 756 757 try { 758 // Resolve after wrapping so path/method rewrites affect routing. 759 resourceMethodHolder.set(resourceMethodResolver.resourceMethodForRequest(requestHolder.get(), serverType).orElse(null)); 760 resourceMethodResolutionExceptionHolder.set(null); 761 } catch (Throwable t) { 762 safelyLog.accept(LogEvent.with(LogEventType.RESOURCE_METHOD_RESOLUTION_FAILED, "Unable to resolve Resource Method") 763 .throwable(t) 764 .request(requestHolder.get()) 765 .build()); 766 767 // If an exception occurs here, keep track of it - we will surface them after letting LifecycleObserver 768 // see that a request has come in. 769 throwables.add(t); 770 resourceMethodResolutionExceptionHolder.set(t); 771 resourceMethodHolder.set(null); 772 } 773 774 try { 775 lifecycleObserver.didStartRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get()); 776 } catch (Throwable t) { 777 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_START_REQUEST_HANDLING_FAILED, 778 format("An exception occurred while invoking %s::didStartRequestHandling", 779 LifecycleObserver.class.getSimpleName())) 780 .throwable(t) 781 .request(requestHolder.get()) 782 .resourceMethod(resourceMethodHolder.get()) 783 .build()); 784 785 throwables.add(t); 786 } 787 788 safelyCollectMetrics.accept( 789 format("An exception occurred while invoking %s::didStartRequestHandling", MetricsCollector.class.getSimpleName()), 790 (metricsInvocation) -> metricsInvocation.didStartRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get())); 791 792 try { 793 AtomicBoolean didInvokeMarshaledResponseConsumer = new AtomicBoolean(false); 794 795 requestInterceptor.interceptRequest(serverType, requestHolder.get(), resourceMethodHolder.get(), (interceptorRequest) -> { 796 requestHolder.set(interceptorRequest); 797 798 try { 799 if (resourceMethodResolutionExceptionHolder.get() != null) 800 throw resourceMethodResolutionExceptionHolder.get(); 801 802 HttpRequestResult requestResult = toHttpRequestResult(requestHolder.get(), resourceMethodHolder.get(), serverType); 803 requestResultHolder.set(requestResult); 804 805 MarshaledResponse originalMarshaledResponse = requestResult.getMarshaledResponse(); 806 MarshaledResponse updatedMarshaledResponse = requestResult.getMarshaledResponse(); 807 808 // A few special cases that are "global" in that they can affect all requests and 809 // need to happen after marshaling the response... 810 811 // 1. Customize response for HEAD (e.g. remove body, set Content-Length header) 812 updatedMarshaledResponse = applyHeadResponseIfApplicable(requestHolder.get(), updatedMarshaledResponse); 813 814 // 2. Apply other standard response customizations (CORS, Content-Length) 815 // Note that we don't want to write Content-Length for SSE "accepted" handshakes 816 SseHandshakeResult sseHandshakeResult = requestResult.getSseHandshakeResult().orElse(null); 817 boolean suppressContentLength = sseHandshakeResult != null && sseHandshakeResult instanceof SseHandshakeResult.Accepted; 818 819 updatedMarshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), updatedMarshaledResponse, suppressContentLength); 820 821 // Update our result holder with the modified response if necessary 822 if (originalMarshaledResponse != updatedMarshaledResponse) { 823 marshaledResponseHolder.set(updatedMarshaledResponse); 824 requestResultHolder.set(requestResult.copy() 825 .marshaledResponse(updatedMarshaledResponse) 826 .finish()); 827 } 828 829 return updatedMarshaledResponse; 830 } catch (Throwable t) { 831 if (t != resourceMethodResolutionExceptionHolder.get()) { 832 throwables.add(t); 833 834 safelyLog.accept(LogEvent.with(LogEventType.REQUEST_PROCESSING_FAILED, 835 "An exception occurred while processing request") 836 .throwable(t) 837 .request(requestHolder.get()) 838 .resourceMethod(resourceMethodHolder.get()) 839 .build()); 840 } 841 842 // Unhappy path. Try to use configuration's exception response marshaler... 843 try { 844 MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get()); 845 marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse); 846 marshaledResponseHolder.set(marshaledResponse); 847 848 return marshaledResponse; 849 } catch (Throwable t2) { 850 throwables.add(t2); 851 852 safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED, 853 format("An exception occurred while trying to write an exception response for %s", t)) 854 .throwable(t2) 855 .request(requestHolder.get()) 856 .resourceMethod(resourceMethodHolder.get()) 857 .build()); 858 859 // The configuration's exception response marshaler failed - provide a failsafe response to recover 860 return provideFailsafeMarshaledResponse(requestHolder.get(), t2); 861 } 862 } 863 }, (interceptorMarshaledResponse) -> { 864 requireNonNull(interceptorMarshaledResponse); 865 didInvokeMarshaledResponseConsumer.set(true); 866 marshaledResponseHolder.set(interceptorMarshaledResponse); 867 }); 868 869 if (!didInvokeMarshaledResponseConsumer.get()) { 870 requestResultHolder.set(null); 871 throw new IllegalStateException(format("%s::interceptRequest must call responseWriter", RequestInterceptor.class.getSimpleName())); 872 } 873 } catch (Throwable t) { 874 throwables.add(t); 875 876 try { 877 // In the event that an error occurs during processing of a RequestInterceptor method, for example 878 safelyLog.accept(LogEvent.with(LogEventType.REQUEST_INTERCEPTOR_INTERCEPT_REQUEST_FAILED, 879 format("An exception occurred while invoking %s::interceptRequest", RequestInterceptor.class.getSimpleName())) 880 .throwable(t) 881 .request(requestHolder.get()) 882 .resourceMethod(resourceMethodHolder.get()) 883 .build()); 884 885 MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get()); 886 marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse); 887 marshaledResponseHolder.set(marshaledResponse); 888 } catch (Throwable t2) { 889 throwables.add(t2); 890 891 safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED, 892 format("An exception occurred while invoking %s::forThrowable when trying to write an exception response for %s", ResponseMarshaler.class.getSimpleName(), t)) 893 .throwable(t2) 894 .request(requestHolder.get()) 895 .resourceMethod(resourceMethodHolder.get()) 896 .build()); 897 898 marshaledResponseHolder.set(provideFailsafeMarshaledResponse(requestHolder.get(), t2)); 899 } 900 } finally { 901 try { 902 try { 903 lifecycleObserver.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get()); 904 } finally { 905 willStartResponseWritingCompleted.set(true); 906 } 907 908 safelyCollectMetrics.accept( 909 format("An exception occurred while invoking %s::willWriteResponse", MetricsCollector.class.getSimpleName()), 910 (metricsInvocation) -> metricsInvocation.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get())); 911 912 Instant responseWriteStarted = Instant.now(); 913 914 try { 915 HttpRequestResult requestResult = requestResultHolder.get(); 916 917 if (requestResult != null) 918 requestResultConsumer.accept(requestResult); 919 else 920 requestResultConsumer.accept(HttpRequestResult.withMarshaledResponse(marshaledResponseHolder.get()) 921 .resourceMethod(resourceMethodHolder.get()) 922 .build()); 923 924 Instant responseWriteFinished = Instant.now(); 925 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 926 927 try { 928 lifecycleObserver.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration); 929 } catch (Throwable t) { 930 throwables.add(t); 931 932 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 933 format("An exception occurred while invoking %s::didWriteResponse", 934 LifecycleObserver.class.getSimpleName())) 935 .throwable(t) 936 .request(requestHolder.get()) 937 .resourceMethod(resourceMethodHolder.get()) 938 .marshaledResponse(marshaledResponseHolder.get()) 939 .build()); 940 } finally { 941 didFinishResponseWritingCompleted.set(true); 942 } 943 944 safelyCollectMetrics.accept( 945 format("An exception occurred while invoking %s::didWriteResponse", MetricsCollector.class.getSimpleName()), 946 (metricsInvocation) -> metricsInvocation.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 947 marshaledResponseHolder.get(), responseWriteDuration)); 948 } catch (Throwable t) { 949 throwables.add(t); 950 951 Instant responseWriteFinished = Instant.now(); 952 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 953 954 try { 955 lifecycleObserver.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration, t); 956 } catch (Throwable t2) { 957 throwables.add(t2); 958 959 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 960 format("An exception occurred while invoking %s::didFailToWriteResponse", 961 LifecycleObserver.class.getSimpleName())) 962 .throwable(t2) 963 .request(requestHolder.get()) 964 .resourceMethod(resourceMethodHolder.get()) 965 .marshaledResponse(marshaledResponseHolder.get()) 966 .build()); 967 } 968 969 safelyCollectMetrics.accept( 970 format("An exception occurred while invoking %s::didFailToWriteResponse", MetricsCollector.class.getSimpleName()), 971 (metricsInvocation) -> metricsInvocation.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 972 marshaledResponseHolder.get(), responseWriteDuration, t)); 973 } 974 } finally { 975 Duration processingDuration = Duration.between(processingStarted, Instant.now()); 976 977 safelyCollectMetrics.accept( 978 format("An exception occurred while invoking %s::didFinishRequestHandling", MetricsCollector.class.getSimpleName()), 979 (metricsInvocation) -> metricsInvocation.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables))); 980 981 try { 982 lifecycleObserver.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables)); 983 } catch (Throwable t) { 984 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_FINISH_REQUEST_HANDLING_FAILED, 985 format("An exception occurred while invoking %s::didFinishRequestHandling", 986 LifecycleObserver.class.getSimpleName())) 987 .throwable(t) 988 .request(requestHolder.get()) 989 .resourceMethod(resourceMethodHolder.get()) 990 .marshaledResponse(marshaledResponseHolder.get()) 991 .build()); 992 } finally { 993 didFinishRequestHandlingCompleted.set(true); 994 } 995 } 996 } 997 }); 998 999 if (!didInvokeWrapRequestConsumer.get()) 1000 throw new IllegalStateException(format("%s::wrapRequest must call requestProcessor", RequestInterceptor.class.getSimpleName())); 1001 } catch (Throwable t) { 1002 // If an error occurred during request wrapping, it's possible a response was never written/communicated back to LifecycleObserver. 1003 // Detect that here and inform LifecycleObserver accordingly. 1004 safelyLog.accept(LogEvent.with(LogEventType.REQUEST_INTERCEPTOR_WRAP_REQUEST_FAILED, 1005 format("An exception occurred while invoking %s::wrapRequest", 1006 RequestInterceptor.class.getSimpleName())) 1007 .throwable(t) 1008 .request(requestHolder.get()) 1009 .resourceMethod(resourceMethodHolder.get()) 1010 .marshaledResponse(marshaledResponseHolder.get()) 1011 .build()); 1012 1013 // If we don't have a response, let the marshaler try to make one for the exception. 1014 // If that fails, use the failsafe. 1015 if (marshaledResponseHolder.get() == null) { 1016 try { 1017 MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get()); 1018 marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse); 1019 marshaledResponseHolder.set(marshaledResponse); 1020 } catch (Throwable t2) { 1021 throwables.add(t2); 1022 1023 safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED, 1024 format("An exception occurred during request wrapping while invoking %s::forThrowable", 1025 ResponseMarshaler.class.getSimpleName())) 1026 .throwable(t2) 1027 .request(requestHolder.get()) 1028 .resourceMethod(resourceMethodHolder.get()) 1029 .marshaledResponse(marshaledResponseHolder.get()) 1030 .build()); 1031 1032 marshaledResponseHolder.set(provideFailsafeMarshaledResponse(requestHolder.get(), t)); 1033 } 1034 } 1035 1036 if (!willStartResponseWritingCompleted.get()) { 1037 try { 1038 lifecycleObserver.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get()); 1039 } catch (Throwable t2) { 1040 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_WILL_WRITE_RESPONSE_FAILED, 1041 format("An exception occurred while invoking %s::willWriteResponse", 1042 LifecycleObserver.class.getSimpleName())) 1043 .throwable(t2) 1044 .request(requestHolder.get()) 1045 .resourceMethod(resourceMethodHolder.get()) 1046 .marshaledResponse(marshaledResponseHolder.get()) 1047 .build()); 1048 } 1049 1050 safelyCollectMetrics.accept( 1051 format("An exception occurred while invoking %s::willWriteResponse", MetricsCollector.class.getSimpleName()), 1052 (metricsInvocation) -> metricsInvocation.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get())); 1053 } 1054 1055 try { 1056 Instant responseWriteStarted = Instant.now(); 1057 1058 if (!didFinishResponseWritingCompleted.get()) { 1059 try { 1060 HttpRequestResult requestResult = requestResultHolder.get(); 1061 1062 if (requestResult != null) 1063 requestResultConsumer.accept(requestResult); 1064 else 1065 requestResultConsumer.accept(HttpRequestResult.withMarshaledResponse(marshaledResponseHolder.get()) 1066 .resourceMethod(resourceMethodHolder.get()) 1067 .build()); 1068 1069 Instant responseWriteFinished = Instant.now(); 1070 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 1071 1072 try { 1073 lifecycleObserver.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration); 1074 } catch (Throwable t2) { 1075 throwables.add(t2); 1076 1077 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 1078 format("An exception occurred while invoking %s::didWriteResponse", 1079 LifecycleObserver.class.getSimpleName())) 1080 .throwable(t2) 1081 .request(requestHolder.get()) 1082 .resourceMethod(resourceMethodHolder.get()) 1083 .marshaledResponse(marshaledResponseHolder.get()) 1084 .build()); 1085 } 1086 1087 safelyCollectMetrics.accept( 1088 format("An exception occurred while invoking %s::didWriteResponse", MetricsCollector.class.getSimpleName()), 1089 (metricsInvocation) -> metricsInvocation.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 1090 marshaledResponseHolder.get(), responseWriteDuration)); 1091 } catch (Throwable t2) { 1092 throwables.add(t2); 1093 1094 Instant responseWriteFinished = Instant.now(); 1095 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 1096 1097 try { 1098 lifecycleObserver.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration, t); 1099 } catch (Throwable t3) { 1100 throwables.add(t3); 1101 1102 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 1103 format("An exception occurred while invoking %s::didFailToWriteResponse", 1104 LifecycleObserver.class.getSimpleName())) 1105 .throwable(t3) 1106 .request(requestHolder.get()) 1107 .resourceMethod(resourceMethodHolder.get()) 1108 .marshaledResponse(marshaledResponseHolder.get()) 1109 .build()); 1110 } 1111 1112 safelyCollectMetrics.accept( 1113 format("An exception occurred while invoking %s::didFailToWriteResponse", MetricsCollector.class.getSimpleName()), 1114 (metricsInvocation) -> metricsInvocation.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 1115 marshaledResponseHolder.get(), responseWriteDuration, t)); 1116 } 1117 } 1118 } finally { 1119 if (!didFinishRequestHandlingCompleted.get()) { 1120 Duration processingDuration = Duration.between(processingStarted, Instant.now()); 1121 1122 safelyCollectMetrics.accept( 1123 format("An exception occurred while invoking %s::didFinishRequestHandling", MetricsCollector.class.getSimpleName()), 1124 (metricsInvocation) -> metricsInvocation.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables))); 1125 1126 try { 1127 lifecycleObserver.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables)); 1128 } catch (Throwable t2) { 1129 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_FINISH_REQUEST_HANDLING_FAILED, 1130 format("An exception occurred while invoking %s::didFinishRequestHandling", 1131 LifecycleObserver.class.getSimpleName())) 1132 .throwable(t2) 1133 .request(requestHolder.get()) 1134 .resourceMethod(resourceMethodHolder.get()) 1135 .marshaledResponse(marshaledResponseHolder.get()) 1136 .build()); 1137 } 1138 } 1139 } 1140 } 1141 } 1142 1143 @NonNull 1144 protected HttpRequestResult toHttpRequestResult(@NonNull Request request, 1145 @Nullable ResourceMethod resourceMethod, 1146 @NonNull ServerType serverType) throws Throwable { 1147 requireNonNull(request); 1148 requireNonNull(serverType); 1149 1150 ResourceMethodParameterProvider resourceMethodParameterProvider = getSokletConfig().getResourceMethodParameterProvider(); 1151 InstanceProvider instanceProvider = getSokletConfig().getInstanceProvider(); 1152 CorsAuthorizer corsAuthorizer = getSokletConfig().getCorsAuthorizer(); 1153 ResourceMethodResolver resourceMethodResolver = getSokletConfig().getResourceMethodResolver(); 1154 ResponseMarshaler responseMarshaler = getSokletConfig().getResponseMarshaler(); 1155 CorsPreflight corsPreflight = request.getCorsPreflight().orElse(null); 1156 1157 // Special short-circuit for big requests 1158 if (request.isContentTooLarge()) 1159 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forContentTooLarge(request, resourceMethodResolver.resourceMethodForRequest(request, serverType).orElse(null))) 1160 .resourceMethod(resourceMethod) 1161 .build(); 1162 1163 // Special short-circuit for OPTIONS * 1164 if (request.getResourcePath() == ResourcePath.OPTIONS_SPLAT_RESOURCE_PATH) 1165 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forOptionsSplat(request)).build(); 1166 1167 // No resource method was found for this HTTP method and path. 1168 if (resourceMethod == null) { 1169 // If this was an OPTIONS request, do special processing. 1170 // If not, figure out if we should return a 404 or 405. 1171 if (request.getHttpMethod() == HttpMethod.OPTIONS) { 1172 // See what methods are available to us for this request's path 1173 Map<HttpMethod, ResourceMethod> matchingResourceMethodsByHttpMethod = resolveMatchingResourceMethodsByHttpMethod(request, resourceMethodResolver, serverType); 1174 1175 // Special handling for CORS preflight requests, if needed 1176 if (corsPreflight != null) { 1177 // Let configuration function determine if we should authorize this request. 1178 // Discard any OPTIONS references - see https://stackoverflow.com/a/68529748 1179 Map<HttpMethod, ResourceMethod> nonOptionsMatchingResourceMethodsByHttpMethod = matchingResourceMethodsByHttpMethod.entrySet().stream() 1180 .filter(entry -> entry.getKey() != HttpMethod.OPTIONS) 1181 .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); 1182 1183 CorsPreflightResponse corsPreflightResponse = corsAuthorizer.authorizePreflight(request, corsPreflight, nonOptionsMatchingResourceMethodsByHttpMethod).orElse(null); 1184 1185 // Allow or reject CORS depending on what the function said to do 1186 if (corsPreflightResponse != null) { 1187 // Allow 1188 MarshaledResponse marshaledResponse = responseMarshaler.forCorsPreflightAllowed(request, corsPreflight, corsPreflightResponse); 1189 1190 return HttpRequestResult.withMarshaledResponse(marshaledResponse) 1191 .corsPreflightResponse(corsPreflightResponse) 1192 .resourceMethod(resourceMethod) 1193 .build(); 1194 } 1195 1196 // Reject 1197 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forCorsPreflightRejected(request, corsPreflight)) 1198 .resourceMethod(resourceMethod) 1199 .build(); 1200 } else { 1201 // Just a normal OPTIONS response (non-CORS-preflight). 1202 // If there's a matching OPTIONS resource method for this OPTIONS request, then invoke it. 1203 ResourceMethod optionsResourceMethod = matchingResourceMethodsByHttpMethod.get(HttpMethod.OPTIONS); 1204 1205 if (optionsResourceMethod != null) { 1206 resourceMethod = optionsResourceMethod; 1207 } else { 1208 Set<HttpMethod> allowedHttpMethods = allowedHttpMethodsForResponse(matchingResourceMethodsByHttpMethod, true); 1209 1210 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forOptions(request, allowedHttpMethods)) 1211 .resourceMethod(resourceMethod) 1212 .build(); 1213 } 1214 } 1215 } else if (request.getHttpMethod() == HttpMethod.HEAD) { 1216 // If there's a matching GET resource method for this HEAD request, then invoke it 1217 Request headGetRequest = request.copy().httpMethod(HttpMethod.GET).finish(); 1218 ResourceMethod headGetResourceMethod = resourceMethodResolver.resourceMethodForRequest(headGetRequest, serverType).orElse(null); 1219 1220 if (headGetResourceMethod != null) 1221 resourceMethod = headGetResourceMethod; 1222 else 1223 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forNotFound(request)) 1224 .resourceMethod(resourceMethod) 1225 .build(); 1226 } else { 1227 // Not an OPTIONS request, so it's possible we have a 405. See if other HTTP methods match... 1228 Map<HttpMethod, ResourceMethod> otherMatchingResourceMethodsByHttpMethod = resolveMatchingResourceMethodsByHttpMethod(request, resourceMethodResolver, serverType); 1229 1230 Set<HttpMethod> matchingNonOptionsHttpMethods = otherMatchingResourceMethodsByHttpMethod.keySet().stream() 1231 .filter(httpMethod -> httpMethod != HttpMethod.OPTIONS) 1232 .collect(Collectors.toSet()); 1233 1234 if (matchingNonOptionsHttpMethods.size() > 0) { 1235 // ...if some do, it's a 405 1236 Set<HttpMethod> allowedHttpMethods = allowedHttpMethodsForResponse(otherMatchingResourceMethodsByHttpMethod, true); 1237 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forMethodNotAllowed(request, allowedHttpMethods)) 1238 .resourceMethod(resourceMethod) 1239 .build(); 1240 } else { 1241 // no matching resource method found, it's a 404 1242 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forNotFound(request)) 1243 .resourceMethod(resourceMethod) 1244 .build(); 1245 } 1246 } 1247 } 1248 1249 // Found a resource method - happy path. 1250 // 1. Get an instance of the resource class 1251 // 2. Get values to pass to the resource method on the resource class 1252 // 3. Invoke the resource method and use its return value to drive a response 1253 Class<?> resourceClass = resourceMethod.getMethod().getDeclaringClass(); 1254 Object resourceClassInstance; 1255 1256 try { 1257 resourceClassInstance = instanceProvider.provide(resourceClass); 1258 } catch (Exception e) { 1259 throw new IllegalArgumentException(format("Unable to acquire an instance of %s", resourceClass.getName()), e); 1260 } 1261 1262 List<Object> parameterValues = resourceMethodParameterProvider.parameterValuesForResourceMethod(request, resourceMethod); 1263 1264 Object responseObject; 1265 1266 try { 1267 responseObject = resourceMethod.getMethod().invoke(resourceClassInstance, parameterValues.toArray()); 1268 } catch (InvocationTargetException e) { 1269 if (e.getTargetException() != null) 1270 throw e.getTargetException(); 1271 1272 throw e; 1273 } 1274 1275 // Unwrap the Optional<T>, if one exists. We do not recurse deeper than one level 1276 if (responseObject instanceof Optional<?>) 1277 responseObject = ((Optional<?>) responseObject).orElse(null); 1278 1279 Response response; 1280 SseHandshakeResult sseHandshakeResult = null; 1281 1282 // If null/void return, it's a 204 1283 // If it's a MarshaledResponse object, no marshaling + return it immediately - caller knows exactly what it wants to write. 1284 // If it's a Response object, use as is. 1285 // If it's a non-Response type of object, assume it's the response body and wrap in a Response. 1286 if (responseObject == null) { 1287 response = Response.withStatusCode(204).build(); 1288 } else if (responseObject instanceof MarshaledResponse) { 1289 MarshaledResponse marshaledResponse = (MarshaledResponse) responseObject; 1290 enforceBodylessStatusCode(marshaledResponse.getStatusCode(), marshaledResponse.getBody().isPresent()); 1291 1292 return HttpRequestResult.withMarshaledResponse(marshaledResponse) 1293 .resourceMethod(resourceMethod) 1294 .build(); 1295 } else if (responseObject instanceof Response) { 1296 response = (Response) responseObject; 1297 } else if (responseObject instanceof SseHandshakeResult.Accepted accepted) { // SSE "accepted" handshake 1298 return HttpRequestResult.withMarshaledResponse(toMarshaledResponse(accepted)) 1299 .resourceMethod(resourceMethod) 1300 .sseHandshakeResult(accepted) 1301 .build(); 1302 } else if (responseObject instanceof SseHandshakeResult.Rejected rejected) { // SSE "rejected" handshake 1303 response = rejected.getResponse(); 1304 sseHandshakeResult = rejected; 1305 } else { 1306 response = Response.withStatusCode(200).body(responseObject).build(); 1307 } 1308 1309 enforceBodylessStatusCode(response.getStatusCode(), response.getBody().isPresent()); 1310 1311 MarshaledResponse marshaledResponse = responseMarshaler.forResourceMethod(request, response, resourceMethod); 1312 1313 enforceBodylessStatusCode(marshaledResponse.getStatusCode(), marshaledResponse.getBody().isPresent()); 1314 1315 return HttpRequestResult.withMarshaledResponse(marshaledResponse) 1316 .response(response) 1317 .resourceMethod(resourceMethod) 1318 .sseHandshakeResult(sseHandshakeResult) 1319 .build(); 1320 } 1321 1322 @NonNull 1323 private MarshaledResponse toMarshaledResponse(SseHandshakeResult.@NonNull Accepted accepted) { 1324 requireNonNull(accepted); 1325 1326 Map<String, Set<String>> headers = accepted.getHeaders(); 1327 LinkedCaseInsensitiveMap<Set<String>> finalHeaders = new LinkedCaseInsensitiveMap<>(DEFAULT_ACCEPTED_HANDSHAKE_HEADERS.size() + headers.size()); 1328 1329 // Start with defaults 1330 for (Map.Entry<String, Set<String>> e : DEFAULT_ACCEPTED_HANDSHAKE_HEADERS.entrySet()) 1331 finalHeaders.put(e.getKey(), e.getValue()); // values already unmodifiable 1332 1333 // Overlay user-supplied headers (prefer user values on key collision) 1334 for (Map.Entry<String, Set<String>> e : headers.entrySet()) { 1335 // Defensively copy so callers can't mutate after construction 1336 Set<String> values = e.getValue() == null ? Set.of() : Set.copyOf(e.getValue()); 1337 finalHeaders.put(e.getKey(), values); 1338 } 1339 1340 return MarshaledResponse.withStatusCode(200) 1341 .headers(finalHeaders) 1342 .cookies(accepted.getCookies()) 1343 .build(); 1344 } 1345 1346 private static void enforceBodylessStatusCode(@NonNull Integer statusCode, 1347 @NonNull Boolean hasBody) { 1348 requireNonNull(statusCode); 1349 requireNonNull(hasBody); 1350 1351 if (hasBody && isBodylessStatusCode(statusCode)) 1352 throw new IllegalStateException(format("HTTP status code %d must not include a response body", statusCode)); 1353 } 1354 1355 private static boolean isBodylessStatusCode(@NonNull Integer statusCode) { 1356 requireNonNull(statusCode); 1357 return (statusCode >= 100 && statusCode < 200) || statusCode == 204 || statusCode == 304; 1358 } 1359 1360 @NonNull 1361 protected MarshaledResponse applyHeadResponseIfApplicable(@NonNull Request request, 1362 @NonNull MarshaledResponse marshaledResponse) { 1363 if (request.getHttpMethod() != HttpMethod.HEAD) 1364 return marshaledResponse; 1365 1366 return getSokletConfig().getResponseMarshaler().forHead(request, marshaledResponse); 1367 } 1368 1369 // Hat tip to Aslan Parçası and GrayStar 1370 @NonNull 1371 protected MarshaledResponse applyCommonPropertiesToMarshaledResponse(@NonNull Request request, 1372 @NonNull MarshaledResponse marshaledResponse) { 1373 requireNonNull(request); 1374 requireNonNull(marshaledResponse); 1375 1376 return applyCommonPropertiesToMarshaledResponse(request, marshaledResponse, false); 1377 } 1378 1379 protected void handleMcpRequest(@NonNull Request request, 1380 @NonNull Consumer<HttpRequestResult> requestResultConsumer) { 1381 requireNonNull(request); 1382 requireNonNull(requestResultConsumer); 1383 1384 requestResultConsumer.accept(this.defaultMcpRuntime.handleRequest(request)); 1385 } 1386 1387 protected void handleSimulatedMcpStreamDisconnect(@NonNull Request request, 1388 @NonNull String sessionId) { 1389 requireNonNull(request); 1390 requireNonNull(sessionId); 1391 1392 this.defaultMcpRuntime.handleClientDisconnectedStream(request, sessionId); 1393 } 1394 1395 @NonNull 1396 protected MarshaledResponse applyCommonPropertiesToMarshaledResponse(@NonNull Request request, 1397 @NonNull MarshaledResponse marshaledResponse, 1398 @NonNull Boolean suppressContentLength) { 1399 requireNonNull(request); 1400 requireNonNull(marshaledResponse); 1401 requireNonNull(suppressContentLength); 1402 1403 // Don't write Content-Length for an accepted SSE Handshake, for example 1404 if (!suppressContentLength) 1405 marshaledResponse = applyContentLengthIfApplicable(request, marshaledResponse); 1406 1407 // If the Date header is missing, add it using our cached provider 1408 if (!marshaledResponse.getHeaders().containsKey("Date")) 1409 marshaledResponse = marshaledResponse.copy() 1410 .headers(headers -> headers.put("Date", Set.of(CachedHttpDate.getCurrentValue()))) 1411 .finish(); 1412 1413 marshaledResponse = applyCorsResponseIfApplicable(request, marshaledResponse); 1414 1415 return marshaledResponse; 1416 } 1417 1418 @NonNull 1419 protected MarshaledResponse applyContentLengthIfApplicable(@NonNull Request request, 1420 @NonNull MarshaledResponse marshaledResponse) { 1421 requireNonNull(request); 1422 requireNonNull(marshaledResponse); 1423 1424 Set<String> normalizedHeaderNames = marshaledResponse.getHeaders().keySet().stream() 1425 .map(headerName -> headerName.toLowerCase(Locale.US)) 1426 .collect(Collectors.toSet()); 1427 1428 // If Content-Length is already specified, don't do anything 1429 if (normalizedHeaderNames.contains("content-length")) 1430 return marshaledResponse; 1431 1432 // If Content-Length is not specified, specify as the number of bytes in the body 1433 return marshaledResponse.copy() 1434 .headers((mutableHeaders) -> { 1435 String contentLengthHeaderValue = String.valueOf(marshaledResponse.getBodyLength()); 1436 mutableHeaders.put("Content-Length", Set.of(contentLengthHeaderValue)); 1437 }).finish(); 1438 } 1439 1440 @NonNull 1441 protected MarshaledResponse applyCorsResponseIfApplicable(@NonNull Request request, 1442 @NonNull MarshaledResponse marshaledResponse) { 1443 requireNonNull(request); 1444 requireNonNull(marshaledResponse); 1445 1446 Cors cors = request.getCors().orElse(null); 1447 1448 // If non-CORS request, nothing further to do (note that CORS preflight was handled earlier) 1449 if (cors == null) 1450 return marshaledResponse; 1451 1452 CorsAuthorizer corsAuthorizer = getSokletConfig().getCorsAuthorizer(); 1453 1454 // Does the authorizer say we are authorized? 1455 CorsResponse corsResponse = corsAuthorizer.authorize(request, cors).orElse(null); 1456 1457 // Not authorized - don't apply CORS headers to the response 1458 if (corsResponse == null) 1459 return marshaledResponse; 1460 1461 // Authorized - OK, let's apply the headers to the response 1462 return getSokletConfig().getResponseMarshaler().forCorsAllowed(request, cors, corsResponse, marshaledResponse); 1463 } 1464 1465 @NonNull 1466 protected Map<@NonNull HttpMethod, @NonNull ResourceMethod> resolveMatchingResourceMethodsByHttpMethod(@NonNull Request request, 1467 @NonNull ResourceMethodResolver resourceMethodResolver, 1468 @NonNull ServerType serverType) { 1469 requireNonNull(request); 1470 requireNonNull(resourceMethodResolver); 1471 requireNonNull(serverType); 1472 1473 // Special handling for OPTIONS * 1474 if (request.getResourcePath() == ResourcePath.OPTIONS_SPLAT_RESOURCE_PATH) 1475 return new LinkedHashMap<>(); 1476 1477 Map<HttpMethod, ResourceMethod> matchingResourceMethodsByHttpMethod = new LinkedHashMap<>(HttpMethod.values().length); 1478 1479 for (HttpMethod httpMethod : HttpMethod.values()) { 1480 // Make a quick copy of the request to see if other paths match 1481 Request otherRequest = Request.withPath(httpMethod, request.getPath()).build(); 1482 ResourceMethod resourceMethod = resourceMethodResolver.resourceMethodForRequest(otherRequest, serverType).orElse(null); 1483 1484 if (resourceMethod != null) 1485 matchingResourceMethodsByHttpMethod.put(httpMethod, resourceMethod); 1486 } 1487 1488 return matchingResourceMethodsByHttpMethod; 1489 } 1490 1491 @NonNull 1492 private static Set<@NonNull HttpMethod> allowedHttpMethodsForResponse(@NonNull Map<@NonNull HttpMethod, @NonNull ResourceMethod> matchingResourceMethodsByHttpMethod, 1493 @NonNull Boolean includeOptions) { 1494 requireNonNull(matchingResourceMethodsByHttpMethod); 1495 requireNonNull(includeOptions); 1496 1497 Set<HttpMethod> allowedHttpMethods = EnumSet.noneOf(HttpMethod.class); 1498 allowedHttpMethods.addAll(matchingResourceMethodsByHttpMethod.keySet()); 1499 1500 if (includeOptions) 1501 allowedHttpMethods.add(HttpMethod.OPTIONS); 1502 1503 if (matchingResourceMethodsByHttpMethod.containsKey(HttpMethod.GET) || matchingResourceMethodsByHttpMethod.containsKey(HttpMethod.HEAD)) 1504 allowedHttpMethods.add(HttpMethod.HEAD); 1505 1506 return allowedHttpMethods; 1507 } 1508 1509 @NonNull 1510 protected MarshaledResponse provideFailsafeMarshaledResponse(@NonNull Request request, 1511 @NonNull Throwable throwable) { 1512 requireNonNull(request); 1513 requireNonNull(throwable); 1514 1515 Integer statusCode = 500; 1516 Charset charset = StandardCharsets.UTF_8; 1517 1518 return MarshaledResponse.withStatusCode(statusCode) 1519 .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", charset.name())))) 1520 .body(format("HTTP %d: %s", statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(charset)) 1521 .build(); 1522 } 1523 1524 /** 1525 * Synonym for {@link #stop()}. 1526 */ 1527 @Override 1528 public void close() { 1529 stop(); 1530 } 1531 1532 /** 1533 * Is any managed transport server started? 1534 * 1535 * @return {@code true} if at least one configured transport server is started, {@code false} otherwise 1536 */ 1537 @NonNull 1538 public Boolean isStarted() { 1539 getLock().lock(); 1540 1541 try { 1542 HttpServer httpServer = getSokletConfig().getHttpServer().orElse(null); 1543 1544 if (httpServer != null && httpServer.isStarted()) 1545 return true; 1546 1547 SseServer sseServer = getSokletConfig().getSseServer().orElse(null); 1548 if (sseServer != null && sseServer.isStarted()) 1549 return true; 1550 1551 McpServer mcpServer = getSokletConfig().getMcpServer().orElse(null); 1552 return mcpServer != null && mcpServer.isStarted(); 1553 } finally { 1554 getLock().unlock(); 1555 } 1556 } 1557 1558 /** 1559 * Runs Soklet with special non-network "simulator" implementations of the configured transport servers - useful for integration testing. 1560 * <p> 1561 * See <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a> for how to write these tests. 1562 * 1563 * @param sokletConfig configuration that drives the Soklet system 1564 * @param simulatorConsumer code to execute within the context of the simulator 1565 */ 1566 public static void runSimulator(@NonNull SokletConfig sokletConfig, 1567 @NonNull Consumer<Simulator> simulatorConsumer) { 1568 requireNonNull(sokletConfig); 1569 requireNonNull(simulatorConsumer); 1570 1571 // Create Soklet instance - this initializes the REAL implementations through proxies 1572 Soklet soklet = Soklet.fromConfig(sokletConfig); 1573 1574 // Extract proxies (they're guaranteed to be proxies now) 1575 HttpServerProxy serverProxy = sokletConfig.getHttpServer() 1576 .map(server -> (HttpServerProxy) server) 1577 .orElse(null); 1578 SseServerProxy sseServerProxy = sokletConfig.getSseServer() 1579 .map(s -> (SseServerProxy) s) 1580 .orElse(null); 1581 McpServerProxy mcpServerProxy = sokletConfig.getMcpServer() 1582 .map(mcpServer -> (McpServerProxy) mcpServer) 1583 .orElse(null); 1584 1585 // Create mock implementations 1586 MockHttpServer mockServer = serverProxy == null ? null : new MockHttpServer(); 1587 MockSseServer mockSseServer = new MockSseServer(); 1588 MockMcpServer mockMcpServer = mcpServerProxy == null ? null : new MockMcpServer(mcpServerProxy.getRealImplementation()); 1589 1590 // Switch proxies to simulator mode 1591 if (serverProxy != null) 1592 serverProxy.enableSimulatorMode(mockServer); 1593 1594 if (sseServerProxy != null) 1595 sseServerProxy.enableSimulatorMode(mockSseServer); 1596 1597 if (mcpServerProxy != null) 1598 mcpServerProxy.enableSimulatorMode(mockMcpServer); 1599 1600 try { 1601 // Initialize mocks with request handlers that delegate to Soklet's processing 1602 if (mockServer != null) 1603 mockServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> { 1604 // Delegate to Soklet's internal request handling 1605 soklet.handleRequest(request, ServerType.STANDARD_HTTP, marshaledResponseConsumer); 1606 }); 1607 1608 if (mockSseServer != null) 1609 mockSseServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> { 1610 // Delegate to Soklet's internal request handling for SSE 1611 soklet.handleRequest(request, ServerType.SSE, marshaledResponseConsumer); 1612 }); 1613 1614 if (mockMcpServer != null) 1615 mockMcpServer.initialize(sokletConfig, soklet::handleMcpRequest); 1616 1617 if (mockMcpServer != null) 1618 mockMcpServer.onClientDisconnectedMcpStream(soklet::handleSimulatedMcpStreamDisconnect); 1619 1620 // Create and provide simulator 1621 Simulator simulator = new DefaultSimulator(mockServer, mockSseServer, mockMcpServer); 1622 simulatorConsumer.accept(simulator); 1623 } finally { 1624 // Always restore to real implementations 1625 if (serverProxy != null) 1626 serverProxy.disableSimulatorMode(); 1627 1628 if (sseServerProxy != null) 1629 sseServerProxy.disableSimulatorMode(); 1630 1631 if (mcpServerProxy != null) 1632 mcpServerProxy.disableSimulatorMode(); 1633 } 1634 } 1635 1636 @NonNull 1637 protected SokletConfig getSokletConfig() { 1638 return this.sokletConfig; 1639 } 1640 1641 @NonNull 1642 protected ReentrantLock getLock() { 1643 return this.lock; 1644 } 1645 1646 @NonNull 1647 protected AtomicReference<CountDownLatch> getAwaitShutdownLatchReference() { 1648 return this.awaitShutdownLatchReference; 1649 } 1650 1651 @ThreadSafe 1652 static class DefaultSimulator implements Simulator { 1653 @Nullable 1654 private MockHttpServer server; 1655 @Nullable 1656 private MockSseServer sseServer; 1657 @Nullable 1658 private MockMcpServer mcpServer; 1659 1660 public DefaultSimulator(@Nullable MockHttpServer server, 1661 @Nullable MockSseServer sseServer, 1662 @Nullable MockMcpServer mcpServer) { 1663 this.server = server; 1664 this.sseServer = sseServer; 1665 this.mcpServer = mcpServer; 1666 } 1667 1668 @NonNull 1669 @Override 1670 public HttpRequestResult performHttpRequest(@NonNull Request request) { 1671 MockHttpServer server = getHttpServer().orElse(null); 1672 1673 if (server == null) 1674 throw new IllegalStateException(format("You must specify a %s in your %s to simulate requests", 1675 HttpServer.class.getSimpleName(), SokletConfig.class.getSimpleName())); 1676 1677 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 1678 HttpServer.RequestHandler requestHandler = server.getRequestHandler().orElse(null); 1679 1680 if (requestHandler == null) 1681 throw new IllegalStateException("You must register a request handler prior to simulating requests"); 1682 1683 requestHandler.handleRequest(request, (requestResult -> { 1684 requestResultHolder.set(requestResult); 1685 })); 1686 1687 return requestResultHolder.get(); 1688 } 1689 1690 @NonNull 1691 @Override 1692 public SseRequestResult performSseRequest(@NonNull Request request) { 1693 MockSseServer sseServer = getSseServer().orElse(null); 1694 1695 if (sseServer == null) 1696 throw new IllegalStateException(format("You must specify a %s in your %s to simulate Server-Sent Event requests", 1697 SseServer.class.getSimpleName(), SokletConfig.class.getSimpleName())); 1698 1699 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 1700 SseServer.RequestHandler requestHandler = sseServer.getRequestHandler().orElse(null); 1701 1702 if (requestHandler == null) 1703 throw new IllegalStateException("You must register a request handler prior to simulating SSE Event Source requests"); 1704 1705 requestHandler.handleRequest(request, (requestResult -> { 1706 requestResultHolder.set(requestResult); 1707 })); 1708 1709 HttpRequestResult requestResult = requestResultHolder.get(); 1710 SseHandshakeResult sseHandshakeResult = requestResult.getSseHandshakeResult().orElse(null); 1711 1712 if (sseHandshakeResult == null) 1713 return new SseRequestResult.RequestFailed(requestResult); 1714 1715 if (sseHandshakeResult instanceof SseHandshakeResult.Accepted acceptedHandshake) { 1716 Consumer<SseUnicaster> clientInitializer = acceptedHandshake.getClientInitializer().orElse(null); 1717 1718 // Create a synthetic logical response using values from the accepted handshake 1719 if (requestResult.getResponse().isEmpty()) 1720 requestResult = requestResult.copy() 1721 .response(Response.withStatusCode(200) 1722 .headers(acceptedHandshake.getHeaders()) 1723 .cookies(acceptedHandshake.getCookies()) 1724 .build()) 1725 .finish(); 1726 1727 HandshakeAccepted handshakeAccepted = new HandshakeAccepted(acceptedHandshake, request.getResourcePath(), requestResult, this, clientInitializer); 1728 return handshakeAccepted; 1729 } 1730 1731 if (sseHandshakeResult instanceof SseHandshakeResult.Rejected rejectedHandshake) 1732 return new HandshakeRejected(rejectedHandshake, requestResult); 1733 1734 throw new IllegalStateException(format("Encountered unexpected %s: %s", SseHandshakeResult.class.getSimpleName(), sseHandshakeResult)); 1735 } 1736 1737 @NonNull 1738 @Override 1739 public McpRequestResult performMcpRequest(@NonNull Request request) { 1740 requireNonNull(request); 1741 1742 MockMcpServer mcpServer = getMcpServer().orElse(null); 1743 1744 if (mcpServer == null) 1745 throw new IllegalStateException(format("You must specify an MCP server in your %s to simulate MCP requests", 1746 SokletConfig.class.getSimpleName())); 1747 1748 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 1749 McpServer.RequestHandler requestHandler = mcpServer.getRequestHandler().orElse(null); 1750 1751 if (requestHandler == null) 1752 throw new IllegalStateException("You must register a request handler prior to simulating MCP requests"); 1753 1754 requestHandler.handleRequest(request, requestResultHolder::set); 1755 1756 HttpRequestResult requestResult = requestResultHolder.get(); 1757 1758 if (requestResult == null) 1759 throw new IllegalStateException("No MCP request result was produced by the simulator"); 1760 1761 if (extractContentTypeFromHeaders(requestResult.getMarshaledResponse().getHeaders()) 1762 .filter(contentType -> contentType.equalsIgnoreCase("text/event-stream")) 1763 .isPresent()) { 1764 McpRequestResult.StreamOpened streamOpened = new McpRequestResult.StreamOpened( 1765 requestResult, 1766 mcpServer.getMcpStreamErrorHandler(), 1767 requestResult.isMcpStreamClosedAfterReplay()); 1768 1769 for (McpObject mcpStreamMessage : requestResult.getMcpStreamMessages()) 1770 streamOpened.emitMessage(mcpStreamMessage); 1771 1772 if (!requestResult.isMcpStreamClosedAfterReplay()) 1773 request.getHeader("MCP-Session-Id").ifPresent(sessionId -> mcpServer.registerOpenStream(sessionId, request, streamOpened)); 1774 1775 return streamOpened; 1776 } 1777 1778 if (request.getHttpMethod() == HttpMethod.DELETE 1779 && Objects.equals(requestResult.getMarshaledResponse().getStatusCode(), 204)) 1780 request.getHeader("MCP-Session-Id").ifPresent(mcpServer::terminateStreamsForSession); 1781 1782 return new McpRequestResult.ResponseCompleted(requestResult); 1783 } 1784 1785 @NonNull 1786 @Override 1787 public Simulator onBroadcastError(@Nullable Consumer<Throwable> onBroadcastError) { 1788 MockSseServer sseServer = getSseServer().orElse(null); 1789 1790 if (sseServer != null) 1791 sseServer.onBroadcastError(onBroadcastError); 1792 1793 return this; 1794 } 1795 1796 @NonNull 1797 @Override 1798 public Simulator onUnicastError(@Nullable Consumer<Throwable> onUnicastError) { 1799 MockSseServer sseServer = getSseServer().orElse(null); 1800 1801 if (sseServer != null) 1802 sseServer.onUnicastError(onUnicastError); 1803 1804 return this; 1805 } 1806 1807 @NonNull 1808 @Override 1809 public Simulator onMcpStreamError(@Nullable Consumer<Throwable> onMcpStreamError) { 1810 MockMcpServer mcpServer = getMcpServer().orElse(null); 1811 1812 if (mcpServer != null) 1813 mcpServer.onMcpStreamError(onMcpStreamError); 1814 1815 return this; 1816 } 1817 1818 @NonNull 1819 Optional<MockHttpServer> getHttpServer() { 1820 return Optional.ofNullable(this.server); 1821 } 1822 1823 @NonNull 1824 Optional<MockSseServer> getSseServer() { 1825 return Optional.ofNullable(this.sseServer); 1826 } 1827 1828 @NonNull 1829 Optional<MockMcpServer> getMcpServer() { 1830 return Optional.ofNullable(this.mcpServer); 1831 } 1832 } 1833 1834 /** 1835 * Mock server that doesn't touch the network at all, useful for testing. 1836 * 1837 * @author <a href="https://www.revetkn.com">Mark Allen</a> 1838 */ 1839 @ThreadSafe 1840 static class MockHttpServer implements HttpServer { 1841 @Nullable 1842 private SokletConfig sokletConfig; 1843 private HttpServer.@Nullable RequestHandler requestHandler; 1844 1845 @Override 1846 public void start() { 1847 // No-op 1848 } 1849 1850 @Override 1851 public void stop() { 1852 // No-op 1853 } 1854 1855 @NonNull 1856 @Override 1857 public Boolean isStarted() { 1858 return true; 1859 } 1860 1861 @Override 1862 public void initialize(@NonNull SokletConfig sokletConfig, 1863 @NonNull RequestHandler requestHandler) { 1864 requireNonNull(sokletConfig); 1865 requireNonNull(requestHandler); 1866 1867 this.sokletConfig = sokletConfig; 1868 this.requestHandler = requestHandler; 1869 } 1870 1871 @NonNull 1872 protected Optional<SokletConfig> getSokletConfig() { 1873 return Optional.ofNullable(this.sokletConfig); 1874 } 1875 1876 @NonNull 1877 protected Optional<RequestHandler> getRequestHandler() { 1878 return Optional.ofNullable(this.requestHandler); 1879 } 1880 } 1881 1882 /** 1883 * Mock MCP server that doesn't touch the network at all, useful for testing. 1884 */ 1885 @ThreadSafe 1886 static class MockMcpServer implements McpServer, InternalMcpSessionMessagePublisher { 1887 @NonNull 1888 private final McpServer realImplementation; 1889 @Nullable 1890 private SokletConfig sokletConfig; 1891 private McpServer.@Nullable RequestHandler requestHandler; 1892 @NonNull 1893 private final AtomicReference<Consumer<Throwable>> mcpStreamErrorHandler; 1894 @NonNull 1895 private final ConcurrentHashMap<@NonNull String, @NonNull CopyOnWriteArrayList<McpRequestResult.StreamOpened>> openStreamsBySessionId; 1896 @NonNull 1897 private final AtomicReference<@Nullable BiConsumer<Request, String>> clientDisconnectedMcpStreamHandler; 1898 1899 public MockMcpServer(@NonNull McpServer realImplementation) { 1900 requireNonNull(realImplementation); 1901 1902 this.realImplementation = realImplementation; 1903 this.mcpStreamErrorHandler = new AtomicReference<>(); 1904 this.openStreamsBySessionId = new ConcurrentHashMap<>(); 1905 this.clientDisconnectedMcpStreamHandler = new AtomicReference<>(); 1906 } 1907 1908 @Override 1909 public void start() { 1910 // No-op 1911 } 1912 1913 @Override 1914 public void stop() { 1915 // No-op 1916 } 1917 1918 @NonNull 1919 @Override 1920 public Boolean isStarted() { 1921 return true; 1922 } 1923 1924 @Override 1925 public void initialize(@NonNull SokletConfig sokletConfig, 1926 @NonNull RequestHandler requestHandler) { 1927 requireNonNull(sokletConfig); 1928 requireNonNull(requestHandler); 1929 1930 this.sokletConfig = sokletConfig; 1931 this.requestHandler = requestHandler; 1932 } 1933 1934 @NonNull 1935 @Override 1936 public McpHandlerResolver getHandlerResolver() { 1937 return getRealImplementation().getHandlerResolver(); 1938 } 1939 1940 @NonNull 1941 @Override 1942 public McpRequestAdmissionPolicy getRequestAdmissionPolicy() { 1943 return getRealImplementation().getRequestAdmissionPolicy(); 1944 } 1945 1946 @NonNull 1947 @Override 1948 public McpRequestInterceptor getRequestInterceptor() { 1949 return getRealImplementation().getRequestInterceptor(); 1950 } 1951 1952 @NonNull 1953 @Override 1954 public McpResponseMarshaler getResponseMarshaler() { 1955 return getRealImplementation().getResponseMarshaler(); 1956 } 1957 1958 @NonNull 1959 @Override 1960 public McpCorsAuthorizer getCorsAuthorizer() { 1961 return getRealImplementation().getCorsAuthorizer(); 1962 } 1963 1964 @NonNull 1965 @Override 1966 public McpSessionStore getSessionStore() { 1967 return getRealImplementation().getSessionStore(); 1968 } 1969 1970 @NonNull 1971 @Override 1972 public IdGenerator<String> getIdGenerator() { 1973 return getRealImplementation().getIdGenerator(); 1974 } 1975 1976 @NonNull 1977 protected McpServer getRealImplementation() { 1978 return this.realImplementation; 1979 } 1980 1981 @NonNull 1982 protected Optional<SokletConfig> getSokletConfig() { 1983 return Optional.ofNullable(this.sokletConfig); 1984 } 1985 1986 @NonNull 1987 protected Optional<RequestHandler> getRequestHandler() { 1988 return Optional.ofNullable(this.requestHandler); 1989 } 1990 1991 protected void onMcpStreamError(@Nullable Consumer<Throwable> onMcpStreamError) { 1992 this.mcpStreamErrorHandler.set(onMcpStreamError); 1993 } 1994 1995 @NonNull 1996 protected AtomicReference<Consumer<Throwable>> getMcpStreamErrorHandler() { 1997 return this.mcpStreamErrorHandler; 1998 } 1999 2000 protected void registerOpenStream(@NonNull String sessionId, 2001 @NonNull Request request, 2002 McpRequestResult.StreamOpened streamOpened) { 2003 requireNonNull(sessionId); 2004 requireNonNull(request); 2005 requireNonNull(streamOpened); 2006 2007 getOpenStreamsBySessionId() 2008 .computeIfAbsent(sessionId, ignored -> new CopyOnWriteArrayList<>()) 2009 .add(streamOpened); 2010 2011 streamOpened.onClose(() -> closeOpenStream(sessionId, request, streamOpened)); 2012 } 2013 2014 protected void terminateStreamsForSession(@NonNull String sessionId) { 2015 requireNonNull(sessionId); 2016 2017 CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().remove(sessionId); 2018 2019 if (streams == null) 2020 return; 2021 2022 for (McpRequestResult.StreamOpened streamOpened : streams) 2023 streamOpened.terminate(); 2024 } 2025 2026 @NonNull 2027 @Override 2028 public Boolean publishSessionMessage(@NonNull String sessionId, 2029 @NonNull McpObject message) { 2030 requireNonNull(sessionId); 2031 requireNonNull(message); 2032 2033 CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().get(sessionId); 2034 2035 if (streams == null || streams.isEmpty()) 2036 return false; 2037 2038 for (int i = streams.size() - 1; i >= 0; i--) { 2039 McpRequestResult.StreamOpened streamOpened = streams.get(i); 2040 2041 if (streamOpened.isClosed()) 2042 continue; 2043 2044 streamOpened.emitMessage(message); 2045 return true; 2046 } 2047 2048 return false; 2049 } 2050 2051 protected void onClientDisconnectedMcpStream(@Nullable BiConsumer<Request, String> clientDisconnectedMcpStreamHandler) { 2052 this.clientDisconnectedMcpStreamHandler.set(clientDisconnectedMcpStreamHandler); 2053 } 2054 2055 protected void closeOpenStream(@NonNull String sessionId, 2056 @NonNull Request request, 2057 McpRequestResult.StreamOpened streamOpened) { 2058 requireNonNull(sessionId); 2059 requireNonNull(request); 2060 requireNonNull(streamOpened); 2061 2062 CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().get(sessionId); 2063 2064 if (streams != null) { 2065 streams.remove(streamOpened); 2066 2067 if (streams.isEmpty()) 2068 getOpenStreamsBySessionId().remove(sessionId, streams); 2069 } 2070 2071 BiConsumer<Request, String> handler = this.clientDisconnectedMcpStreamHandler.get(); 2072 2073 if (handler != null) 2074 handler.accept(request, sessionId); 2075 } 2076 2077 @NonNull 2078 protected ConcurrentHashMap<@NonNull String, @NonNull CopyOnWriteArrayList<McpRequestResult.StreamOpened>> getOpenStreamsBySessionId() { 2079 return this.openStreamsBySessionId; 2080 } 2081 } 2082 2083 /** 2084 * Mock Server-Sent Event unicaster that doesn't touch the network at all, useful for testing. 2085 */ 2086 @ThreadSafe 2087 static class MockSseUnicaster implements SseUnicaster { 2088 @NonNull 2089 private final ResourcePath resourcePath; 2090 @NonNull 2091 private final Consumer<SseEvent> eventConsumer; 2092 @NonNull 2093 private final Consumer<SseComment> commentConsumer; 2094 @NonNull 2095 private final AtomicReference<Consumer<Throwable>> unicastErrorHandler; 2096 2097 public MockSseUnicaster(@NonNull ResourcePath resourcePath, 2098 @NonNull Consumer<SseEvent> eventConsumer, 2099 @NonNull Consumer<SseComment> commentConsumer, 2100 @NonNull AtomicReference<Consumer<Throwable>> unicastErrorHandler) { 2101 requireNonNull(resourcePath); 2102 requireNonNull(eventConsumer); 2103 requireNonNull(commentConsumer); 2104 requireNonNull(unicastErrorHandler); 2105 2106 this.resourcePath = resourcePath; 2107 this.eventConsumer = eventConsumer; 2108 this.commentConsumer = commentConsumer; 2109 this.unicastErrorHandler = unicastErrorHandler; 2110 } 2111 2112 @Override 2113 public void unicastEvent(@NonNull SseEvent sseEvent) { 2114 requireNonNull(sseEvent); 2115 try { 2116 getEventConsumer().accept(sseEvent); 2117 } catch (Throwable throwable) { 2118 handleUnicastError(throwable); 2119 } 2120 } 2121 2122 @Override 2123 public void unicastComment(@NonNull SseComment sseComment) { 2124 requireNonNull(sseComment); 2125 try { 2126 getCommentConsumer().accept(sseComment); 2127 } catch (Throwable throwable) { 2128 handleUnicastError(throwable); 2129 } 2130 } 2131 2132 @NonNull 2133 @Override 2134 public ResourcePath getResourcePath() { 2135 return this.resourcePath; 2136 } 2137 2138 @NonNull 2139 protected Consumer<SseEvent> getEventConsumer() { 2140 return this.eventConsumer; 2141 } 2142 2143 @NonNull 2144 protected Consumer<SseComment> getCommentConsumer() { 2145 return this.commentConsumer; 2146 } 2147 2148 protected void handleUnicastError(@NonNull Throwable throwable) { 2149 requireNonNull(throwable); 2150 Consumer<Throwable> handler = this.unicastErrorHandler.get(); 2151 2152 if (handler != null) { 2153 try { 2154 handler.accept(throwable); 2155 return; 2156 } catch (Throwable ignored) { 2157 // Fall through to default behavior 2158 } 2159 } 2160 2161 throwable.printStackTrace(); 2162 } 2163 } 2164 2165 /** 2166 * Mock Server-Sent Event broadcaster that doesn't touch the network at all, useful for testing. 2167 */ 2168 @ThreadSafe 2169 static class MockSseBroadcaster implements SseBroadcaster { 2170 // ConcurrentHashMap doesn't allow null values, so we use a sentinel if context is null 2171 private static final Object NULL_CONTEXT_SENTINEL; 2172 2173 static { 2174 NULL_CONTEXT_SENTINEL = new Object(); 2175 } 2176 2177 @NonNull 2178 private final ResourcePath resourcePath; 2179 // Maps the Consumer (Listener) to its Context object (e.g. Locale) 2180 @NonNull 2181 private final Map<@NonNull Consumer<SseEvent>, @NonNull Object> eventConsumers; 2182 // Same goes for comments 2183 @NonNull 2184 private final Map<@NonNull Consumer<SseComment>, @NonNull Object> commentConsumers; 2185 @NonNull 2186 private final AtomicReference<Consumer<Throwable>> broadcastErrorHandler; 2187 2188 public MockSseBroadcaster(@NonNull ResourcePath resourcePath, 2189 @NonNull AtomicReference<Consumer<Throwable>> broadcastErrorHandler) { 2190 requireNonNull(resourcePath); 2191 requireNonNull(broadcastErrorHandler); 2192 2193 this.resourcePath = resourcePath; 2194 this.eventConsumers = new ConcurrentHashMap<>(); 2195 this.commentConsumers = new ConcurrentHashMap<>(); 2196 this.broadcastErrorHandler = broadcastErrorHandler; 2197 } 2198 2199 @NonNull 2200 @Override 2201 public ResourcePath getResourcePath() { 2202 return this.resourcePath; 2203 } 2204 2205 @NonNull 2206 @Override 2207 public Long getClientCount() { 2208 return Long.valueOf(getEventConsumers().size() + getCommentConsumers().size()); 2209 } 2210 2211 @Override 2212 public void broadcastEvent(@NonNull SseEvent sseEvent) { 2213 requireNonNull(sseEvent); 2214 2215 for (Consumer<SseEvent> eventConsumer : getEventConsumers().keySet()) { 2216 try { 2217 eventConsumer.accept(sseEvent); 2218 } catch (Throwable throwable) { 2219 handleBroadcastError(throwable); 2220 } 2221 } 2222 } 2223 2224 @Override 2225 public void broadcastComment(@NonNull SseComment sseComment) { 2226 requireNonNull(sseComment); 2227 2228 for (Consumer<SseComment> commentConsumer : getCommentConsumers().keySet()) { 2229 try { 2230 commentConsumer.accept(sseComment); 2231 } catch (Throwable throwable) { 2232 handleBroadcastError(throwable); 2233 } 2234 } 2235 } 2236 2237 @Override 2238 public <T> void broadcastEvent( 2239 @NonNull Function<Object, T> keySelector, 2240 @NonNull Function<T, SseEvent> eventProvider 2241 ) { 2242 requireNonNull(keySelector); 2243 requireNonNull(eventProvider); 2244 2245 // 1. Create a temporary cache for this specific broadcast operation. 2246 // This ensures we only run the expensive 'eventProvider' once per unique key. 2247 Map<T, SseEvent> payloadCache = new HashMap<>(); 2248 2249 this.getEventConsumers().forEach((consumer, context) -> { 2250 try { 2251 // 2. Derive the key from the subscriber's context 2252 T key = keySelector.apply(context); 2253 2254 // 3. Memoize: Generate the payload if we haven't seen this key yet, otherwise reuse it 2255 SseEvent event = payloadCache.computeIfAbsent(key, eventProvider); 2256 2257 // 4. Dispatch 2258 consumer.accept(event); 2259 } catch (Throwable throwable) { 2260 handleBroadcastError(throwable); 2261 } 2262 }); 2263 } 2264 2265 @Override 2266 public <T> void broadcastComment( 2267 @NonNull Function<Object, T> keySelector, 2268 @NonNull Function<T, SseComment> commentProvider 2269 ) { 2270 requireNonNull(keySelector); 2271 requireNonNull(commentProvider); 2272 2273 // 1. Create temporary cache 2274 Map<T, SseComment> commentCache = new HashMap<>(); 2275 2276 this.getCommentConsumers().forEach((consumer, context) -> { 2277 try { 2278 // 2. Derive key 2279 T key = keySelector.apply(context); 2280 2281 // 3. Memoize 2282 SseComment comment = commentCache.computeIfAbsent(key, commentProvider); 2283 2284 // 4. Dispatch 2285 consumer.accept(comment); 2286 } catch (Throwable throwable) { 2287 handleBroadcastError(throwable); 2288 } 2289 }); 2290 } 2291 2292 @NonNull 2293 public Boolean registerEventConsumer(@NonNull Consumer<SseEvent> eventConsumer) { 2294 return registerEventConsumer(eventConsumer, null); 2295 } 2296 2297 /** 2298 * Registers a consumer with an associated context, simulating a client with specific traits. 2299 */ 2300 @NonNull 2301 public Boolean registerEventConsumer(@NonNull Consumer<SseEvent> eventConsumer, @Nullable Object context) { 2302 requireNonNull(eventConsumer); 2303 // map.put returns null if the key was new, which conceptually matches "add" returning true 2304 return this.getEventConsumers().put(eventConsumer, context == null ? NULL_CONTEXT_SENTINEL : context) == null; 2305 } 2306 2307 @NonNull 2308 public Boolean unregisterEventConsumer(@NonNull Consumer<SseEvent> eventConsumer) { 2309 requireNonNull(eventConsumer); 2310 return this.getEventConsumers().remove(eventConsumer) != null; 2311 } 2312 2313 @NonNull 2314 public Boolean registerCommentConsumer(@NonNull Consumer<SseComment> commentConsumer) { 2315 return registerCommentConsumer(commentConsumer, null); 2316 } 2317 2318 /** 2319 * Registers a consumer with an associated context, simulating a client with specific traits. 2320 */ 2321 @NonNull 2322 public Boolean registerCommentConsumer(@NonNull Consumer<SseComment> commentConsumer, @Nullable Object context) { 2323 requireNonNull(commentConsumer); 2324 return this.getCommentConsumers().put(commentConsumer, context == null ? NULL_CONTEXT_SENTINEL : context) == null; 2325 } 2326 2327 @NonNull 2328 public Boolean unregisterCommentConsumer(@NonNull Consumer<SseComment> commentConsumer) { 2329 requireNonNull(commentConsumer); 2330 return this.getCommentConsumers().remove(commentConsumer) != null; 2331 } 2332 2333 @NonNull 2334 protected Map<@NonNull Consumer<SseEvent>, @NonNull Object> getEventConsumers() { 2335 return this.eventConsumers; 2336 } 2337 2338 @NonNull 2339 protected Map<@NonNull Consumer<SseComment>, @NonNull Object> getCommentConsumers() { 2340 return this.commentConsumers; 2341 } 2342 2343 protected void handleBroadcastError(@NonNull Throwable throwable) { 2344 requireNonNull(throwable); 2345 Consumer<Throwable> handler = this.broadcastErrorHandler.get(); 2346 2347 if (handler != null) { 2348 try { 2349 handler.accept(throwable); 2350 return; 2351 } catch (Throwable ignored) { 2352 // Fall through to default behavior 2353 } 2354 } 2355 2356 throwable.printStackTrace(); 2357 } 2358 } 2359 2360 /** 2361 * Mock Server-Sent Event server that doesn't touch the network at all, useful for testing. 2362 * 2363 * @author <a href="https://www.revetkn.com">Mark Allen</a> 2364 */ 2365 @ThreadSafe 2366 static class MockSseServer implements SseServer { 2367 @Nullable 2368 private SokletConfig sokletConfig; 2369 private SseServer.@Nullable RequestHandler requestHandler; 2370 @NonNull 2371 private final ConcurrentHashMap<@NonNull ResourcePath, @NonNull MockSseBroadcaster> broadcastersByResourcePath; 2372 @NonNull 2373 private final AtomicReference<Consumer<Throwable>> broadcastErrorHandler; 2374 @NonNull 2375 private final AtomicReference<Consumer<Throwable>> unicastErrorHandler; 2376 2377 public MockSseServer() { 2378 this.broadcastersByResourcePath = new ConcurrentHashMap<>(); 2379 this.broadcastErrorHandler = new AtomicReference<>(); 2380 this.unicastErrorHandler = new AtomicReference<>(); 2381 } 2382 2383 @Override 2384 public void start() { 2385 // No-op 2386 } 2387 2388 @Override 2389 public void stop() { 2390 // No-op 2391 } 2392 2393 @NonNull 2394 @Override 2395 public Boolean isStarted() { 2396 return true; 2397 } 2398 2399 @NonNull 2400 @Override 2401 public Optional<? extends SseBroadcaster> acquireBroadcaster(@Nullable ResourcePath resourcePath) { 2402 if (resourcePath == null) 2403 return Optional.empty(); 2404 2405 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath() 2406 .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler)); 2407 2408 return Optional.of(broadcaster); 2409 } 2410 2411 public void registerEventConsumer(@NonNull ResourcePath resourcePath, 2412 @NonNull Consumer<SseEvent> eventConsumer) { 2413 registerEventConsumer(resourcePath, eventConsumer, null); 2414 } 2415 2416 public void registerEventConsumer(@NonNull ResourcePath resourcePath, 2417 @NonNull Consumer<SseEvent> eventConsumer, 2418 @Nullable Object context) { 2419 requireNonNull(resourcePath); 2420 requireNonNull(eventConsumer); 2421 2422 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath() 2423 .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler)); 2424 2425 broadcaster.registerEventConsumer(eventConsumer, context); 2426 } 2427 2428 @NonNull 2429 public Boolean unregisterEventConsumer(@NonNull ResourcePath resourcePath, 2430 @NonNull Consumer<SseEvent> eventConsumer) { 2431 requireNonNull(resourcePath); 2432 requireNonNull(eventConsumer); 2433 2434 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath().get(resourcePath); 2435 2436 if (broadcaster == null) 2437 return false; 2438 2439 return broadcaster.unregisterEventConsumer(eventConsumer); 2440 } 2441 2442 public void registerCommentConsumer(@NonNull ResourcePath resourcePath, 2443 @NonNull Consumer<SseComment> commentConsumer) { 2444 registerCommentConsumer(resourcePath, commentConsumer, null); 2445 } 2446 2447 public void registerCommentConsumer(@NonNull ResourcePath resourcePath, 2448 @NonNull Consumer<SseComment> commentConsumer, 2449 @Nullable Object context) { 2450 requireNonNull(resourcePath); 2451 requireNonNull(commentConsumer); 2452 2453 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath() 2454 .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler)); 2455 2456 broadcaster.registerCommentConsumer(commentConsumer, context); 2457 } 2458 2459 @NonNull 2460 public Boolean unregisterCommentConsumer(@NonNull ResourcePath resourcePath, 2461 @NonNull Consumer<SseComment> commentConsumer) { 2462 requireNonNull(resourcePath); 2463 requireNonNull(commentConsumer); 2464 2465 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath().get(resourcePath); 2466 2467 if (broadcaster == null) 2468 return false; 2469 2470 return broadcaster.unregisterCommentConsumer(commentConsumer); 2471 } 2472 2473 @Override 2474 public void initialize(@NonNull SokletConfig sokletConfig, 2475 SseServer.@NonNull RequestHandler requestHandler) { 2476 requireNonNull(sokletConfig); 2477 requireNonNull(requestHandler); 2478 2479 this.sokletConfig = sokletConfig; 2480 this.requestHandler = requestHandler; 2481 } 2482 2483 public void onBroadcastError(@Nullable Consumer<Throwable> onBroadcastError) { 2484 this.broadcastErrorHandler.set(onBroadcastError); 2485 } 2486 2487 public void onUnicastError(@Nullable Consumer<Throwable> onUnicastError) { 2488 this.unicastErrorHandler.set(onUnicastError); 2489 } 2490 2491 @NonNull 2492 protected Optional<SokletConfig> getSokletConfig() { 2493 return Optional.ofNullable(this.sokletConfig); 2494 } 2495 2496 @NonNull 2497 protected Optional<SseServer.RequestHandler> getRequestHandler() { 2498 return Optional.ofNullable(this.requestHandler); 2499 } 2500 2501 @NonNull 2502 protected ConcurrentHashMap<@NonNull ResourcePath, @NonNull MockSseBroadcaster> getBroadcastersByResourcePath() { 2503 return this.broadcastersByResourcePath; 2504 } 2505 2506 @NonNull 2507 protected AtomicReference<Consumer<Throwable>> getUnicastErrorHandler() { 2508 return this.unicastErrorHandler; 2509 } 2510 } 2511 2512 /** 2513 * Efficiently provides the current time in RFC 1123 HTTP-date format. 2514 * Updates once per second to avoid the overhead of formatting dates on every request. 2515 */ 2516 @ThreadSafe 2517 private static final class CachedHttpDate { 2518 @NonNull 2519 private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss z", Locale.ENGLISH) 2520 .withZone(ZoneId.of("GMT")); 2521 @NonNull 2522 private static volatile String CURRENT_VALUE = FORMATTER.format(Instant.now()); 2523 2524 static { 2525 Thread t = new Thread(() -> { 2526 while (true) { 2527 try { 2528 Thread.sleep(1000); 2529 CURRENT_VALUE = FORMATTER.format(Instant.now()); 2530 } catch (InterruptedException e) { 2531 break; // Allow thread to die on JVM shutdown 2532 } 2533 } 2534 }, "soklet-date-header-value-updater"); 2535 t.setDaemon(true); 2536 t.start(); 2537 } 2538 2539 @NonNull 2540 static String getCurrentValue() { 2541 return CURRENT_VALUE; 2542 } 2543 } 2544}