1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.giraph.comm.netty.handler;
20
21 import org.apache.giraph.comm.netty.NettyClient;
22 import org.apache.giraph.comm.netty.SaslNettyClient;
23 import org.apache.giraph.comm.requests.RequestType;
24 import org.apache.giraph.comm.requests.SaslCompleteRequest;
25 import org.apache.giraph.comm.requests.SaslTokenMessageRequest;
26 import org.apache.giraph.comm.requests.WritableRequest;
27 import org.apache.hadoop.conf.Configuration;
28 import org.apache.hadoop.util.ReflectionUtils;
29 import org.apache.log4j.Logger;
30 import io.netty.buffer.ByteBuf;
31 import io.netty.buffer.ByteBufInputStream;
32 import io.netty.channel.ChannelHandlerContext;
33 import io.netty.channel.ChannelInboundHandlerAdapter;
34 import io.netty.handler.codec.FixedLengthFrameDecoder;
35 import io.netty.util.ReferenceCountUtil;
36
37 import java.io.IOException;
38
39
40
41
42
43 public class SaslClientHandler extends ChannelInboundHandlerAdapter {
44
45 private static final Logger LOG = Logger.getLogger(SaslClientHandler.class);
46
47 private final Configuration conf;
48
49
50
51
52
53
54 public SaslClientHandler(Configuration conf) {
55 this.conf = conf;
56 }
57
58 @Override
59 public void channelRead(ChannelHandlerContext ctx, Object msg)
60 throws Exception {
61 WritableRequest decodedMessage = decode(ctx, msg);
62
63 SaslNettyClient saslNettyClient = ctx.attr(NettyClient.SASL).get();
64 if (saslNettyClient == null) {
65 throw new Exception("handleUpstream: saslNettyClient was unexpectedly " +
66 "null for channel: " + ctx.channel());
67 }
68 if (decodedMessage.getClass() == SaslCompleteRequest.class) {
69 if (LOG.isDebugEnabled()) {
70 LOG.debug("handleUpstream: Server has sent us the SaslComplete " +
71 "message. Allowing normal work to proceed.");
72 }
73 synchronized (saslNettyClient.getAuthenticated()) {
74 saslNettyClient.getAuthenticated().notify();
75 }
76 if (!saslNettyClient.isComplete()) {
77 LOG.error("handleUpstream: Server returned a Sasl-complete message, " +
78 "but as far as we can tell, we are not authenticated yet.");
79 throw new Exception("handleUpstream: Server returned a " +
80 "Sasl-complete message, but as far as " +
81 "we can tell, we are not authenticated yet.");
82 }
83
84
85 ctx.pipeline().remove(this);
86 ctx.pipeline().replace("length-field-based-frame-decoder",
87 "fixed-length-frame-decoder",
88 new FixedLengthFrameDecoder(RequestServerHandler.RESPONSE_BYTES));
89 return;
90 }
91 SaslTokenMessageRequest serverToken =
92 (SaslTokenMessageRequest) decodedMessage;
93 if (LOG.isDebugEnabled()) {
94 LOG.debug("handleUpstream: Responding to server's token of length: " +
95 serverToken.getSaslToken().length);
96 }
97
98
99 byte[] responseToServer = saslNettyClient.saslResponse(serverToken);
100 if (responseToServer == null) {
101
102
103 if (LOG.isDebugEnabled()) {
104 LOG.debug("handleUpstream: Response to server is null: " +
105 "authentication should now be complete.");
106 }
107 if (!saslNettyClient.isComplete()) {
108 LOG.warn("handleUpstream: Generated a null response, " +
109 "but authentication is not complete.");
110 }
111 return;
112 } else {
113 if (LOG.isDebugEnabled()) {
114 LOG.debug("handleUpstream: Response to server token has length:" +
115 responseToServer.length);
116 }
117 }
118
119
120 SaslTokenMessageRequest saslResponse =
121 new SaslTokenMessageRequest(responseToServer);
122 ctx.channel().writeAndFlush(saslResponse);
123 }
124
125
126
127
128
129
130
131
132
133 protected WritableRequest decode(ChannelHandlerContext ctx, Object msg)
134 throws Exception {
135 if (!(msg instanceof ByteBuf)) {
136 throw new IllegalStateException("decode: Got illegal message " + msg);
137 }
138
139
140
141
142 ByteBuf buf = (ByteBuf) msg;
143 ByteBufInputStream inputStream = new ByteBufInputStream(buf);
144
145 int enumValue = inputStream.readByte();
146 RequestType type = RequestType.values()[enumValue];
147 if (LOG.isDebugEnabled()) {
148 LOG.debug("decode: Got a response of type " + type + " from server:" +
149 ctx.channel().remoteAddress());
150 }
151
152 Class<? extends WritableRequest> writableRequestClass =
153 type.getRequestClass();
154 WritableRequest serverResponse =
155 ReflectionUtils.newInstance(writableRequestClass, conf);
156
157
158 try {
159 serverResponse.readFields(inputStream);
160 } catch (IOException e) {
161 LOG.error("decode: Exception when trying to read server response: " + e);
162 }
163 ReferenceCountUtil.release(buf);
164
165 return serverResponse;
166 }
167 }