1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27 package org.apache.hc.core5.testing;
28
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

import org.apache.hc.core5.net.InetAddressUtils;
import org.apache.hc.core5.util.TimeValue;
44
45
46
47
48 public class SocksProxy {
49
50 private static class SocksProxyHandler {
51
52 public static final int VERSION_5 = 5;
53 public static final int COMMAND_CONNECT = 1;
54 public static final int ATYP_DOMAINNAME = 3;
55
56 private final SocksProxy parent;
57 private final Socket socket;
58 private volatile Socket remote;
59
60 public SocksProxyHandler(final SocksProxy parent, final Socket socket) {
61 this.parent = parent;
62 this.socket = socket;
63 }
64
65 public void start() {
66 new Thread(new Runnable() {
67 @Override
68 public void run() {
69 try {
70 final DataInputStream input = new DataInputStream(socket.getInputStream());
71 final DataOutputStream output = new DataOutputStream(socket.getOutputStream());
72 final Socket target = establishConnection(input, output);
73 remote = target;
74
75 final Thread t1 = pumpStream(input, target.getOutputStream());
76 final Thread t2 = pumpStream(target.getInputStream(), output);
77 try {
78 t1.join();
79 } catch (final InterruptedException e) {
80 }
81 try {
82 t2.join();
83 } catch (final InterruptedException e) {
84 }
85 } catch (final IOException e) {
86 } finally {
87 parent.cleanupSocksProxyHandler(SocksProxyHandler.this);
88 }
89 }
90
91 private Socket establishConnection(final DataInputStream input, final DataOutputStream output) throws IOException {
92 final int clientVersion = input.readUnsignedByte();
93 if (clientVersion != VERSION_5) {
94 throw new IOException("SOCKS implementation only supports version 5");
95 }
96 final int nMethods = input.readUnsignedByte();
97 for (int i = 0; i < nMethods; i++) {
98 input.readUnsignedByte();
99 }
100
101 output.writeByte(VERSION_5);
102 output.writeByte(0);
103 output.flush();
104
105 input.readUnsignedByte();
106 final int command = input.readUnsignedByte();
107 if (command != COMMAND_CONNECT) {
108 throw new IOException("SOCKS implementation only supports CONNECT command");
109 }
110 input.readUnsignedByte();
111
112 final String targetHost;
113 final byte[] targetAddress;
114 final int addressType = input.readUnsignedByte();
115 switch (addressType) {
116 case InetAddressUtils.IPV4:
117 targetHost = null;
118 targetAddress = new byte[4];
119 for (int i = 0; i < targetAddress.length; i++) {
120 targetAddress[i] = input.readByte();
121 }
122 break;
123 case InetAddressUtils.IPV6:
124 targetHost = null;
125 targetAddress = new byte[16];
126 for (int i = 0; i < targetAddress.length; i++) {
127 targetAddress[i] = input.readByte();
128 }
129 break;
130 case ATYP_DOMAINNAME:
131 final int length = input.readUnsignedByte();
132 final StringBuilder domainname = new StringBuilder();
133 for (int i = 0; i < length; i++) {
134 domainname.append((char) input.readUnsignedByte());
135 }
136 targetHost = domainname.toString();
137 targetAddress = null;
138 break;
139 default:
140 throw new IOException("Unsupported address type: " + addressType);
141 }
142
143 final int targetPort = input.readUnsignedShort();
144 final Socket target;
145 if (targetHost != null) {
146 target = new Socket(targetHost, targetPort);
147 } else {
148 target = new Socket(InetAddress.getByAddress(targetAddress), targetPort);
149 }
150
151 output.writeByte(VERSION_5);
152 output.writeByte(0);
153 output.writeByte(0);
154 final byte[] localAddress = target.getLocalAddress().getAddress();
155 if (localAddress.length == 4) {
156 output.writeByte(InetAddressUtils.IPV4);
157 } else if (localAddress.length == 16) {
158 output.writeByte(InetAddressUtils.IPV6);
159 } else {
160 throw new IOException("Unsupported localAddress byte length: " + localAddress.length);
161 }
162 output.write(localAddress);
163 output.writeShort(target.getLocalPort());
164 output.flush();
165
166 return target;
167 }
168
169 private Thread pumpStream(final InputStream input, final OutputStream output) {
170 final Thread t = new Thread(() -> {
171 final byte[] buffer = new byte[1024 * 8];
172 try {
173 while (true) {
174 final int read = input.read(buffer);
175 if (read < 0) {
176 break;
177 }
178 output.write(buffer, 0, read);
179 output.flush();
180 }
181 } catch (final IOException e) {
182 } finally {
183 shutdown();
184 }
185 });
186 t.start();
187 return t;
188 }
189
190 }).start();
191 }
192
193 public void shutdown() {
194 try {
195 this.socket.close();
196 } catch (final IOException e) {
197 }
198 if (this.remote != null) {
199 try {
200 this.remote.close();
201 } catch (final IOException e) {
202 }
203 }
204 }
205
206 }
207
208 private final int port;
209
210 private final List<SocksProxyHandler> handlers = new ArrayList<>();
211 private ServerSocket server;
212 private Thread serverThread;
213 private final ReentrantLock lock;
214
215 public SocksProxy() {
216 this(0);
217 }
218
219 public SocksProxy(final int port) {
220 this.port = port;
221 this.lock = new ReentrantLock();
222 }
223
224 public void start() throws IOException {
225 lock.lock();
226 try {
227 if (this.server == null) {
228 this.server = new ServerSocket(this.port);
229 this.serverThread = new Thread(() -> {
230 try {
231 while (true) {
232 final Socket socket = server.accept();
233 startSocksProxyHandler(socket);
234 }
235 } catch (final IOException e) {
236 } finally {
237 if (server != null) {
238 try {
239 server.close();
240 } catch (final IOException e) {
241 }
242 server = null;
243 }
244 }
245 });
246 this.serverThread.start();
247 }
248 } finally {
249 lock.unlock();
250 }
251 }
252
253 public void shutdown(final TimeValue timeout) throws InterruptedException {
254 final long waitUntil = System.currentTimeMillis() + timeout.toMilliseconds();
255 Thread t = null;
256 lock.lock();
257 try {
258 if (this.server != null) {
259 try {
260 this.server.close();
261 } catch (final IOException e) {
262 } finally {
263 this.server = null;
264 }
265 t = this.serverThread;
266 this.serverThread = null;
267 }
268 for (final SocksProxyHandler handler : this.handlers) {
269 handler.shutdown();
270 }
271 while (!this.handlers.isEmpty()) {
272 final long waitTime = waitUntil - System.currentTimeMillis();
273 if (waitTime > 0) {
274 wait(waitTime);
275 }
276 }
277 } finally {
278 lock.unlock();
279 }
280 if (t != null) {
281 final long waitTime = waitUntil - System.currentTimeMillis();
282 if (waitTime > 0) {
283 t.join(waitTime);
284 }
285 }
286 }
287
288 protected void startSocksProxyHandler(final Socket socket) {
289 final SocksProxyHandler handler = new SocksProxyHandler(this, socket);
290 lock.lock();
291 try {
292 this.handlers.add(handler);
293 } finally {
294 lock.unlock();
295 }
296 handler.start();
297 }
298
299 protected void cleanupSocksProxyHandler(final SocksProxyHandler handler) {
300 lock.lock();
301 try {
302 this.handlers.remove(handler);
303 } finally {
304 lock.unlock();
305 }
306 }
307
308 public SocketAddress getProxyAddress() {
309 return this.server.getLocalSocketAddress();
310 }
311
312 }