001/* 002 * Copyright 2022-2026 Revetware LLC. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package com.soklet; 018 019import org.jspecify.annotations.NonNull; 020 021import javax.annotation.concurrent.ThreadSafe; 022import java.net.URI; 023import java.time.Duration; 024import java.util.LinkedHashSet; 025import java.util.Locale; 026import java.util.Optional; 027import java.util.Set; 028import java.util.function.Function; 029import java.util.function.Predicate; 030 031import static java.util.Objects.requireNonNull; 032 033/** 034 * CORS authorization contract for MCP transport requests. 035 * 036 * @author <a href="https://www.revetkn.com">Mark Allen</a> 037 */ 038@ThreadSafe 039public interface McpCorsAuthorizer { 040 /** 041 * Authorizes a non-preflight browser-originated MCP request and, when allowed, supplies the CORS response metadata to apply. 042 * 043 * @param context the MCP CORS context 044 * @param cors the parsed Soklet CORS request metadata 045 * @return the CORS response metadata to apply, or {@link Optional#empty()} to withhold CORS authorization 046 */ 047 @NonNull 048 Optional<CorsResponse> authorize(@NonNull McpCorsContext context, 049 @NonNull Cors cors); 050 051 /** 052 * Authorizes a browser preflight request for the MCP transport. 053 * 054 * @param context the MCP CORS context 055 * @param corsPreflight the parsed preflight metadata 056 * @param availableHttpMethods the MCP transport methods available for the current endpoint 057 * @return the preflight response metadata to apply, or {@link Optional#empty()} to reject the preflight 058 */ 059 @NonNull 060 Optional<CorsPreflightResponse> authorizePreflight(@NonNull McpCorsContext context, 061 @NonNull CorsPreflight corsPreflight, 062 @NonNull Set<@NonNull HttpMethod> availableHttpMethods); 063 064 /** 065 * Acquires an authorizer that rejects all browser-originated MCP CORS requests. 066 * 067 * @return a rejecting authorizer 068 */ 069 @NonNull 070 static McpCorsAuthorizer rejectAllInstance() { 071 return new McpCorsAuthorizer() { 072 @NonNull 073 @Override 074 public Optional<CorsResponse> authorize(@NonNull McpCorsContext context, 075 @NonNull Cors cors) { 076 requireNonNull(context); 077 requireNonNull(cors); 078 return Optional.empty(); 079 } 080 081 @NonNull 082 @Override 083 public Optional<CorsPreflightResponse> authorizePreflight(@NonNull McpCorsContext context, 084 @NonNull CorsPreflight corsPreflight, 085 @NonNull Set<@NonNull HttpMethod> availableHttpMethods) { 086 requireNonNull(context); 087 requireNonNull(corsPreflight); 088 requireNonNull(availableHttpMethods); 089 return Optional.empty(); 090 } 091 }; 092 } 093 094 /** 095 * Acquires the conservative default authorizer that leaves non-browser MCP requests alone while rejecting browser CORS authorization. 096 * 097 * @return the default non-browser-only authorizer 098 */ 099 @NonNull 100 static McpCorsAuthorizer nonBrowserClientsOnlyInstance() { 101 return rejectAllInstance(); 102 } 103 104 /** 105 * Acquires an authorizer that allows all browser origins and always enables credentials. 106 * 107 * @return a permissive authorizer 108 */ 109 @NonNull 110 static McpCorsAuthorizer acceptAllInstance() { 111 return fromOriginAuthorizer(context -> true, origin -> true); 112 } 113 114 /** 115 * Acquires an authorizer that allows only the provided normalized origins and disables credentials by default. 116 * 117 * @param whitelistedOrigins the origins to allow 118 * @return a whitelisting authorizer 119 */ 120 @NonNull 121 static McpCorsAuthorizer fromWhitelistedOrigins(@NonNull Set<@NonNull String> whitelistedOrigins) { 122 return fromWhitelistedOrigins(whitelistedOrigins, origin -> false); 123 } 124 125 /** 126 * Acquires an authorizer that allows only the provided normalized origins and delegates credential behavior per origin. 127 * 128 * @param whitelistedOrigins the origins to allow 129 * @param allowCredentialsResolver resolves whether credentials should be allowed for an origin 130 * @return a whitelisting authorizer 131 */ 132 @NonNull 133 static McpCorsAuthorizer fromWhitelistedOrigins(@NonNull Set<@NonNull String> whitelistedOrigins, 134 @NonNull Function<String, Boolean> allowCredentialsResolver) { 135 requireNonNull(whitelistedOrigins); 136 requireNonNull(allowCredentialsResolver); 137 138 Set<String> normalizedOrigins = new LinkedHashSet<>(); 139 140 for (String whitelistedOrigin : whitelistedOrigins) { 141 requireNonNull(whitelistedOrigin); 142 normalizedOrigins.add(normalizeOrigin(whitelistedOrigin)); 143 } 144 145 return fromOriginAuthorizer(context -> { 146 requireNonNull(context); 147 148 if (context.origin() == null) 149 return false; 150 151 return normalizedOrigins.contains(normalizeOrigin(context.origin())); 152 }, allowCredentialsResolver); 153 } 154 155 /** 156 * Acquires an authorizer backed by an origin-authorization predicate and with credentials disabled by default. 157 * 158 * @param originAuthorizer the origin predicate 159 * @return an authorizer backed by the predicate 160 */ 161 @NonNull 162 static McpCorsAuthorizer fromOriginAuthorizer(@NonNull Predicate<@NonNull McpCorsContext> originAuthorizer) { 163 return fromOriginAuthorizer(originAuthorizer, origin -> false); 164 } 165 166 /** 167 * Acquires an authorizer backed by an origin-authorization predicate plus a credentials resolver. 168 * 169 * @param originAuthorizer the origin predicate 170 * @param allowCredentialsResolver resolves whether credentials should be allowed for an origin 171 * @return an authorizer backed by the supplied callbacks 172 */ 173 @NonNull 174 static McpCorsAuthorizer fromOriginAuthorizer(@NonNull Predicate<@NonNull McpCorsContext> originAuthorizer, 175 @NonNull Function<String, Boolean> allowCredentialsResolver) { 176 requireNonNull(originAuthorizer); 177 requireNonNull(allowCredentialsResolver); 178 179 return new McpCorsAuthorizer() { 180 @NonNull 181 @Override 182 public Optional<CorsResponse> authorize(@NonNull McpCorsContext context, 183 @NonNull Cors cors) { 184 requireNonNull(context); 185 requireNonNull(cors); 186 187 if (!originAuthorizer.test(context)) 188 return Optional.empty(); 189 190 return Optional.of(CorsResponse.withAccessControlAllowOrigin(cors.getOrigin()) 191 .accessControlAllowCredentials(allowCredentialsResolver.apply(normalizeOrigin(cors.getOrigin()))) 192 .accessControlExposeHeaders(defaultExposedHeaders()) 193 .build()); 194 } 195 196 @NonNull 197 @Override 198 public Optional<CorsPreflightResponse> authorizePreflight(@NonNull McpCorsContext context, 199 @NonNull CorsPreflight corsPreflight, 200 @NonNull Set<@NonNull HttpMethod> availableHttpMethods) { 201 requireNonNull(context); 202 requireNonNull(corsPreflight); 203 requireNonNull(availableHttpMethods); 204 205 if (!originAuthorizer.test(context)) 206 return Optional.empty(); 207 208 return Optional.of(CorsPreflightResponse.withAccessControlAllowOrigin(corsPreflight.getOrigin()) 209 .accessControlAllowMethods(availableHttpMethods) 210 .accessControlAllowHeaders(corsPreflight.getAccessControlRequestHeaders()) 211 .accessControlAllowCredentials(allowCredentialsResolver.apply(normalizeOrigin(corsPreflight.getOrigin()))) 212 .accessControlMaxAge(Duration.ofMinutes(10)) 213 .build()); 214 } 215 }; 216 } 217 218 @NonNull 219 private static Set<String> defaultExposedHeaders() { 220 return Set.of("MCP-Session-Id", "WWW-Authenticate"); 221 } 222 223 @NonNull 224 private static String normalizeOrigin(@NonNull String origin) { 225 requireNonNull(origin); 226 227 if ("null".equals(origin)) 228 return "null"; 229 230 URI uri = URI.create(origin.trim()); 231 String scheme = uri.getScheme() == null ? "" : uri.getScheme().toLowerCase(Locale.ROOT); 232 String host = uri.getHost() == null ? "" : uri.getHost().toLowerCase(Locale.ROOT); 233 Integer port = uri.getPort() == -1 ? null : uri.getPort(); 234 235 if (("http".equals(scheme) && Integer.valueOf(80).equals(port)) 236 || ("https".equals(scheme) && Integer.valueOf(443).equals(port))) 237 port = null; 238 239 return port == null ? "%s://%s".formatted(scheme, host) : "%s://%s:%s".formatted(scheme, host, port); 240 } 241}