1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.giraph.edge;
19
20 import it.unimi.dsi.fastutil.bytes.ByteArrays;
21 import it.unimi.dsi.fastutil.longs.LongArrayList;
22
23 import java.io.DataInput;
24 import java.io.DataOutput;
25 import java.io.IOException;
26 import java.util.Arrays;
27 import java.util.BitSet;
28 import java.util.Iterator;
29
30 import javax.annotation.concurrent.NotThreadSafe;
31
32 import org.apache.giraph.utils.ExtendedByteArrayDataInput;
33 import org.apache.giraph.utils.ExtendedByteArrayDataOutput;
34 import org.apache.giraph.utils.ExtendedDataInput;
35 import org.apache.giraph.utils.ExtendedDataOutput;
36 import org.apache.giraph.utils.UnsafeByteArrayInputStream;
37 import org.apache.giraph.utils.UnsafeByteArrayOutputStream;
38 import org.apache.giraph.utils.Varint;
39 import org.apache.hadoop.io.LongWritable;
40 import org.apache.hadoop.io.Writable;
41
42 import com.google.common.base.Preconditions;
43
44
45
46
47
48
49
50
51
52
53
54 @NotThreadSafe
55 public class LongDiffArray implements Writable {
56
57
58
59
60 private byte[] compressedData;
61
62
63
64
65
66
67 private int size;
68
69
70
71
72 private TransientChanges transientData;
73
74
75
76
77 private boolean useUnsafeSerialization = true;
78
79
80
81
82
83 public void setUseUnsafeSerialization(boolean useUnsafeSerialization) {
84 this.useUnsafeSerialization = useUnsafeSerialization;
85 }
86
87
88
89
90
91 public void initialize(int capacity) {
92 reset();
93 if (capacity > 0) {
94 transientData = new TransientChanges(capacity);
95 }
96 }
97
98
99
100
101 public void initialize() {
102 reset();
103 }
104
105
106
107
108
109 public void add(long id) {
110 checkTransientData();
111 transientData.add(id);
112 }
113
114
115
116
117
118
119 public void remove(long id) {
120 checkTransientData();
121
122 if (size > 0) {
123 LongsDiffReader reader = new LongsDiffReader(
124 compressedData,
125 useUnsafeSerialization
126 );
127 for (int i = 0; i < size; i++) {
128 long cur = reader.readNext();
129 if (cur == id) {
130 transientData.markRemoved(i);
131 } else if (cur > id) {
132 break;
133 }
134 }
135 }
136 transientData.removeAdded(id);
137 }
138
139
140
141
142
143 public int size() {
144 int result = size;
145 if (transientData != null) {
146 result += transientData.size();
147 }
148 return result;
149 }
150
151
152
153
154
155 public Iterator<LongWritable> iterator() {
156 trim();
157 return new Iterator<LongWritable>() {
158
159 private int position;
160 private final LongsDiffReader reader =
161 new LongsDiffReader(compressedData, useUnsafeSerialization);
162
163
164 private final LongWritable reusableLong = new LongWritable();
165
166 @Override
167 public boolean hasNext() {
168 return position < size;
169 }
170
171 @Override
172 public LongWritable next() {
173 position++;
174 reusableLong.set(reader.readNext());
175 return reusableLong;
176 }
177
178 @Override
179 public void remove() {
180 removeAt(position - 1);
181 }
182 };
183 }
184
185 @Override
186 public void write(DataOutput out) throws IOException {
187 trim();
188 Varint.writeUnsignedVarInt(compressedData.length, out);
189 Varint.writeUnsignedVarInt(size, out);
190 out.write(compressedData);
191 }
192
193 @Override
194 public void readFields(DataInput in) throws IOException {
195 reset();
196 compressedData = new byte[Varint.readUnsignedVarInt(in)];
197
198
199 size = Varint.readUnsignedVarInt(in);
200 in.readFully(compressedData);
201 }
202
203
204
205
206
207 public void trim() {
208 if (transientData == null) {
209
210 return;
211 }
212
213
214 long[] transientValues = transientData.sortedValues();
215 int pCompressed = 0;
216 int pTransient = 0;
217
218 LongsDiffReader reader = new LongsDiffReader(
219 compressedData,
220 useUnsafeSerialization
221 );
222 LongsDiffWriter writer = new LongsDiffWriter(useUnsafeSerialization);
223
224 long curValue = size > 0 ? reader.readNext() : Long.MAX_VALUE;
225
226
227
228
229 while (pTransient < transientData.numberOfAddedElements() ||
230 pCompressed < size) {
231 if (pTransient < transientData.numberOfAddedElements() &&
232 curValue >= transientValues[pTransient]) {
233 writer.writeNext(transientValues[pTransient]);
234 pTransient++;
235 } else {
236 if (!transientData.isRemoved(pCompressed)) {
237 writer.writeNext(curValue);
238 }
239 pCompressed++;
240 if (pCompressed < size) {
241 curValue = reader.readNext();
242 } else {
243 curValue = Long.MAX_VALUE;
244 }
245 }
246 }
247
248 compressedData = writer.toByteArray();
249 size += transientData.size();
250 transientData = null;
251 }
252
253
254
255
256
257
258
259 private void removeAt(int i) {
260 checkTransientData();
261 if (i < size) {
262 transientData.markRemoved(i);
263 } else {
264 transientData.removeAddedAt(i - size);
265 }
266 }
267
268
269
270
271 private void checkTransientData() {
272 if (transientData == null) {
273 transientData = new TransientChanges();
274 }
275 }
276
277
278
279
280 private void reset() {
281 compressedData = ByteArrays.EMPTY_ARRAY;
282 size = 0;
283 transientData = null;
284 }
285
286
287
288
289 private static class LongsDiffReader {
290
291 private final ExtendedDataInput input;
292
293 private long current;
294
295 private boolean first = true;
296
297
298
299
300
301
302
303 public LongsDiffReader(byte[] compressedData, boolean useUnsafeReader) {
304 if (useUnsafeReader) {
305 input = new UnsafeByteArrayInputStream(compressedData);
306 } else {
307 input = new ExtendedByteArrayDataInput(compressedData);
308 }
309 }
310
311
312
313
314
315 long readNext() {
316 try {
317 if (first) {
318 current = input.readLong();
319 first = false;
320 } else {
321 current += Varint.readUnsignedVarLong(input);
322 }
323 return current;
324 } catch (IOException e) {
325 throw new IllegalStateException(e);
326 }
327 }
328 }
329
330
331
332
333 private static class LongsDiffWriter {
334
335 private final ExtendedDataOutput out;
336
337 private long lastWritten;
338
339 private boolean first = true;
340
341
342
343
344
345 public LongsDiffWriter(boolean useUnsafeWriter) {
346 if (useUnsafeWriter) {
347 out = new UnsafeByteArrayOutputStream();
348 } else {
349 out = new ExtendedByteArrayDataOutput();
350 }
351 }
352
353
354
355
356
357 void writeNext(long value) {
358 try {
359 if (first) {
360 out.writeLong(value);
361 first = false;
362 } else {
363 Preconditions.checkState(value >= lastWritten,
364 "Values need to be in order");
365 Preconditions.checkState((value - lastWritten) >= 0,
366 "In order to use this class, difference of consecutive IDs " +
367 "cannot overflow longs");
368 Varint.writeUnsignedVarLong(value - lastWritten, out);
369 }
370 lastWritten = value;
371 } catch (IOException e) {
372 throw new IllegalStateException(e);
373 }
374 }
375
376
377
378
379
380 byte[] toByteArray() {
381 return out.toByteArray();
382 }
383 }
384
385
386
387
388
389
390
391 private static class TransientChanges {
392
393 private final LongArrayList neighborsAdded;
394
395 private final BitSet removed = new BitSet();
396
397 private int removedCount;
398
399
400
401
402
403 private TransientChanges(int capacity) {
404 neighborsAdded = new LongArrayList(capacity);
405 }
406
407
408
409
410 private TransientChanges() {
411 neighborsAdded = new LongArrayList();
412 }
413
414
415
416
417
418 private void add(long value) {
419 neighborsAdded.add(value);
420 }
421
422
423
424
425
426 private void markRemoved(int index) {
427 if (!removed.get(index)) {
428 removedCount++;
429 removed.set(index);
430 }
431 }
432
433
434
435
436
437 private void removeAddedAt(int index) {
438
439
440
441 if (index == neighborsAdded.size() - 1) {
442 neighborsAdded.popLong();
443 } else {
444 neighborsAdded.set(index, neighborsAdded.popLong());
445 }
446 }
447
448
449
450
451
452 private int numberOfAddedElements() {
453 return neighborsAdded.size();
454 }
455
456
457
458
459
460 private void removeAdded(long target) {
461 neighborsAdded.rem(target);
462 }
463
464
465
466
467
468 private int size() {
469 return neighborsAdded.size() - removedCount;
470 }
471
472
473
474
475
476 private long[] sortedValues() {
477 long[] ret = neighborsAdded.elements();
478 Arrays.sort(ret, 0, neighborsAdded.size());
479 return ret;
480 }
481
482
483
484
485
486
487 private boolean isRemoved(int i) {
488 return removed.get(i);
489 }
490 }
491 }