001/* 002 * Copyright 2022-2026 Revetware LLC. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package com.soklet; 018 019import com.soklet.SseRequestResult.HandshakeAccepted; 020import com.soklet.SseRequestResult.HandshakeRejected; 021import com.soklet.annotation.SseEventSource; 022import com.soklet.internal.spring.LinkedCaseInsensitiveMap; 023import org.jspecify.annotations.NonNull; 024import org.jspecify.annotations.Nullable; 025 026import javax.annotation.concurrent.NotThreadSafe; 027import javax.annotation.concurrent.ThreadSafe; 028import java.io.BufferedReader; 029import java.io.ByteArrayOutputStream; 030import java.io.IOException; 031import java.io.InputStreamReader; 032import java.io.Reader; 033import java.lang.reflect.InvocationTargetException; 034import java.nio.ByteBuffer; 035import java.nio.CharBuffer; 036import java.nio.charset.Charset; 037import java.nio.charset.CharsetEncoder; 038import java.nio.charset.CharacterCodingException; 039import java.nio.charset.CoderResult; 040import java.nio.charset.StandardCharsets; 041import java.time.Duration; 042import java.time.Instant; 043import java.util.ArrayList; 044import java.util.Collections; 045import java.util.EnumSet; 046import java.util.HashMap; 047import java.util.LinkedHashMap; 048import java.util.List; 049import java.util.Locale; 050import java.util.Map; 051import java.util.Map.Entry; 052import java.util.Objects; 053import java.util.Optional; 054import java.util.Set; 055import java.util.concurrent.ConcurrentHashMap; 056import java.util.concurrent.CopyOnWriteArrayList; 057import java.util.concurrent.CopyOnWriteArraySet; 058import java.util.concurrent.CountDownLatch; 059import java.util.concurrent.Flow; 060import java.util.concurrent.TimeUnit; 061import java.util.concurrent.atomic.AtomicBoolean; 062import java.util.concurrent.atomic.AtomicReference; 063import java.util.concurrent.locks.ReentrantLock; 064import java.util.function.BiConsumer; 065import java.util.function.Consumer; 066import java.util.function.Function; 067import java.util.stream.Collectors; 068 069import static com.soklet.Utilities.emptyByteArray; 070import static com.soklet.Utilities.extractContentTypeFromHeaders; 071import static java.lang.String.format; 072import static java.util.Objects.requireNonNull; 073 074/** 075 * Soklet's main class - manages one or more configured transport servers ({@link HttpServer}, {@link SseServer}, and/or {@link McpServer}) 076 * using the provided system configuration. 077 * <p> 078 * <pre>{@code // Use out-of-the-box defaults 079 * SokletConfig config = SokletConfig.withHttpServer( 080 * HttpServer.fromPort(8080) 081 * ).build(); 082 * 083 * try (Soklet soklet = Soklet.fromConfig(config)) { 084 * soklet.start(); 085 * System.out.println("Soklet started, press [enter] to exit"); 086 * soklet.awaitShutdown(ShutdownTrigger.ENTER_KEY); 087 * }}</pre> 088 * <p> 089 * Soklet also offers an off-network {@link Simulator} concept via {@link #runSimulator(SokletConfig, Consumer)}, useful for integration testing. 090 * <p> 091 * Given a <em>Resource Method</em>... 092 * <pre>{@code public class HelloResource { 093 * @GET("/hello") 094 * public String hello(@QueryParameter String name) { 095 * return String.format("Hello, %s", name); 096 * } 097 * }}</pre> 098 * ...we might test it like this: 099 * <pre>{@code @Test 100 * public void integrationTest() { 101 * // Just use your app's existing configuration 102 * SokletConfig config = obtainMySokletConfig(); 103 * 104 * // Instead of running on a real HTTP server that listens on a port, 105 * // a non-network Simulator is provided against which you can 106 * // issue requests and receive responses. 107 * Soklet.runSimulator(config, (simulator -> { 108 * // Construct a request 109 * Request request = Request.withPath(HttpMethod.GET, "/hello") 110 * .queryParameters(Map.of("name", Set.of("Mark"))) 111 * .build(); 112 * 113 * // Perform the request and get a handle to the response 114 * HttpRequestResult result = simulator.performHttpRequest(request); 115 * MarshaledResponse marshaledResponse = result.getMarshaledResponse(); 116 * 117 * // Verify status code 118 * Integer expectedCode = 200; 119 * Integer actualCode = marshaledResponse.getStatusCode(); 120 * assertEquals(expectedCode, actualCode, "Bad status code"); 121 * 122 * // Verify response body 123 * marshaledResponse.getBody().ifPresentOrElse(body -> { 124 * String expectedBody = "Hello, Mark"; 125 * byte[] bytes = ((MarshaledResponseBody.Bytes) body).getBytes(); 126 * String actualBody = new String(bytes, StandardCharsets.UTF_8); 127 * assertEquals(expectedBody, actualBody, "Bad response body"); 128 * }, () -> { 129 * Assertions.fail("No response body"); 130 * }); 131 * })); 132 * }}</pre> 133 * <p> 134 * The {@link Simulator} also supports Server-Sent Events. 135 * <p> 136 * Integration testing documentation is available at <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a>. 137 * 138 * @author <a href="https://www.revetkn.com">Mark Allen</a> 139 */ 140@ThreadSafe 141public final class Soklet implements AutoCloseable { 142 @NonNull 143 private static final Map<@NonNull String, @NonNull Set<@NonNull String>> DEFAULT_ACCEPTED_HANDSHAKE_HEADERS; 144 145 static { 146 // Generally speaking, we always want these headers for SSE streaming responses. 147 // Users can override if they think necessary 148 LinkedCaseInsensitiveMap<Set<String>> defaultAcceptedHandshakeHeaders = new LinkedCaseInsensitiveMap<>(4); 149 defaultAcceptedHandshakeHeaders.put("Content-Type", Set.of("text/event-stream; charset=UTF-8")); 150 defaultAcceptedHandshakeHeaders.put("Cache-Control", Set.of("no-cache", "no-transform")); 151 defaultAcceptedHandshakeHeaders.put("Connection", Set.of("keep-alive")); 152 defaultAcceptedHandshakeHeaders.put("X-Accel-Buffering", Set.of("no")); 153 154 DEFAULT_ACCEPTED_HANDSHAKE_HEADERS = Collections.unmodifiableMap(defaultAcceptedHandshakeHeaders); 155 } 156 157 /** 158 * Acquires a Soklet instance with the given configuration. 159 * 160 * @param sokletConfig configuration that drives the Soklet system 161 * @return a Soklet instance 162 */ 163 @NonNull 164 public static Soklet fromConfig(@NonNull SokletConfig sokletConfig) { 165 requireNonNull(sokletConfig); 166 return new Soklet(sokletConfig); 167 } 168 169 @NonNull 170 private final SokletConfig sokletConfig; 171 @NonNull 172 private final ReentrantLock lock; 173 @NonNull 174 private final AtomicReference<CountDownLatch> awaitShutdownLatchReference; 175 @NonNull 176 private final DefaultMcpRuntime defaultMcpRuntime; 177 178 /** 179 * Creates a Soklet instance with the given configuration. 180 * 181 * @param sokletConfig configuration that drives the Soklet system 182 */ 183 private Soklet(@NonNull SokletConfig sokletConfig) { 184 requireNonNull(sokletConfig); 185 186 this.sokletConfig = sokletConfig; 187 this.lock = new ReentrantLock(); 188 this.awaitShutdownLatchReference = new AtomicReference<>(new CountDownLatch(1)); 189 this.defaultMcpRuntime = new DefaultMcpRuntime(this); 190 191 sokletConfig.getMcpServer() 192 .map(McpServer::getSessionStore) 193 .filter(DefaultMcpSessionStore.class::isInstance) 194 .map(DefaultMcpSessionStore.class::cast) 195 .ifPresent(sessionStore -> sessionStore.pinnedSessionPredicate(this.defaultMcpRuntime::hasActiveStream)); 196 197 sokletConfig.getMcpServer() 198 .map(mcpServer -> mcpServer instanceof McpServerProxy mcpServerProxy ? mcpServerProxy.getRealImplementation() : mcpServer) 199 .filter(DefaultMcpServer.class::isInstance) 200 .map(DefaultMcpServer.class::cast) 201 .ifPresent(defaultMcpServer -> defaultMcpServer.mcpRuntime(this.defaultMcpRuntime)); 202 203 Set<ResourceMethod> resourceMethods = sokletConfig.getResourceMethodResolver().getResourceMethods(); 204 205 // Fail fast in the event that Soklet appears misconfigured 206 if (resourceMethods.size() == 0 207 && sokletConfig.getMcpServer().isEmpty()) 208 throw new IllegalStateException(format("No Soklet Resource Methods were found. First, try to rebuild and see if that solves the problem. If not, please ensure your %s is configured correctly. " 209 + "See https://www.soklet.com/docs/request-handling#resource-method-resolution for details.", ResourceMethodResolver.class.getSimpleName())); 210 211 boolean hasStandardHttpResourceMethods = resourceMethods.stream() 212 .anyMatch(resourceMethod -> !resourceMethod.isSseEventSource()); 213 214 if (hasStandardHttpResourceMethods && sokletConfig.getHttpServer().isEmpty()) 215 throw new IllegalStateException(format("Resource Methods were found, but no %s is configured. See https://www.soklet.com/docs/server-configuration for details.", 216 HttpServer.class.getSimpleName())); 217 218 // SSE misconfiguration check: @SseEventSource resource methods are declared, but not SseServer exists 219 boolean hasSseResourceMethods = resourceMethods.stream() 220 .anyMatch(resourceMethod -> resourceMethod.isSseEventSource()); 221 222 if (hasSseResourceMethods && sokletConfig.getSseServer().isEmpty()) 223 throw new IllegalStateException(format("Resource Methods annotated with @%s were found, but no %s is configured. See https://www.soklet.com/docs/server-sent-events for details.", 224 SseEventSource.class.getSimpleName(), SseServer.class.getSimpleName())); 225 226 MetricsCollector metricsCollector = sokletConfig.getMetricsCollector(); 227 228 if (metricsCollector instanceof DefaultMetricsCollector defaultMetricsCollector) { 229 try { 230 defaultMetricsCollector.initialize(sokletConfig); 231 } catch (Throwable t) { 232 sokletConfig.getAggregateLifecycleObserver().didReceiveLogEvent( 233 LogEvent.with(LogEventType.METRICS_COLLECTOR_FAILED, 234 format("An exception occurred while initializing %s", metricsCollector.getClass().getSimpleName())) 235 .throwable(t) 236 .build()); 237 } 238 } 239 240 // Use a layer of indirection here so the Soklet type does not need to directly implement the `RequestHandler` interface. 241 // Reasoning: the `handleRequest` method for Soklet should not be public, which might lead to accidental invocation by users. 242 // That method should only be called by the managed `HttpServer` instance. 243 Soklet soklet = this; 244 245 sokletConfig.getHttpServer().ifPresent(server -> server.initialize(getSokletConfig(), (request, marshaledResponseConsumer) -> { 246 // Delegate to Soklet's internal request handling method 247 soklet.handleRequest(request, ServerType.STANDARD_HTTP, marshaledResponseConsumer); 248 })); 249 250 SseServer sseServer = sokletConfig.getSseServer().orElse(null); 251 252 if (sseServer != null) 253 sseServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> { 254 // Delegate to Soklet's internal request handling method 255 soklet.handleRequest(request, ServerType.SSE, marshaledResponseConsumer); 256 }); 257 258 McpServer mcpServer = sokletConfig.getMcpServer().orElse(null); 259 260 if (mcpServer != null) 261 mcpServer.initialize(sokletConfig, soklet::handleMcpRequest); 262 } 263 264 /** 265 * Starts the managed server instance[s]. 266 * <p> 267 * If the managed server[s] are already started, this is a no-op. 268 */ 269 public void start() { 270 getLock().lock(); 271 272 try { 273 if (isStarted()) 274 return; 275 276 getAwaitShutdownLatchReference().set(new CountDownLatch(1)); 277 278 SokletConfig sokletConfig = getSokletConfig(); 279 LifecycleObserver lifecycleObserver = sokletConfig.getAggregateLifecycleObserver(); 280 281 // 1. Notify global intent to start 282 lifecycleObserver.willStartSoklet(this); 283 284 HttpServer httpServer = sokletConfig.getHttpServer().orElse(null); 285 SseServer sseServer = sokletConfig.getSseServer().orElse(null); 286 McpServer mcpServer = sokletConfig.getMcpServer().orElse(null); 287 boolean httpServerStarted = false; 288 boolean sseServerStarted = false; 289 boolean mcpServerStarted = false; 290 291 try { 292 // 2. Attempt to start Main HttpServer 293 if (httpServer != null) { 294 lifecycleObserver.willStartHttpServer(httpServer); 295 try { 296 httpServer.start(); 297 httpServerStarted = true; 298 lifecycleObserver.didStartHttpServer(httpServer); 299 } catch (Throwable t) { 300 lifecycleObserver.didFailToStartHttpServer(httpServer, t); 301 throw t; // Rethrow to trigger outer catch block 302 } 303 } 304 305 // 3. Attempt to start SSE HttpServer (if present) 306 if (sseServer != null) { 307 lifecycleObserver.willStartSseServer(sseServer); 308 try { 309 sseServer.start(); 310 sseServerStarted = true; 311 lifecycleObserver.didStartSseServer(sseServer); 312 } catch (Throwable t) { 313 lifecycleObserver.didFailToStartSseServer(sseServer, t); 314 throw t; // Rethrow to trigger outer catch block 315 } 316 } 317 318 if (mcpServer != null) { 319 lifecycleObserver.willStartMcpServer(mcpServer); 320 try { 321 mcpServer.start(); 322 mcpServerStarted = true; 323 lifecycleObserver.didStartMcpServer(mcpServer); 324 } catch (Throwable t) { 325 lifecycleObserver.didFailToStartMcpServer(mcpServer, t); 326 throw t; 327 } 328 } 329 330 // 4. Global success 331 lifecycleObserver.didStartSoklet(this); 332 } catch (Throwable t) { 333 rollbackStartedServersAfterFailedStart(lifecycleObserver, 334 httpServer, httpServerStarted, 335 sseServer, sseServerStarted, 336 mcpServer, mcpServerStarted, 337 t); 338 339 // 5. Global failure 340 lifecycleObserver.didFailToStartSoklet(this, t); 341 342 // Ensure the exception bubbles up so the application knows startup failed 343 if (t instanceof RuntimeException) 344 throw (RuntimeException) t; 345 346 throw new RuntimeException(t); 347 } 348 } finally { 349 getLock().unlock(); 350 } 351 } 352 353 private void rollbackStartedServersAfterFailedStart(@NonNull LifecycleObserver lifecycleObserver, 354 @Nullable HttpServer httpServer, 355 boolean httpServerStarted, 356 @Nullable SseServer sseServer, 357 boolean sseServerStarted, 358 @Nullable McpServer mcpServer, 359 boolean mcpServerStarted, 360 @NonNull Throwable startupFailure) { 361 requireNonNull(lifecycleObserver); 362 requireNonNull(startupFailure); 363 364 if (mcpServerStarted && mcpServer != null) 365 stopStartedMcpServerForRollback(lifecycleObserver, mcpServer, startupFailure); 366 367 if (sseServerStarted && sseServer != null) 368 stopStartedSseServerForRollback(lifecycleObserver, sseServer, startupFailure); 369 370 if (httpServerStarted && httpServer != null) 371 stopStartedHttpServerForRollback(lifecycleObserver, httpServer, startupFailure); 372 373 CountDownLatch awaitShutdownLatch = getAwaitShutdownLatchReference().get(); 374 375 if (awaitShutdownLatch != null) 376 awaitShutdownLatch.countDown(); 377 } 378 379 private void stopStartedHttpServerForRollback(@NonNull LifecycleObserver lifecycleObserver, 380 @NonNull HttpServer httpServer, 381 @NonNull Throwable startupFailure) { 382 requireNonNull(lifecycleObserver); 383 requireNonNull(httpServer); 384 requireNonNull(startupFailure); 385 386 try { 387 lifecycleObserver.willStopHttpServer(httpServer); 388 } catch (Throwable t) { 389 startupFailure.addSuppressed(t); 390 } 391 392 try { 393 httpServer.stop(); 394 try { 395 lifecycleObserver.didStopHttpServer(httpServer); 396 } catch (Throwable t) { 397 startupFailure.addSuppressed(t); 398 } 399 } catch (Throwable t) { 400 startupFailure.addSuppressed(t); 401 402 try { 403 lifecycleObserver.didFailToStopHttpServer(httpServer, t); 404 } catch (Throwable t2) { 405 startupFailure.addSuppressed(t2); 406 } 407 } 408 } 409 410 private void stopStartedSseServerForRollback(@NonNull LifecycleObserver lifecycleObserver, 411 @NonNull SseServer sseServer, 412 @NonNull Throwable startupFailure) { 413 requireNonNull(lifecycleObserver); 414 requireNonNull(sseServer); 415 requireNonNull(startupFailure); 416 417 try { 418 lifecycleObserver.willStopSseServer(sseServer); 419 } catch (Throwable t) { 420 startupFailure.addSuppressed(t); 421 } 422 423 try { 424 sseServer.stop(); 425 try { 426 lifecycleObserver.didStopSseServer(sseServer); 427 } catch (Throwable t) { 428 startupFailure.addSuppressed(t); 429 } 430 } catch (Throwable t) { 431 startupFailure.addSuppressed(t); 432 433 try { 434 lifecycleObserver.didFailToStopSseServer(sseServer, t); 435 } catch (Throwable t2) { 436 startupFailure.addSuppressed(t2); 437 } 438 } 439 } 440 441 private void stopStartedMcpServerForRollback(@NonNull LifecycleObserver lifecycleObserver, 442 @NonNull McpServer mcpServer, 443 @NonNull Throwable startupFailure) { 444 requireNonNull(lifecycleObserver); 445 requireNonNull(mcpServer); 446 requireNonNull(startupFailure); 447 448 try { 449 lifecycleObserver.willStopMcpServer(mcpServer); 450 } catch (Throwable t) { 451 startupFailure.addSuppressed(t); 452 } 453 454 try { 455 mcpServer.stop(); 456 try { 457 lifecycleObserver.didStopMcpServer(mcpServer); 458 } catch (Throwable t) { 459 startupFailure.addSuppressed(t); 460 } 461 } catch (Throwable t) { 462 startupFailure.addSuppressed(t); 463 464 try { 465 lifecycleObserver.didFailToStopMcpServer(mcpServer, t); 466 } catch (Throwable t2) { 467 startupFailure.addSuppressed(t2); 468 } 469 } 470 } 471 472 /** 473 * Stops the managed server instance[s]. 474 * <p> 475 * If the managed server[s] are already stopped, this is a no-op. 476 */ 477 public void stop() { 478 getLock().lock(); 479 480 try { 481 if (isStarted()) { 482 SokletConfig sokletConfig = getSokletConfig(); 483 LifecycleObserver lifecycleObserver = sokletConfig.getAggregateLifecycleObserver(); 484 485 // 1. Notify global intent to stop 486 lifecycleObserver.willStopSoklet(this); 487 488 Throwable firstEncounteredException = null; 489 HttpServer httpServer = sokletConfig.getHttpServer().orElse(null); 490 491 // 2. Attempt to stop Main HttpServer 492 if (httpServer != null && httpServer.isStarted()) { 493 lifecycleObserver.willStopHttpServer(httpServer); 494 try { 495 httpServer.stop(); 496 lifecycleObserver.didStopHttpServer(httpServer); 497 } catch (Throwable t) { 498 firstEncounteredException = t; 499 lifecycleObserver.didFailToStopHttpServer(httpServer, t); 500 } 501 } 502 503 // 3. Attempt to stop SSE HttpServer 504 SseServer sseServer = sokletConfig.getSseServer().orElse(null); 505 506 if (sseServer != null && sseServer.isStarted()) { 507 lifecycleObserver.willStopSseServer(sseServer); 508 try { 509 sseServer.stop(); 510 lifecycleObserver.didStopSseServer(sseServer); 511 } catch (Throwable t) { 512 if (firstEncounteredException == null) 513 firstEncounteredException = t; 514 515 lifecycleObserver.didFailToStopSseServer(sseServer, t); 516 } 517 } 518 519 McpServer mcpServer = sokletConfig.getMcpServer().orElse(null); 520 521 if (mcpServer != null && mcpServer.isStarted()) { 522 lifecycleObserver.willStopMcpServer(mcpServer); 523 try { 524 mcpServer.stop(); 525 lifecycleObserver.didStopMcpServer(mcpServer); 526 } catch (Throwable t) { 527 if (firstEncounteredException == null) 528 firstEncounteredException = t; 529 530 lifecycleObserver.didFailToStopMcpServer(mcpServer, t); 531 } 532 } 533 534 // 4. Global completion (Success or Failure) 535 if (firstEncounteredException == null) 536 lifecycleObserver.didStopSoklet(this); 537 else 538 lifecycleObserver.didFailToStopSoklet(this, firstEncounteredException); 539 } 540 } finally { 541 try { 542 getAwaitShutdownLatchReference().get().countDown(); 543 } finally { 544 getLock().unlock(); 545 } 546 } 547 } 548 549 /** 550 * Blocks the current thread until JVM shutdown ({@code SIGTERM/SIGINT/System.exit(...)} and so forth), <strong>or</strong> if one of the provided {@code shutdownTriggers} occurs. 551 * <p> 552 * This method will automatically invoke this instance's {@link #stop()} method once it becomes unblocked. 553 * <p> 554 * <strong>Notes regarding {@link ShutdownTrigger#ENTER_KEY}:</strong> 555 * <ul> 556 * <li>It will invoke {@link #stop()} on <i>all</i> Soklet instances, as stdin is process-wide</li> 557 * <li>It is only supported for environments with an interactive TTY and will be ignored if none exists (e.g. running in a Docker container) - Soklet will detect this and fire {@link LifecycleObserver#didReceiveLogEvent(LogEvent)} with an event of type {@link LogEventType#CONFIGURATION_UNSUPPORTED}</li> 558 * </ul> 559 * 560 * @param shutdownTriggers additional trigger[s] which signal that shutdown should occur, e.g. {@link ShutdownTrigger#ENTER_KEY} for "enter key pressed" 561 * @throws InterruptedException if the current thread has its interrupted status set on entry to this method, or is interrupted while waiting 562 */ 563 public void awaitShutdown(@Nullable ShutdownTrigger... shutdownTriggers) throws InterruptedException { 564 Thread shutdownHook = null; 565 boolean registeredEnterKeyShutdownTrigger = false; 566 Set<ShutdownTrigger> shutdownTriggersAsSet = shutdownTriggers == null || shutdownTriggers.length == 0 ? Set.of() : EnumSet.copyOf(Set.of(shutdownTriggers)); 567 568 try { 569 // Optionally listen for enter key 570 if (shutdownTriggersAsSet.contains(ShutdownTrigger.ENTER_KEY)) { 571 registeredEnterKeyShutdownTrigger = KeypressManager.register(this); // returns false if stdin unusable/disabled 572 573 if (!registeredEnterKeyShutdownTrigger) { 574 LogEvent logEvent = LogEvent.with( 575 LogEventType.CONFIGURATION_UNSUPPORTED, 576 format("Ignoring request for %s.%s - it is unsupported in this environment (no interactive TTY detected)", ShutdownTrigger.class.getSimpleName(), ShutdownTrigger.ENTER_KEY.name()) 577 ).build(); 578 579 getSokletConfig().getAggregateLifecycleObserver().didReceiveLogEvent(logEvent); 580 } 581 } 582 583 // Always register a shutdown hook 584 shutdownHook = new Thread(() -> { 585 try { 586 stop(); 587 } catch (Throwable ignored) { 588 // Nothing to do 589 } 590 }, "soklet-shutdown-hook"); 591 592 Runtime.getRuntime().addShutdownHook(shutdownHook); 593 594 // Wait until "close" finishes 595 getAwaitShutdownLatchReference().get().await(); 596 } finally { 597 if (registeredEnterKeyShutdownTrigger) 598 KeypressManager.unregister(this); 599 600 try { 601 Runtime.getRuntime().removeShutdownHook(shutdownHook); 602 } catch (IllegalStateException ignored) { 603 // JVM shutting down 604 } 605 } 606 } 607 608 /** 609 * Handles "awaitShutdown" for {@link ShutdownTrigger#ENTER_KEY} by listening to stdin - all Soklet instances are terminated on keypress. 610 */ 611 @ThreadSafe 612 private static final class KeypressManager { 613 @NonNull 614 private static final Set<@NonNull Soklet> SOKLET_REGISTRY; 615 @NonNull 616 private static final AtomicBoolean LISTENER_STARTED; 617 618 static { 619 SOKLET_REGISTRY = new CopyOnWriteArraySet<>(); 620 LISTENER_STARTED = new AtomicBoolean(false); 621 } 622 623 /** 624 * Register a Soklet for Enter-to-stop support. Returns true iff a listener is (or was already) active. 625 * If System.in is not usable (or disabled), returns false and does nothing. 626 */ 627 @NonNull 628 synchronized static Boolean register(@NonNull Soklet soklet) { 629 requireNonNull(soklet); 630 631 // If stdin is not readable (e.g., container with no TTY), don't start a listener. 632 if (!canReadFromStdin()) 633 return false; 634 635 SOKLET_REGISTRY.add(soklet); 636 637 // Start a single process-wide listener once. 638 if (LISTENER_STARTED.compareAndSet(false, true)) { 639 Thread thread = new Thread(KeypressManager::runLoop, "soklet-keypress-shutdown-listener"); 640 thread.setDaemon(true); // never block JVM exit 641 thread.start(); 642 } 643 644 return true; 645 } 646 647 synchronized static void unregister(@NonNull Soklet soklet) { 648 SOKLET_REGISTRY.remove(soklet); 649 // We intentionally keep the listener alive; it's daemon and cheap. 650 // If stdin hits EOF, the listener exits on its own. 651 } 652 653 /** 654 * Heuristic: if System.in is present and calling available() doesn't throw, 655 * treat it as readable. Works even in IDEs where System.console() is null. 656 */ 657 @NonNull 658 private static Boolean canReadFromStdin() { 659 if (System.in == null) 660 return false; 661 662 try { 663 // available() >= 0 means stream is open; 0 means no buffered data (that’s fine). 664 return System.in.available() >= 0; 665 } catch (IOException e) { 666 return false; 667 } 668 } 669 670 /** 671 * Single blocking read on stdin. On any line (or EOF), stop all registered servers. 672 */ 673 private static void runLoop() { 674 try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8))) { 675 // Blocks until newline or EOF; EOF (null) happens with /dev/null or closed pipe. 676 bufferedReader.readLine(); 677 stopAllSoklets(); 678 } catch (Throwable ignored) { 679 // If stdin is closed mid-run, just exit quietly. 680 } 681 } 682 683 synchronized private static void stopAllSoklets() { 684 // Either a line or EOF → stop everything that’s currently registered. 685 for (Soklet soklet : SOKLET_REGISTRY) { 686 try { 687 soklet.stop(); 688 } catch (Throwable ignored) { 689 // Nothing to do 690 } 691 } 692 } 693 694 private KeypressManager() {} 695 } 696 697 /** 698 * Nonpublic "informal" implementation of {@link com.soklet.HttpServer.RequestHandler} so Soklet does not need to expose {@code handleRequest} publicly. 699 * Reasoning: users of this library should never call {@code handleRequest} directly - it should only be invoked in response to events 700 * provided by a {@link HttpServer} or {@link SseServer} implementation. 701 */ 702 protected void handleRequest(@NonNull Request request, 703 @NonNull ServerType serverType, 704 @NonNull Consumer<HttpRequestResult> requestResultConsumer) { 705 requireNonNull(request); 706 requireNonNull(serverType); 707 requireNonNull(requestResultConsumer); 708 709 Instant processingStarted = Instant.now(); 710 711 SokletConfig sokletConfig = getSokletConfig(); 712 ResourceMethodResolver resourceMethodResolver = sokletConfig.getResourceMethodResolver(); 713 ResponseMarshaler responseMarshaler = sokletConfig.getResponseMarshaler(); 714 LifecycleObserver lifecycleObserver = sokletConfig.getAggregateLifecycleObserver(); 715 RequestInterceptor requestInterceptor = sokletConfig.getRequestInterceptor(); 716 MetricsCollector metricsCollector = sokletConfig.getMetricsCollector(); 717 718 // Holders to permit mutable effectively-final variables 719 AtomicReference<MarshaledResponse> marshaledResponseHolder = new AtomicReference<>(); 720 AtomicReference<Throwable> resourceMethodResolutionExceptionHolder = new AtomicReference<>(); 721 AtomicReference<Request> requestHolder = new AtomicReference<>(request); 722 AtomicReference<ResourceMethod> resourceMethodHolder = new AtomicReference<>(); 723 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 724 725 // Holders to permit mutable effectively-final state tracking 726 AtomicBoolean willStartResponseWritingCompleted = new AtomicBoolean(false); 727 AtomicBoolean didFinishResponseWritingCompleted = new AtomicBoolean(false); 728 AtomicBoolean didFinishRequestHandlingCompleted = new AtomicBoolean(false); 729 AtomicBoolean didInvokeWrapRequestConsumer = new AtomicBoolean(false); 730 731 List<Throwable> throwables = new ArrayList<>(10); 732 733 Consumer<LogEvent> safelyLog = (logEvent -> { 734 try { 735 lifecycleObserver.didReceiveLogEvent(logEvent); 736 } catch (Throwable throwable) { 737 throwable.printStackTrace(); 738 throwables.add(throwable); 739 } 740 }); 741 742 BiConsumer<String, Consumer<MetricsCollector>> safelyCollectMetrics = (message, metricsInvocation) -> { 743 if (metricsCollector == null) 744 return; 745 746 try { 747 metricsInvocation.accept(metricsCollector); 748 } catch (Throwable throwable) { 749 safelyLog.accept(LogEvent.with(LogEventType.METRICS_COLLECTOR_FAILED, message) 750 .throwable(throwable) 751 .request(requestHolder.get()) 752 .resourceMethod(resourceMethodHolder.get()) 753 .marshaledResponse(marshaledResponseHolder.get()) 754 .build()); 755 } 756 }; 757 758 requestHolder.set(request); 759 760 try { 761 requestInterceptor.wrapRequest(serverType, request, (wrappedRequest) -> { 762 didInvokeWrapRequestConsumer.set(true); 763 requestHolder.set(wrappedRequest); 764 765 try { 766 // Resolve after wrapping so path/method rewrites affect routing. 767 resourceMethodHolder.set(resourceMethodResolver.resourceMethodForRequest(requestHolder.get(), serverType).orElse(null)); 768 resourceMethodResolutionExceptionHolder.set(null); 769 } catch (Throwable t) { 770 safelyLog.accept(LogEvent.with(LogEventType.RESOURCE_METHOD_RESOLUTION_FAILED, "Unable to resolve Resource Method") 771 .throwable(t) 772 .request(requestHolder.get()) 773 .build()); 774 775 // If an exception occurs here, keep track of it - we will surface them after letting LifecycleObserver 776 // see that a request has come in. 777 throwables.add(t); 778 resourceMethodResolutionExceptionHolder.set(t); 779 resourceMethodHolder.set(null); 780 } 781 782 try { 783 lifecycleObserver.didStartRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get()); 784 } catch (Throwable t) { 785 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_START_REQUEST_HANDLING_FAILED, 786 format("An exception occurred while invoking %s::didStartRequestHandling", 787 LifecycleObserver.class.getSimpleName())) 788 .throwable(t) 789 .request(requestHolder.get()) 790 .resourceMethod(resourceMethodHolder.get()) 791 .build()); 792 793 throwables.add(t); 794 } 795 796 safelyCollectMetrics.accept( 797 format("An exception occurred while invoking %s::didStartRequestHandling", MetricsCollector.class.getSimpleName()), 798 (metricsInvocation) -> metricsInvocation.didStartRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get())); 799 800 try { 801 AtomicBoolean didInvokeMarshaledResponseConsumer = new AtomicBoolean(false); 802 803 requestInterceptor.interceptRequest(serverType, requestHolder.get(), resourceMethodHolder.get(), (interceptorRequest) -> { 804 requestHolder.set(interceptorRequest); 805 806 try { 807 if (resourceMethodResolutionExceptionHolder.get() != null) 808 throw resourceMethodResolutionExceptionHolder.get(); 809 810 HttpRequestResult requestResult = toHttpRequestResult(requestHolder.get(), resourceMethodHolder.get(), serverType); 811 requestResultHolder.set(requestResult); 812 813 MarshaledResponse originalMarshaledResponse = requestResult.getMarshaledResponse(); 814 MarshaledResponse updatedMarshaledResponse = requestResult.getMarshaledResponse(); 815 816 // A few special cases that are "global" in that they can affect all requests and 817 // need to happen after marshaling the response... 818 819 // 1. Customize response for HEAD (e.g. remove body, set Content-Length header) 820 updatedMarshaledResponse = applyHeadResponseIfApplicable(requestHolder.get(), updatedMarshaledResponse); 821 822 // 2. Apply other standard response customizations (CORS, Content-Length) 823 // Note that we don't want to write Content-Length for SSE "accepted" handshakes 824 SseHandshakeResult sseHandshakeResult = requestResult.getSseHandshakeResult().orElse(null); 825 boolean suppressContentLength = sseHandshakeResult != null && sseHandshakeResult instanceof SseHandshakeResult.Accepted; 826 827 updatedMarshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), updatedMarshaledResponse, suppressContentLength); 828 829 // Update our result holder with the modified response if necessary 830 if (originalMarshaledResponse != updatedMarshaledResponse) { 831 marshaledResponseHolder.set(updatedMarshaledResponse); 832 requestResultHolder.set(requestResult.copy() 833 .marshaledResponse(updatedMarshaledResponse) 834 .finish()); 835 } 836 837 return updatedMarshaledResponse; 838 } catch (Throwable t) { 839 if (t != resourceMethodResolutionExceptionHolder.get()) { 840 throwables.add(t); 841 842 safelyLog.accept(LogEvent.with(LogEventType.REQUEST_PROCESSING_FAILED, 843 "An exception occurred while processing request") 844 .throwable(t) 845 .request(requestHolder.get()) 846 .resourceMethod(resourceMethodHolder.get()) 847 .build()); 848 } 849 850 // Unhappy path. Try to use configuration's exception response marshaler... 851 try { 852 MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get()); 853 marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse); 854 marshaledResponseHolder.set(marshaledResponse); 855 856 return marshaledResponse; 857 } catch (Throwable t2) { 858 throwables.add(t2); 859 860 safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED, 861 format("An exception occurred while trying to write an exception response for %s", t)) 862 .throwable(t2) 863 .request(requestHolder.get()) 864 .resourceMethod(resourceMethodHolder.get()) 865 .build()); 866 867 // The configuration's exception response marshaler failed - provide a failsafe response to recover 868 return provideFailsafeMarshaledResponse(requestHolder.get(), t2); 869 } 870 } 871 }, (interceptorMarshaledResponse) -> { 872 requireNonNull(interceptorMarshaledResponse); 873 didInvokeMarshaledResponseConsumer.set(true); 874 marshaledResponseHolder.set(interceptorMarshaledResponse); 875 }); 876 877 if (!didInvokeMarshaledResponseConsumer.get()) { 878 requestResultHolder.set(null); 879 throw new IllegalStateException(format("%s::interceptRequest must call responseWriter", RequestInterceptor.class.getSimpleName())); 880 } 881 } catch (Throwable t) { 882 throwables.add(t); 883 884 try { 885 // In the event that an error occurs during processing of a RequestInterceptor method, for example 886 safelyLog.accept(LogEvent.with(LogEventType.REQUEST_INTERCEPTOR_INTERCEPT_REQUEST_FAILED, 887 format("An exception occurred while invoking %s::interceptRequest", RequestInterceptor.class.getSimpleName())) 888 .throwable(t) 889 .request(requestHolder.get()) 890 .resourceMethod(resourceMethodHolder.get()) 891 .build()); 892 893 MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get()); 894 marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse); 895 marshaledResponseHolder.set(marshaledResponse); 896 } catch (Throwable t2) { 897 throwables.add(t2); 898 899 safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED, 900 format("An exception occurred while invoking %s::forThrowable when trying to write an exception response for %s", ResponseMarshaler.class.getSimpleName(), t)) 901 .throwable(t2) 902 .request(requestHolder.get()) 903 .resourceMethod(resourceMethodHolder.get()) 904 .build()); 905 906 marshaledResponseHolder.set(provideFailsafeMarshaledResponse(requestHolder.get(), t2)); 907 } 908 } finally { 909 try { 910 try { 911 lifecycleObserver.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get()); 912 } finally { 913 willStartResponseWritingCompleted.set(true); 914 } 915 916 safelyCollectMetrics.accept( 917 format("An exception occurred while invoking %s::willWriteResponse", MetricsCollector.class.getSimpleName()), 918 (metricsInvocation) -> metricsInvocation.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get())); 919 920 Instant responseWriteStarted = Instant.now(); 921 922 try { 923 HttpRequestResult requestResult = requestResultHolder.get(); 924 925 if (requestResult != null) 926 requestResultConsumer.accept(requestResult); 927 else 928 requestResultConsumer.accept(HttpRequestResult.withMarshaledResponse(marshaledResponseHolder.get()) 929 .resourceMethod(resourceMethodHolder.get()) 930 .build()); 931 932 Instant responseWriteFinished = Instant.now(); 933 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 934 935 try { 936 lifecycleObserver.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration); 937 } catch (Throwable t) { 938 throwables.add(t); 939 940 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 941 format("An exception occurred while invoking %s::didWriteResponse", 942 LifecycleObserver.class.getSimpleName())) 943 .throwable(t) 944 .request(requestHolder.get()) 945 .resourceMethod(resourceMethodHolder.get()) 946 .marshaledResponse(marshaledResponseHolder.get()) 947 .build()); 948 } finally { 949 didFinishResponseWritingCompleted.set(true); 950 } 951 952 safelyCollectMetrics.accept( 953 format("An exception occurred while invoking %s::didWriteResponse", MetricsCollector.class.getSimpleName()), 954 (metricsInvocation) -> metricsInvocation.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 955 marshaledResponseHolder.get(), responseWriteDuration)); 956 } catch (Throwable t) { 957 throwables.add(t); 958 959 Instant responseWriteFinished = Instant.now(); 960 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 961 962 try { 963 lifecycleObserver.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration, t); 964 } catch (Throwable t2) { 965 throwables.add(t2); 966 967 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 968 format("An exception occurred while invoking %s::didFailToWriteResponse", 969 LifecycleObserver.class.getSimpleName())) 970 .throwable(t2) 971 .request(requestHolder.get()) 972 .resourceMethod(resourceMethodHolder.get()) 973 .marshaledResponse(marshaledResponseHolder.get()) 974 .build()); 975 } 976 977 safelyCollectMetrics.accept( 978 format("An exception occurred while invoking %s::didFailToWriteResponse", MetricsCollector.class.getSimpleName()), 979 (metricsInvocation) -> metricsInvocation.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 980 marshaledResponseHolder.get(), responseWriteDuration, t)); 981 } 982 } finally { 983 Duration processingDuration = Duration.between(processingStarted, Instant.now()); 984 985 safelyCollectMetrics.accept( 986 format("An exception occurred while invoking %s::didFinishRequestHandling", MetricsCollector.class.getSimpleName()), 987 (metricsInvocation) -> metricsInvocation.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables))); 988 989 try { 990 lifecycleObserver.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables)); 991 } catch (Throwable t) { 992 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_FINISH_REQUEST_HANDLING_FAILED, 993 format("An exception occurred while invoking %s::didFinishRequestHandling", 994 LifecycleObserver.class.getSimpleName())) 995 .throwable(t) 996 .request(requestHolder.get()) 997 .resourceMethod(resourceMethodHolder.get()) 998 .marshaledResponse(marshaledResponseHolder.get()) 999 .build()); 1000 } finally { 1001 didFinishRequestHandlingCompleted.set(true); 1002 } 1003 } 1004 } 1005 }); 1006 1007 if (!didInvokeWrapRequestConsumer.get()) 1008 throw new IllegalStateException(format("%s::wrapRequest must call requestProcessor", RequestInterceptor.class.getSimpleName())); 1009 } catch (Throwable t) { 1010 // If an error occurred during request wrapping, it's possible a response was never written/communicated back to LifecycleObserver. 1011 // Detect that here and inform LifecycleObserver accordingly. 1012 safelyLog.accept(LogEvent.with(LogEventType.REQUEST_INTERCEPTOR_WRAP_REQUEST_FAILED, 1013 format("An exception occurred while invoking %s::wrapRequest", 1014 RequestInterceptor.class.getSimpleName())) 1015 .throwable(t) 1016 .request(requestHolder.get()) 1017 .resourceMethod(resourceMethodHolder.get()) 1018 .marshaledResponse(marshaledResponseHolder.get()) 1019 .build()); 1020 1021 // If we don't have a response, let the marshaler try to make one for the exception. 1022 // If that fails, use the failsafe. 1023 if (marshaledResponseHolder.get() == null) { 1024 try { 1025 MarshaledResponse marshaledResponse = responseMarshaler.forThrowable(requestHolder.get(), t, resourceMethodHolder.get()); 1026 marshaledResponse = applyCommonPropertiesToMarshaledResponse(requestHolder.get(), marshaledResponse); 1027 marshaledResponseHolder.set(marshaledResponse); 1028 } catch (Throwable t2) { 1029 throwables.add(t2); 1030 1031 safelyLog.accept(LogEvent.with(LogEventType.RESPONSE_MARSHALER_FOR_THROWABLE_FAILED, 1032 format("An exception occurred during request wrapping while invoking %s::forThrowable", 1033 ResponseMarshaler.class.getSimpleName())) 1034 .throwable(t2) 1035 .request(requestHolder.get()) 1036 .resourceMethod(resourceMethodHolder.get()) 1037 .marshaledResponse(marshaledResponseHolder.get()) 1038 .build()); 1039 1040 marshaledResponseHolder.set(provideFailsafeMarshaledResponse(requestHolder.get(), t)); 1041 } 1042 } 1043 1044 if (!willStartResponseWritingCompleted.get()) { 1045 try { 1046 lifecycleObserver.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get()); 1047 } catch (Throwable t2) { 1048 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_WILL_WRITE_RESPONSE_FAILED, 1049 format("An exception occurred while invoking %s::willWriteResponse", 1050 LifecycleObserver.class.getSimpleName())) 1051 .throwable(t2) 1052 .request(requestHolder.get()) 1053 .resourceMethod(resourceMethodHolder.get()) 1054 .marshaledResponse(marshaledResponseHolder.get()) 1055 .build()); 1056 } 1057 1058 safelyCollectMetrics.accept( 1059 format("An exception occurred while invoking %s::willWriteResponse", MetricsCollector.class.getSimpleName()), 1060 (metricsInvocation) -> metricsInvocation.willWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get())); 1061 } 1062 1063 try { 1064 Instant responseWriteStarted = Instant.now(); 1065 1066 if (!didFinishResponseWritingCompleted.get()) { 1067 try { 1068 HttpRequestResult requestResult = requestResultHolder.get(); 1069 1070 if (requestResult != null) 1071 requestResultConsumer.accept(requestResult); 1072 else 1073 requestResultConsumer.accept(HttpRequestResult.withMarshaledResponse(marshaledResponseHolder.get()) 1074 .resourceMethod(resourceMethodHolder.get()) 1075 .build()); 1076 1077 Instant responseWriteFinished = Instant.now(); 1078 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 1079 1080 try { 1081 lifecycleObserver.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration); 1082 } catch (Throwable t2) { 1083 throwables.add(t2); 1084 1085 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 1086 format("An exception occurred while invoking %s::didWriteResponse", 1087 LifecycleObserver.class.getSimpleName())) 1088 .throwable(t2) 1089 .request(requestHolder.get()) 1090 .resourceMethod(resourceMethodHolder.get()) 1091 .marshaledResponse(marshaledResponseHolder.get()) 1092 .build()); 1093 } 1094 1095 safelyCollectMetrics.accept( 1096 format("An exception occurred while invoking %s::didWriteResponse", MetricsCollector.class.getSimpleName()), 1097 (metricsInvocation) -> metricsInvocation.didWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 1098 marshaledResponseHolder.get(), responseWriteDuration)); 1099 } catch (Throwable t2) { 1100 throwables.add(t2); 1101 1102 Instant responseWriteFinished = Instant.now(); 1103 Duration responseWriteDuration = Duration.between(responseWriteStarted, responseWriteFinished); 1104 1105 try { 1106 lifecycleObserver.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), responseWriteDuration, t); 1107 } catch (Throwable t3) { 1108 throwables.add(t3); 1109 1110 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_WRITE_RESPONSE_FAILED, 1111 format("An exception occurred while invoking %s::didFailToWriteResponse", 1112 LifecycleObserver.class.getSimpleName())) 1113 .throwable(t3) 1114 .request(requestHolder.get()) 1115 .resourceMethod(resourceMethodHolder.get()) 1116 .marshaledResponse(marshaledResponseHolder.get()) 1117 .build()); 1118 } 1119 1120 safelyCollectMetrics.accept( 1121 format("An exception occurred while invoking %s::didFailToWriteResponse", MetricsCollector.class.getSimpleName()), 1122 (metricsInvocation) -> metricsInvocation.didFailToWriteResponse(serverType, requestHolder.get(), resourceMethodHolder.get(), 1123 marshaledResponseHolder.get(), responseWriteDuration, t)); 1124 } 1125 } 1126 } finally { 1127 if (!didFinishRequestHandlingCompleted.get()) { 1128 Duration processingDuration = Duration.between(processingStarted, Instant.now()); 1129 1130 safelyCollectMetrics.accept( 1131 format("An exception occurred while invoking %s::didFinishRequestHandling", MetricsCollector.class.getSimpleName()), 1132 (metricsInvocation) -> metricsInvocation.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables))); 1133 1134 try { 1135 lifecycleObserver.didFinishRequestHandling(serverType, requestHolder.get(), resourceMethodHolder.get(), marshaledResponseHolder.get(), processingDuration, Collections.unmodifiableList(throwables)); 1136 } catch (Throwable t2) { 1137 safelyLog.accept(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_FINISH_REQUEST_HANDLING_FAILED, 1138 format("An exception occurred while invoking %s::didFinishRequestHandling", 1139 LifecycleObserver.class.getSimpleName())) 1140 .throwable(t2) 1141 .request(requestHolder.get()) 1142 .resourceMethod(resourceMethodHolder.get()) 1143 .marshaledResponse(marshaledResponseHolder.get()) 1144 .build()); 1145 } 1146 } 1147 } 1148 } 1149 } 1150 1151 @NonNull 1152 protected HttpRequestResult toHttpRequestResult(@NonNull Request request, 1153 @Nullable ResourceMethod resourceMethod, 1154 @NonNull ServerType serverType) throws Throwable { 1155 requireNonNull(request); 1156 requireNonNull(serverType); 1157 1158 ResourceMethodParameterProvider resourceMethodParameterProvider = getSokletConfig().getResourceMethodParameterProvider(); 1159 InstanceProvider instanceProvider = getSokletConfig().getInstanceProvider(); 1160 CorsAuthorizer corsAuthorizer = getSokletConfig().getCorsAuthorizer(); 1161 ResourceMethodResolver resourceMethodResolver = getSokletConfig().getResourceMethodResolver(); 1162 ResponseMarshaler responseMarshaler = getSokletConfig().getResponseMarshaler(); 1163 CorsPreflight corsPreflight = request.getCorsPreflight().orElse(null); 1164 1165 // Special short-circuit for big requests 1166 if (request.isContentTooLarge()) 1167 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forContentTooLarge(request, resourceMethodResolver.resourceMethodForRequest(request, serverType).orElse(null))) 1168 .resourceMethod(resourceMethod) 1169 .build(); 1170 1171 // Special short-circuit for OPTIONS * 1172 if (request.getResourcePath() == ResourcePath.OPTIONS_SPLAT_RESOURCE_PATH) 1173 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forOptionsSplat(request)).build(); 1174 1175 // No resource method was found for this HTTP method and path. 1176 if (resourceMethod == null) { 1177 // If this was an OPTIONS request, do special processing. 1178 // If not, figure out if we should return a 404 or 405. 1179 if (request.getHttpMethod() == HttpMethod.OPTIONS) { 1180 // See what methods are available to us for this request's path 1181 Map<HttpMethod, ResourceMethod> matchingResourceMethodsByHttpMethod = resolveMatchingResourceMethodsByHttpMethod(request, resourceMethodResolver, serverType); 1182 1183 // Special handling for CORS preflight requests, if needed 1184 if (corsPreflight != null) { 1185 // Let configuration function determine if we should authorize this request. 1186 // Discard any OPTIONS references - see https://stackoverflow.com/a/68529748 1187 Map<HttpMethod, ResourceMethod> nonOptionsMatchingResourceMethodsByHttpMethod = matchingResourceMethodsByHttpMethod.entrySet().stream() 1188 .filter(entry -> entry.getKey() != HttpMethod.OPTIONS) 1189 .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); 1190 1191 CorsPreflightResponse corsPreflightResponse = corsAuthorizer.authorizePreflight(request, corsPreflight, nonOptionsMatchingResourceMethodsByHttpMethod).orElse(null); 1192 1193 // Allow or reject CORS depending on what the function said to do 1194 if (corsPreflightResponse != null) { 1195 // Allow 1196 MarshaledResponse marshaledResponse = responseMarshaler.forCorsPreflightAllowed(request, corsPreflight, corsPreflightResponse); 1197 1198 return HttpRequestResult.withMarshaledResponse(marshaledResponse) 1199 .corsPreflightResponse(corsPreflightResponse) 1200 .resourceMethod(resourceMethod) 1201 .build(); 1202 } 1203 1204 // Reject 1205 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forCorsPreflightRejected(request, corsPreflight)) 1206 .resourceMethod(resourceMethod) 1207 .build(); 1208 } else { 1209 // Just a normal OPTIONS response (non-CORS-preflight). 1210 // If there's a matching OPTIONS resource method for this OPTIONS request, then invoke it. 1211 ResourceMethod optionsResourceMethod = matchingResourceMethodsByHttpMethod.get(HttpMethod.OPTIONS); 1212 1213 if (optionsResourceMethod != null) { 1214 resourceMethod = optionsResourceMethod; 1215 } else { 1216 Set<HttpMethod> allowedHttpMethods = allowedHttpMethodsForResponse(matchingResourceMethodsByHttpMethod, true); 1217 1218 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forOptions(request, allowedHttpMethods)) 1219 .resourceMethod(resourceMethod) 1220 .build(); 1221 } 1222 } 1223 } else if (request.getHttpMethod() == HttpMethod.HEAD) { 1224 // If there's a matching GET resource method for this HEAD request, then invoke it 1225 Request headGetRequest = request.copy().httpMethod(HttpMethod.GET).finish(); 1226 ResourceMethod headGetResourceMethod = resourceMethodResolver.resourceMethodForRequest(headGetRequest, serverType).orElse(null); 1227 1228 if (headGetResourceMethod != null) 1229 resourceMethod = headGetResourceMethod; 1230 else 1231 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forNotFound(request)) 1232 .resourceMethod(resourceMethod) 1233 .build(); 1234 } else { 1235 // Not an OPTIONS request, so it's possible we have a 405. See if other HTTP methods match... 1236 Map<HttpMethod, ResourceMethod> otherMatchingResourceMethodsByHttpMethod = resolveMatchingResourceMethodsByHttpMethod(request, resourceMethodResolver, serverType); 1237 1238 Set<HttpMethod> matchingNonOptionsHttpMethods = otherMatchingResourceMethodsByHttpMethod.keySet().stream() 1239 .filter(httpMethod -> httpMethod != HttpMethod.OPTIONS) 1240 .collect(Collectors.toSet()); 1241 1242 if (matchingNonOptionsHttpMethods.size() > 0) { 1243 // ...if some do, it's a 405 1244 Set<HttpMethod> allowedHttpMethods = allowedHttpMethodsForResponse(otherMatchingResourceMethodsByHttpMethod, true); 1245 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forMethodNotAllowed(request, allowedHttpMethods)) 1246 .resourceMethod(resourceMethod) 1247 .build(); 1248 } else { 1249 // no matching resource method found, it's a 404 1250 return HttpRequestResult.withMarshaledResponse(responseMarshaler.forNotFound(request)) 1251 .resourceMethod(resourceMethod) 1252 .build(); 1253 } 1254 } 1255 } 1256 1257 // Found a resource method - happy path. 1258 // 1. Get an instance of the resource class 1259 // 2. Get values to pass to the resource method on the resource class 1260 // 3. Invoke the resource method and use its return value to drive a response 1261 Class<?> resourceClass = resourceMethod.getMethod().getDeclaringClass(); 1262 Object resourceClassInstance; 1263 1264 try { 1265 resourceClassInstance = instanceProvider.provide(resourceClass); 1266 } catch (Exception e) { 1267 throw new IllegalArgumentException(format("Unable to acquire an instance of %s", resourceClass.getName()), e); 1268 } 1269 1270 List<Object> parameterValues = resourceMethodParameterProvider.parameterValuesForResourceMethod(request, resourceMethod); 1271 1272 Object responseObject; 1273 1274 try { 1275 responseObject = resourceMethod.getMethod().invoke(resourceClassInstance, parameterValues.toArray()); 1276 } catch (InvocationTargetException e) { 1277 if (e.getTargetException() != null) 1278 throw e.getTargetException(); 1279 1280 throw e; 1281 } 1282 1283 // Unwrap the Optional<T>, if one exists. We do not recurse deeper than one level 1284 if (responseObject instanceof Optional<?>) 1285 responseObject = ((Optional<?>) responseObject).orElse(null); 1286 1287 Response response; 1288 SseHandshakeResult sseHandshakeResult = null; 1289 1290 // If null/void return, it's a 204 1291 // If it's a MarshaledResponse object, no marshaling + return it immediately - caller knows exactly what it wants to write. 1292 // If it's a Response object, use as is. 1293 // If it's a non-Response type of object, assume it's the response body and wrap in a Response. 1294 if (responseObject == null) { 1295 response = Response.withStatusCode(204).build(); 1296 } else if (responseObject instanceof MarshaledResponse) { 1297 MarshaledResponse marshaledResponse = (MarshaledResponse) responseObject; 1298 enforceBodylessStatusCode(marshaledResponse.getStatusCode(), marshaledResponse.getBody().isPresent() || marshaledResponse.getStream().isPresent()); 1299 1300 return HttpRequestResult.withMarshaledResponse(marshaledResponse) 1301 .resourceMethod(resourceMethod) 1302 .build(); 1303 } else if (responseObject instanceof Response) { 1304 response = (Response) responseObject; 1305 } else if (responseObject instanceof SseHandshakeResult.Accepted accepted) { // SSE "accepted" handshake 1306 return HttpRequestResult.withMarshaledResponse(toMarshaledResponse(accepted)) 1307 .resourceMethod(resourceMethod) 1308 .sseHandshakeResult(accepted) 1309 .build(); 1310 } else if (responseObject instanceof SseHandshakeResult.Rejected rejected) { // SSE "rejected" handshake 1311 response = rejected.getResponse(); 1312 sseHandshakeResult = rejected; 1313 } else { 1314 response = Response.withStatusCode(200).body(responseObject).build(); 1315 } 1316 1317 enforceBodylessStatusCode(response.getStatusCode(), response.getBody().isPresent()); 1318 1319 MarshaledResponse marshaledResponse = responseMarshaler.forResourceMethod(request, response, resourceMethod); 1320 1321 enforceBodylessStatusCode(marshaledResponse.getStatusCode(), marshaledResponse.getBody().isPresent() || marshaledResponse.getStream().isPresent()); 1322 1323 return HttpRequestResult.withMarshaledResponse(marshaledResponse) 1324 .response(response) 1325 .resourceMethod(resourceMethod) 1326 .sseHandshakeResult(sseHandshakeResult) 1327 .build(); 1328 } 1329 1330 @NonNull 1331 private MarshaledResponse toMarshaledResponse(SseHandshakeResult.@NonNull Accepted accepted) { 1332 requireNonNull(accepted); 1333 1334 Map<String, Set<String>> headers = accepted.getHeaders(); 1335 LinkedCaseInsensitiveMap<Set<String>> finalHeaders = new LinkedCaseInsensitiveMap<>(DEFAULT_ACCEPTED_HANDSHAKE_HEADERS.size() + headers.size()); 1336 1337 // Start with defaults 1338 for (Map.Entry<String, Set<String>> e : DEFAULT_ACCEPTED_HANDSHAKE_HEADERS.entrySet()) 1339 finalHeaders.put(e.getKey(), e.getValue()); // values already unmodifiable 1340 1341 // Overlay user-supplied headers (prefer user values on key collision) 1342 for (Map.Entry<String, Set<String>> e : headers.entrySet()) { 1343 // Defensively copy so callers can't mutate after construction 1344 Set<String> values = e.getValue() == null ? Set.of() : Set.copyOf(e.getValue()); 1345 finalHeaders.put(e.getKey(), values); 1346 } 1347 1348 return MarshaledResponse.withStatusCode(200) 1349 .headers(finalHeaders) 1350 .cookies(accepted.getCookies()) 1351 .build(); 1352 } 1353 1354 private static void enforceBodylessStatusCode(@NonNull Integer statusCode, 1355 @NonNull Boolean hasBody) { 1356 requireNonNull(statusCode); 1357 requireNonNull(hasBody); 1358 1359 if (hasBody && isBodylessStatusCode(statusCode)) 1360 throw new IllegalStateException(format("HTTP status code %d must not include a response body", statusCode)); 1361 } 1362 1363 private static boolean isBodylessStatusCode(@NonNull Integer statusCode) { 1364 requireNonNull(statusCode); 1365 return (statusCode >= 100 && statusCode < 200) || statusCode == 204 || statusCode == 304; 1366 } 1367 1368 @NonNull 1369 protected MarshaledResponse applyHeadResponseIfApplicable(@NonNull Request request, 1370 @NonNull MarshaledResponse marshaledResponse) { 1371 if (request.getHttpMethod() != HttpMethod.HEAD) 1372 return marshaledResponse; 1373 1374 return getSokletConfig().getResponseMarshaler().forHead(request, marshaledResponse); 1375 } 1376 1377 // Hat tip to Aslan Parçası and GrayStar 1378 @NonNull 1379 protected MarshaledResponse applyCommonPropertiesToMarshaledResponse(@NonNull Request request, 1380 @NonNull MarshaledResponse marshaledResponse) { 1381 requireNonNull(request); 1382 requireNonNull(marshaledResponse); 1383 1384 return applyCommonPropertiesToMarshaledResponse(request, marshaledResponse, false); 1385 } 1386 1387 protected void handleMcpRequest(@NonNull Request request, 1388 @NonNull Consumer<HttpRequestResult> requestResultConsumer) { 1389 requireNonNull(request); 1390 requireNonNull(requestResultConsumer); 1391 1392 requestResultConsumer.accept(this.defaultMcpRuntime.handleRequest(request)); 1393 } 1394 1395 protected void handleSimulatedMcpStreamDisconnect(@NonNull Request request, 1396 @NonNull String sessionId) { 1397 requireNonNull(request); 1398 requireNonNull(sessionId); 1399 1400 this.defaultMcpRuntime.handleClientDisconnectedStream(request, sessionId); 1401 } 1402 1403 @NonNull 1404 protected MarshaledResponse applyCommonPropertiesToMarshaledResponse(@NonNull Request request, 1405 @NonNull MarshaledResponse marshaledResponse, 1406 @NonNull Boolean suppressContentLength) { 1407 requireNonNull(request); 1408 requireNonNull(marshaledResponse); 1409 requireNonNull(suppressContentLength); 1410 1411 // Don't write Content-Length for an accepted SSE Handshake, for example 1412 if (!suppressContentLength) 1413 marshaledResponse = applyContentLengthIfApplicable(request, marshaledResponse); 1414 1415 // If the Date header is missing, add it using our cached provider 1416 if (!marshaledResponse.getHeaders().containsKey("Date")) 1417 marshaledResponse = marshaledResponse.copy() 1418 .headers(headers -> headers.put("Date", Set.of(HttpDate.currentSecondHeaderValue()))) 1419 .finish(); 1420 1421 marshaledResponse = applyCorsResponseIfApplicable(request, marshaledResponse); 1422 1423 return marshaledResponse; 1424 } 1425 1426 @NonNull 1427 protected MarshaledResponse applyContentLengthIfApplicable(@NonNull Request request, 1428 @NonNull MarshaledResponse marshaledResponse) { 1429 requireNonNull(request); 1430 requireNonNull(marshaledResponse); 1431 1432 if (marshaledResponse.isStreaming()) 1433 return marshaledResponse; 1434 1435 Set<String> normalizedHeaderNames = marshaledResponse.getHeaders().keySet().stream() 1436 .map(headerName -> headerName.toLowerCase(Locale.US)) 1437 .collect(Collectors.toSet()); 1438 1439 // If Content-Length is already specified, don't do anything 1440 if (normalizedHeaderNames.contains("content-length") || normalizedHeaderNames.contains("transfer-encoding")) 1441 return marshaledResponse; 1442 1443 // If Content-Length is not specified, specify as the number of bytes in the body 1444 return marshaledResponse.copy() 1445 .headers((mutableHeaders) -> { 1446 String contentLengthHeaderValue = String.valueOf(marshaledResponse.getBodyLength()); 1447 mutableHeaders.put("Content-Length", Set.of(contentLengthHeaderValue)); 1448 }).finish(); 1449 } 1450 1451 @NonNull 1452 protected MarshaledResponse applyCorsResponseIfApplicable(@NonNull Request request, 1453 @NonNull MarshaledResponse marshaledResponse) { 1454 requireNonNull(request); 1455 requireNonNull(marshaledResponse); 1456 1457 Cors cors = request.getCors().orElse(null); 1458 1459 // If non-CORS request, nothing further to do (note that CORS preflight was handled earlier) 1460 if (cors == null) 1461 return marshaledResponse; 1462 1463 CorsAuthorizer corsAuthorizer = getSokletConfig().getCorsAuthorizer(); 1464 1465 // Does the authorizer say we are authorized? 1466 CorsResponse corsResponse = corsAuthorizer.authorize(request, cors).orElse(null); 1467 1468 // Not authorized - don't apply CORS headers to the response 1469 if (corsResponse == null) 1470 return marshaledResponse; 1471 1472 // Authorized - OK, let's apply the headers to the response 1473 return getSokletConfig().getResponseMarshaler().forCorsAllowed(request, cors, corsResponse, marshaledResponse); 1474 } 1475 1476 @NonNull 1477 protected Map<@NonNull HttpMethod, @NonNull ResourceMethod> resolveMatchingResourceMethodsByHttpMethod(@NonNull Request request, 1478 @NonNull ResourceMethodResolver resourceMethodResolver, 1479 @NonNull ServerType serverType) { 1480 requireNonNull(request); 1481 requireNonNull(resourceMethodResolver); 1482 requireNonNull(serverType); 1483 1484 // Special handling for OPTIONS * 1485 if (request.getResourcePath() == ResourcePath.OPTIONS_SPLAT_RESOURCE_PATH) 1486 return new LinkedHashMap<>(); 1487 1488 Map<HttpMethod, ResourceMethod> matchingResourceMethodsByHttpMethod = new LinkedHashMap<>(HttpMethod.values().length); 1489 1490 for (HttpMethod httpMethod : HttpMethod.values()) { 1491 // Make a quick copy of the request to see if other paths match 1492 Request otherRequest = Request.withPath(httpMethod, request.getPath()).build(); 1493 ResourceMethod resourceMethod = resourceMethodResolver.resourceMethodForRequest(otherRequest, serverType).orElse(null); 1494 1495 if (resourceMethod != null) 1496 matchingResourceMethodsByHttpMethod.put(httpMethod, resourceMethod); 1497 } 1498 1499 return matchingResourceMethodsByHttpMethod; 1500 } 1501 1502 @NonNull 1503 private static Set<@NonNull HttpMethod> allowedHttpMethodsForResponse(@NonNull Map<@NonNull HttpMethod, @NonNull ResourceMethod> matchingResourceMethodsByHttpMethod, 1504 @NonNull Boolean includeOptions) { 1505 requireNonNull(matchingResourceMethodsByHttpMethod); 1506 requireNonNull(includeOptions); 1507 1508 Set<HttpMethod> allowedHttpMethods = EnumSet.noneOf(HttpMethod.class); 1509 allowedHttpMethods.addAll(matchingResourceMethodsByHttpMethod.keySet()); 1510 1511 if (includeOptions) 1512 allowedHttpMethods.add(HttpMethod.OPTIONS); 1513 1514 if (matchingResourceMethodsByHttpMethod.containsKey(HttpMethod.GET) || matchingResourceMethodsByHttpMethod.containsKey(HttpMethod.HEAD)) 1515 allowedHttpMethods.add(HttpMethod.HEAD); 1516 1517 return allowedHttpMethods; 1518 } 1519 1520 @NonNull 1521 protected MarshaledResponse provideFailsafeMarshaledResponse(@NonNull Request request, 1522 @NonNull Throwable throwable) { 1523 requireNonNull(request); 1524 requireNonNull(throwable); 1525 1526 Integer statusCode = 500; 1527 Charset charset = StandardCharsets.UTF_8; 1528 1529 return MarshaledResponse.withStatusCode(statusCode) 1530 .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", charset.name())))) 1531 .body(format("HTTP %d: %s", statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(charset)) 1532 .build(); 1533 } 1534 1535 /** 1536 * Synonym for {@link #stop()}. 1537 */ 1538 @Override 1539 public void close() { 1540 stop(); 1541 } 1542 1543 /** 1544 * Is any managed transport server started? 1545 * 1546 * @return {@code true} if at least one configured transport server is started, {@code false} otherwise 1547 */ 1548 @NonNull 1549 public Boolean isStarted() { 1550 getLock().lock(); 1551 1552 try { 1553 HttpServer httpServer = getSokletConfig().getHttpServer().orElse(null); 1554 1555 if (httpServer != null && httpServer.isStarted()) 1556 return true; 1557 1558 SseServer sseServer = getSokletConfig().getSseServer().orElse(null); 1559 if (sseServer != null && sseServer.isStarted()) 1560 return true; 1561 1562 McpServer mcpServer = getSokletConfig().getMcpServer().orElse(null); 1563 return mcpServer != null && mcpServer.isStarted(); 1564 } finally { 1565 getLock().unlock(); 1566 } 1567 } 1568 1569 /** 1570 * Runs Soklet with special non-network "simulator" implementations of the configured transport servers - useful for integration testing. 1571 * <p> 1572 * See <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a> for how to write these tests. 1573 * 1574 * @param sokletConfig configuration that drives the Soklet system 1575 * @param simulatorConsumer code to execute within the context of the simulator 1576 */ 1577 public static void runSimulator(@NonNull SokletConfig sokletConfig, 1578 @NonNull Consumer<Simulator> simulatorConsumer) { 1579 runSimulator(sokletConfig, SimulatorOptions.defaultInstance(), simulatorConsumer); 1580 } 1581 1582 /** 1583 * Runs Soklet with special non-network "simulator" implementations of the configured transport servers - useful for integration testing. 1584 * <p> 1585 * See <a href="https://www.soklet.com/docs/testing">https://www.soklet.com/docs/testing</a> for how to write these tests. 1586 * 1587 * @param sokletConfig configuration that drives the Soklet system 1588 * @param simulatorOptions simulator behavior options 1589 * @param simulatorConsumer code to execute within the context of the simulator 1590 */ 1591 public static void runSimulator(@NonNull SokletConfig sokletConfig, 1592 @NonNull SimulatorOptions simulatorOptions, 1593 @NonNull Consumer<Simulator> simulatorConsumer) { 1594 requireNonNull(sokletConfig); 1595 requireNonNull(simulatorOptions); 1596 requireNonNull(simulatorConsumer); 1597 1598 // Create Soklet instance - this initializes the REAL implementations through proxies 1599 Soklet soklet = Soklet.fromConfig(sokletConfig); 1600 1601 // Extract proxies (they're guaranteed to be proxies now) 1602 HttpServerProxy serverProxy = sokletConfig.getHttpServer() 1603 .map(server -> (HttpServerProxy) server) 1604 .orElse(null); 1605 SseServerProxy sseServerProxy = sokletConfig.getSseServer() 1606 .map(s -> (SseServerProxy) s) 1607 .orElse(null); 1608 McpServerProxy mcpServerProxy = sokletConfig.getMcpServer() 1609 .map(mcpServer -> (McpServerProxy) mcpServer) 1610 .orElse(null); 1611 1612 // Create mock implementations 1613 MockHttpServer mockServer = serverProxy == null ? null : new MockHttpServer(); 1614 MockSseServer mockSseServer = new MockSseServer(); 1615 MockMcpServer mockMcpServer = mcpServerProxy == null ? null : new MockMcpServer(mcpServerProxy.getRealImplementation()); 1616 1617 // Switch proxies to simulator mode 1618 if (serverProxy != null) 1619 serverProxy.enableSimulatorMode(mockServer); 1620 1621 if (sseServerProxy != null) 1622 sseServerProxy.enableSimulatorMode(mockSseServer); 1623 1624 if (mcpServerProxy != null) 1625 mcpServerProxy.enableSimulatorMode(mockMcpServer); 1626 1627 try { 1628 // Initialize mocks with request handlers that delegate to Soklet's processing 1629 if (mockServer != null) 1630 mockServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> { 1631 // Delegate to Soklet's internal request handling 1632 soklet.handleRequest(request, ServerType.STANDARD_HTTP, marshaledResponseConsumer); 1633 }); 1634 1635 if (mockSseServer != null) 1636 mockSseServer.initialize(sokletConfig, (request, marshaledResponseConsumer) -> { 1637 // Delegate to Soklet's internal request handling for SSE 1638 soklet.handleRequest(request, ServerType.SSE, marshaledResponseConsumer); 1639 }); 1640 1641 if (mockMcpServer != null) 1642 mockMcpServer.initialize(sokletConfig, soklet::handleMcpRequest); 1643 1644 if (mockMcpServer != null) 1645 mockMcpServer.onClientDisconnectedMcpStream(soklet::handleSimulatedMcpStreamDisconnect); 1646 1647 // Create and provide simulator 1648 Simulator simulator = new DefaultSimulator(mockServer, mockSseServer, mockMcpServer, simulatorOptions); 1649 simulatorConsumer.accept(simulator); 1650 } finally { 1651 // Always restore to real implementations 1652 if (serverProxy != null) 1653 serverProxy.disableSimulatorMode(); 1654 1655 if (sseServerProxy != null) 1656 sseServerProxy.disableSimulatorMode(); 1657 1658 if (mcpServerProxy != null) 1659 mcpServerProxy.disableSimulatorMode(); 1660 } 1661 } 1662 1663 @NonNull 1664 protected SokletConfig getSokletConfig() { 1665 return this.sokletConfig; 1666 } 1667 1668 @NonNull 1669 protected ReentrantLock getLock() { 1670 return this.lock; 1671 } 1672 1673 @NonNull 1674 protected AtomicReference<CountDownLatch> getAwaitShutdownLatchReference() { 1675 return this.awaitShutdownLatchReference; 1676 } 1677 1678 @ThreadSafe 1679 static class DefaultSimulator implements Simulator { 1680 @Nullable 1681 private MockHttpServer server; 1682 @Nullable 1683 private MockSseServer sseServer; 1684 @Nullable 1685 private MockMcpServer mcpServer; 1686 @NonNull 1687 private final SimulatorOptions simulatorOptions; 1688 1689 public DefaultSimulator(@Nullable MockHttpServer server, 1690 @Nullable MockSseServer sseServer, 1691 @Nullable MockMcpServer mcpServer) { 1692 this(server, sseServer, mcpServer, SimulatorOptions.defaultInstance()); 1693 } 1694 1695 public DefaultSimulator(@Nullable MockHttpServer server, 1696 @Nullable MockSseServer sseServer, 1697 @Nullable MockMcpServer mcpServer, 1698 @NonNull SimulatorOptions simulatorOptions) { 1699 this.server = server; 1700 this.sseServer = sseServer; 1701 this.mcpServer = mcpServer; 1702 this.simulatorOptions = requireNonNull(simulatorOptions); 1703 } 1704 1705 @NonNull 1706 @Override 1707 public HttpRequestResult performHttpRequest(@NonNull Request request) { 1708 MockHttpServer server = getHttpServer().orElse(null); 1709 1710 if (server == null) 1711 throw new IllegalStateException(format("You must specify a %s in your %s to simulate requests", 1712 HttpServer.class.getSimpleName(), SokletConfig.class.getSimpleName())); 1713 1714 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 1715 HttpServer.RequestHandler requestHandler = server.getRequestHandler().orElse(null); 1716 1717 if (requestHandler == null) 1718 throw new IllegalStateException("You must register a request handler prior to simulating requests"); 1719 1720 requestHandler.handleRequest(request, (requestResult -> { 1721 requestResultHolder.set(requestResult); 1722 })); 1723 1724 return materializeStreamingResponse(request, requestResultHolder.get()); 1725 } 1726 1727 @NonNull 1728 private HttpRequestResult materializeStreamingResponse(@NonNull Request request, 1729 @Nullable HttpRequestResult requestResult) { 1730 requireNonNull(request); 1731 1732 if (requestResult == null) 1733 throw new IllegalStateException("No HTTP request result was produced by the simulator"); 1734 1735 StreamingResponseBody stream = requestResult.getMarshaledResponse().getStream().orElse(null); 1736 1737 if (stream == null) 1738 return requestResult; 1739 1740 byte[] bytes; 1741 Instant streamStarted = Instant.now(); 1742 1743 try { 1744 bytes = materializeStreamingResponseBody(request, requestResult, stream); 1745 notifyDidTerminateSimulatorResponseStream(request, requestResult, streamStarted, 1746 Duration.between(streamStarted, Instant.now()), null, null); 1747 } catch (StreamingResponseCanceledException e) { 1748 StreamTerminationReason cancelationReason = e.getCancelationReason(); 1749 Throwable cause = e.getCancelationCause().orElse(null); 1750 notifyDidTerminateSimulatorResponseStream(request, requestResult, streamStarted, 1751 Duration.between(streamStarted, Instant.now()), cancelationReason, cause); 1752 throw new IllegalStateException("Simulated streaming response was canceled: " + cancelationReason.name(), e); 1753 } catch (InterruptedException e) { 1754 Thread.currentThread().interrupt(); 1755 notifyDidTerminateSimulatorResponseStream(request, requestResult, streamStarted, 1756 Duration.between(streamStarted, Instant.now()), StreamTerminationReason.CLIENT_DISCONNECTED, e); 1757 throw new IllegalStateException("Simulated streaming response was canceled: CLIENT_DISCONNECTED", e); 1758 } catch (Throwable t) { 1759 notifyDidTerminateSimulatorResponseStream(request, requestResult, streamStarted, 1760 Duration.between(streamStarted, Instant.now()), StreamTerminationReason.PRODUCER_FAILED, t); 1761 1762 if (t instanceof Error error) 1763 throw error; 1764 1765 throw new IllegalStateException("Simulated streaming response failed.", t); 1766 } 1767 1768 MarshaledResponse marshaledResponse = requestResult.getMarshaledResponse().copy() 1769 .withoutStream() 1770 .body(bytes) 1771 .finish(); 1772 1773 return requestResult.copy() 1774 .marshaledResponse(marshaledResponse) 1775 .finish(); 1776 } 1777 1778 private void notifyDidTerminateSimulatorResponseStream(@NonNull Request request, 1779 @NonNull HttpRequestResult requestResult, 1780 @NonNull Instant establishedAt, 1781 @NonNull Duration streamDuration, 1782 @Nullable StreamTerminationReason cancelationReason, 1783 @Nullable Throwable throwable) { 1784 requireNonNull(request); 1785 requireNonNull(requestResult); 1786 requireNonNull(establishedAt); 1787 requireNonNull(streamDuration); 1788 1789 MockHttpServer server = getHttpServer().orElse(null); 1790 SokletConfig sokletConfig = server == null ? null : server.getSokletConfig().orElse(null); 1791 1792 if (sokletConfig == null) 1793 return; 1794 1795 MarshaledResponse marshaledResponse = requestResult.getMarshaledResponse(); 1796 ResourceMethod resourceMethod = requestResult.getResourceMethod().orElse(null); 1797 LifecycleObserver lifecycleObserver = sokletConfig.getAggregateLifecycleObserver(); 1798 StreamingResponseHandle streamingResponse = new DefaultStreamingResponseHandle(ServerType.STANDARD_HTTP, 1799 request, resourceMethod, marshaledResponse, establishedAt); 1800 StreamTermination termination = StreamTermination 1801 .with(cancelationReason == null ? StreamTerminationReason.COMPLETED : cancelationReason, streamDuration) 1802 .cause(throwable) 1803 .build(); 1804 1805 try { 1806 lifecycleObserver.willTerminateResponseStream(streamingResponse, termination); 1807 } catch (Throwable t) { 1808 try { 1809 lifecycleObserver.didReceiveLogEvent(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_WILL_TERMINATE_RESPONSE_STREAM_FAILED, 1810 format("An exception occurred while invoking %s::willTerminateResponseStream", LifecycleObserver.class.getSimpleName())) 1811 .throwable(t) 1812 .request(request) 1813 .resourceMethod(resourceMethod) 1814 .marshaledResponse(marshaledResponse) 1815 .build()); 1816 } catch (Throwable ignored) { 1817 // Keep simulator lifecycle observer failures contained. 1818 } 1819 } 1820 1821 try { 1822 lifecycleObserver.didTerminateResponseStream(streamingResponse, termination); 1823 } catch (Throwable t) { 1824 try { 1825 lifecycleObserver.didReceiveLogEvent(LogEvent.with(LogEventType.LIFECYCLE_OBSERVER_DID_TERMINATE_RESPONSE_STREAM_FAILED, 1826 format("An exception occurred while invoking %s::didTerminateResponseStream", LifecycleObserver.class.getSimpleName())) 1827 .throwable(t) 1828 .request(request) 1829 .resourceMethod(resourceMethod) 1830 .marshaledResponse(marshaledResponse) 1831 .build()); 1832 } catch (Throwable ignored) { 1833 // Keep simulator lifecycle observer failures contained. 1834 } 1835 } 1836 } 1837 1838 private void notifyDidReceiveSimulatorStreamCancelationCallbackFailure(@NonNull Request request, 1839 @NonNull HttpRequestResult requestResult, 1840 @NonNull Throwable throwable) { 1841 requireNonNull(request); 1842 requireNonNull(requestResult); 1843 requireNonNull(throwable); 1844 1845 MockHttpServer server = getHttpServer().orElse(null); 1846 SokletConfig sokletConfig = server == null ? null : server.getSokletConfig().orElse(null); 1847 1848 if (sokletConfig == null) 1849 return; 1850 1851 LifecycleObserver lifecycleObserver = sokletConfig.getAggregateLifecycleObserver(); 1852 ResourceMethod resourceMethod = requestResult.getResourceMethod().orElse(null); 1853 MarshaledResponse marshaledResponse = requestResult.getMarshaledResponse(); 1854 1855 try { 1856 lifecycleObserver.didReceiveLogEvent(LogEvent.with(LogEventType.RESPONSE_STREAM_CANCELATION_CALLBACK_FAILED, 1857 "An exception occurred while invoking a streaming response cancelation callback") 1858 .throwable(throwable) 1859 .request(request) 1860 .resourceMethod(resourceMethod) 1861 .marshaledResponse(marshaledResponse) 1862 .build()); 1863 } catch (Throwable ignored) { 1864 // Keep simulator lifecycle observer failures contained. 1865 } 1866 } 1867 1868 @NonNull 1869 private byte[] materializeStreamingResponseBody(@NonNull Request request, 1870 @NonNull HttpRequestResult requestResult, 1871 @NonNull StreamingResponseBody stream) throws Exception { 1872 requireNonNull(request); 1873 requireNonNull(requestResult); 1874 requireNonNull(stream); 1875 1876 SimulatorCancelationToken cancelationToken = new SimulatorCancelationToken(throwable -> 1877 notifyDidReceiveSimulatorStreamCancelationCallbackFailure(request, requestResult, throwable)); 1878 SimulatorStreamingResponseContext context = new SimulatorStreamingResponseContext(request, cancelationToken); 1879 SimulatorResponseStream output = new SimulatorResponseStream(getSimulatorOptions().getStreamingResponseBodyLimitInBytes(), cancelationToken); 1880 1881 try { 1882 if (stream instanceof StreamingResponseBody.WriterBody writerBody) { 1883 writerBody.getWriter().writeTo(output, context); 1884 } else if (stream instanceof StreamingResponseBody.InputStreamBody inputStreamBody) { 1885 try (java.io.InputStream inputStream = requireNonNull(inputStreamBody.getInputStreamSupplier().get()); 1886 AutoCloseable ignored = context.onCancel(() -> closeQuietly(inputStream))) { 1887 byte[] buffer = new byte[inputStreamBody.getBufferSizeInBytes()]; 1888 int read; 1889 1890 while ((read = inputStream.read(buffer)) >= 0) { 1891 context.throwIfCanceled(); 1892 if (read > 0) 1893 output.write(ByteBuffer.wrap(buffer, 0, read)); 1894 } 1895 } 1896 } else if (stream instanceof StreamingResponseBody.ReaderBody readerBody) { 1897 try (java.io.Reader reader = requireNonNull(readerBody.getReaderSupplier().get()); 1898 AutoCloseable ignored = context.onCancel(() -> closeQuietly(reader))) { 1899 materializeReader(readerBody, reader, output, context); 1900 } 1901 } else if (stream instanceof StreamingResponseBody.PublisherBody publisherBody) { 1902 materializePublisher(publisherBody, output, context); 1903 } else { 1904 throw new IllegalStateException(format("Unsupported streaming response body type: %s", stream.getClass().getName())); 1905 } 1906 } catch (StreamingResponseCanceledException e) { 1907 cancelationToken.cancel(e.getCancelationReason(), e.getCancelationCause().orElse(null)); 1908 throw e; 1909 } catch (InterruptedException e) { 1910 Thread.currentThread().interrupt(); 1911 cancelationToken.cancel(cancelationToken.getCancelationReason() 1912 .orElse(StreamTerminationReason.CLIENT_DISCONNECTED), e); 1913 throw e; 1914 } catch (Throwable t) { 1915 cancelationToken.cancel(StreamTerminationReason.PRODUCER_FAILED, t); 1916 1917 if (t instanceof Exception exception) 1918 throw exception; 1919 1920 if (t instanceof Error error) 1921 throw error; 1922 1923 throw new RuntimeException(t); 1924 } 1925 1926 return output.toByteArray(); 1927 } 1928 1929 private void materializeReader(com.soklet.StreamingResponseBody.@NonNull ReaderBody readerBody, 1930 @NonNull Reader reader, 1931 @NonNull SimulatorResponseStream output, 1932 @NonNull SimulatorStreamingResponseContext context) throws IOException, InterruptedException, StreamingResponseCanceledException, CharacterCodingException { 1933 requireNonNull(readerBody); 1934 requireNonNull(reader); 1935 requireNonNull(output); 1936 requireNonNull(context); 1937 1938 CharsetEncoder encoder = readerBody.newEncoder(); 1939 CharBuffer charBuffer = CharBuffer.allocate(readerBody.getBufferSizeInCharacters()); 1940 ByteBuffer byteBuffer = ByteBuffer.allocate(Math.max(128, (int) Math.ceil(readerBody.getBufferSizeInCharacters() * encoder.maxBytesPerChar()))); 1941 1942 while (reader.read(charBuffer) >= 0) { 1943 context.throwIfCanceled(); 1944 charBuffer.flip(); 1945 encodeCharsForSimulator(encoder, charBuffer, byteBuffer, false, output); 1946 charBuffer.compact(); 1947 } 1948 1949 charBuffer.flip(); 1950 encodeCharsForSimulator(encoder, charBuffer, byteBuffer, true, output); 1951 1952 CoderResult result; 1953 do { 1954 result = encoder.flush(byteBuffer); 1955 writeEncodedBytesForSimulator(byteBuffer, output); 1956 if (result.isError()) 1957 result.throwException(); 1958 } while (result.isOverflow()); 1959 } 1960 1961 private void encodeCharsForSimulator(@NonNull CharsetEncoder encoder, 1962 @NonNull CharBuffer charBuffer, 1963 @NonNull ByteBuffer byteBuffer, 1964 boolean endOfInput, 1965 @NonNull SimulatorResponseStream output) throws IOException, InterruptedException, StreamingResponseCanceledException, CharacterCodingException { 1966 CoderResult result; 1967 1968 do { 1969 result = encoder.encode(charBuffer, byteBuffer, endOfInput); 1970 writeEncodedBytesForSimulator(byteBuffer, output); 1971 1972 if (result.isError()) 1973 result.throwException(); 1974 } while (result.isOverflow()); 1975 } 1976 1977 private void writeEncodedBytesForSimulator(@NonNull ByteBuffer byteBuffer, 1978 @NonNull SimulatorResponseStream output) throws IOException, InterruptedException, StreamingResponseCanceledException { 1979 byteBuffer.flip(); 1980 if (byteBuffer.hasRemaining()) 1981 output.write(byteBuffer); 1982 byteBuffer.clear(); 1983 } 1984 1985 private void materializePublisher(com.soklet.StreamingResponseBody.@NonNull PublisherBody publisherBody, 1986 @NonNull SimulatorResponseStream output, 1987 @NonNull SimulatorStreamingResponseContext context) throws Exception { 1988 requireNonNull(publisherBody); 1989 requireNonNull(output); 1990 requireNonNull(context); 1991 1992 CountDownLatch completed = new CountDownLatch(1); 1993 AtomicBoolean publisherTerminated = new AtomicBoolean(false); 1994 AtomicReference<Throwable> failure = new AtomicReference<>(); 1995 AtomicReference<Flow.Subscription> subscriptionRef = new AtomicReference<>(); 1996 1997 try (AutoCloseable cancelationRegistration = context.onCancel(() -> { 1998 Flow.Subscription subscription = subscriptionRef.get(); 1999 2000 if (subscription != null) 2001 subscription.cancel(); 2002 })) { 2003 publisherBody.getPublisher().subscribe(new Flow.Subscriber<>() { 2004 @Override 2005 public void onSubscribe(Flow.Subscription subscription) { 2006 requireNonNull(subscription); 2007 2008 if (!subscriptionRef.compareAndSet(null, subscription)) { 2009 subscription.cancel(); 2010 return; 2011 } 2012 2013 subscription.request(1L); 2014 } 2015 2016 @Override 2017 public void onNext(ByteBuffer item) { 2018 Flow.Subscription subscription = subscriptionRef.get(); 2019 2020 try { 2021 context.throwIfCanceled(); 2022 output.write(requireNonNull(item)); 2023 context.throwIfCanceled(); 2024 } catch (Throwable t) { 2025 failure.compareAndSet(null, t); 2026 publisherTerminated.set(true); 2027 2028 if (subscription != null) 2029 subscription.cancel(); 2030 2031 completed.countDown(); 2032 return; 2033 } 2034 2035 if (subscription != null) 2036 subscription.request(1L); 2037 } 2038 2039 @Override 2040 public void onError(Throwable throwable) { 2041 publisherTerminated.set(true); 2042 failure.compareAndSet(null, throwable == null 2043 ? new IllegalStateException("Publisher failed without an error") 2044 : throwable); 2045 completed.countDown(); 2046 } 2047 2048 @Override 2049 public void onComplete() { 2050 publisherTerminated.set(true); 2051 completed.countDown(); 2052 } 2053 }); 2054 2055 while (!completed.await(100L, TimeUnit.MILLISECONDS)) 2056 context.throwIfCanceled(); 2057 } finally { 2058 if (!publisherTerminated.get()) { 2059 Flow.Subscription subscription = subscriptionRef.get(); 2060 2061 if (subscription != null) 2062 subscription.cancel(); 2063 } 2064 } 2065 2066 Throwable throwable = failure.get(); 2067 2068 if (throwable != null) { 2069 if (throwable instanceof Exception exception) 2070 throw exception; 2071 2072 if (throwable instanceof Error error) 2073 throw error; 2074 2075 throw new RuntimeException(throwable); 2076 } 2077 } 2078 2079 private void closeQuietly(@NonNull AutoCloseable closeable) { 2080 requireNonNull(closeable); 2081 2082 try { 2083 closeable.close(); 2084 } catch (InterruptedException e) { 2085 Thread.currentThread().interrupt(); 2086 } catch (Throwable ignored) { 2087 // Best effort only. The producer will observe cancelation separately. 2088 } 2089 } 2090 2091 @NonNull 2092 @Override 2093 public SseRequestResult performSseRequest(@NonNull Request request) { 2094 MockSseServer sseServer = getSseServer().orElse(null); 2095 2096 if (sseServer == null) 2097 throw new IllegalStateException(format("You must specify a %s in your %s to simulate Server-Sent Event requests", 2098 SseServer.class.getSimpleName(), SokletConfig.class.getSimpleName())); 2099 2100 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 2101 SseServer.RequestHandler requestHandler = sseServer.getRequestHandler().orElse(null); 2102 2103 if (requestHandler == null) 2104 throw new IllegalStateException("You must register a request handler prior to simulating SSE Event Source requests"); 2105 2106 requestHandler.handleRequest(request, (requestResult -> { 2107 requestResultHolder.set(requestResult); 2108 })); 2109 2110 HttpRequestResult requestResult = requestResultHolder.get(); 2111 SseHandshakeResult sseHandshakeResult = requestResult.getSseHandshakeResult().orElse(null); 2112 2113 if (sseHandshakeResult == null) 2114 return new SseRequestResult.RequestFailed(requestResult); 2115 2116 if (sseHandshakeResult instanceof SseHandshakeResult.Accepted acceptedHandshake) { 2117 Consumer<SseUnicaster> clientInitializer = acceptedHandshake.getClientInitializer().orElse(null); 2118 2119 // Create a synthetic logical response using values from the accepted handshake 2120 if (requestResult.getResponse().isEmpty()) 2121 requestResult = requestResult.copy() 2122 .response(Response.withStatusCode(200) 2123 .headers(acceptedHandshake.getHeaders()) 2124 .cookies(acceptedHandshake.getCookies()) 2125 .build()) 2126 .finish(); 2127 2128 HandshakeAccepted handshakeAccepted = new HandshakeAccepted(acceptedHandshake, request.getResourcePath(), requestResult, this, clientInitializer); 2129 return handshakeAccepted; 2130 } 2131 2132 if (sseHandshakeResult instanceof SseHandshakeResult.Rejected rejectedHandshake) 2133 return new HandshakeRejected(rejectedHandshake, requestResult); 2134 2135 throw new IllegalStateException(format("Encountered unexpected %s: %s", SseHandshakeResult.class.getSimpleName(), sseHandshakeResult)); 2136 } 2137 2138 @NonNull 2139 @Override 2140 public McpRequestResult performMcpRequest(@NonNull Request request) { 2141 requireNonNull(request); 2142 2143 MockMcpServer mcpServer = getMcpServer().orElse(null); 2144 2145 if (mcpServer == null) 2146 throw new IllegalStateException(format("You must specify an MCP server in your %s to simulate MCP requests", 2147 SokletConfig.class.getSimpleName())); 2148 2149 AtomicReference<HttpRequestResult> requestResultHolder = new AtomicReference<>(); 2150 McpServer.RequestHandler requestHandler = mcpServer.getRequestHandler().orElse(null); 2151 2152 if (requestHandler == null) 2153 throw new IllegalStateException("You must register a request handler prior to simulating MCP requests"); 2154 2155 requestHandler.handleRequest(request, requestResultHolder::set); 2156 2157 HttpRequestResult requestResult = requestResultHolder.get(); 2158 2159 if (requestResult == null) 2160 throw new IllegalStateException("No MCP request result was produced by the simulator"); 2161 2162 if (extractContentTypeFromHeaders(requestResult.getMarshaledResponse().getHeaders()) 2163 .filter(contentType -> contentType.equalsIgnoreCase("text/event-stream")) 2164 .isPresent()) { 2165 McpRequestResult.StreamOpened streamOpened = new McpRequestResult.StreamOpened( 2166 requestResult, 2167 mcpServer.getMcpStreamErrorHandler(), 2168 requestResult.isMcpStreamClosedAfterReplay()); 2169 2170 for (McpObject mcpStreamMessage : requestResult.getMcpStreamMessages()) 2171 streamOpened.emitMessage(mcpStreamMessage); 2172 2173 if (!requestResult.isMcpStreamClosedAfterReplay()) 2174 request.getHeader("MCP-Session-Id").ifPresent(sessionId -> mcpServer.registerOpenStream(sessionId, request, streamOpened)); 2175 2176 return streamOpened; 2177 } 2178 2179 if (request.getHttpMethod() == HttpMethod.DELETE 2180 && Objects.equals(requestResult.getMarshaledResponse().getStatusCode(), 204)) 2181 request.getHeader("MCP-Session-Id").ifPresent(mcpServer::terminateStreamsForSession); 2182 2183 return new McpRequestResult.ResponseCompleted(requestResult); 2184 } 2185 2186 @NonNull 2187 @Override 2188 public Simulator onBroadcastError(@Nullable Consumer<Throwable> onBroadcastError) { 2189 MockSseServer sseServer = getSseServer().orElse(null); 2190 2191 if (sseServer != null) 2192 sseServer.onBroadcastError(onBroadcastError); 2193 2194 return this; 2195 } 2196 2197 @NonNull 2198 @Override 2199 public Simulator onUnicastError(@Nullable Consumer<Throwable> onUnicastError) { 2200 MockSseServer sseServer = getSseServer().orElse(null); 2201 2202 if (sseServer != null) 2203 sseServer.onUnicastError(onUnicastError); 2204 2205 return this; 2206 } 2207 2208 @NonNull 2209 @Override 2210 public Simulator onMcpStreamError(@Nullable Consumer<Throwable> onMcpStreamError) { 2211 MockMcpServer mcpServer = getMcpServer().orElse(null); 2212 2213 if (mcpServer != null) 2214 mcpServer.onMcpStreamError(onMcpStreamError); 2215 2216 return this; 2217 } 2218 2219 @NonNull 2220 Optional<MockHttpServer> getHttpServer() { 2221 return Optional.ofNullable(this.server); 2222 } 2223 2224 @NonNull 2225 Optional<MockSseServer> getSseServer() { 2226 return Optional.ofNullable(this.sseServer); 2227 } 2228 2229 @NonNull 2230 Optional<MockMcpServer> getMcpServer() { 2231 return Optional.ofNullable(this.mcpServer); 2232 } 2233 2234 @NonNull 2235 SimulatorOptions getSimulatorOptions() { 2236 return this.simulatorOptions; 2237 } 2238 } 2239 2240 @NotThreadSafe 2241 private static final class SimulatorResponseStream implements ResponseStream { 2242 @NonNull 2243 private final ByteArrayOutputStream byteArrayOutputStream; 2244 @NonNull 2245 private final Integer limitInBytes; 2246 @NonNull 2247 private final SimulatorCancelationToken cancelationToken; 2248 private boolean closed; 2249 2250 private SimulatorResponseStream(@NonNull Integer limitInBytes, 2251 @NonNull SimulatorCancelationToken cancelationToken) { 2252 this.byteArrayOutputStream = new ByteArrayOutputStream(); 2253 this.limitInBytes = requireNonNull(limitInBytes); 2254 this.cancelationToken = requireNonNull(cancelationToken); 2255 } 2256 2257 @Override 2258 public void write(@NonNull byte[] bytes) throws IOException, StreamingResponseCanceledException { 2259 requireNonNull(bytes); 2260 write(ByteBuffer.wrap(bytes)); 2261 } 2262 2263 @Override 2264 public void write(@NonNull ByteBuffer byteBuffer) throws IOException, StreamingResponseCanceledException { 2265 requireNonNull(byteBuffer); 2266 this.cancelationToken.throwIfCanceled(); 2267 2268 if (this.closed) 2269 throw new StreamingResponseCanceledException(StreamTerminationReason.APPLICATION_CANCELED); 2270 2271 ByteBuffer source = byteBuffer.asReadOnlyBuffer(); 2272 int bytesToWrite = source.remaining(); 2273 2274 if ((long) this.byteArrayOutputStream.size() + bytesToWrite > this.limitInBytes) { 2275 this.cancelationToken.cancel(StreamTerminationReason.SIMULATOR_LIMIT_EXCEEDED, null); 2276 throw new StreamingResponseCanceledException(StreamTerminationReason.SIMULATOR_LIMIT_EXCEEDED); 2277 } 2278 2279 byte[] bytes = new byte[bytesToWrite]; 2280 source.get(bytes); 2281 this.byteArrayOutputStream.write(bytes); 2282 } 2283 2284 @Override 2285 public void flush() throws StreamingResponseCanceledException { 2286 this.cancelationToken.throwIfCanceled(); 2287 } 2288 2289 @Override 2290 @NonNull 2291 public Boolean isOpen() { 2292 return !this.closed && !this.cancelationToken.isCanceled(); 2293 } 2294 2295 @NonNull 2296 private byte[] toByteArray() { 2297 this.closed = true; 2298 return this.byteArrayOutputStream.toByteArray(); 2299 } 2300 } 2301 2302 @ThreadSafe 2303 private static final class SimulatorCancelationToken implements CancelationToken { 2304 @NonNull 2305 private final AtomicBoolean canceled; 2306 @NonNull 2307 private final CopyOnWriteArrayList<Runnable> callbacks; 2308 @NonNull 2309 private final Consumer<Throwable> callbackFailureConsumer; 2310 @Nullable 2311 private volatile StreamTerminationReason reason; 2312 @Nullable 2313 private volatile Throwable cause; 2314 2315 private SimulatorCancelationToken(@NonNull Consumer<Throwable> callbackFailureConsumer) { 2316 this.canceled = new AtomicBoolean(false); 2317 this.callbacks = new CopyOnWriteArrayList<>(); 2318 this.callbackFailureConsumer = requireNonNull(callbackFailureConsumer); 2319 } 2320 2321 @Override 2322 @NonNull 2323 public Boolean isCanceled() { 2324 return this.canceled.get(); 2325 } 2326 2327 @Override 2328 @NonNull 2329 public Optional<StreamTerminationReason> getCancelationReason() { 2330 return Optional.ofNullable(this.reason); 2331 } 2332 2333 @Override 2334 @NonNull 2335 public Optional<Throwable> getCancelationCause() { 2336 return Optional.ofNullable(this.cause); 2337 } 2338 2339 @Override 2340 @NonNull 2341 public AutoCloseable onCancel(@NonNull Runnable callback) { 2342 requireNonNull(callback); 2343 2344 boolean runImmediately; 2345 2346 synchronized (this) { 2347 runImmediately = this.canceled.get(); 2348 2349 if (!runImmediately) 2350 this.callbacks.add(callback); 2351 } 2352 2353 if (runImmediately) { 2354 runCallback(callback); 2355 return () -> { 2356 // No-op 2357 }; 2358 } 2359 2360 return () -> { 2361 synchronized (this) { 2362 this.callbacks.remove(callback); 2363 } 2364 }; 2365 } 2366 2367 private boolean cancel(@NonNull StreamTerminationReason reason, 2368 @Nullable Throwable cause) { 2369 requireNonNull(reason); 2370 2371 if (reason == StreamTerminationReason.COMPLETED) 2372 throw new IllegalArgumentException("Cancelation reason cannot be COMPLETED"); 2373 2374 List<Runnable> callbacksToRun; 2375 2376 synchronized (this) { 2377 if (this.canceled.get()) 2378 return false; 2379 2380 this.reason = reason; 2381 this.cause = cause; 2382 this.canceled.set(true); 2383 callbacksToRun = List.copyOf(this.callbacks); 2384 this.callbacks.clear(); 2385 } 2386 2387 for (Runnable callback : callbacksToRun) 2388 runCallback(callback); 2389 2390 return true; 2391 } 2392 2393 private void runCallback(@NonNull Runnable callback) { 2394 requireNonNull(callback); 2395 2396 try { 2397 callback.run(); 2398 } catch (Throwable t) { 2399 this.callbackFailureConsumer.accept(t); 2400 } 2401 } 2402 } 2403 2404 @ThreadSafe 2405 private static final class SimulatorStreamingResponseContext implements StreamingResponseContext { 2406 @NonNull 2407 private final Request request; 2408 @NonNull 2409 private final CancelationToken cancelationToken; 2410 2411 private SimulatorStreamingResponseContext(@NonNull Request request, 2412 @NonNull CancelationToken cancelationToken) { 2413 this.request = requireNonNull(request); 2414 this.cancelationToken = requireNonNull(cancelationToken); 2415 } 2416 2417 @Override 2418 @NonNull 2419 public CancelationToken getCancelationToken() { 2420 return this.cancelationToken; 2421 } 2422 2423 @Override 2424 @NonNull 2425 public Request getRequest() { 2426 return this.request; 2427 } 2428 2429 @Override 2430 @NonNull 2431 public Optional<Instant> getDeadline() { 2432 return Optional.empty(); 2433 } 2434 2435 @Override 2436 @NonNull 2437 public Optional<Duration> getIdleTimeout() { 2438 return Optional.empty(); 2439 } 2440 } 2441 2442 /** 2443 * Mock server that doesn't touch the network at all, useful for testing. 2444 * 2445 * @author <a href="https://www.revetkn.com">Mark Allen</a> 2446 */ 2447 @ThreadSafe 2448 static class MockHttpServer implements HttpServer { 2449 @Nullable 2450 private SokletConfig sokletConfig; 2451 private HttpServer.@Nullable RequestHandler requestHandler; 2452 2453 @Override 2454 public void start() { 2455 // No-op 2456 } 2457 2458 @Override 2459 public void stop() { 2460 // No-op 2461 } 2462 2463 @NonNull 2464 @Override 2465 public Boolean isStarted() { 2466 return true; 2467 } 2468 2469 @Override 2470 public void initialize(@NonNull SokletConfig sokletConfig, 2471 @NonNull RequestHandler requestHandler) { 2472 requireNonNull(sokletConfig); 2473 requireNonNull(requestHandler); 2474 2475 this.sokletConfig = sokletConfig; 2476 this.requestHandler = requestHandler; 2477 } 2478 2479 @NonNull 2480 protected Optional<SokletConfig> getSokletConfig() { 2481 return Optional.ofNullable(this.sokletConfig); 2482 } 2483 2484 @NonNull 2485 protected Optional<RequestHandler> getRequestHandler() { 2486 return Optional.ofNullable(this.requestHandler); 2487 } 2488 } 2489 2490 /** 2491 * Mock MCP server that doesn't touch the network at all, useful for testing. 2492 */ 2493 @ThreadSafe 2494 static class MockMcpServer implements McpServer, InternalMcpSessionMessagePublisher { 2495 @NonNull 2496 private final McpServer realImplementation; 2497 @Nullable 2498 private SokletConfig sokletConfig; 2499 private McpServer.@Nullable RequestHandler requestHandler; 2500 @NonNull 2501 private final AtomicReference<Consumer<Throwable>> mcpStreamErrorHandler; 2502 @NonNull 2503 private final ConcurrentHashMap<@NonNull String, @NonNull CopyOnWriteArrayList<McpRequestResult.StreamOpened>> openStreamsBySessionId; 2504 @NonNull 2505 private final AtomicReference<@Nullable BiConsumer<Request, String>> clientDisconnectedMcpStreamHandler; 2506 2507 public MockMcpServer(@NonNull McpServer realImplementation) { 2508 requireNonNull(realImplementation); 2509 2510 this.realImplementation = realImplementation; 2511 this.mcpStreamErrorHandler = new AtomicReference<>(); 2512 this.openStreamsBySessionId = new ConcurrentHashMap<>(); 2513 this.clientDisconnectedMcpStreamHandler = new AtomicReference<>(); 2514 } 2515 2516 @Override 2517 public void start() { 2518 // No-op 2519 } 2520 2521 @Override 2522 public void stop() { 2523 // No-op 2524 } 2525 2526 @NonNull 2527 @Override 2528 public Boolean isStarted() { 2529 return true; 2530 } 2531 2532 @Override 2533 public void initialize(@NonNull SokletConfig sokletConfig, 2534 @NonNull RequestHandler requestHandler) { 2535 requireNonNull(sokletConfig); 2536 requireNonNull(requestHandler); 2537 2538 this.sokletConfig = sokletConfig; 2539 this.requestHandler = requestHandler; 2540 } 2541 2542 @NonNull 2543 @Override 2544 public McpHandlerResolver getHandlerResolver() { 2545 return getRealImplementation().getHandlerResolver(); 2546 } 2547 2548 @NonNull 2549 @Override 2550 public McpRequestAdmissionPolicy getRequestAdmissionPolicy() { 2551 return getRealImplementation().getRequestAdmissionPolicy(); 2552 } 2553 2554 @NonNull 2555 @Override 2556 public McpRequestInterceptor getRequestInterceptor() { 2557 return getRealImplementation().getRequestInterceptor(); 2558 } 2559 2560 @NonNull 2561 @Override 2562 public McpResponseMarshaler getResponseMarshaler() { 2563 return getRealImplementation().getResponseMarshaler(); 2564 } 2565 2566 @NonNull 2567 @Override 2568 public McpCorsAuthorizer getCorsAuthorizer() { 2569 return getRealImplementation().getCorsAuthorizer(); 2570 } 2571 2572 @NonNull 2573 @Override 2574 public McpSessionStore getSessionStore() { 2575 return getRealImplementation().getSessionStore(); 2576 } 2577 2578 @NonNull 2579 @Override 2580 public IdGenerator<String> getSessionIdGenerator() { 2581 return getRealImplementation().getSessionIdGenerator(); 2582 } 2583 2584 @NonNull 2585 protected McpServer getRealImplementation() { 2586 return this.realImplementation; 2587 } 2588 2589 @NonNull 2590 protected Optional<SokletConfig> getSokletConfig() { 2591 return Optional.ofNullable(this.sokletConfig); 2592 } 2593 2594 @NonNull 2595 protected Optional<RequestHandler> getRequestHandler() { 2596 return Optional.ofNullable(this.requestHandler); 2597 } 2598 2599 protected void onMcpStreamError(@Nullable Consumer<Throwable> onMcpStreamError) { 2600 this.mcpStreamErrorHandler.set(onMcpStreamError); 2601 } 2602 2603 @NonNull 2604 protected AtomicReference<Consumer<Throwable>> getMcpStreamErrorHandler() { 2605 return this.mcpStreamErrorHandler; 2606 } 2607 2608 protected void registerOpenStream(@NonNull String sessionId, 2609 @NonNull Request request, 2610 McpRequestResult.StreamOpened streamOpened) { 2611 requireNonNull(sessionId); 2612 requireNonNull(request); 2613 requireNonNull(streamOpened); 2614 2615 getOpenStreamsBySessionId() 2616 .computeIfAbsent(sessionId, ignored -> new CopyOnWriteArrayList<>()) 2617 .add(streamOpened); 2618 2619 streamOpened.onClose(() -> closeOpenStream(sessionId, request, streamOpened)); 2620 } 2621 2622 protected void terminateStreamsForSession(@NonNull String sessionId) { 2623 requireNonNull(sessionId); 2624 2625 CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().remove(sessionId); 2626 2627 if (streams == null) 2628 return; 2629 2630 for (McpRequestResult.StreamOpened streamOpened : streams) 2631 streamOpened.terminate(); 2632 } 2633 2634 @NonNull 2635 @Override 2636 public Boolean publishSessionMessage(@NonNull String sessionId, 2637 @NonNull McpObject message) { 2638 requireNonNull(sessionId); 2639 requireNonNull(message); 2640 2641 CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().get(sessionId); 2642 2643 if (streams == null || streams.isEmpty()) 2644 return false; 2645 2646 for (int i = streams.size() - 1; i >= 0; i--) { 2647 McpRequestResult.StreamOpened streamOpened = streams.get(i); 2648 2649 if (streamOpened.isClosed()) 2650 continue; 2651 2652 streamOpened.emitMessage(message); 2653 return true; 2654 } 2655 2656 return false; 2657 } 2658 2659 protected void onClientDisconnectedMcpStream(@Nullable BiConsumer<Request, String> clientDisconnectedMcpStreamHandler) { 2660 this.clientDisconnectedMcpStreamHandler.set(clientDisconnectedMcpStreamHandler); 2661 } 2662 2663 protected void closeOpenStream(@NonNull String sessionId, 2664 @NonNull Request request, 2665 McpRequestResult.StreamOpened streamOpened) { 2666 requireNonNull(sessionId); 2667 requireNonNull(request); 2668 requireNonNull(streamOpened); 2669 2670 CopyOnWriteArrayList<McpRequestResult.StreamOpened> streams = getOpenStreamsBySessionId().get(sessionId); 2671 2672 if (streams != null) { 2673 streams.remove(streamOpened); 2674 2675 if (streams.isEmpty()) 2676 getOpenStreamsBySessionId().remove(sessionId, streams); 2677 } 2678 2679 BiConsumer<Request, String> handler = this.clientDisconnectedMcpStreamHandler.get(); 2680 2681 if (handler != null) 2682 handler.accept(request, sessionId); 2683 } 2684 2685 @NonNull 2686 protected ConcurrentHashMap<@NonNull String, @NonNull CopyOnWriteArrayList<McpRequestResult.StreamOpened>> getOpenStreamsBySessionId() { 2687 return this.openStreamsBySessionId; 2688 } 2689 } 2690 2691 /** 2692 * Mock Server-Sent Event unicaster that doesn't touch the network at all, useful for testing. 2693 */ 2694 @ThreadSafe 2695 static class MockSseUnicaster implements SseUnicaster { 2696 @NonNull 2697 private final ResourcePath resourcePath; 2698 @NonNull 2699 private final Consumer<SseEvent> eventConsumer; 2700 @NonNull 2701 private final Consumer<SseComment> commentConsumer; 2702 @NonNull 2703 private final AtomicReference<Consumer<Throwable>> unicastErrorHandler; 2704 2705 public MockSseUnicaster(@NonNull ResourcePath resourcePath, 2706 @NonNull Consumer<SseEvent> eventConsumer, 2707 @NonNull Consumer<SseComment> commentConsumer, 2708 @NonNull AtomicReference<Consumer<Throwable>> unicastErrorHandler) { 2709 requireNonNull(resourcePath); 2710 requireNonNull(eventConsumer); 2711 requireNonNull(commentConsumer); 2712 requireNonNull(unicastErrorHandler); 2713 2714 this.resourcePath = resourcePath; 2715 this.eventConsumer = eventConsumer; 2716 this.commentConsumer = commentConsumer; 2717 this.unicastErrorHandler = unicastErrorHandler; 2718 } 2719 2720 @Override 2721 public void unicastEvent(@NonNull SseEvent sseEvent) { 2722 requireNonNull(sseEvent); 2723 try { 2724 getEventConsumer().accept(sseEvent); 2725 } catch (Throwable throwable) { 2726 handleUnicastError(throwable); 2727 } 2728 } 2729 2730 @Override 2731 public void unicastComment(@NonNull SseComment sseComment) { 2732 requireNonNull(sseComment); 2733 try { 2734 getCommentConsumer().accept(sseComment); 2735 } catch (Throwable throwable) { 2736 handleUnicastError(throwable); 2737 } 2738 } 2739 2740 @NonNull 2741 @Override 2742 public ResourcePath getResourcePath() { 2743 return this.resourcePath; 2744 } 2745 2746 @NonNull 2747 protected Consumer<SseEvent> getEventConsumer() { 2748 return this.eventConsumer; 2749 } 2750 2751 @NonNull 2752 protected Consumer<SseComment> getCommentConsumer() { 2753 return this.commentConsumer; 2754 } 2755 2756 protected void handleUnicastError(@NonNull Throwable throwable) { 2757 requireNonNull(throwable); 2758 Consumer<Throwable> handler = this.unicastErrorHandler.get(); 2759 2760 if (handler != null) { 2761 try { 2762 handler.accept(throwable); 2763 return; 2764 } catch (Throwable ignored) { 2765 // Fall through to default behavior 2766 } 2767 } 2768 2769 throwable.printStackTrace(); 2770 } 2771 } 2772 2773 /** 2774 * Mock Server-Sent Event broadcaster that doesn't touch the network at all, useful for testing. 2775 */ 2776 @ThreadSafe 2777 static class MockSseBroadcaster implements SseBroadcaster { 2778 // ConcurrentHashMap doesn't allow null values, so we use a sentinel if context is null 2779 private static final Object NULL_CONTEXT_SENTINEL; 2780 2781 static { 2782 NULL_CONTEXT_SENTINEL = new Object(); 2783 } 2784 2785 @NonNull 2786 private final ResourcePath resourcePath; 2787 // Maps the Consumer (Listener) to its Context object (e.g. Locale) 2788 @NonNull 2789 private final Map<@NonNull Consumer<SseEvent>, @NonNull Object> eventConsumers; 2790 // Same goes for comments 2791 @NonNull 2792 private final Map<@NonNull Consumer<SseComment>, @NonNull Object> commentConsumers; 2793 @NonNull 2794 private final AtomicReference<Consumer<Throwable>> broadcastErrorHandler; 2795 2796 public MockSseBroadcaster(@NonNull ResourcePath resourcePath, 2797 @NonNull AtomicReference<Consumer<Throwable>> broadcastErrorHandler) { 2798 requireNonNull(resourcePath); 2799 requireNonNull(broadcastErrorHandler); 2800 2801 this.resourcePath = resourcePath; 2802 this.eventConsumers = new ConcurrentHashMap<>(); 2803 this.commentConsumers = new ConcurrentHashMap<>(); 2804 this.broadcastErrorHandler = broadcastErrorHandler; 2805 } 2806 2807 @NonNull 2808 @Override 2809 public ResourcePath getResourcePath() { 2810 return this.resourcePath; 2811 } 2812 2813 @NonNull 2814 @Override 2815 public Long getClientCount() { 2816 return Long.valueOf(getEventConsumers().size() + getCommentConsumers().size()); 2817 } 2818 2819 @Override 2820 public void broadcastEvent(@NonNull SseEvent sseEvent) { 2821 requireNonNull(sseEvent); 2822 2823 for (Consumer<SseEvent> eventConsumer : getEventConsumers().keySet()) { 2824 try { 2825 eventConsumer.accept(sseEvent); 2826 } catch (Throwable throwable) { 2827 handleBroadcastError(throwable); 2828 } 2829 } 2830 } 2831 2832 @Override 2833 public void broadcastComment(@NonNull SseComment sseComment) { 2834 requireNonNull(sseComment); 2835 2836 for (Consumer<SseComment> commentConsumer : getCommentConsumers().keySet()) { 2837 try { 2838 commentConsumer.accept(sseComment); 2839 } catch (Throwable throwable) { 2840 handleBroadcastError(throwable); 2841 } 2842 } 2843 } 2844 2845 @Override 2846 public <T> void broadcastEvent( 2847 @NonNull Function<Object, T> keySelector, 2848 @NonNull Function<T, SseEvent> eventProvider 2849 ) { 2850 requireNonNull(keySelector); 2851 requireNonNull(eventProvider); 2852 2853 // 1. Create a temporary cache for this specific broadcast operation. 2854 // This ensures we only run the expensive 'eventProvider' once per unique key. 2855 Map<T, SseEvent> payloadCache = new HashMap<>(); 2856 2857 this.getEventConsumers().forEach((consumer, context) -> { 2858 try { 2859 // 2. Derive the key from the subscriber's context 2860 T key = keySelector.apply(context); 2861 2862 // 3. Memoize: Generate the payload if we haven't seen this key yet, otherwise reuse it 2863 SseEvent event = payloadCache.computeIfAbsent(key, eventProvider); 2864 2865 // 4. Dispatch 2866 consumer.accept(event); 2867 } catch (Throwable throwable) { 2868 handleBroadcastError(throwable); 2869 } 2870 }); 2871 } 2872 2873 @Override 2874 public <T> void broadcastComment( 2875 @NonNull Function<Object, T> keySelector, 2876 @NonNull Function<T, SseComment> commentProvider 2877 ) { 2878 requireNonNull(keySelector); 2879 requireNonNull(commentProvider); 2880 2881 // 1. Create temporary cache 2882 Map<T, SseComment> commentCache = new HashMap<>(); 2883 2884 this.getCommentConsumers().forEach((consumer, context) -> { 2885 try { 2886 // 2. Derive key 2887 T key = keySelector.apply(context); 2888 2889 // 3. Memoize 2890 SseComment comment = commentCache.computeIfAbsent(key, commentProvider); 2891 2892 // 4. Dispatch 2893 consumer.accept(comment); 2894 } catch (Throwable throwable) { 2895 handleBroadcastError(throwable); 2896 } 2897 }); 2898 } 2899 2900 @NonNull 2901 public Boolean registerEventConsumer(@NonNull Consumer<SseEvent> eventConsumer) { 2902 return registerEventConsumer(eventConsumer, null); 2903 } 2904 2905 /** 2906 * Registers a consumer with an associated context, simulating a client with specific traits. 2907 */ 2908 @NonNull 2909 public Boolean registerEventConsumer(@NonNull Consumer<SseEvent> eventConsumer, @Nullable Object context) { 2910 requireNonNull(eventConsumer); 2911 // map.put returns null if the key was new, which conceptually matches "add" returning true 2912 return this.getEventConsumers().put(eventConsumer, context == null ? NULL_CONTEXT_SENTINEL : context) == null; 2913 } 2914 2915 @NonNull 2916 public Boolean unregisterEventConsumer(@NonNull Consumer<SseEvent> eventConsumer) { 2917 requireNonNull(eventConsumer); 2918 return this.getEventConsumers().remove(eventConsumer) != null; 2919 } 2920 2921 @NonNull 2922 public Boolean registerCommentConsumer(@NonNull Consumer<SseComment> commentConsumer) { 2923 return registerCommentConsumer(commentConsumer, null); 2924 } 2925 2926 /** 2927 * Registers a consumer with an associated context, simulating a client with specific traits. 2928 */ 2929 @NonNull 2930 public Boolean registerCommentConsumer(@NonNull Consumer<SseComment> commentConsumer, @Nullable Object context) { 2931 requireNonNull(commentConsumer); 2932 return this.getCommentConsumers().put(commentConsumer, context == null ? NULL_CONTEXT_SENTINEL : context) == null; 2933 } 2934 2935 @NonNull 2936 public Boolean unregisterCommentConsumer(@NonNull Consumer<SseComment> commentConsumer) { 2937 requireNonNull(commentConsumer); 2938 return this.getCommentConsumers().remove(commentConsumer) != null; 2939 } 2940 2941 @NonNull 2942 protected Map<@NonNull Consumer<SseEvent>, @NonNull Object> getEventConsumers() { 2943 return this.eventConsumers; 2944 } 2945 2946 @NonNull 2947 protected Map<@NonNull Consumer<SseComment>, @NonNull Object> getCommentConsumers() { 2948 return this.commentConsumers; 2949 } 2950 2951 protected void handleBroadcastError(@NonNull Throwable throwable) { 2952 requireNonNull(throwable); 2953 Consumer<Throwable> handler = this.broadcastErrorHandler.get(); 2954 2955 if (handler != null) { 2956 try { 2957 handler.accept(throwable); 2958 return; 2959 } catch (Throwable ignored) { 2960 // Fall through to default behavior 2961 } 2962 } 2963 2964 throwable.printStackTrace(); 2965 } 2966 } 2967 2968 /** 2969 * Mock Server-Sent Event server that doesn't touch the network at all, useful for testing. 2970 * 2971 * @author <a href="https://www.revetkn.com">Mark Allen</a> 2972 */ 2973 @ThreadSafe 2974 static class MockSseServer implements SseServer { 2975 @Nullable 2976 private SokletConfig sokletConfig; 2977 private SseServer.@Nullable RequestHandler requestHandler; 2978 @NonNull 2979 private final ConcurrentHashMap<@NonNull ResourcePath, @NonNull MockSseBroadcaster> broadcastersByResourcePath; 2980 @NonNull 2981 private final AtomicReference<Consumer<Throwable>> broadcastErrorHandler; 2982 @NonNull 2983 private final AtomicReference<Consumer<Throwable>> unicastErrorHandler; 2984 2985 public MockSseServer() { 2986 this.broadcastersByResourcePath = new ConcurrentHashMap<>(); 2987 this.broadcastErrorHandler = new AtomicReference<>(); 2988 this.unicastErrorHandler = new AtomicReference<>(); 2989 } 2990 2991 @Override 2992 public void start() { 2993 // No-op 2994 } 2995 2996 @Override 2997 public void stop() { 2998 // No-op 2999 } 3000 3001 @NonNull 3002 @Override 3003 public Boolean isStarted() { 3004 return true; 3005 } 3006 3007 @NonNull 3008 @Override 3009 public Optional<? extends SseBroadcaster> acquireBroadcaster(@Nullable ResourcePath resourcePath) { 3010 if (resourcePath == null) 3011 return Optional.empty(); 3012 3013 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath() 3014 .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler)); 3015 3016 return Optional.of(broadcaster); 3017 } 3018 3019 public void registerEventConsumer(@NonNull ResourcePath resourcePath, 3020 @NonNull Consumer<SseEvent> eventConsumer) { 3021 registerEventConsumer(resourcePath, eventConsumer, null); 3022 } 3023 3024 public void registerEventConsumer(@NonNull ResourcePath resourcePath, 3025 @NonNull Consumer<SseEvent> eventConsumer, 3026 @Nullable Object context) { 3027 requireNonNull(resourcePath); 3028 requireNonNull(eventConsumer); 3029 3030 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath() 3031 .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler)); 3032 3033 broadcaster.registerEventConsumer(eventConsumer, context); 3034 } 3035 3036 @NonNull 3037 public Boolean unregisterEventConsumer(@NonNull ResourcePath resourcePath, 3038 @NonNull Consumer<SseEvent> eventConsumer) { 3039 requireNonNull(resourcePath); 3040 requireNonNull(eventConsumer); 3041 3042 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath().get(resourcePath); 3043 3044 if (broadcaster == null) 3045 return false; 3046 3047 return broadcaster.unregisterEventConsumer(eventConsumer); 3048 } 3049 3050 public void registerCommentConsumer(@NonNull ResourcePath resourcePath, 3051 @NonNull Consumer<SseComment> commentConsumer) { 3052 registerCommentConsumer(resourcePath, commentConsumer, null); 3053 } 3054 3055 public void registerCommentConsumer(@NonNull ResourcePath resourcePath, 3056 @NonNull Consumer<SseComment> commentConsumer, 3057 @Nullable Object context) { 3058 requireNonNull(resourcePath); 3059 requireNonNull(commentConsumer); 3060 3061 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath() 3062 .computeIfAbsent(resourcePath, rp -> new MockSseBroadcaster(rp, broadcastErrorHandler)); 3063 3064 broadcaster.registerCommentConsumer(commentConsumer, context); 3065 } 3066 3067 @NonNull 3068 public Boolean unregisterCommentConsumer(@NonNull ResourcePath resourcePath, 3069 @NonNull Consumer<SseComment> commentConsumer) { 3070 requireNonNull(resourcePath); 3071 requireNonNull(commentConsumer); 3072 3073 MockSseBroadcaster broadcaster = getBroadcastersByResourcePath().get(resourcePath); 3074 3075 if (broadcaster == null) 3076 return false; 3077 3078 return broadcaster.unregisterCommentConsumer(commentConsumer); 3079 } 3080 3081 @Override 3082 public void initialize(@NonNull SokletConfig sokletConfig, 3083 SseServer.@NonNull RequestHandler requestHandler) { 3084 requireNonNull(sokletConfig); 3085 requireNonNull(requestHandler); 3086 3087 this.sokletConfig = sokletConfig; 3088 this.requestHandler = requestHandler; 3089 } 3090 3091 public void onBroadcastError(@Nullable Consumer<Throwable> onBroadcastError) { 3092 this.broadcastErrorHandler.set(onBroadcastError); 3093 } 3094 3095 public void onUnicastError(@Nullable Consumer<Throwable> onUnicastError) { 3096 this.unicastErrorHandler.set(onUnicastError); 3097 } 3098 3099 @NonNull 3100 protected Optional<SokletConfig> getSokletConfig() { 3101 return Optional.ofNullable(this.sokletConfig); 3102 } 3103 3104 @NonNull 3105 protected Optional<SseServer.RequestHandler> getRequestHandler() { 3106 return Optional.ofNullable(this.requestHandler); 3107 } 3108 3109 @NonNull 3110 protected ConcurrentHashMap<@NonNull ResourcePath, @NonNull MockSseBroadcaster> getBroadcastersByResourcePath() { 3111 return this.broadcastersByResourcePath; 3112 } 3113 3114 @NonNull 3115 protected AtomicReference<Consumer<Throwable>> getUnicastErrorHandler() { 3116 return this.unicastErrorHandler; 3117 } 3118 } 3119 3120}