1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.giraph.examples;
20
21 import com.google.common.base.Preconditions;
22 import org.apache.giraph.edge.Edge;
23 import org.apache.giraph.graph.Vertex;
24 import org.apache.giraph.utils.MathUtils;
25 import org.apache.hadoop.io.DoubleWritable;
26 import org.apache.hadoop.io.LongWritable;
27
28
29
30
31
32
33 public class RandomWalkWithRestartComputation
34 extends RandomWalkComputation<DoubleWritable> {
35
36
37 static final String SOURCE_VERTEX = RandomWalkWithRestartComputation.class
38 .getName() + ".sourceVertex";
39
40
41
42
43
44
45 private boolean isSourceVertex(Vertex<LongWritable, ?, ?> vertex) {
46 return ((RandomWalkWorkerContext) getWorkerContext()).isSource(
47 vertex.getId().get());
48 }
49
50
51
52
53
54 private int numSourceVertices() {
55 return ((RandomWalkWorkerContext) getWorkerContext()).numSources();
56 }
57
58 @Override
59 protected double transitionProbability(
60 Vertex<LongWritable, DoubleWritable, DoubleWritable>
61 vertex,
62 double stateProbability, Edge<LongWritable, DoubleWritable> edge) {
63 return stateProbability * edge.getValue().get();
64 }
65
66 @Override
67 protected double recompute(
68 Vertex<LongWritable, DoubleWritable, DoubleWritable> vertex,
69 Iterable<DoubleWritable> transitionProbabilities,
70 double teleportationProbability) {
71 int numSourceVertices = numSourceVertices();
72 Preconditions.checkState(numSourceVertices > 0, "No source vertex found");
73
74 double stateProbability = MathUtils.sum(transitionProbabilities);
75
76
77 stateProbability += getDanglingProbability() / getTotalNumVertices();
78
79 stateProbability *= 1 - teleportationProbability;
80 if (isSourceVertex(vertex)) {
81 stateProbability += teleportationProbability / numSourceVertices;
82 }
83 return stateProbability;
84 }
85 }