Warm tip: This article is reproduced from serverfault.com, please click

Why am I getting this output when parallelising my tree search?

发布于 2020-11-27 09:26:32

I have a binary tree where each node is a 0 or a 1. Each path from root to leaf is a bit string. My code prints out all bit strings sequentially, and it works fine. However, when I try to parallelise it, I am getting unexpected output.

Class Node

/**
 * One node of the binary tree. Each node stores an int value and links to its
 * two children; a node without children is a leaf.
 */
public class Node{
  /** Value stored at this node. */
  int value;
  /** Left child, or null when absent. */
  Node left;
  /** Right child, or null when absent. */
  Node right;
  /** Not assigned by the constructor, so it keeps the default value 0. */
  int depth;

  /**
   * Creates a childless node.
   *
   * @param v the value to store in the node
   */
  public Node(int v){
    this.value = v;
    this.left = null;
    this.right = null;
  }
}

Sequential Version of Tree.java

import java.util.*;
import java.util.concurrent.*;

/**
 * Binary tree whose root-to-leaf paths are printed and collected sequentially.
 *
 * <p>Every path that is printed is also stored in {@link #all}, in the order
 * in which the leaves are visited (left subtree before right subtree).
 */
public class Tree{
  /** Highest level that still receives nodes; the root counts as level 1. */
  private static final int MAX_LEVELS = 6;

  Node root;
  int levels;
  LinkedList<LinkedList<Integer>> all;

  /**
   * Creates a tree that consists of a single root node.
   *
   * @param v the value stored in the root
   */
  Tree(int v){
    root = new Node(v);
    levels = 1;
    all = new LinkedList<LinkedList<Integer>>();
  }

  /** Creates an empty tree without a root. */
  Tree(){
    root = null;
    levels = 0;
    // Initialise the result list here too, so a tree created through this
    // constructor does not fail with a NullPointerException once a root is
    // attached and its paths are collected.
    all = new LinkedList<LinkedList<Integer>>();
  }

  /**
   * Builds the example tree and prints every root-to-leaf path, one per line.
   *
   * @param args unused
   */
  public static void main(String[] args){
    Tree tree = new Tree(0);
    populate(tree, tree.root, tree.levels);
    int processors = Runtime.getRuntime().availableProcessors();
    System.out.println("Available core: "+processors);

    tree.printPaths(tree.root);
  }

  /**
   * Recursively gives {@code n} a left child holding 0 and a right child
   * holding 1 until level {@link #MAX_LEVELS} has been created.
   *
   * @param t      tree whose {@code levels} field is updated while building
   * @param n      node to extend; must not be null
   * @param levels level of {@code n} (the root is level 1)
   */
  public static void populate(Tree t, Node n, int levels){
    levels++;
    if (levels > MAX_LEVELS) {
      // n sits on the last level: make sure it is a leaf.
      n.left = null;
      n.right = null;
    } else {
      t.levels = levels;
      n.left = new Node(0);
      n.right = new Node(1);
      populate(t, n.left, levels);
      populate(t, n.right, levels);
    }
  }

  /**
   * Prints every path from {@code node} down to a leaf and stores each of
   * them in {@link #all}.
   *
   * @param node root of the subtree to print; null prints nothing
   */
  public void printPaths(Node node)
  {
    printPathsRecur(node, new LinkedList<Integer>(), 0);
  }

  /**
   * Depth-first worker behind {@link #printPaths(Node)}.
   *
   * @param node    current node; null ends the recursion
   * @param path    scratch list; its first {@code pathLen} entries are the
   *                values of the ancestors of {@code node}
   * @param pathLen number of valid entries at the front of {@code path}
   * @return {@link #all}, or null when {@code node} is null
   * @throws IndexOutOfBoundsException if {@code pathLen} is negative or
   *         larger than {@code path.size()}
   */
  LinkedList<LinkedList<Integer>> printPathsRecur(Node node, LinkedList<Integer> path, int pathLen)
  {
    if (node == null) {
      return null;
    }
    if (pathLen < 0 || pathLen > path.size()) {
      throw new IndexOutOfBoundsException(
          "pathLen " + pathLen + " does not fit a path of size " + path.size());
    }

    // Entries beyond pathLen are leftovers from a sibling subtree. Drop them
    // before appending so the list always holds exactly the current path and
    // does not grow by one element per visited node.
    path.subList(pathLen, path.size()).clear();
    path.add(node.value);
    pathLen++;

    if (node.left == null && node.right == null) {
      // Leaf: print the path and keep an independent copy of it.
      printArray(path, pathLen);
      all.add(new LinkedList<Integer>(path));
    } else {
      printPathsRecur(node.left, path, pathLen);
      printPathsRecur(node.right, path, pathLen);
    }
    return all;
  }

