View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.syncope.core.spring.security.jws;
20  
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSASigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.time.Duration;
import java.util.Date;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
61  
62  @ExtendWith(MockitoExtension.class)
63  public class MSEntraAccessTokenJWSVerifierTest {
64  
65      private static class SpyableMSEntraAccessTokenJWSVerifier extends MSEntraAccessTokenJWSVerifier {
66  
67          SpyableMSEntraAccessTokenJWSVerifier() {
68              super(null, null, Duration.ofHours(24));
69          }
70      }
71  
72      private static final String TENANT_ID = "test-tenant-id";
73  
74      private static final String APP_ID = "test-app-id";
75  
76      private static String createSignedJWT(final JWK jwk) throws JOSEException {
77          // Create JWT header
78          JWSHeader header = new JWSHeader.Builder((JWSAlgorithm) jwk.getAlgorithm())
79                  .type(JOSEObjectType.JWT)
80                  .keyID(jwk.getKeyID())
81                  .build();
82  
83          // Create JWT payload
84          JWTClaimsSet payload = new JWTClaimsSet.Builder()
85                  .issuer(TENANT_ID)
86                  .audience(APP_ID)
87                  .build();
88  
89          // Create signed JWT
90          SignedJWT signedJWT = new SignedJWT(header, payload);
91  
92          JWSSigner signer = jwk.getAlgorithm() == JWSAlgorithm.RS256
93                  ? new RSASSASigner(jwk.toRSAKey())
94                  : new ECDSASigner(jwk.toECKey());
95  
96          signedJWT.sign(signer);
97          return signedJWT.serialize();
98      }
99  
100     private static MSEntraAccessTokenJWSVerifier getSpyInstance(
101             final String jwksUri, final String oidc, final String jwks) {
102 
103         MSEntraAccessTokenJWSVerifier v = spy(SpyableMSEntraAccessTokenJWSVerifier.class);
104         doAnswer(m -> m.getArgument(0).equals(jwksUri) ? jwks : oidc).when(v).fetchDocument(anyString());
105         return v;
106     }
107 
108     private static JWK generateJWKRSA() throws NoSuchAlgorithmException {
109         KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA");
110         gen.initialize(2048);
111         KeyPair keyPair = gen.generateKeyPair();
112 
113         // Convert to JWK format
114         return new RSAKey.Builder((RSAPublicKey) keyPair.getPublic())
115                 .privateKey((RSAPrivateKey) keyPair.getPrivate())
116                 .keyUse(KeyUse.SIGNATURE)
117                 .algorithm(JWSAlgorithm.RS256)
118                 .keyID(UUID.randomUUID().toString())
119                 .issueTime(new Date())
120                 .build();
121     }
122 
123     private static JWK generateJWKEC() throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
124         // Generate EC key pair with P-256 curve
125         KeyPairGenerator gen = KeyPairGenerator.getInstance("EC");
126         gen.initialize(Curve.P_256.toECParameterSpec());
127         KeyPair keyPair = gen.generateKeyPair();
128 
129         // Convert to JWK format
130         return new ECKey.Builder(Curve.P_256, (ECPublicKey) keyPair.getPublic())
131                 .privateKey((ECPrivateKey) keyPair.getPrivate())
132                 .algorithm(JWSAlgorithm.ES256)
133                 .keyUse(KeyUse.SIGNATURE)
134                 .keyID(UUID.randomUUID().toString())
135                 .issueTime(new Date())
136                 .build();
137     }
138 
139     @Test
140     void getOpenIDMetadataDocumentUrl() {
141         // Tenant id and app id
142         MSEntraAccessTokenJWSVerifier v1 = new MSEntraAccessTokenJWSVerifier(TENANT_ID, APP_ID, Duration.ofHours(24));
143         assertEquals(String.format(
144                 "https://login.microsoftonline.com/%s/.well-known/openid-configuration?appid=%s", TENANT_ID, APP_ID),
145                 v1.getOpenIDMetadataDocumentUrl());
146 
147         // Tenant id, no app id
148         MSEntraAccessTokenJWSVerifier v2 = new MSEntraAccessTokenJWSVerifier(TENANT_ID, null, Duration.ofHours(24));
149         assertEquals(
150                 String.format("https://login.microsoftonline.com/%s/.well-known/openid-configuration", TENANT_ID),
151                 v2.getOpenIDMetadataDocumentUrl());
152 
153         // No tenant id, no app id
154         MSEntraAccessTokenJWSVerifier v3 = new MSEntraAccessTokenJWSVerifier(null, null, Duration.ofHours(24));
155         assertEquals(
156                 "https://login.microsoftonline.com/common/.well-known/openid-configuration",
157                 v3.getOpenIDMetadataDocumentUrl());
158     }
159 
160     @Test
161     void extractJwksUri() {
162         String doc = "{\"jwks_uri\": \"https://login.microsoftonline.com/common/discovery/keys\"}";
163 
164         MSEntraAccessTokenJWSVerifier v = new MSEntraAccessTokenJWSVerifier(TENANT_ID, APP_ID, Duration.ofHours(24));
165         assertEquals("https://login.microsoftonline.com/common/discovery/keys", v.extractJwksUri(doc));
166     }
167 
168     @Test
169     void parseJsonWebKeySetRSA() throws Exception {
170         // Create JWK, JWKS and jwt string
171         JWK jwk = generateJWKRSA();
172         String jwks = "{\"keys\": [" + jwk.toPublicJWK().toJSONString() + "]}";
173         String jwt = createSignedJWT(jwk);
174 
175         // Create JWSVerifier
176         MSEntraAccessTokenJWSVerifier v = new MSEntraAccessTokenJWSVerifier(
177                 "unknown-tenant-id", null, Duration.ofHours(24));
178 
179         assertDoesNotThrow(() -> v.parseJsonWebKeySet(jwks));
180 
181         Map<String, JWSVerifier> verifiersMap = v.parseJsonWebKeySet(jwks);
182         assertEquals(1, verifiersMap.size());
183         JWSVerifier v1 = verifiersMap.get(jwk.getKeyID());
184         assertNotNull(v1);
185         assertTrue(v1.supportedJWSAlgorithms().contains((JWSAlgorithm) jwk.getAlgorithm()));
186 
187         // Verify JWT
188         String[] chunks = jwt.split("\\.");
189         assertTrue(v1.verify(
190                 JWSHeader.parse(new Base64URL(chunks[0])),
191                 (chunks[0] + "." + chunks[1]).getBytes(),
192                 new Base64URL(chunks[2])));
193     }
194 
195     @Test
196     void parseJsonWebKeySetEC() throws Exception {
197         // Create JWK, JWKS and jwt string
198         JWK jwk = generateJWKEC();
199         String jwks = "{\"keys\": [" + jwk.toPublicJWK().toJSONString() + "]}";
200         String jwt = createSignedJWT(jwk);
201 
202         // Create JWSVerifier
203         MSEntraAccessTokenJWSVerifier v = new MSEntraAccessTokenJWSVerifier(
204                 "unknown-tenant-id", null, Duration.ofHours(24));
205 
206         assertDoesNotThrow(() -> v.parseJsonWebKeySet(jwks));
207         Map<String, JWSVerifier> verifiersMap = v.parseJsonWebKeySet(jwks);
208         assertEquals(1, verifiersMap.size());
209         JWSVerifier v1 = verifiersMap.get(jwk.getKeyID());
210         assertNotNull(v1);
211         assertTrue(v1.supportedJWSAlgorithms().contains((JWSAlgorithm) jwk.getAlgorithm()));
212 
213         // Verify JWT
214         String[] chunks = jwt.split("\\.");
215         assertTrue(v1.verify(
216                 JWSHeader.parse(new Base64URL(chunks[0])),
217                 (chunks[0] + "." + chunks[1]).getBytes(),
218                 new Base64URL(chunks[2])));
219     }
220 
221     @Test
222     void supportedJWSAlgorithmsEmpty() {
223         String jwksUri = "https://example.com/keys";
224         String oidc = "{\"jwks_uri\": \"" + jwksUri + "\"}";
225         String jwks = "{\"keys\": []}";
226 
227         MSEntraAccessTokenJWSVerifier v = getSpyInstance(jwksUri, oidc, jwks);
228 
229         assertTrue(v.supportedJWSAlgorithms().isEmpty());
230     }
231 
232     @Test
233     void supportedJWSAlgorithmsRSA() throws Exception {
234         JWK jwk = generateJWKRSA();
235         String[] chunks = createSignedJWT(jwk).split("\\.");
236 
237         String jwksUri = "https://example.com/keys";
238         MSEntraAccessTokenJWSVerifier v = getSpyInstance(
239                 jwksUri,
240                 "{\"jwks_uri\": \"" + jwksUri + "\"}",
241                 "{\"keys\": [" + jwk.toPublicJWK().toJSONString() + "]}");
242 
243         assertTrue(v.verify(
244                 JWSHeader.parse(new Base64URL(chunks[0])),
245                 (chunks[0] + "." + chunks[1]).getBytes(),
246                 new Base64URL(chunks[2])));
247         assertTrue(v.supportedJWSAlgorithms().contains((JWSAlgorithm) jwk.getAlgorithm()));
248         assertDoesNotThrow(v::getJCAContext);
249     }
250 
251     @Test
252     void supportedJWSAlgorithmsRSAJWSAlgorithm() throws Exception {
253         JWK jwk = generateJWKRSA();
254 
255         String jwksUri = "https://example.com/keys";
256         MSEntraAccessTokenJWSVerifier v = getSpyInstance(
257                 jwksUri,
258                 "{\"jwks_uri\": \"" + jwksUri + "\"}",
259                 "{\"keys\": [" + jwk.toPublicJWK().toJSONString() + "]}");
260 
261         assertTrue(v.supportedJWSAlgorithms().contains((JWSAlgorithm) jwk.getAlgorithm()));
262     }
263 
264     @Test
265     void supportedJWSAlgorithmsRSAJCAContext() throws NoSuchAlgorithmException, JOSEException {
266         JWK jwk = generateJWKRSA();
267 
268         String jwksUri = "https://example.com/keys";
269         String oidc = "{\"jwks_uri\": \"" + jwksUri + "\"}";
270         String jwks = "{\"keys\": [" + jwk.toPublicJWK().toJSONString() + "]}";
271 
272         MSEntraAccessTokenJWSVerifier v = getSpyInstance(jwksUri, oidc, jwks);
273 
274         assertDoesNotThrow(v::getJCAContext);
275     }
276 
277     @Test
278     void supportedJWSAlgorithmsEC() throws Exception {
279         JWK jwk = generateJWKEC();
280         String[] chunks = createSignedJWT(jwk).split("\\.");
281 
282         String jwksUri = "https://example.com/keys";
283         MSEntraAccessTokenJWSVerifier v = getSpyInstance(
284                 jwksUri,
285                 "{\"jwks_uri\": \"" + jwksUri + "\"}",
286                 "{\"keys\": [" + jwk.toPublicJWK().toJSONString() + "]}");
287 
288         assertTrue(v.verify(
289                 JWSHeader.parse(new Base64URL(chunks[0])),
290                 (chunks[0] + "." + chunks[1]).getBytes(),
291                 new Base64URL(chunks[2])
292         ));
293         assertTrue(v.supportedJWSAlgorithms().contains((JWSAlgorithm) jwk.getAlgorithm()));
294         assertDoesNotThrow(v::getJCAContext);
295     }
296 
297     @Test
298     void supportedJWSAlgorithmsMixed() throws Exception {
299         JWK jwkRSA = generateJWKRSA();
300         JWK jwkEC = generateJWKEC();
301         String[] chunksRSA = createSignedJWT(jwkRSA).split("\\.");
302         String[] chunksEC = createSignedJWT(jwkEC).split("\\.");
303 
304         String jwksUri = "https://example.com/keys";
305         MSEntraAccessTokenJWSVerifier v = getSpyInstance(jwksUri,
306                 "{\"jwks_uri\": \"" + jwksUri + "\"}",
307                 "{\"keys\": ["
308                 + jwkRSA.toPublicJWK().toJSONString() + ","
309                 + jwkEC.toPublicJWK().toJSONString()
310                 + "]}");
311 
312         // Verify with RSA
313         assertTrue(v.verify(
314                 JWSHeader.parse(new Base64URL(chunksRSA[0])),
315                 (chunksRSA[0] + "." + chunksRSA[1]).getBytes(),
316                 new Base64URL(chunksRSA[2])));
317 
318         // Verify with EC
319         assertTrue(v.verify(
320                 JWSHeader.parse(new Base64URL(chunksEC[0])),
321                 (chunksEC[0] + "." + chunksEC[1]).getBytes(),
322                 new Base64URL(chunksEC[2])));
323 
324         Set.of((JWSAlgorithm) jwkRSA.getAlgorithm(), (JWSAlgorithm) jwkEC.getAlgorithm()).
325                 forEach(jwsAlgorithm -> assertTrue(v.supportedJWSAlgorithms().contains(jwsAlgorithm)));
326 
327         assertDoesNotThrow(v::getJCAContext);
328     }
329 }