001/*
002 * Copyright 2022-2026 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.soklet;
018
019import org.jspecify.annotations.NonNull;
020import org.jspecify.annotations.Nullable;
021
022import javax.annotation.concurrent.ThreadSafe;
023import java.util.ArrayList;
024import java.util.Collection;
025import java.util.LinkedHashSet;
026import java.util.List;
027import java.util.Objects;
028import java.util.Optional;
029
030import static java.lang.String.format;
031import static java.util.Objects.requireNonNull;
032
033/**
034 * Parsed W3C trace context from {@code traceparent} and {@code tracestate} HTTP header values.
035 * <p>
036 * This type models the normalized trace context understood by Soklet. Future-version extension fields are ignored and
037 * not preserved.
038 *
039 * @author <a href="https://www.revetkn.com">Mark Allen</a>
040 */
041@ThreadSafe
042public final class TraceContext {
043        private static final int VERSION_00_TRACEPARENT_LENGTH = 55;
044        private static final String ZERO_TRACE_ID = "00000000000000000000000000000000";
045        private static final String ZERO_PARENT_ID = "0000000000000000";
046        private static final Integer SAMPLED_FLAG = 1;
047        private static final Integer MAX_TRACESTATE_ENTRIES = 32;
048        private static final Integer MAX_TRACESTATE_LENGTH = 512;
049        private static final Integer MAX_TRACESTATE_ENTRY_TRUNCATION_LENGTH = 128;
050        @NonNull
051        private final String traceId;
052        @NonNull
053        private final String parentId;
054        @NonNull
055        private final Integer traceFlags;
056        @NonNull
057        private final List<@NonNull TraceStateEntry> traceStateEntries;
058
059        /**
060         * Parses W3C trace context from physical HTTP header values.
061         *
062         * @param traceparentHeaderValues the physical {@code traceparent} header values
063         * @param tracestateHeaderValues  the physical {@code tracestate} header values, in arrival order
064         * @return the parsed trace context, or {@link Optional#empty()} if no valid {@code traceparent} is available
065         */
066        @NonNull
067        public static Optional<TraceContext> fromHeaderValues(@Nullable Collection<@NonNull String> traceparentHeaderValues,
068                                                                                                                                                                                                                         @Nullable List<@NonNull String> tracestateHeaderValues) {
069                if (traceparentHeaderValues == null || traceparentHeaderValues.size() != 1)
070                        return Optional.empty();
071
072                ParsedTraceparent parsedTraceparent = parseTraceparent(traceparentHeaderValues.iterator().next()).orElse(null);
073
074                if (parsedTraceparent == null)
075                        return Optional.empty();
076
077                return Optional.of(new TraceContext(
078                                parsedTraceparent.traceId(),
079                                parsedTraceparent.parentId(),
080                                parsedTraceparent.traceFlags(),
081                                parseTraceStateEntries(tracestateHeaderValues)));
082        }
083
084        private TraceContext(@NonNull String traceId,
085                                                                                         @NonNull String parentId,
086                                                                                         @NonNull Integer traceFlags,
087                                                                                         @NonNull List<@NonNull TraceStateEntry> traceStateEntries) {
088                this.traceId = requireNonNull(traceId);
089                this.parentId = requireNonNull(parentId);
090                this.traceFlags = requireNonNull(traceFlags);
091                this.traceStateEntries = List.copyOf(traceStateEntries);
092        }
093
094        /**
095         * Returns the 32-character lowercase hexadecimal trace identifier.
096         *
097         * @return the trace identifier
098         */
099        @NonNull
100        public String getTraceId() {
101                return this.traceId;
102        }
103
104        /**
105         * Returns the 16-character lowercase hexadecimal parent identifier.
106         *
107         * @return the parent identifier
108         */
109        @NonNull
110        public String getParentId() {
111                return this.parentId;
112        }
113
114        /**
115         * Returns the trace flags as an unsigned 8-bit value represented by an {@link Integer}.
116         *
117         * @return the trace flags, in the range {@code 0..255}
118         */
119        @NonNull
120        public Integer getTraceFlags() {
121                return this.traceFlags;
122        }
123
124        /**
125         * Is the W3C sampled flag set?
126         *
127         * @return {@code true} if the sampled flag is set
128         */
129        @NonNull
130        public Boolean isSampled() {
131                return (getTraceFlags() & SAMPLED_FLAG) == SAMPLED_FLAG;
132        }
133
134        /**
135         * Returns the normalized W3C {@code tracestate} entries.
136         *
137         * @return the trace-state entries, or an empty list if none are present
138         */
139        @NonNull
140        public List<@NonNull TraceStateEntry> getTraceStateEntries() {
141                return this.traceStateEntries;
142        }
143
144        /**
145         * Returns this context in W3C {@code traceparent} header value form.
146         *
147         * @return the {@code traceparent} header value
148         */
149        @NonNull
150        public String toTraceparentHeaderValue() {
151                return format("00-%s-%s-%02x", getTraceId(), getParentId(), getTraceFlags());
152        }
153
154        /**
155         * Returns this context's normalized W3C {@code tracestate} header value.
156         *
157         * @return the {@code tracestate} header value, or {@link Optional#empty()} if none is present
158         */
159        @NonNull
160        public Optional<String> toTracestateHeaderValue() {
161                if (getTraceStateEntries().isEmpty())
162                        return Optional.empty();
163
164                StringBuilder value = new StringBuilder();
165
166                for (TraceStateEntry entry : getTraceStateEntries()) {
167                        if (!value.isEmpty())
168                                value.append(',');
169
170                        value.append(entry.toHeaderMemberValue());
171                }
172
173                return Optional.of(value.toString());
174        }
175
176        @Override
177        @NonNull
178        public String toString() {
179                return format("%s{traceId=%s, parentId=%s, traceFlags=%s, traceStateEntryCount=%s}",
180                                getClass().getSimpleName(), getTraceId(), getParentId(), getTraceFlags(), getTraceStateEntries().size());
181        }
182
183        @Override
184        public boolean equals(@Nullable Object object) {
185                if (this == object)
186                        return true;
187
188                if (!(object instanceof TraceContext traceContext))
189                        return false;
190
191                return Objects.equals(getTraceId(), traceContext.getTraceId())
192                                && Objects.equals(getParentId(), traceContext.getParentId())
193                                && Objects.equals(getTraceFlags(), traceContext.getTraceFlags())
194                                && Objects.equals(getTraceStateEntries(), traceContext.getTraceStateEntries());
195        }
196
197        @Override
198        public int hashCode() {
199                return Objects.hash(getTraceId(), getParentId(), getTraceFlags(), getTraceStateEntries());
200        }
201
202        @NonNull
203        private static Optional<ParsedTraceparent> parseTraceparent(@Nullable String traceparentHeaderValue) {
204                if (traceparentHeaderValue == null || traceparentHeaderValue.length() < VERSION_00_TRACEPARENT_LENGTH)
205                        return Optional.empty();
206
207                if (traceparentHeaderValue.charAt(2) != '-'
208                                || !isLowercaseHex(traceparentHeaderValue.charAt(0))
209                                || !isLowercaseHex(traceparentHeaderValue.charAt(1)))
210                        return Optional.empty();
211
212                String version = traceparentHeaderValue.substring(0, 2);
213
214                if ("ff".equals(version))
215                        return Optional.empty();
216
217                boolean version00 = "00".equals(version);
218
219                if (version00 && traceparentHeaderValue.length() != VERSION_00_TRACEPARENT_LENGTH)
220                        return Optional.empty();
221
222                if (!version00 && traceparentHeaderValue.length() > VERSION_00_TRACEPARENT_LENGTH && traceparentHeaderValue.charAt(VERSION_00_TRACEPARENT_LENGTH) != '-')
223                        return Optional.empty();
224
225                if (traceparentHeaderValue.charAt(35) != '-' || traceparentHeaderValue.charAt(52) != '-')
226                        return Optional.empty();
227
228                String traceId = traceparentHeaderValue.substring(3, 35);
229                String parentId = traceparentHeaderValue.substring(36, 52);
230                String traceFlagsText = traceparentHeaderValue.substring(53, 55);
231
232                if (!isLowercaseHex(traceId)
233                                || !isLowercaseHex(parentId)
234                                || !isLowercaseHex(traceFlagsText)
235                                || ZERO_TRACE_ID.equals(traceId)
236                                || ZERO_PARENT_ID.equals(parentId))
237                        return Optional.empty();
238
239                int traceFlags = Integer.parseInt(traceFlagsText, 16);
240
241                if (!version00)
242                        traceFlags = traceFlags & SAMPLED_FLAG;
243
244                return Optional.of(new ParsedTraceparent(traceId, parentId, traceFlags));
245        }
246
247        @NonNull
248        private static List<@NonNull TraceStateEntry> parseTraceStateEntries(@Nullable List<@NonNull String> tracestateHeaderValues) {
249                if (tracestateHeaderValues == null || tracestateHeaderValues.isEmpty())
250                        return List.of();
251
252                List<TraceStateEntry> entries = new ArrayList<>();
253                LinkedHashSet<String> keys = new LinkedHashSet<>();
254
255                for (String tracestateHeaderValue : tracestateHeaderValues) {
256                        if (tracestateHeaderValue == null)
257                                continue;
258
259                        for (String member : tracestateHeaderValue.split(",", -1)) {
260                                TraceStateEntry entry = TraceStateEntry.fromMember(member).orElse(null);
261
262                                if (entry == null || keys.contains(entry.getKey()))
263                                        continue;
264
265                                keys.add(entry.getKey());
266                                entries.add(entry);
267                        }
268                }
269
270                return truncateTraceStateEntries(entries);
271        }
272
273        @NonNull
274        private static List<@NonNull TraceStateEntry> truncateTraceStateEntries(@NonNull List<@NonNull TraceStateEntry> entries) {
275                requireNonNull(entries);
276
277                if (entries.isEmpty())
278                        return List.of();
279
280                List<TraceStateEntry> normalizedEntries = new ArrayList<>(entries);
281
282                int largeEntryIndex;
283                while (isOverTraceStateLimits(normalizedEntries)
284                                && (largeEntryIndex = lastLargeTraceStateEntryIndex(normalizedEntries)) >= 0)
285                        normalizedEntries.remove(largeEntryIndex);
286
287                while (normalizedEntries.size() > MAX_TRACESTATE_ENTRIES)
288                        normalizedEntries.remove(normalizedEntries.size() - 1);
289
290                while (traceStateLength(normalizedEntries) > MAX_TRACESTATE_LENGTH)
291                        normalizedEntries.remove(normalizedEntries.size() - 1);
292
293                return List.copyOf(normalizedEntries);
294        }
295
296        private static boolean isOverTraceStateLimits(@NonNull List<@NonNull TraceStateEntry> entries) {
297                requireNonNull(entries);
298                return entries.size() > MAX_TRACESTATE_ENTRIES || traceStateLength(entries) > MAX_TRACESTATE_LENGTH;
299        }
300
301        private static int lastLargeTraceStateEntryIndex(@NonNull List<@NonNull TraceStateEntry> entries) {
302                requireNonNull(entries);
303
304                for (int i = entries.size() - 1; i >= 0; i--)
305                        if (entries.get(i).toHeaderMemberValue().length() > MAX_TRACESTATE_ENTRY_TRUNCATION_LENGTH)
306                                return i;
307
308                return -1;
309        }
310
311        private static int traceStateLength(@NonNull List<@NonNull TraceStateEntry> entries) {
312                requireNonNull(entries);
313
314                int length = 0;
315
316                for (TraceStateEntry entry : entries) {
317                        if (length > 0)
318                                length++;
319
320                        length += entry.toHeaderMemberValue().length();
321                }
322
323                return length;
324        }
325
326        private static boolean isLowercaseHex(@NonNull String value) {
327                requireNonNull(value);
328
329                for (int i = 0; i < value.length(); i++)
330                        if (!isLowercaseHex(value.charAt(i)))
331                                return false;
332
333                return true;
334        }
335
336        private static boolean isLowercaseHex(char c) {
337                return (c >= '0' && c <= '9')
338                                || (c >= 'a' && c <= 'f');
339        }
340
341        private record ParsedTraceparent(@NonNull String traceId,
342                                                                                                                                         @NonNull String parentId,
343                                                                                                                                         @NonNull Integer traceFlags) {
344                // Value record
345        }
346}