001/*
002 * Copyright 2022-2025 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.soklet.core.impl;
018
019import com.soklet.core.Cors;
020import com.soklet.core.CorsPreflight;
021import com.soklet.core.CorsPreflightResponse;
022import com.soklet.core.CorsResponse;
023import com.soklet.core.HttpMethod;
024import com.soklet.core.MarshaledResponse;
025import com.soklet.core.Request;
026import com.soklet.core.ResourceMethod;
027import com.soklet.core.Response;
028import com.soklet.core.ResponseMarshaler;
029import com.soklet.core.StatusCode;
030import com.soklet.exception.BadRequestException;
031import com.soklet.internal.spring.LinkedCaseInsensitiveMap;
032
033import javax.annotation.Nonnull;
034import javax.annotation.Nullable;
035import javax.annotation.concurrent.ThreadSafe;
036import java.nio.charset.Charset;
037import java.nio.charset.StandardCharsets;
038import java.time.Duration;
039import java.util.LinkedHashMap;
040import java.util.LinkedHashSet;
041import java.util.Map;
042import java.util.Set;
043import java.util.SortedSet;
044import java.util.TreeSet;
045import java.util.stream.Collectors;
046
047import static com.soklet.core.Utilities.emptyByteArray;
048import static java.lang.String.format;
049import static java.util.Objects.requireNonNull;
050
051/**
052 * @author <a href="https://www.revetkn.com">Mark Allen</a>
053 */
054@ThreadSafe
055public class DefaultResponseMarshaler implements ResponseMarshaler {
056        @Nonnull
057        private static final DefaultResponseMarshaler SHARED_INSTANCE;
058        @Nonnull
059        private static final Charset DEFAULT_CHARSET;
060
061        static {
062                DEFAULT_CHARSET = StandardCharsets.UTF_8;
063                SHARED_INSTANCE = new DefaultResponseMarshaler();
064        }
065
066        @Nonnull
067        private final Charset charset;
068
069        public DefaultResponseMarshaler() {
070                this(null);
071        }
072
073        public DefaultResponseMarshaler(@Nullable Charset charset) {
074                this.charset = charset == null ? DEFAULT_CHARSET : charset;
075        }
076
077        @Nonnull
078        public static DefaultResponseMarshaler sharedInstance() {
079                return SHARED_INSTANCE;
080        }
081
082        @Nonnull
083        @Override
084        public MarshaledResponse forHappyPath(@Nonnull Request request,
085                                                                                                                                                                @Nonnull Response response,
086                                                                                                                                                                @Nonnull ResourceMethod resourceMethod) {
087                requireNonNull(request);
088                requireNonNull(response);
089                requireNonNull(resourceMethod);
090
091                byte[] body = null;
092                Object bodyAsObject = response.getBody().orElse(null);
093                boolean binaryResponse = false;
094
095                // If response body is a byte array, pass through as-is.
096                // Otherwise, default representation is toString() output.
097                // Real systems would use a different representation, e.g. JSON
098                if (bodyAsObject != null) {
099                        if (bodyAsObject instanceof byte[]) {
100                                body = (byte[]) bodyAsObject;
101                                binaryResponse = true;
102                        } else {
103                                body = bodyAsObject.toString().getBytes(getCharset());
104                        }
105                }
106
107                Map<String, Set<String>> headers = new LinkedCaseInsensitiveMap<>(response.getHeaders());
108
109                // If no Content-Type specified, supply a default
110                if (!headers.keySet().contains("Content-Type"))
111                        headers.put("Content-Type", Set.of(binaryResponse ? "application/octet-stream" : format("text/plain; charset=%s", getCharset().name())));
112
113                return MarshaledResponse.withStatusCode(response.getStatusCode())
114                                .headers(headers)
115                                .cookies(response.getCookies())
116                                .body(body)
117                                .build();
118        }
119
120        @Nonnull
121        @Override
122        public MarshaledResponse forNotFound(@Nonnull Request request) {
123                requireNonNull(request);
124
125                Integer statusCode = 404;
126
127                return MarshaledResponse.withStatusCode(statusCode)
128                                .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", getCharset().name()))))
129                                .body(format("HTTP %d: %s", statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(getCharset()))
130                                .build();
131        }
132
133        @Nonnull
134        @Override
135        public MarshaledResponse forMethodNotAllowed(@Nonnull Request request,
136                                                                                                                                                                                         @Nonnull Set<HttpMethod> allowedHttpMethods) {
137                requireNonNull(request);
138                requireNonNull(allowedHttpMethods);
139
140                SortedSet<String> allowedHttpMethodsAsStrings = new TreeSet<>(allowedHttpMethods.stream()
141                                .map(httpMethod -> httpMethod.name())
142                                .collect(Collectors.toSet()));
143
144                Integer statusCode = 405;
145
146                Map<String, Set<String>> headers = new LinkedHashMap<>();
147                headers.put("Allow", allowedHttpMethodsAsStrings);
148                headers.put("Content-Type", Set.of(format("text/plain; charset=%s", getCharset().name())));
149
150                return MarshaledResponse.withStatusCode(statusCode)
151                                .headers(headers)
152                                .body(format("HTTP %d: %s. Requested: %s, Allowed: %s",
153                                                statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase(), request.getHttpMethod().name(),
154                                                String.join(", ", allowedHttpMethodsAsStrings)).getBytes(getCharset()))
155                                .build();
156        }
157
158        @Nonnull
159        @Override
160        public MarshaledResponse forContentTooLarge(@Nonnull Request request,
161                                                                                                                                                                                        @Nullable ResourceMethod resourceMethod) {
162                requireNonNull(request);
163
164                Integer statusCode = 413;
165
166                return MarshaledResponse.withStatusCode(statusCode)
167                                .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", getCharset().name()))))
168                                .body(format("HTTP %d: %s", statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(getCharset()))
169                                .build();
170        }
171
172        @Nonnull
173        @Override
174        public MarshaledResponse forOptions(@Nonnull Request request,
175                                                                                                                                                        @Nonnull Set<HttpMethod> allowedHttpMethods) {
176                requireNonNull(request);
177                requireNonNull(allowedHttpMethods);
178
179                SortedSet<String> allowedHttpMethodsAsStrings = new TreeSet<>(allowedHttpMethods.stream()
180                                .map(httpMethod -> httpMethod.name())
181                                .collect(Collectors.toSet()));
182
183                return MarshaledResponse.withStatusCode(204)
184                                .headers(Map.of("Allow", allowedHttpMethodsAsStrings))
185                                .build();
186        }
187
188        @Nonnull
189        @Override
190        public MarshaledResponse forHead(@Nonnull Request request,
191                                                                                                                                         @Nonnull MarshaledResponse getMethodMarshaledResponse) {
192                requireNonNull(request);
193                requireNonNull(getMethodMarshaledResponse);
194
195                // A HEAD can never write a response body, but we explicitly set its Content-Length header
196                // so the client knows how long the response would have been.
197                return getMethodMarshaledResponse.copy()
198                                .body(null)
199                                .headers((mutableHeaders) -> {
200                                        byte[] responseBytes = getMethodMarshaledResponse.getBody().orElse(emptyByteArray());
201                                        mutableHeaders.put("Content-Length", Set.of(String.valueOf(responseBytes.length)));
202                                }).finish();
203        }
204
205        @Nonnull
206        @Override
207        public MarshaledResponse forThrowable(@Nonnull Request request,
208                                                                                                                                                                @Nonnull Throwable throwable,
209                                                                                                                                                                @Nullable ResourceMethod resourceMethod) {
210                requireNonNull(request);
211                requireNonNull(throwable);
212
213                Integer statusCode = throwable instanceof BadRequestException ? 400 : 500;
214
215                return MarshaledResponse.withStatusCode(statusCode)
216                                .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", getCharset().name()))))
217                                .body(format("HTTP %d: %s", statusCode, StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(getCharset()))
218                                .build();
219        }
220
221        @Nonnull
222        @Override
223        public MarshaledResponse forCorsPreflightAllowed(@Nonnull Request request,
224                                                                                                                                                                                                         @Nonnull CorsPreflight corsPreflight,
225                                                                                                                                                                                                         @Nonnull CorsPreflightResponse corsPreflightResponse) {
226                requireNonNull(request);
227                requireNonNull(corsPreflight);
228                requireNonNull(corsPreflightResponse);
229
230                Integer statusCode = 204;
231                Map<String, Set<String>> headers = new LinkedHashMap<>();
232
233                headers.put("Access-Control-Allow-Origin", Set.of(corsPreflightResponse.getAccessControlAllowOrigin()));
234
235                Boolean accessControlAllowCredentials = corsPreflightResponse.getAccessControlAllowCredentials().orElse(null);
236
237                // Either "true" or omit entirely
238                if (accessControlAllowCredentials != null && accessControlAllowCredentials)
239                        headers.put("Access-Control-Allow-Credentials", Set.of("true"));
240
241                Set<String> accessControlAllowHeaders = corsPreflightResponse.getAccessControlAllowHeaders();
242
243                if (accessControlAllowHeaders.size() > 0)
244                        headers.put("Access-Control-Allow-Headers", new LinkedHashSet<>(accessControlAllowHeaders));
245
246                Set<String> accessControlAllowMethodAsStrings = new LinkedHashSet<>();
247
248                for (HttpMethod httpMethod : corsPreflightResponse.getAccessControlAllowMethods())
249                        accessControlAllowMethodAsStrings.add(httpMethod.name());
250
251                if (accessControlAllowMethodAsStrings.size() > 0)
252                        headers.put("Access-Control-Allow-Methods", accessControlAllowMethodAsStrings);
253
254                Duration accessControlMaxAge = corsPreflightResponse.getAccessControlMaxAge().orElse(null);
255
256                if (accessControlMaxAge != null)
257                        headers.put("Access-Control-Max-Age", Set.of(String.valueOf(accessControlMaxAge.toSeconds())));
258
259                return MarshaledResponse.withStatusCode(statusCode)
260                                .headers(headers)
261                                .build();
262        }
263
264        @Nonnull
265        @Override
266        public MarshaledResponse forCorsPreflightRejected(@Nonnull Request request,
267                                                                                                                                                                                                                @Nonnull CorsPreflight corsPreflight) {
268                requireNonNull(request);
269                requireNonNull(corsPreflight);
270
271                Integer statusCode = 403;
272
273                return MarshaledResponse.withStatusCode(statusCode)
274                                .headers(Map.of("Content-Type", Set.of(format("text/plain; charset=%s", getCharset().name()))))
275                                .body(format("HTTP %d: %s (CORS preflight rejected)", statusCode,
276                                                StatusCode.fromStatusCode(statusCode).get().getReasonPhrase()).getBytes(getCharset()))
277                                .build();
278        }
279
280        @Nonnull
281        @Override
282        public MarshaledResponse forCorsAllowed(@Nonnull Request request,
283                                                                                                                                                                        @Nonnull Cors cors,
284                                                                                                                                                                        @Nonnull CorsResponse corsResponse,
285                                                                                                                                                                        @Nonnull MarshaledResponse marshaledResponse) {
286                requireNonNull(request);
287                requireNonNull(cors);
288                requireNonNull(corsResponse);
289                requireNonNull(marshaledResponse);
290
291                return marshaledResponse.copy()
292                                .headers((mutableHeaders) -> {
293                                        mutableHeaders.put("Access-Control-Allow-Origin", Set.of(corsResponse.getAccessControlAllowOrigin()));
294
295                                        Boolean accessControlAllowCredentials = corsResponse.getAccessControlAllowCredentials().orElse(null);
296
297                                        // Either "true" or omit entirely
298                                        if (accessControlAllowCredentials != null && accessControlAllowCredentials)
299                                                mutableHeaders.put("Access-Control-Allow-Credentials", Set.of("true"));
300
301                                        Set<String> accessControlExposeHeaders = corsResponse.getAccessControlExposeHeaders();
302
303                                        if (accessControlExposeHeaders.size() > 0)
304                                                mutableHeaders.put("Access-Control-Expose-Headers", new LinkedHashSet<>(accessControlExposeHeaders));
305                                }).finish();
306        }
307
308        @Nonnull
309        protected Charset getCharset() {
310                return this.charset;
311        }
312}