Classes in this File | Line Coverage | Branch Coverage | Complexity | ||||
SendGlobalCommCache |
|
| 1.75;1.75 |
1 | /* | |
2 | * Licensed to the Apache Software Foundation (ASF) under one | |
3 | * or more contributor license agreements. See the NOTICE file | |
4 | * distributed with this work for additional information | |
5 | * regarding copyright ownership. The ASF licenses this file | |
6 | * to you under the Apache License, Version 2.0 (the | |
7 | * "License"); you may not use this file except in compliance | |
8 | * with the License. You may obtain a copy of the License at | |
9 | * | |
10 | * http://www.apache.org/licenses/LICENSE-2.0 | |
11 | * | |
12 | * Unless required by applicable law or agreed to in writing, software | |
13 | * distributed under the License is distributed on an "AS IS" BASIS, | |
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
15 | * See the License for the specific language governing permissions and | |
16 | * limitations under the License. | |
17 | */ | |
18 | ||
19 | package org.apache.giraph.comm.aggregators; | |
20 | ||
21 | import java.io.IOException; | |
22 | import java.util.Map; | |
23 | ||
24 | import org.apache.giraph.comm.GlobalCommType; | |
25 | import org.apache.hadoop.io.LongWritable; | |
26 | import org.apache.hadoop.io.Writable; | |
27 | ||
28 | import com.google.common.collect.Maps; | |
29 | ||
30 | /** | |
31 | * Takes and serializes global communication values and keeps them grouped by | |
32 | * owner partition id, to be sent in bulk. | |
33 | * Includes broadcast messages, reducer registrations and special count. | |
34 | */ | |
35 | public class SendGlobalCommCache extends CountingCache { | |
36 | /** Map from worker partition id to global communication output stream */ | |
37 | 0 | private final Map<Integer, GlobalCommValueOutputStream> globalCommMap = |
38 | 0 | Maps.newHashMap(); |
39 | ||
40 | /** whether to write Class object for values into the stream */ | |
41 | private final boolean writeClass; | |
42 | ||
43 | /** | |
44 | * Constructor | |
45 | * | |
46 | * @param writeClass boolean whether to write Class object for values | |
47 | */ | |
48 | 0 | public SendGlobalCommCache(boolean writeClass) { |
49 | 0 | this.writeClass = writeClass; |
50 | 0 | } |
51 | ||
52 | /** | |
53 | * Add global communication value to the cache | |
54 | * | |
55 | * @param taskId Task id of worker which owns the value | |
56 | * @param name Name | |
57 | * @param type Global communication type | |
58 | * @param value Value | |
59 | * @return Number of bytes in serialized data for this worker | |
60 | * @throws IOException | |
61 | */ | |
62 | public int addValue(Integer taskId, String name, | |
63 | GlobalCommType type, Writable value) throws IOException { | |
64 | 0 | GlobalCommValueOutputStream out = globalCommMap.get(taskId); |
65 | 0 | if (out == null) { |
66 | 0 | out = new GlobalCommValueOutputStream(writeClass); |
67 | 0 | globalCommMap.put(taskId, out); |
68 | } | |
69 | 0 | return out.addValue(name, type, value); |
70 | } | |
71 | ||
72 | /** | |
73 | * Remove and get values for certain worker | |
74 | * | |
75 | * @param taskId Partition id of worker owner | |
76 | * @return Serialized global communication data for this worker | |
77 | */ | |
78 | public byte[] removeSerialized(Integer taskId) { | |
79 | 0 | incrementCounter(taskId); |
80 | 0 | GlobalCommValueOutputStream out = globalCommMap.remove(taskId); |
81 | 0 | if (out == null) { |
82 | 0 | return new byte[4]; |
83 | } else { | |
84 | 0 | return out.flush(); |
85 | } | |
86 | } | |
87 | ||
88 | /** | |
89 | * Creates special value which will hold the total number of global | |
90 | * communication requests for worker with selected task id. This should be | |
91 | * called after all values for the worker have been added to the cache. | |
92 | * | |
93 | * @param taskId Destination worker's task id | |
94 | * @throws IOException | |
95 | */ | |
96 | public void addSpecialCount(Integer taskId) throws IOException { | |
97 | // current number of requests, plus one for the last flush | |
98 | 0 | long totalCount = getCount(taskId) + 1; |
99 | 0 | addValue(taskId, GlobalCommType.SPECIAL_COUNT.name(), |
100 | GlobalCommType.SPECIAL_COUNT, new LongWritable(totalCount)); | |
101 | 0 | } |
102 | } |