  /**
   * Prints the first {@code len} elements of {@code l} on one line, each
   * followed by a single space.
   *
   * @param l   list to print
   * @param len number of leading elements to print
   * @throws IndexOutOfBoundsException if {@code len} exceeds {@code l.size()}
   */
  void printArray(LinkedList<Integer> l, int len){
    if (len > l.size()) {
      throw new IndexOutOfBoundsException("len " + len + " exceeds list size " + l.size());
    }
    // Walk the list once instead of calling get(i), which restarts from an
    // end of the linked list on every call.
    StringBuilder line = new StringBuilder();
    int printed = 0;
    for (Integer element : l) {
      if (printed >= len) {
        break;
      }
      line.append(element).append(' ');
      printed++;
    }
    System.out.println(line);
  }
}

This produces the expected output:

0 0 0 0 0 0
0 0 0 0 0 1
0 0 0 0 1 0
0 0 0 0 1 1
...

Then I parallelised Tree.java:

import java.util.*;
import java.util.concurrent.*;

/**
 * Parallel variant of the tree: it is built the same way as before, but the
 * root-to-leaf paths are gathered by a {@code PrintTask} executed on a
 * {@link ForkJoinPool} and printed from {@code all} afterwards.
 */
public class Tree{
  Node root;
  int levels;
  LinkedList<LinkedList<Integer>> all;

  /**
   * Creates a tree consisting of a single root node.
   *
   * @param v the value stored in the root
   */
  Tree(int v){
    this.root = new Node(v);
    this.levels = 1;
    this.all = new LinkedList<LinkedList<Integer>>();
  }

  /** Creates an empty tree; {@code all} is not assigned by this constructor. */
  Tree(){
    this.root = null;
    this.levels = 0;
  }

  /**
   * Builds the example tree, runs a {@code PrintTask} from the root on a
   * fork/join pool sized to the number of available processors, and then
   * prints every list found in {@code tree.all}.
   *
   * @param args unused
   */
  public static void main(String[] args){
    Tree tree = new Tree(0);
    populate(tree, tree.root, tree.levels);

    int cores = Runtime.getRuntime().availableProcessors();
    System.out.println("Available core: "+cores);
    ForkJoinPool pool = new ForkJoinPool(cores);

    // This task takes the place of the sequential tree.printPaths(tree.root) call.
    PrintTask rootTask = new PrintTask(tree.root, new LinkedList<Integer>(), 0, tree.all);
    pool.invoke(rootTask);

    int index = 0;
    while (index < tree.all.size()) {
      System.out.println(tree.all.get(index));
      index++;
    }
  }

  /**
   * Recursively gives {@code n} a left child holding 0 and a right child
   * holding 1 until six levels exist (the root is level 1).
   *
   * @param t      tree whose {@code levels} field is updated while building
   * @param n      node to extend; must not be null
   * @param levels level of {@code n}
   */
  public static void populate(Tree t, Node n, int levels){
    int childLevel = levels + 1;
    if (childLevel > 6) {
      // n is on the last level, so it stays a leaf.
      n.left = null;
      n.right = null;
      return;
    }
    t.levels = childLevel;
    n.left = new Node(0);
    n.right = new Node(1);
    populate(t, n.left, childLevel);
    populate(t, n.right, childLevel);
  }
}

and added a task class:

import java.util.concurrent.*;
import java.util.*;

/**
 * Fork/join task that visits one node of the tree: it adds the node's value to
 * {@code path} and, at a leaf, prints the path and appends a copy of it to
 * {@code all}; otherwise it hands both children to {@code invokeAll}.
 */
class PrintTask extends RecursiveAction {
  // The two list initialisers below are replaced by the constructor, which
  // stores the caller's list references instead.
  LinkedList<Integer> path = new LinkedList<Integer>();
  Node node;
  int pathLen;
  LinkedList<LinkedList<Integer>> all = new LinkedList<LinkedList<Integer>>();

  /**
   * @param node    node this task visits; may be null
   * @param path    list holding the values on the way to {@code node}; the
   *                reference is stored, not copied
   * @param pathLen number of leading entries of {@code path} that belong to
   *                the ancestors of {@code node}
   * @param all     list that receives one copy of the path per leaf; the
   *                reference is stored, not copied
   */
  PrintTask(Node node, LinkedList<Integer> path, int pathLen, LinkedList<LinkedList<Integer>> all){
    this.node = node;
    this.path = path;
    this.pathLen = pathLen;
    this.all = all;
  }

  /** Processes {@code node}; does nothing when it is null. */
  protected void compute(){
    if (node == null){
      return;
    }
    // Inserts at index pathLen; anything already stored at or after that
    // index is shifted one position to the right rather than overwritten.
    path.add(pathLen, node.value);
    pathLen++;

    if(node.left == null && node.right == null){
      // Leaf: print the first pathLen entries and store a copy of them.
      printArray(path, pathLen);
      LinkedList<Integer> temp = new LinkedList<Integer>();
      for (int i = 0; i < pathLen; i++){
          temp.add(path.get(i));
      }
      all.add(temp);

    }
    else{
      // Both child tasks receive the SAME path and all instances; only
      // pathLen is copied, because it is an int.
      // NOTE(review): LinkedList is not synchronized, so when the two tasks
      // run on different worker threads they modify these lists concurrently
      // -- this looks like the source of the scrambled output; confirm.
      invokeAll(new PrintTask(node.left, path, pathLen, all), new PrintTask(node.right, path, pathLen, all));

    }
  }
  /**
   * Prints the first {@code len} elements of {@code l} on one line, each
   * followed by a space.
   */
  void printArray(LinkedList<Integer> l, int len){
      for (int i = 0; i < len; i++){
          System.out.print(l.get(i) + " ");
      }
      System.out.println("");
  }

}

