1 | |
|
2 | |
|
3 | |
|
4 | |
|
5 | |
|
6 | |
|
7 | |
|
8 | |
|
9 | |
|
10 | |
|
11 | |
|
12 | |
|
13 | |
|
14 | |
|
15 | |
|
16 | |
|
17 | |
|
18 | |
|
19 | |
package org.apache.giraph.comm.netty.handler; |
20 | |
|
21 | |
import org.apache.giraph.comm.netty.NettyClient; |
22 | |
import org.apache.giraph.comm.netty.SaslNettyClient; |
23 | |
import org.apache.giraph.comm.requests.RequestType; |
24 | |
import org.apache.giraph.comm.requests.SaslCompleteRequest; |
25 | |
import org.apache.giraph.comm.requests.SaslTokenMessageRequest; |
26 | |
import org.apache.giraph.comm.requests.WritableRequest; |
27 | |
import org.apache.hadoop.conf.Configuration; |
28 | |
import org.apache.hadoop.util.ReflectionUtils; |
29 | |
import org.apache.log4j.Logger; |
30 | |
import io.netty.buffer.ByteBuf; |
31 | |
import io.netty.buffer.ByteBufInputStream; |
32 | |
import io.netty.channel.ChannelHandlerContext; |
33 | |
import io.netty.channel.ChannelInboundHandlerAdapter; |
34 | |
import io.netty.handler.codec.FixedLengthFrameDecoder; |
35 | |
import io.netty.util.ReferenceCountUtil; |
36 | |
|
37 | |
import java.io.IOException; |
38 | |
|
39 | |
|
40 | |
|
41 | |
|
42 | |
|
43 | |
public class SaslClientHandler extends ChannelInboundHandlerAdapter { |
44 | |
|
45 | 0 | private static final Logger LOG = Logger.getLogger(SaslClientHandler.class); |
46 | |
|
47 | |
private final Configuration conf; |
48 | |
|
49 | |
|
50 | |
|
51 | |
|
52 | |
|
53 | |
|
54 | 0 | public SaslClientHandler(Configuration conf) { |
55 | 0 | this.conf = conf; |
56 | 0 | } |
57 | |
|
58 | |
@Override |
59 | |
public void channelRead(ChannelHandlerContext ctx, Object msg) |
60 | |
throws Exception { |
61 | 0 | WritableRequest decodedMessage = decode(ctx, msg); |
62 | |
|
63 | 0 | SaslNettyClient saslNettyClient = ctx.attr(NettyClient.SASL).get(); |
64 | 0 | if (saslNettyClient == null) { |
65 | 0 | throw new Exception("handleUpstream: saslNettyClient was unexpectedly " + |
66 | 0 | "null for channel: " + ctx.channel()); |
67 | |
} |
68 | 0 | if (decodedMessage.getClass() == SaslCompleteRequest.class) { |
69 | 0 | if (LOG.isDebugEnabled()) { |
70 | 0 | LOG.debug("handleUpstream: Server has sent us the SaslComplete " + |
71 | |
"message. Allowing normal work to proceed."); |
72 | |
} |
73 | 0 | synchronized (saslNettyClient.getAuthenticated()) { |
74 | 0 | saslNettyClient.getAuthenticated().notify(); |
75 | 0 | } |
76 | 0 | if (!saslNettyClient.isComplete()) { |
77 | 0 | LOG.error("handleUpstream: Server returned a Sasl-complete message, " + |
78 | |
"but as far as we can tell, we are not authenticated yet."); |
79 | 0 | throw new Exception("handleUpstream: Server returned a " + |
80 | |
"Sasl-complete message, but as far as " + |
81 | |
"we can tell, we are not authenticated yet."); |
82 | |
} |
83 | |
|
84 | |
|
85 | 0 | ctx.pipeline().remove(this); |
86 | 0 | ctx.pipeline().replace("length-field-based-frame-decoder", |
87 | |
"fixed-length-frame-decoder", |
88 | |
new FixedLengthFrameDecoder(RequestServerHandler.RESPONSE_BYTES)); |
89 | 0 | return; |
90 | |
} |
91 | 0 | SaslTokenMessageRequest serverToken = |
92 | |
(SaslTokenMessageRequest) decodedMessage; |
93 | 0 | if (LOG.isDebugEnabled()) { |
94 | 0 | LOG.debug("handleUpstream: Responding to server's token of length: " + |
95 | 0 | serverToken.getSaslToken().length); |
96 | |
} |
97 | |
|
98 | |
|
99 | 0 | byte[] responseToServer = saslNettyClient.saslResponse(serverToken); |
100 | 0 | if (responseToServer == null) { |
101 | |
|
102 | |
|
103 | 0 | if (LOG.isDebugEnabled()) { |
104 | 0 | LOG.debug("handleUpstream: Response to server is null: " + |
105 | |
"authentication should now be complete."); |
106 | |
} |
107 | 0 | if (!saslNettyClient.isComplete()) { |
108 | 0 | LOG.warn("handleUpstream: Generated a null response, " + |
109 | |
"but authentication is not complete."); |
110 | |
} |
111 | 0 | return; |
112 | |
} else { |
113 | 0 | if (LOG.isDebugEnabled()) { |
114 | 0 | LOG.debug("handleUpstream: Response to server token has length:" + |
115 | |
responseToServer.length); |
116 | |
} |
117 | |
} |
118 | |
|
119 | |
|
120 | 0 | SaslTokenMessageRequest saslResponse = |
121 | |
new SaslTokenMessageRequest(responseToServer); |
122 | 0 | ctx.channel().writeAndFlush(saslResponse); |
123 | 0 | } |
124 | |
|
125 | |
|
126 | |
|
127 | |
|
128 | |
|
129 | |
|
130 | |
|
131 | |
|
132 | |
|
133 | |
protected WritableRequest decode(ChannelHandlerContext ctx, Object msg) |
134 | |
throws Exception { |
135 | 0 | if (!(msg instanceof ByteBuf)) { |
136 | 0 | throw new IllegalStateException("decode: Got illegal message " + msg); |
137 | |
} |
138 | |
|
139 | |
|
140 | |
|
141 | |
|
142 | 0 | ByteBuf buf = (ByteBuf) msg; |
143 | 0 | ByteBufInputStream inputStream = new ByteBufInputStream(buf); |
144 | |
|
145 | 0 | int enumValue = inputStream.readByte(); |
146 | 0 | RequestType type = RequestType.values()[enumValue]; |
147 | 0 | if (LOG.isDebugEnabled()) { |
148 | 0 | LOG.debug("decode: Got a response of type " + type + " from server:" + |
149 | 0 | ctx.channel().remoteAddress()); |
150 | |
} |
151 | |
|
152 | 0 | Class<? extends WritableRequest> writableRequestClass = |
153 | 0 | type.getRequestClass(); |
154 | 0 | WritableRequest serverResponse = |
155 | 0 | ReflectionUtils.newInstance(writableRequestClass, conf); |
156 | |
|
157 | |
|
158 | |
try { |
159 | 0 | serverResponse.readFields(inputStream); |
160 | 0 | } catch (IOException e) { |
161 | 0 | LOG.error("decode: Exception when trying to read server response: " + e); |
162 | 0 | } |
163 | 0 | ReferenceCountUtil.release(buf); |
164 | |
|
165 | 0 | return serverResponse; |
166 | |
} |
167 | |
} |