001/*
002 * Copyright 2022-2026 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.soklet;
018
019import org.jspecify.annotations.NonNull;
020
021import javax.annotation.concurrent.ThreadSafe;
022import java.net.URI;
023import java.time.Duration;
024import java.util.LinkedHashSet;
025import java.util.Locale;
026import java.util.Optional;
027import java.util.Set;
028import java.util.function.Function;
029import java.util.function.Predicate;
030
031import static java.util.Objects.requireNonNull;
032
033/**
034 * CORS authorization contract for MCP transport requests.
035 *
036 * @author <a href="https://www.revetkn.com">Mark Allen</a>
037 */
038@ThreadSafe
039public interface McpCorsAuthorizer {
040        /**
041         * Authorizes a non-preflight browser-originated MCP request and, when allowed, supplies the CORS response metadata to apply.
042         *
043         * @param context the MCP CORS context
044         * @param cors the parsed Soklet CORS request metadata
045         * @return the CORS response metadata to apply, or {@link Optional#empty()} to withhold CORS authorization
046         */
047        @NonNull
048        Optional<CorsResponse> authorize(@NonNull McpCorsContext context,
049                                                                                                                                         @NonNull Cors cors);
050
051        /**
052         * Authorizes a browser preflight request for the MCP transport.
053         *
054         * @param context the MCP CORS context
055         * @param corsPreflight the parsed preflight metadata
056         * @param availableHttpMethods the MCP transport methods available for the current endpoint
057         * @return the preflight response metadata to apply, or {@link Optional#empty()} to reject the preflight
058         */
059        @NonNull
060        Optional<CorsPreflightResponse> authorizePreflight(@NonNull McpCorsContext context,
061                                                                                                                                                                                                                 @NonNull CorsPreflight corsPreflight,
062                                                                                                                                                                                                                 @NonNull Set<@NonNull HttpMethod> availableHttpMethods);
063
064        /**
065         * Acquires an authorizer that rejects all browser-originated MCP CORS requests.
066         *
067         * @return a rejecting authorizer
068         */
069        @NonNull
070        static McpCorsAuthorizer rejectAllInstance() {
071                return new McpCorsAuthorizer() {
072                        @NonNull
073                        @Override
074                        public Optional<CorsResponse> authorize(@NonNull McpCorsContext context,
075                                                                                                                                                                                        @NonNull Cors cors) {
076                                requireNonNull(context);
077                                requireNonNull(cors);
078                                return Optional.empty();
079                        }
080
081                        @NonNull
082                        @Override
083                        public Optional<CorsPreflightResponse> authorizePreflight(@NonNull McpCorsContext context,
084                                                                                                                                                                                                                                                @NonNull CorsPreflight corsPreflight,
085                                                                                                                                                                                                                                                @NonNull Set<@NonNull HttpMethod> availableHttpMethods) {
086                                requireNonNull(context);
087                                requireNonNull(corsPreflight);
088                                requireNonNull(availableHttpMethods);
089                                return Optional.empty();
090                        }
091                };
092        }
093
094        /**
095         * Acquires the conservative default authorizer that leaves non-browser MCP requests alone while rejecting browser CORS authorization.
096         *
097         * @return the default non-browser-only authorizer
098         */
099        @NonNull
100        static McpCorsAuthorizer nonBrowserClientsOnlyInstance() {
101                return rejectAllInstance();
102        }
103
104        /**
105         * Acquires an authorizer that allows all browser origins and always enables credentials.
106         *
107         * @return a permissive authorizer
108         */
109        @NonNull
110        static McpCorsAuthorizer acceptAllInstance() {
111                return fromOriginAuthorizer(context -> true, origin -> true);
112        }
113
114        /**
115         * Acquires an authorizer that allows only the provided normalized origins and disables credentials by default.
116         *
117         * @param whitelistedOrigins the origins to allow
118         * @return a whitelisting authorizer
119         */
120        @NonNull
121        static McpCorsAuthorizer fromWhitelistedOrigins(@NonNull Set<@NonNull String> whitelistedOrigins) {
122                return fromWhitelistedOrigins(whitelistedOrigins, origin -> false);
123        }
124
125        /**
126         * Acquires an authorizer that allows only the provided normalized origins and delegates credential behavior per origin.
127         *
128         * @param whitelistedOrigins the origins to allow
129         * @param allowCredentialsResolver resolves whether credentials should be allowed for an origin
130         * @return a whitelisting authorizer
131         */
132        @NonNull
133        static McpCorsAuthorizer fromWhitelistedOrigins(@NonNull Set<@NonNull String> whitelistedOrigins,
134                                                                                                                                                                                                        @NonNull Function<String, Boolean> allowCredentialsResolver) {
135                requireNonNull(whitelistedOrigins);
136                requireNonNull(allowCredentialsResolver);
137
138                Set<String> normalizedOrigins = new LinkedHashSet<>();
139
140                for (String whitelistedOrigin : whitelistedOrigins) {
141                        requireNonNull(whitelistedOrigin);
142                        normalizedOrigins.add(normalizeOrigin(whitelistedOrigin));
143                }
144
145                return fromOriginAuthorizer(context -> {
146                        requireNonNull(context);
147
148                        if (context.origin() == null)
149                                return false;
150
151                        return normalizedOrigins.contains(normalizeOrigin(context.origin()));
152                }, allowCredentialsResolver);
153        }
154
155        /**
156         * Acquires an authorizer backed by an origin-authorization predicate and with credentials disabled by default.
157         *
158         * @param originAuthorizer the origin predicate
159         * @return an authorizer backed by the predicate
160         */
161        @NonNull
162        static McpCorsAuthorizer fromOriginAuthorizer(@NonNull Predicate<@NonNull McpCorsContext> originAuthorizer) {
163                return fromOriginAuthorizer(originAuthorizer, origin -> false);
164        }
165
166        /**
167         * Acquires an authorizer backed by an origin-authorization predicate plus a credentials resolver.
168         *
169         * @param originAuthorizer the origin predicate
170         * @param allowCredentialsResolver resolves whether credentials should be allowed for an origin
171         * @return an authorizer backed by the supplied callbacks
172         */
173        @NonNull
174        static McpCorsAuthorizer fromOriginAuthorizer(@NonNull Predicate<@NonNull McpCorsContext> originAuthorizer,
175                                                                                                                                                                                                                @NonNull Function<String, Boolean> allowCredentialsResolver) {
176                requireNonNull(originAuthorizer);
177                requireNonNull(allowCredentialsResolver);
178
179                return new McpCorsAuthorizer() {
180                        @NonNull
181                        @Override
182                        public Optional<CorsResponse> authorize(@NonNull McpCorsContext context,
183                                                                                                                                                                                        @NonNull Cors cors) {
184                                requireNonNull(context);
185                                requireNonNull(cors);
186
187                                if (!originAuthorizer.test(context))
188                                        return Optional.empty();
189
190                                return Optional.of(CorsResponse.withAccessControlAllowOrigin(cors.getOrigin())
191                                                .accessControlAllowCredentials(allowCredentialsResolver.apply(normalizeOrigin(cors.getOrigin())))
192                                                .accessControlExposeHeaders(defaultExposedHeaders())
193                                                .build());
194                        }
195
196                        @NonNull
197                        @Override
198                        public Optional<CorsPreflightResponse> authorizePreflight(@NonNull McpCorsContext context,
199                                                                                                                                                                                                                                                @NonNull CorsPreflight corsPreflight,
200                                                                                                                                                                                                                                                @NonNull Set<@NonNull HttpMethod> availableHttpMethods) {
201                                requireNonNull(context);
202                                requireNonNull(corsPreflight);
203                                requireNonNull(availableHttpMethods);
204
205                                if (!originAuthorizer.test(context))
206                                        return Optional.empty();
207
208                                return Optional.of(CorsPreflightResponse.withAccessControlAllowOrigin(corsPreflight.getOrigin())
209                                                .accessControlAllowMethods(availableHttpMethods)
210                                                .accessControlAllowHeaders(corsPreflight.getAccessControlRequestHeaders())
211                                                .accessControlAllowCredentials(allowCredentialsResolver.apply(normalizeOrigin(corsPreflight.getOrigin())))
212                                                .accessControlMaxAge(Duration.ofMinutes(10))
213                                                .build());
214                        }
215                };
216        }
217
218        @NonNull
219        private static Set<String> defaultExposedHeaders() {
220                return Set.of("MCP-Session-Id", "WWW-Authenticate");
221        }
222
223        @NonNull
224        private static String normalizeOrigin(@NonNull String origin) {
225                requireNonNull(origin);
226
227                if ("null".equals(origin))
228                        return "null";
229
230                URI uri = URI.create(origin.trim());
231                String scheme = uri.getScheme() == null ? "" : uri.getScheme().toLowerCase(Locale.ROOT);
232                String host = uri.getHost() == null ? "" : uri.getHost().toLowerCase(Locale.ROOT);
233                Integer port = uri.getPort() == -1 ? null : uri.getPort();
234
235                if (("http".equals(scheme) && Integer.valueOf(80).equals(port))
236                                || ("https".equals(scheme) && Integer.valueOf(443).equals(port)))
237                        port = null;
238
239                return port == null ? "%s://%s".formatted(scheme, host) : "%s://%s:%s".formatted(scheme, host, port);
240        }
241}