And I get this output:

Available core: 8
0 0 1 0 1 1 1 0 0
0 1 1 0 1 1 1 0 1
0 0 1 1 1 0 0
1 1 1 1 0 1
1 0 1 1 0 1 1 1 0 0 1 1 0 0 0 1 1 1 0 1
1 1 1 1 0
0 1
...

[0, 1, 1, 0, 1, 1]
[0, 1, 1, 0, 0, 0]
[0, 1, 1, 0, 0, 1]
[0, 1, 1, 1, 0, 0]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 1, 0]
[0, 1, 1, 1, 0, 0]
...

So, while dynamically printing the path, it seems very different from the expected output where each path was made of 6 bits. In this version, I store all paths in a list of lists, and print the list at the end. It contains some correct looking bit strings, but the problem is that it's not all of them. It only outputs bit strings that start with 011.

Questioner
Ard K.
Viewed
0
Prasanna 2020-11-30 17:51:31

The issue with parallel implementation is due to the below line of code.

invokeAll(new PrintTask(node.left, path, pathLen, all), new PrintTask(node.right, path, pathLen, all));

invokeAll will execute the tasks in parallel. This will result in 2 issues.

  • There is no guarantee left node will be executed before right
  • Race conditions can occur on the path and all lists, which are shared across all the tasks (pathLen is an int that is copied into each task, so it is not itself shared).

Simplest option to correct it is to invoke left and right task in sequence. Like below:

// invoke() runs the task and waits for it to complete, so the whole left
// subtree is finished before the right subtree starts.
new PrintTask(node.left, path, pathLen, all).invoke();
new PrintTask(node.right, path, pathLen, all).invoke();

But doing so, you lose the benefits of parallel processing; it is effectively the same as executing them sequentially.


To ensure correctness and still have parallelism, we will make the changes below:

  • Change the type of all from LinkedList<LinkedList<Integer>> to LinkedList[]. We will set the size of the array to 2 ^ (levels - 1) so that it has one slot per leaf, i.e. one per root-to-leaf path.
  • Additionally we will introduce an insertIndex variable, so that the leaf nodes will insert the list at the correct index in the result array. We will left shift this insertIndex at each level and for right tree, we will also increment it by 1.
  • We will create 2 new linked lists at each level (a separate copy of the current path for each child) to avoid the race condition.

Modified PrintTask:

/**
 * Fork/join task that records root-to-leaf paths without sharing a mutable
 * path between sibling tasks: every child task gets its own copy of the path,
 * and every leaf writes its finished path into its own slot of {@code all}.
 */
class PrintTask extends RecursiveAction {
    LinkedList<Integer> path;
    Node node;
    LinkedList[] all;
    int insertIndex;

    /**
     * @param node        node this task visits; may be null
     * @param path        values of the ancestors of {@code node}; this task
     *                    appends to it
     * @param all         result array; a leaf stores its path at
     *                    {@code insertIndex}
     * @param insertIndex slot assigned to this node; it is doubled for the
     *                    left child and doubled plus one for the right child
     */
    PrintTask(Node node, LinkedList<Integer> path, LinkedList[] all, int insertIndex) {
        this.node = node;
        this.path = path;
        this.all = all;
        this.insertIndex = insertIndex;
    }

    /** Appends this node's value, then stores the path (leaf) or forks both children. */
    @Override
    protected void compute() {
        if (node == null) {
            return;
        }
        path.add(node.value);

        boolean isLeaf = node.left == null && node.right == null;
        if (isLeaf) {
            all[insertIndex] = path;
            return;
        }

        // Each child continues on a private copy of the path built so far.
        int leftSlot = insertIndex << 1;
        int rightSlot = (insertIndex << 1) + 1;
        PrintTask leftTask = new PrintTask(node.left, new LinkedList<>(path), all, leftSlot);
        PrintTask rightTask = new PrintTask(node.right, new LinkedList<>(path), all, rightSlot);
        invokeAll(leftTask, rightTask);
    }
}

main() changes:

...
// '-' binds tighter than '<<', so the length is 1 << (tree.levels - 1).
LinkedList[] result = new LinkedList[1 << tree.levels - 1];
// The root task starts with insert index 0.
PrintTask task = new PrintTask(tree.root, path, result, 0);
pool.invoke(task);
// Slots are printed in index order; a slot that no leaf wrote to is still
// null and prints as "null".
for (LinkedList linkedList : result) 
   System.out.println(linkedList);
...