1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.apache.amber.oauth2.rsfilter;
23
24 import org.apache.amber.oauth2.common.OAuth;
25 import org.apache.amber.oauth2.common.error.OAuthError;
26 import org.apache.amber.oauth2.common.exception.OAuthProblemException;
27 import org.apache.amber.oauth2.common.exception.OAuthSystemException;
28 import org.apache.amber.oauth2.common.message.OAuthResponse;
29 import org.apache.amber.oauth2.common.message.types.ParameterStyle;
30 import org.apache.amber.oauth2.rs.request.OAuthAccessResourceRequest;
31 import org.apache.amber.oauth2.rs.response.OAuthRSResponse;
32
33 import javax.servlet.*;
34 import javax.servlet.http.HttpServletRequest;
35 import javax.servlet.http.HttpServletRequestWrapper;
36 import javax.servlet.http.HttpServletResponse;
37 import java.io.IOException;
38 import java.security.Principal;
39
40
41
42
43
44
45 public class OAuthFilter implements Filter {
46
47 public static final String OAUTH_RS_PROVIDER_CLASS = "oauth.rs.provider-class";
48
49 public static final String RS_REALM = "oauth.rs.realm";
50 public static final String RS_REALM_DEFAULT = "OAuth Protected Service";
51
52 public static final String RS_TOKENS = "oauth.rs.tokens";
53 public static final ParameterStyle RS_TOKENS_DEFAULT = ParameterStyle.HEADER;
54
55 private static final String TOKEN_DELIMITER = ",";
56
57 private String realm;
58
59 private OAuthRSProvider provider;
60
61 private ParameterStyle[] parameterStyles;
62
63
64 @Override
65 public void init(FilterConfig filterConfig) throws ServletException {
66
67 provider = OAuthUtils
68 .initiateServletContext(filterConfig.getServletContext(), OAUTH_RS_PROVIDER_CLASS,
69 OAuthRSProvider.class);
70 realm = filterConfig.getServletContext().getInitParameter(RS_REALM);
71 if (OAuthUtils.isEmpty(realm)) {
72 realm = RS_REALM_DEFAULT;
73 }
74
75 String parameterStylesString = filterConfig.getServletContext().getInitParameter(RS_TOKENS);
76 if (OAuthUtils.isEmpty(parameterStylesString)) {
77 parameterStyles = new ParameterStyle[] {RS_TOKENS_DEFAULT};
78 } else {
79 String[] parameters = parameterStylesString.split(TOKEN_DELIMITER);
80 if (parameters != null && parameters.length > 0) {
81 parameterStyles = new ParameterStyle[parameters.length];
82 for (int i = 0; i < parameters.length; i++) {
83 ParameterStyle tempParameterStyle = ParameterStyle.valueOf(parameters[i]);
84 if (tempParameterStyle != null) {
85 parameterStyles[i] = tempParameterStyle;
86 } else {
87 throw new ServletException("Incorrect ParameterStyle: " + parameters[i]);
88 }
89 }
90 }
91 }
92
93 }
94
95 @Override
96 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
97 throws IOException, ServletException {
98 HttpServletRequest req = (HttpServletRequest)request;
99 HttpServletResponse res = (HttpServletResponse)response;
100
101 try {
102
103
104 OAuthAccessResourceRequest oauthRequest = new OAuthAccessResourceRequest(req,
105 parameterStyles);
106
107
108 String accessToken = oauthRequest.getAccessToken();
109
110 final OAuthDecision decision = provider.validateRequest(realm, accessToken, req);
111
112 final Principal principal = decision.getPrincipal();
113
114 request = new HttpServletRequestWrapper((HttpServletRequest)request) {
115 @Override
116 public String getRemoteUser() {
117 return principal != null ? principal.getName() : null;
118 }
119 @Override
120 public Principal getUserPrincipal() {
121 return principal;
122 }
123
124 };
125
126 request.setAttribute(OAuth.OAUTH_CLIENT_ID, decision.getOAuthClient().getClientId());
127
128 chain.doFilter(request, response);
129 return;
130
131 } catch (OAuthSystemException e1) {
132 throw new ServletException(e1);
133 } catch (OAuthProblemException e) {
134 respondWithError(res, e);
135 return;
136 }
137
138 }
139
140
141 @Override
142 public void destroy() {
143
144 }
145
146 private void respondWithError(HttpServletResponse resp, OAuthProblemException error)
147 throws IOException, ServletException {
148
149 OAuthResponse oauthResponse = null;
150
151 try {
152 if (OAuthUtils.isEmpty(error.getError())) {
153 oauthResponse = OAuthRSResponse.errorResponse(HttpServletResponse.SC_UNAUTHORIZED)
154 .setRealm(realm)
155 .buildHeaderMessage();
156
157 } else {
158
159 int responseCode = 401;
160 if (error.getError().equals(OAuthError.CodeResponse.INVALID_REQUEST)) {
161 responseCode = 400;
162 } else if (error.getError().equals(OAuthError.ResourceResponse.INSUFFICIENT_SCOPE)) {
163 responseCode = 403;
164 }
165
166 oauthResponse = OAuthRSResponse
167 .errorResponse(responseCode)
168 .setRealm(realm)
169 .setError(error.getError())
170 .setErrorDescription(error.getDescription())
171 .setErrorUri(error.getUri())
172 .buildHeaderMessage();
173 }
174 resp.addHeader(OAuth.HeaderType.WWW_AUTHENTICATE,
175 oauthResponse.getHeader(OAuth.HeaderType.WWW_AUTHENTICATE));
176 resp.sendError(oauthResponse.getResponseStatus());
177 } catch (OAuthSystemException e) {
178 throw new ServletException(e);
179 }
180 }
181 }