1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.giraph.comm;
20
21 import java.io.IOException;
22 import java.util.Arrays;
23 import java.util.Iterator;
24
25 import javax.annotation.concurrent.NotThreadSafe;
26
27 import org.apache.giraph.bsp.CentralizedServiceWorker;
28 import org.apache.giraph.comm.netty.NettyWorkerClientRequestProcessor;
29 import org.apache.giraph.comm.requests.SendWorkerMessagesRequest;
30 import org.apache.giraph.comm.requests.SendWorkerOneMessageToManyRequest;
31 import org.apache.giraph.comm.requests.WritableRequest;
32 import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
33 import org.apache.giraph.partition.PartitionOwner;
34 import org.apache.giraph.utils.ByteArrayOneMessageToManyIds;
35 import org.apache.giraph.utils.ExtendedDataOutput;
36 import org.apache.giraph.utils.PairList;
37 import org.apache.giraph.utils.VertexIdMessages;
38 import org.apache.giraph.worker.WorkerInfo;
39 import org.apache.hadoop.io.Writable;
40 import org.apache.hadoop.io.WritableComparable;
41 import org.apache.log4j.Logger;
42
43
44
45
46
47
48
49
50 @NotThreadSafe
51 @SuppressWarnings("unchecked")
52 public class SendOneMessageToManyCache<I extends WritableComparable,
53 M extends Writable> extends SendMessageCache<I, M> {
54
55 private static final Logger LOG =
56 Logger.getLogger(SendOneMessageToManyCache.class);
57
58 private final ByteArrayOneMessageToManyIds<I, M>[] msgVidsCache;
59
60 private final int[] msgVidsSizes;
61
62 private final ExtendedDataOutput[] idSerializer;
63
64 private final int[] idCounter;
65
66
67
68
69 private final int[] firstPartitionMap;
70
71 private final WorkerInfo[] workerInfoList;
72
73
74
75
76
77
78
79
80
81 public SendOneMessageToManyCache(ImmutableClassesGiraphConfiguration conf,
82 CentralizedServiceWorker<?, ?, ?> serviceWorker,
83 NettyWorkerClientRequestProcessor<I, ?, ?> processor,
84 int maxMsgSize) {
85 super(conf, serviceWorker, processor, maxMsgSize);
86 int numWorkers = getNumWorkers();
87 msgVidsCache = new ByteArrayOneMessageToManyIds[numWorkers];
88 msgVidsSizes = new int[numWorkers];
89 idSerializer = new ExtendedDataOutput[numWorkers];
90
91
92 int initialBufferSize = 0;
93 for (int i = 0; i < this.idSerializer.length; i++) {
94 initialBufferSize = getSendWorkerInitialBufferSize(i);
95 if (initialBufferSize > 0) {
96
97
98 idSerializer[i] = conf.createExtendedDataOutput(initialBufferSize);
99 }
100 }
101 idCounter = new int[numWorkers];
102 firstPartitionMap = new int[numWorkers];
103
104 workerInfoList = new WorkerInfo[numWorkers];
105
106 for (WorkerInfo workerInfo : serviceWorker.getWorkerInfoList()) {
107 workerInfoList[workerInfo.getTaskId()] = workerInfo;
108 }
109 }
110
111
112
113
114
115 private void resetIdSerializers() {
116 for (int i = 0; i < this.idSerializer.length; i++) {
117 if (idSerializer[i] != null) {
118 idSerializer[i].reset();
119 }
120 }
121 }
122
123
124
125
126 private void resetIdCounter() {
127 Arrays.fill(idCounter, 0);
128 }
129
130
131
132
133
134
135
136
137
138
139
140
141 private int addOneToManyMessage(
142 WorkerInfo workerInfo, byte[] ids, int idPos, int count, M message) {
143
144 ByteArrayOneMessageToManyIds<I, M> workerData =
145 msgVidsCache[workerInfo.getTaskId()];
146 if (workerData == null) {
147 workerData = new ByteArrayOneMessageToManyIds<I, M>(
148 messageValueFactory);
149 workerData.setConf(getConf());
150 workerData.initialize(getSendWorkerInitialBufferSize(
151 workerInfo.getTaskId()));
152 msgVidsCache[workerInfo.getTaskId()] = workerData;
153 }
154 workerData.add(ids, idPos, count, message);
155
156 msgVidsSizes[workerInfo.getTaskId()] =
157 workerData.getSize();
158 return msgVidsSizes[workerInfo.getTaskId()];
159 }
160
161
162
163
164
165
166
167
168
169
170
171
172 private ByteArrayOneMessageToManyIds<I, M>
173 removeWorkerMsgVids(WorkerInfo workerInfo) {
174 ByteArrayOneMessageToManyIds<I, M> workerData =
175 msgVidsCache[workerInfo.getTaskId()];
176 if (workerData != null) {
177 msgVidsCache[workerInfo.getTaskId()] = null;
178 msgVidsSizes[workerInfo.getTaskId()] = 0;
179 }
180 return workerData;
181 }
182
183
184
185
186
187
188 private PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>
189 removeAllMsgVids() {
190 PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>> allData =
191 new PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>();
192 allData.initialize(msgVidsCache.length);
193 for (WorkerInfo workerInfo : getWorkerPartitions().keySet()) {
194 ByteArrayOneMessageToManyIds<I, M> workerData =
195 removeWorkerMsgVids(workerInfo);
196 if (workerData != null && !workerData.isEmpty()) {
197 allData.add(workerInfo, workerData);
198 }
199 }
200 return allData;
201 }
202
203 @Override
204 public void sendMessageToAllRequest(Iterator<I> vertexIdIterator, M message) {
205
206 resetIdSerializers();
207 resetIdCounter();
208
209 int currentMachineId = 0;
210 PartitionOwner owner = null;
211 WorkerInfo workerInfo = null;
212 I vertexId = null;
213 while (vertexIdIterator.hasNext()) {
214 vertexId = vertexIdIterator.next();
215 owner = getServiceWorker().getVertexPartitionOwner(vertexId);
216 workerInfo = owner.getWorkerInfo();
217 currentMachineId = workerInfo.getTaskId();
218
219 try {
220 vertexId.write(idSerializer[currentMachineId]);
221 } catch (IOException e) {
222 throw new IllegalStateException(
223 "Failed to serialize the target vertex id.");
224 }
225 idCounter[currentMachineId]++;
226
227
228
229 if (idCounter[currentMachineId] == 1) {
230 firstPartitionMap[currentMachineId] = owner.getPartitionId();
231 }
232 }
233
234 int idSerializerPos = 0;
235 int workerMessageSize = 0;
236 byte[] serializedId = null;
237 WritableRequest writableRequest = null;
238 for (int i = 0; i < idCounter.length; i++) {
239 if (idCounter[i] == 1) {
240 serializedId = idSerializer[i].getByteArray();
241 idSerializerPos = idSerializer[i].getPos();
242
243 workerMessageSize = addMessage(workerInfoList[i],
244 firstPartitionMap[i], serializedId, idSerializerPos, message);
245
246 if (LOG.isTraceEnabled()) {
247 LOG.trace("sendMessageToAllRequest: Send bytes (" +
248 message.toString() + ") to one target in worker " +
249 workerInfoList[i]);
250 }
251 ++totalMsgsSentInSuperstep;
252 if (workerMessageSize >= maxMessagesSizePerWorker) {
253 PairList<Integer, VertexIdMessages<I, M>>
254 workerMessages = removeWorkerMessages(workerInfoList[i]);
255 writableRequest = new SendWorkerMessagesRequest<>(workerMessages);
256 totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
257 clientProcessor.doRequest(workerInfoList[i], writableRequest);
258
259 getServiceWorker().getGraphTaskManager().notifySentMessages();
260 }
261 } else if (idCounter[i] > 1) {
262 serializedId = idSerializer[i].getByteArray();
263 idSerializerPos = idSerializer[i].getPos();
264 workerMessageSize = addOneToManyMessage(
265 workerInfoList[i], serializedId, idSerializerPos, idCounter[i],
266 message);
267
268 if (LOG.isTraceEnabled()) {
269 LOG.trace("sendMessageToAllRequest: Send bytes (" +
270 message.toString() + ") to all targets in worker" +
271 workerInfoList[i]);
272 }
273 totalMsgsSentInSuperstep += idCounter[i];
274 if (workerMessageSize >= maxMessagesSizePerWorker) {
275 ByteArrayOneMessageToManyIds<I, M> workerMsgVids =
276 removeWorkerMsgVids(workerInfoList[i]);
277 writableRequest = new SendWorkerOneMessageToManyRequest<>(
278 workerMsgVids, getConf());
279 totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
280 clientProcessor.doRequest(workerInfoList[i], writableRequest);
281
282 getServiceWorker().getGraphTaskManager().notifySentMessages();
283 }
284 }
285 }
286 }
287
288 @Override
289 public void flush() {
290 super.flush();
291 PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>
292 remainingMsgVidsCache = removeAllMsgVids();
293 PairList<WorkerInfo,
294 ByteArrayOneMessageToManyIds<I, M>>.Iterator
295 msgIdsIterator = remainingMsgVidsCache.getIterator();
296 while (msgIdsIterator.hasNext()) {
297 msgIdsIterator.next();
298 WritableRequest writableRequest =
299 new SendWorkerOneMessageToManyRequest<>(
300 msgIdsIterator.getCurrentSecond(), getConf());
301 totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
302 clientProcessor.doRequest(
303 msgIdsIterator.getCurrentFirst(), writableRequest);
304 }
305 }
306 }