1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.syncope.core.spring.security;
20
21 import com.nimbusds.jose.JOSEException;
22 import com.nimbusds.jwt.SignedJWT;
23 import java.io.IOException;
24 import java.text.ParseException;
25 import java.util.Optional;
26 import java.util.Set;
27 import javax.servlet.FilterChain;
28 import javax.servlet.ServletException;
29 import javax.servlet.http.HttpServletRequest;
30 import javax.servlet.http.HttpServletResponse;
31 import javax.ws.rs.core.HttpHeaders;
32 import org.apache.commons.lang3.tuple.Pair;
33 import org.slf4j.Logger;
34 import org.slf4j.LoggerFactory;
35 import org.springframework.security.authentication.AuthenticationManager;
36 import org.springframework.security.authentication.BadCredentialsException;
37 import org.springframework.security.core.AuthenticationException;
38 import org.springframework.security.core.context.SecurityContextHolder;
39 import org.springframework.security.web.AuthenticationEntryPoint;
40 import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
41
42
43
44
45
46 public class JWTAuthenticationFilter extends BasicAuthenticationFilter {
47
48 private static final Logger LOG = LoggerFactory.getLogger(JWTAuthenticationFilter.class);
49
50 private final AuthenticationEntryPoint authenticationEntryPoint;
51
52 private final SyncopeAuthenticationDetailsSource authenticationDetailsSource;
53
54 private final AuthDataAccessor dataAccessor;
55
56 private final DefaultCredentialChecker credentialChecker;
57
58 public JWTAuthenticationFilter(
59 final AuthenticationManager authenticationManager,
60 final AuthenticationEntryPoint authenticationEntryPoint,
61 final SyncopeAuthenticationDetailsSource authenticationDetailsSource,
62 final AuthDataAccessor dataAccessor,
63 final DefaultCredentialChecker credentialChecker) {
64
65 super(authenticationManager);
66 this.authenticationEntryPoint = authenticationEntryPoint;
67 this.authenticationDetailsSource = authenticationDetailsSource;
68 this.dataAccessor = dataAccessor;
69 this.credentialChecker = credentialChecker;
70 }
71
72 @Override
73 protected void doFilterInternal(
74 final HttpServletRequest request,
75 final HttpServletResponse response,
76 final FilterChain chain)
77 throws ServletException, IOException {
78
79 String auth = request.getHeader(HttpHeaders.AUTHORIZATION);
80 String[] parts = Optional.ofNullable(auth).map(s -> s.split(" ")).orElse(null);
81 if (parts == null || parts.length != 2 || !"Bearer".equals(parts[0])) {
82 chain.doFilter(request, response);
83 return;
84 }
85
86 String stringToken = parts[1];
87 LOG.debug("JWT received: {}", stringToken);
88
89 try {
90 credentialChecker.checkIsDefaultJWSKeyInUse();
91
92 SignedJWT jwt = SignedJWT.parse(stringToken);
93 JWTSSOProvider jwtSSOProvider = dataAccessor.getJWTSSOProvider(jwt.getJWTClaimsSet().getIssuer());
94 if (!jwt.verify(jwtSSOProvider)) {
95 throw new BadCredentialsException("Invalid signature found in JWT");
96 }
97
98 JWTAuthentication jwtAuthentication =
99 new JWTAuthentication(jwt.getJWTClaimsSet(), authenticationDetailsSource.buildDetails(request));
100 AuthContextUtils.callAsAdmin(jwtAuthentication.getDetails().getDomain(), () -> {
101 Pair<String, Set<SyncopeGrantedAuthority>> authenticated = dataAccessor.authenticate(jwtAuthentication);
102 jwtAuthentication.setUsername(authenticated.getLeft());
103 jwtAuthentication.getAuthorities().addAll(authenticated.getRight());
104 return null;
105 });
106 SecurityContextHolder.getContext().setAuthentication(jwtAuthentication);
107
108 chain.doFilter(request, response);
109 } catch (ParseException | JOSEException e) {
110 SecurityContextHolder.clearContext();
111 this.authenticationEntryPoint.commence(
112 request, response, new BadCredentialsException("Invalid JWT: " + stringToken, e));
113 } catch (AuthenticationException e) {
114 SecurityContextHolder.clearContext();
115 this.authenticationEntryPoint.commence(request, response, e);
116 }
117 }
118 }