Coverage Report - org.apache.giraph.comm.SendOneMessageToManyCache
 
Classes in this File Line Coverage Branch Coverage Complexity
SendOneMessageToManyCache
0%
0/119
0%
0/40
0
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one
 3  
  * or more contributor license agreements.  See the NOTICE file
 4  
  * distributed with this work for additional information
 5  
  * regarding copyright ownership.  The ASF licenses this file
 6  
  * to you under the Apache License, Version 2.0 (the
 7  
  * "License"); you may not use this file except in compliance
 8  
  * with the License.  You may obtain a copy of the License at
 9  
  *
 10  
  *     http://www.apache.org/licenses/LICENSE-2.0
 11  
  *
 12  
  * Unless required by applicable law or agreed to in writing, software
 13  
  * distributed under the License is distributed on an "AS IS" BASIS,
 14  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15  
  * See the License for the specific language governing permissions and
 16  
  * limitations under the License.
 17  
  */
 18  
 
 19  
 package org.apache.giraph.comm;
 20  
 
 21  
 import java.io.IOException;
 22  
 import java.util.Arrays;
 23  
 import java.util.Iterator;
 24  
 
 25  
 import javax.annotation.concurrent.NotThreadSafe;
 26  
 
 27  
 import org.apache.giraph.bsp.CentralizedServiceWorker;
 28  
 import org.apache.giraph.comm.netty.NettyWorkerClientRequestProcessor;
 29  
 import org.apache.giraph.comm.requests.SendWorkerMessagesRequest;
 30  
 import org.apache.giraph.comm.requests.SendWorkerOneMessageToManyRequest;
 31  
 import org.apache.giraph.comm.requests.WritableRequest;
 32  
 import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
 33  
 import org.apache.giraph.partition.PartitionOwner;
 34  
 import org.apache.giraph.utils.ByteArrayOneMessageToManyIds;
 35  
 import org.apache.giraph.utils.ExtendedDataOutput;
 36  
 import org.apache.giraph.utils.PairList;
 37  
 import org.apache.giraph.utils.VertexIdMessages;
 38  
 import org.apache.giraph.worker.WorkerInfo;
 39  
 import org.apache.hadoop.io.Writable;
 40  
 import org.apache.hadoop.io.WritableComparable;
 41  
 import org.apache.log4j.Logger;
 42  
 
 43  
 /**
 44  
  * Aggregates the messages to be sent to workers so they can be sent
 45  
  * in bulk.
 46  
  *
 47  
  * @param <I> Vertex id
 48  
  * @param <M> Message data
 49  
  */
 50  
 @NotThreadSafe
 51  
 @SuppressWarnings("unchecked")
 52  
 public class SendOneMessageToManyCache<I extends WritableComparable,
 53  
   M extends Writable> extends SendMessageCache<I, M> {
 54  
   /** Class logger */
 55  0
   private static final Logger LOG =
 56  0
       Logger.getLogger(SendOneMessageToManyCache.class);
 57  
   /** Cache serialized one to many messages for each worker */
 58  
   private final ByteArrayOneMessageToManyIds<I, M>[] msgVidsCache;
 59  
   /** Tracking message-vertexIds sizes for each worker */
 60  
   private final int[] msgVidsSizes;
 61  
   /** Reused byte array to serialize target ids on each worker */
 62  
   private final ExtendedDataOutput[] idSerializer;
 63  
   /** Reused int array to count target id distribution */
 64  
   private final int[] idCounter;
 65  
   /**
 66  
    * Reused int array to record the partition id
 67  
    * of the first target vertex id found on the worker.
 68  
    */
 69  
   private final int[] firstPartitionMap;
 70  
   /** The WorkerInfo list */
 71  
   private final WorkerInfo[] workerInfoList;
 72  
 
 73  
   /**
 74  
    * Constructor
 75  
    *
 76  
    * @param conf Giraph configuration
 77  
    * @param serviceWorker Service worker
 78  
    * @param processor NettyWorkerClientRequestProcessor
 79  
    * @param maxMsgSize Max message size sent to a worker
 80  
    */
 81  
   public SendOneMessageToManyCache(ImmutableClassesGiraphConfiguration conf,
 82  
     CentralizedServiceWorker<?, ?, ?> serviceWorker,
 83  
     NettyWorkerClientRequestProcessor<I, ?, ?> processor,
 84  
     int maxMsgSize) {
 85  0
     super(conf, serviceWorker, processor, maxMsgSize);
 86  0
     int numWorkers = getNumWorkers();
 87  0
     msgVidsCache = new ByteArrayOneMessageToManyIds[numWorkers];
 88  0
     msgVidsSizes = new int[numWorkers];
 89  0
     idSerializer = new ExtendedDataOutput[numWorkers];
 90  
     // InitialBufferSizes is alo initialized based on the number of workers.
 91  
     // As a result, initialBufferSizes is the same as idSerializer in length
 92  0
     int initialBufferSize = 0;
 93  0
     for (int i = 0; i < this.idSerializer.length; i++) {
 94  0
       initialBufferSize = getSendWorkerInitialBufferSize(i);
 95  0
       if (initialBufferSize > 0) {
 96  
         // InitialBufferSizes is from super class.
 97  
         // Each element is for one worker.
 98  0
         idSerializer[i] = conf.createExtendedDataOutput(initialBufferSize);
 99  
       }
 100  
     }
 101  0
     idCounter = new int[numWorkers];
 102  0
     firstPartitionMap = new int[numWorkers];
 103  
     // Get worker info list.
 104  0
     workerInfoList = new WorkerInfo[numWorkers];
 105  
     // Remember there could be null in the array.
 106  0
     for (WorkerInfo workerInfo : serviceWorker.getWorkerInfoList()) {
 107  0
       workerInfoList[workerInfo.getTaskId()] = workerInfo;
 108  0
     }
 109  0
   }
 110  
 
 111  
   /**
 112  
    * Reset ExtendedDataOutput array for id serialization
 113  
    * in next message-Vids encoding
 114  
    */
 115  
   private void resetIdSerializers() {
 116  0
     for (int i = 0; i < this.idSerializer.length; i++) {
 117  0
       if (idSerializer[i] != null) {
 118  0
         idSerializer[i].reset();
 119  
       }
 120  
     }
 121  0
   }
 122  
 
 123  
   /**
 124  
    * Reset id counter for next message-vertexIds encoding
 125  
    */
 126  
   private void resetIdCounter() {
 127  0
     Arrays.fill(idCounter, 0);
 128  0
   }
 129  
 
 130  
   /**
 131  
    * Add message with multiple target ids to message cache.
 132  
    *
 133  
    * @param workerInfo The remote worker destination
 134  
    * @param ids A byte array to hold serialized vertex ids
 135  
    * @param idPos The end position of ids
 136  
    *              information in the byte array above
 137  
    * @param count The number of target ids
 138  
    * @param message Message to send to remote worker
 139  
    * @return The size of messages for the worker.
 140  
    */
 141  
   private int addOneToManyMessage(
 142  
     WorkerInfo workerInfo, byte[] ids, int idPos, int count, M message) {
 143  
     // Get the data collection
 144  0
     ByteArrayOneMessageToManyIds<I, M> workerData =
 145  0
       msgVidsCache[workerInfo.getTaskId()];
 146  0
     if (workerData == null) {
 147  0
       workerData = new ByteArrayOneMessageToManyIds<I, M>(
 148  
           messageValueFactory);
 149  0
       workerData.setConf(getConf());
 150  0
       workerData.initialize(getSendWorkerInitialBufferSize(
 151  0
         workerInfo.getTaskId()));
 152  0
       msgVidsCache[workerInfo.getTaskId()] = workerData;
 153  
     }
 154  0
     workerData.add(ids, idPos, count, message);
 155  
     // Update the size of cached, outgoing data per worker
 156  0
     msgVidsSizes[workerInfo.getTaskId()] =
 157  0
       workerData.getSize();
 158  0
     return msgVidsSizes[workerInfo.getTaskId()];
 159  
   }
 160  
 
 161  
   /**
 162  
    * Gets the messages + vertexIds for a worker and removes it from the cache.
 163  
    * Here the {@link org.apache.giraph.utils.ByteArrayOneMessageToManyIds}
 164  
    * returned could be null.But when invoking this method, we also check if
 165  
    * the data size sent to this worker is above the threshold.
 166  
    * Therefore, it doesn't matter if the result is null or not.
 167  
    *
 168  
    * @param workerInfo Target worker to which one messages - many ids are sent
 169  
    * @return {@link org.apache.giraph.utils.ByteArrayOneMessageToManyIds}
 170  
    *         that belong to the workerInfo
 171  
    */
 172  
   private ByteArrayOneMessageToManyIds<I, M>
 173  
   removeWorkerMsgVids(WorkerInfo workerInfo) {
 174  0
     ByteArrayOneMessageToManyIds<I, M> workerData =
 175  0
       msgVidsCache[workerInfo.getTaskId()];
 176  0
     if (workerData != null) {
 177  0
       msgVidsCache[workerInfo.getTaskId()] = null;
 178  0
       msgVidsSizes[workerInfo.getTaskId()] = 0;
 179  
     }
 180  0
     return workerData;
 181  
   }
 182  
 
 183  
   /**
 184  
    * Gets all messages - vertexIds and removes them from the cache.
 185  
    *
 186  
    * @return All vertex messages for all workers
 187  
    */
 188  
   private PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>
 189  
   removeAllMsgVids() {
 190  0
     PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>> allData =
 191  
       new PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>();
 192  0
     allData.initialize(msgVidsCache.length);
 193  0
     for (WorkerInfo workerInfo : getWorkerPartitions().keySet()) {
 194  0
       ByteArrayOneMessageToManyIds<I, M> workerData =
 195  0
         removeWorkerMsgVids(workerInfo);
 196  0
       if (workerData != null && !workerData.isEmpty()) {
 197  0
         allData.add(workerInfo, workerData);
 198  
       }
 199  0
     }
 200  0
     return allData;
 201  
   }
 202  
 
 203  
   @Override
 204  
   public void sendMessageToAllRequest(Iterator<I> vertexIdIterator, M message) {
 205  
     // This is going to be reused through every message sending
 206  0
     resetIdSerializers();
 207  0
     resetIdCounter();
 208  
     // Count messages
 209  0
     int currentMachineId = 0;
 210  0
     PartitionOwner owner = null;
 211  0
     WorkerInfo workerInfo = null;
 212  0
     I vertexId = null;
 213  0
     while (vertexIdIterator.hasNext()) {
 214  0
       vertexId = vertexIdIterator.next();
 215  0
       owner = getServiceWorker().getVertexPartitionOwner(vertexId);
 216  0
       workerInfo = owner.getWorkerInfo();
 217  0
       currentMachineId = workerInfo.getTaskId();
 218  
       // Serialize this target vertex id
 219  
       try {
 220  0
         vertexId.write(idSerializer[currentMachineId]);
 221  0
       } catch (IOException e) {
 222  0
         throw new IllegalStateException(
 223  
           "Failed to serialize the target vertex id.");
 224  0
       }
 225  0
       idCounter[currentMachineId]++;
 226  
       // Record the first partition id in the worker which message send to.
 227  
       // If idCounter shows there is only one target on this worker
 228  
       // then this is the partition number of the target vertex.
 229  0
       if (idCounter[currentMachineId] == 1) {
 230  0
         firstPartitionMap[currentMachineId] = owner.getPartitionId();
 231  
       }
 232  
     }
 233  
     // Add the message to the cache
 234  0
     int idSerializerPos = 0;
 235  0
     int workerMessageSize = 0;
 236  0
     byte[] serializedId  = null;
 237  0
     WritableRequest writableRequest = null;
 238  0
     for (int i = 0; i < idCounter.length; i++) {
 239  0
       if (idCounter[i] == 1) {
 240  0
         serializedId = idSerializer[i].getByteArray();
 241  0
         idSerializerPos = idSerializer[i].getPos();
 242  
         // Add the message to the cache
 243  0
         workerMessageSize = addMessage(workerInfoList[i],
 244  
           firstPartitionMap[i], serializedId, idSerializerPos, message);
 245  
 
 246  0
         if (LOG.isTraceEnabled()) {
 247  0
           LOG.trace("sendMessageToAllRequest: Send bytes (" +
 248  0
             message.toString() + ") to one target in  worker " +
 249  
             workerInfoList[i]);
 250  
         }
 251  0
         ++totalMsgsSentInSuperstep;
 252  0
         if (workerMessageSize >= maxMessagesSizePerWorker) {
 253  
           PairList<Integer, VertexIdMessages<I, M>>
 254  0
             workerMessages = removeWorkerMessages(workerInfoList[i]);
 255  0
           writableRequest = new SendWorkerMessagesRequest<>(workerMessages);
 256  0
           totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
 257  0
           clientProcessor.doRequest(workerInfoList[i], writableRequest);
 258  
           // Notify sending
 259  0
           getServiceWorker().getGraphTaskManager().notifySentMessages();
 260  0
         }
 261  0
       } else if (idCounter[i] > 1) {
 262  0
         serializedId = idSerializer[i].getByteArray();
 263  0
         idSerializerPos = idSerializer[i].getPos();
 264  0
         workerMessageSize = addOneToManyMessage(
 265  
             workerInfoList[i], serializedId, idSerializerPos, idCounter[i],
 266  
             message);
 267  
 
 268  0
         if (LOG.isTraceEnabled()) {
 269  0
           LOG.trace("sendMessageToAllRequest: Send bytes (" +
 270  0
             message.toString() + ") to all targets in worker" +
 271  
             workerInfoList[i]);
 272  
         }
 273  0
         totalMsgsSentInSuperstep += idCounter[i];
 274  0
         if (workerMessageSize >= maxMessagesSizePerWorker) {
 275  0
           ByteArrayOneMessageToManyIds<I, M> workerMsgVids =
 276  0
             removeWorkerMsgVids(workerInfoList[i]);
 277  0
           writableRequest =  new SendWorkerOneMessageToManyRequest<>(
 278  0
             workerMsgVids, getConf());
 279  0
           totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
 280  0
           clientProcessor.doRequest(workerInfoList[i], writableRequest);
 281  
           // Notify sending
 282  0
           getServiceWorker().getGraphTaskManager().notifySentMessages();
 283  
         }
 284  
       }
 285  
     }
 286  0
   }
 287  
 
 288  
   @Override
 289  
   public void flush() {
 290  0
     super.flush();
 291  
     PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>
 292  0
     remainingMsgVidsCache = removeAllMsgVids();
 293  
     PairList<WorkerInfo,
 294  
         ByteArrayOneMessageToManyIds<I, M>>.Iterator
 295  0
     msgIdsIterator = remainingMsgVidsCache.getIterator();
 296  0
     while (msgIdsIterator.hasNext()) {
 297  0
       msgIdsIterator.next();
 298  0
       WritableRequest writableRequest =
 299  
         new SendWorkerOneMessageToManyRequest<>(
 300  0
             msgIdsIterator.getCurrentSecond(), getConf());
 301  0
       totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
 302  0
       clientProcessor.doRequest(
 303  0
         msgIdsIterator.getCurrentFirst(), writableRequest);
 304  0
     }
 305  0
   }
 306  
 }