How can I implement a concurrent quicksort or mergesort algorithm for Java?
We've had issues on a 16-(virtual)-core Mac where only one core (!) was working when using the default Java sorting algorithm, and it was, well, not good to see that very fine machine so completely underused. So I wrote our own multithreaded quicksort and we did indeed gain good speedups (quicksort parallelizes very well thanks to its partitioning nature, but I could have written a mergesort too). But my implementation only scales up to 4 threads, it's proprietary code, and I'd rather use one from a reputable source than my reinvented wheel.
The only one I found on the Web is an example of how not to write a multithreaded quicksort in Java. It busy-loops (which is really terrible) using:
while (helpRequested) { }
http://broadcast.oreilly.com/2009/06/may-column-multithreaded-algor.html
So in addition to losing one thread for no reason, it makes sure to kill performance by busy-looping in that while loop (which is mind-boggling).
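For contrast, here is a minimal sketch of waiting for a helper thread without spinning. It is not taken from the linked column; the class name, the method name and the use of CountDownLatch in place of the helpRequested flag are assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;

public class WaitWithoutSpinning {
    /** Sorts a copy of the input on a helper thread and blocks until it is done. */
    public static int[] sortWithHelper(int[] input) {
        int[] data = input.clone();
        // Hypothetical replacement for the helpRequested flag: a one-shot latch.
        CountDownLatch helperDone = new CountDownLatch(1);
        Thread helper = new Thread(() -> {
            Arrays.sort(data);        // the helper's share of the work
            helperDone.countDown();   // releases any thread blocked in await()
        });
        helper.start();
        try {
            // Parks the calling thread; no CPU is burned while waiting,
            // unlike while (helpRequested) { }.
            helperDone.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return data;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sortWithHelper(new int[] {5, 3, 1, 4, 2})));
    }
}
```

The latch also gives the waiting thread a guaranteed view of the helper's writes, which a plain non-volatile boolean flag does not.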
Hence my question: do you know of any correctly multithreaded quicksort or mergesort implementation in Java that would be coming from a reputable source?
I want to emphasize that I know the complexity stays O(n log n), but I'd still very much enjoy seeing all these cores start working instead of idling. Note that for other tasks, on that same 16-virtual-core Mac, I saw speedups of up to x7 by parallelizing the code (and I'm by no means an expert in concurrency).
So even though the complexity stays O(n log n), I'd really appreciate a x7 or x8 or even x16 speedup.
Give the fork/join framework by Doug Lea a try:
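import java.util.Arrays;
import java.util.concurrent.RecursiveAction;

public class MergeSort extends RecursiveAction {
    // Below this size a task sorts sequentially. The value is an assumption
    // (the excerpt does not define it); tune it by measurement.
    static final int SEQUENTIAL_THRESHOLD = 10_000;

    final int[] numbers;
    final int startPos, endPos;
    final int[] result; // holds the sorted copy of numbers[startPos, endPos)

    // Constructor is not shown in the excerpt; reconstructed from the calls below.
    MergeSort(int[] numbers, int startPos, int endPos) {
        this.numbers = numbers;
        this.startPos = startPos;
        this.endPos = endPos;
        this.result = new int[endPos - startPos];
    }

    // Merges the two sorted halves into this task's result array.
    private void merge(MergeSort left, MergeSort right) {
        int i = 0, leftPos = 0, rightPos = 0, leftSize = left.size(), rightSize = right.size();
        while (leftPos < leftSize && rightPos < rightSize)
            result[i++] = (left.result[leftPos] <= right.result[rightPos])
                ? left.result[leftPos++]
                : right.result[rightPos++];
        while (leftPos < leftSize)
            result[i++] = left.result[leftPos++];
        while (rightPos < rightSize)
            result[i++] = right.result[rightPos++];
    }

    public int size() {
        return endPos - startPos;
    }

    @Override
    protected void compute() {
        if (size() < SEQUENTIAL_THRESHOLD) {
            System.arraycopy(numbers, startPos, result, 0, size());
            Arrays.sort(result, 0, size());
        } else {
            int midpoint = size() / 2;
            MergeSort left = new MergeSort(numbers, startPos, startPos + midpoint);
            MergeSort right = new MergeSort(numbers, startPos + midpoint, endPos);
            // The article called coInvoke(left, right), a name from the jsr166y
            // preview; in java.util.concurrent the method is invokeAll.
            invokeAll(left, right);
            merge(left, right);
        }
    }
}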
(source: http://www.ibm.com/developerworks/java/library/j-jtp03048.html?S_TACT=105AGX01&S_CMP=LP)
Java 8 provides java.util.Arrays.parallelSort, which sorts arrays in parallel using the fork-join framework. The documentation provides some details about the current implementation (but these are non-normative notes):
The sorting algorithm is a parallel sort-merge that breaks the array into sub-arrays that are themselves sorted and then merged. When the sub-array length reaches a minimum granularity, the sub-array is sorted using the appropriate Arrays.sort method. If the length of the specified array is less than the minimum granularity, then it is sorted using the appropriate Arrays.sort method. The algorithm requires a working space no greater than the size of the original array. The ForkJoin common pool is used to execute any parallel tasks.
There does not seem to be a corresponding parallel sort method for lists (even though RandomAccess lists should play nicely with sorting), so you'll need to use toArray, sort that array, and store the result back into the list. (I've asked a question about this here.)
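A minimal sketch of both uses follows: sorting a primitive array directly, and the toArray round trip for a list. The class name ParallelSortDemo and the helper parallelSortList are made up for this example; only Arrays.parallelSort comes from the JDK.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;

public class ParallelSortDemo {
    /** Hypothetical helper: copy the list to an array, sort it in parallel, write it back. */
    public static <T> void parallelSortList(List<T> list, Comparator<? super T> cmp) {
        @SuppressWarnings("unchecked")
        T[] tmp = (T[]) list.toArray();      // safe here: tmp never escapes this method
        Arrays.parallelSort(tmp, cmp);
        // Write back through an iterator so linked lists are not indexed repeatedly.
        ListIterator<T> it = list.listIterator();
        for (T element : tmp) {
            it.next();
            it.set(element);
        }
    }

    public static void main(String[] args) {
        int[] numbers = {7, 2, 9, 10, 1};
        Arrays.parallelSort(numbers);        // primitive arrays are supported directly
        System.out.println(Arrays.toString(numbers));

        List<String> words = new ArrayList<>(Arrays.asList("pear", "fig", "apple"));
        parallelSortList(words, Comparator.naturalOrder());
        System.out.println(words);
    }
}
```

For arrays this small the call simply falls back to the sequential sort, as the quoted documentation describes; the parallel path only kicks in above the minimum granularity.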
Sorry, but what you are asking for isn't really possible. I believe someone else mentioned that sorting is I/O-bound, and they are most likely correct. The fork/join merge sort from the IBM article is a nice piece of work, but I believe it is intended mostly as an example of how to write such code. Notice that the article never posts benchmarks for it and instead posts benchmarks for other code, such as calculating averages and finding the min and max in parallel.
Below are the benchmarks for a plain merge sort, a plain quicksort, Doug's merge sort running on a fork/join pool (Merge TP), and a quicksort I wrote for a fork/join pool (Quick TP). Merge sort is the fastest for N of 100 or less, quicksort for N from 1,000 to 10,000, and the fork/join quicksort beats the rest for N of 100,000 and higher. Each data point is the average of 30 runs on arrays of random numbers, on a quad-core machine with about 2 GB of RAM. The quicksort code follows the numbers. The main takeaway is that unless you are sorting a very large array, you should back away from parallelizing your code's sort, since the parallel versions run very slowly for small N.
Average time per sort (units not stated; presumably seconds):
N         Merge Sort    Quick Sort    Merge TP      Quick TP
10        7.51E-06      5.13E-05      1.87E-04      2.28E-04
100       1.34E-04      1.60E-04      6.41E-04      4.40E-04
1000      0.003286269   7.20E-04      0.003704411   0.002716065
10000     0.023988694   9.61E-04      0.014830678   0.003115251
100000    0.022994328   0.01949271    0.019474009   0.014046681
1000000   0.329776132   0.32528383    0.19581768    0.157845389
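The averaging procedure described above (30 runs per size on random arrays) could be sketched roughly as follows. This is not the harness that produced the numbers; the class name SortTiming, the fixed seed and the use of Arrays.sort and Arrays.parallelSort as the two contenders are assumptions. It also ignores JIT warm-up, so the small-N figures it prints should be read with caution.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.Consumer;

public class SortTiming {
    /** Average seconds per call of sorter over `runs` fresh random arrays of length n. */
    public static double averageSeconds(int n, int runs, Consumer<int[]> sorter) {
        Random rnd = new Random(42);           // fixed seed: every sorter sees the same inputs
        long totalNanos = 0;
        for (int r = 0; r < runs; r++) {
            int[] data = rnd.ints(n).toArray();
            long start = System.nanoTime();
            sorter.accept(data);               // only the sort itself is timed
            totalNanos += System.nanoTime() - start;
        }
        return totalNanos / 1e9 / runs;
    }

    public static void main(String[] args) {
        System.out.println("       N  sequential  parallel");
        for (int n = 10; n <= 1_000_000; n *= 10) {
            double seq = averageSeconds(n, 30, Arrays::sort);
            double par = averageSeconds(n, 30, Arrays::parallelSort);
            System.out.printf("%8d  %.3e  %.3e%n", n, seq, par);
        }
    }
}
```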
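// Originally written against the jsr166y preview package; since Java 7 the
// same classes live in java.util.concurrent.
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
// derived from
// http://www.cs.princeton.edu/introcs/42sort/QuickSort.java.html
// Copyright © 2007, Robert Sedgewick and Kevin Wayne.
// Hastily modified for fork/join by me.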
public class QuickSort {
Comparable array[];
static int limiter = 10000;
public QuickSort(Comparable array[]) {
this.array = array;
}
public void sort(ForkJoinPool pool) {
RecursiveAction start = new Partition(0, array.length - 1);
pool.invoke(start);
}
class Partition extends RecursiveAction {
int left;
int right;
Partition(int left, int right) {
this.left = left;
this.right = right;
}
public int size() {
return right - left;
}
@Override
protected void compute() {
int i = left, j = right;
Comparable tmp;
Comparable pivot = array[(left + right) / 2];
while (i <= j) {
while (array[i].compareTo(pivot) < 0) {
i++;
}
while (array[j].compareTo(pivot) > 0) {
j--;
}
if (i <= j) {
tmp = array[i];
array[i] = array[j];
array[j] = tmp;
i++;
j--;
}
}
Partition leftTask = null;
Partition rightTask = null;
if (left < i - 1) {
leftTask = new Partition(left, i - 1);
}
if (i < right) {
rightTask = new Partition(i, right);
}
if (size() > limiter) {
if (leftTask != null && rightTask != null) {
invokeAll(leftTask, rightTask);
} else if (leftTask != null) {
invokeAll(leftTask);
} else if (rightTask != null) {
invokeAll(rightTask);
}
}else{
if (leftTask != null) {
leftTask.compute();
}
if (rightTask != null) {
rightTask.compute();
}
}
}
}
}
I just coded up the MergeSort above and performance was very poor.
The original code block called coInvoke(left, right), but I could find no definition of it, so I replaced it with invokeAll(left, right).
The test code is:
MergeSort mysort = new MergeSort(array, 0, array.length);
ForkJoinPool threadPool = new ForkJoinPool();
threadPool.invoke(mysort);
but I had to stop it because it was so slow.
I see that the article above is almost a year old, so maybe things have changed since.
I have found the code in the alternative article to work: http://blog.quibb.org/2010/03/jsr-166-the-java-forkjoin-framework/
You have probably considered this already, but it might help to look at the concrete problem from a higher level. For example, if you are sorting more than one array or list, it may be much easier to sort the individual collections concurrently using the traditional algorithm than to sort a single collection concurrently.
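A minimal sketch of that idea, assuming the data already arrives as several independent arrays: each array is sorted with the ordinary sequential Arrays.sort, and only the distribution of arrays across threads is parallel. The class and method names are made up for this example.

```java
import java.util.Arrays;
import java.util.List;

public class SortManyArrays {
    /** Sorts every array in place; each array is one unit of parallel work. */
    public static void sortAll(List<int[]> arrays) {
        // The parallel stream hands whole arrays to the common fork/join pool;
        // no array is ever touched by two threads at once.
        arrays.parallelStream().forEach(Arrays::sort);
    }

    public static void main(String[] args) {
        List<int[]> arrays = Arrays.asList(
            new int[] {20, 5, 14, 9, 8},
            new int[] {102, 17, 80, 59, 56},
            new int[] {15, 2, 11, 7, 4});
        sortAll(arrays);
        for (int[] a : arrays) {
            System.out.println(Arrays.toString(a));
        }
    }
}
```

This scales with the number of arrays rather than with their size, so it helps most when there are at least as many collections as cores.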
I've been facing the multithreaded sort problem myself the last couple of days. As explained on this Caltech slide, the best you can do by simply multithreading each step of a divide-and-conquer approach over the obvious number of threads (the number of divisions) is limited. I guess this is because, while you can run 64 divisions on 64 threads using all 64 cores of your machine, the 4 divisions can only run on 4 threads, the 2 on 2, and the 1 on 1. So for many levels of the recursion your machine is under-utilized.
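That limit can be made concrete with a rough work/span estimate. The derivation below is a sketch, not taken from the slide; it assumes a merge sort whose two halves run in parallel but whose merge of n elements is sequential and costs about n steps.

```latex
\begin{align*}
T_1(n) &= 2\,T_1(n/2) + n \;\approx\; n \log_2 n
  && \text{(total work on one thread)} \\
T_\infty(n) &= T_\infty(n/2) + n \;=\; n + \tfrac{n}{2} + \tfrac{n}{4} + \dots \;\approx\; 2n
  && \text{(critical path with unlimited threads)} \\
\frac{T_1(n)}{T_\infty(n)} &\approx \frac{n \log_2 n}{2n} \;=\; \tfrac{1}{2} \log_2 n
  && \text{(upper bound on the speedup)}
\end{align*}
```

Under these assumptions, sorting about a million elements (log2 n of about 20) cannot run more than roughly 10 times faster no matter how many cores are available, because the top-level merges are done by a single thread.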
A solution occurred to me last night which might be useful in my own work, so I'll post it here.
If the first criterion of your sorting function is an integer with s possible values, be it an actual integer or a char in a string, such that this value fully determines the highest level of your sort, then I think there's a very fast (and easy) solution. Simply use that initial value to divide your sorting problem into s smaller sorting problems, and sort those using the standard single-threaded sort algorithm of your choice. The division into s classes can be done in a single pass, I think. There is no merging problem after doing the s independent sorts, because you already know that everything in class 1 sorts before class 2, and so on.
Example: if you wish to sort using strcmp(), use the first char of each string to break your data into 256 classes, then sort each class on the next available thread until they're all done.
This method keeps all available cores busy until the problem is solved, and I think it's easy to implement. I haven't implemented it yet, though, so there may be problems with it that I have yet to find. It clearly can't work for floating-point sorts, and it would be inefficient for large s. Its performance would also depend heavily on the entropy of the integer or char used to define the classes.
This may be what Fabian Steeg was suggesting in fewer words, but I'm making it explicit that you can create multiple smaller sorts from a larger sort in some circumstances.
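Here is a minimal sketch of the class-splitting idea for Java strings. It is an illustration under stated assumptions rather than a tested design: Java chars are 16-bit, so there can be far more than 256 classes, and the sketch only creates the classes that actually occur. The class name and the use of a TreeMap to keep the classes in order are choices made for this example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.TreeMap;

public class FirstCharBucketSort {
    public static List<String> sort(List<String> input) {
        // Single pass: the class key is the first char, or -1 for the empty string,
        // which String.compareTo orders before every non-empty string.
        TreeMap<Integer, List<String>> classes = new TreeMap<>();
        for (String s : input) {
            int key = s.isEmpty() ? -1 : (int) s.charAt(0);
            classes.computeIfAbsent(key, k -> new ArrayList<>()).add(s);
        }
        // Each class is sorted independently with the ordinary sequential sort.
        List<List<String>> buckets = new ArrayList<>(classes.values());
        buckets.parallelStream().forEach(bucket -> Collections.sort(bucket));
        // No merge step: the TreeMap already yields the classes in key order.
        List<String> sorted = new ArrayList<>(input.size());
        for (List<String> bucket : buckets) {
            sorted.addAll(bucket);
        }
        return sorted;
    }

    public static void main(String[] args) {
        System.out.println(sort(Arrays.asList("pear", "apple", "", "fig", "avocado", "plum")));
    }
}
```

As the answer warns, the balance of the work depends on how evenly the first character is distributed: if most strings start with the same char, one class holds nearly everything and a single thread does nearly all the sorting.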
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
public class IQ1 {
public static void main(String[] args) {
// Get number of available processors
int numberOfProcessors = Runtime.getRuntime().availableProcessors();
System.out.println("Number of processors : " + numberOfProcessors);
// Input data, it can be anything e.g. log records, file records etc
long[][] input = new long[][]{
{ 5, 8, 9, 14, 20 },
{ 17, 56, 59, 80, 102 },
{ 2, 4, 7, 11, 15 },
{ 34, 37, 39, 45, 50 }
};
/* A special thread pool designed to work with fork-and-join task splitting
* The pool size is going to be based on number of cores available
*/
ForkJoinPool pool = new ForkJoinPool(numberOfProcessors);
long[] result = pool.invoke(new Merger(input, 0, input.length));
System.out.println(Arrays.toString(result));
}
/* Recursive task which returns the result
* An instance of this will be used by the ForkJoinPool to start working on the problem
* Each thread from the pool will call the compute and the problem size will reduce in each call
*/
static class Merger extends RecursiveTask<long[]>{
long[][] input;
int low;
int high;
Merger(long[][] input, int low, int high){
this.input = input;
this.low = low;
this.high = high;
}
@Override
protected long[] compute() {
long[] result = merge();
return result;
}
// Merge
private long[] merge(){
long[] result;
// a single array is already sorted; return that array, not always the first one
if(high - low < 2){
return input[low];
}
// base case
if(high - low == 2){
long[] a = input[low];
long[] b = input[high-1];
result = mergeTwoSortedArrays(a, b);
}
else{
// divide the problem into smaller problems
int mid = low + (high - low) / 2;
Merger first = new Merger(input, low, mid);
Merger second = new Merger(input, mid, high);
first.fork();
long[] secondResult = second.compute();
long[] firstResult = first.join();
result = mergeTwoSortedArrays(firstResult, secondResult);
}
return result;
}
// method to merge two sorted arrays
private long[] mergeTwoSortedArrays(long[] a, long[] b){
long[] result = new long[a.length + b.length];
int i=0;
int j=0;
int k=0;
while(i<a.length && j<b.length){
if(a[i] < b[j]){
result[k] = a[i];
i++;
} else{
result[k] = b[j];
j++;
}
k++;
}
while(i<a.length){
result[k] = a[i];
i++;
k++;
}
while(j<b.length){
result[k] = b[j];
j++;
k++;
}
return result;
}
}
}
The most convenient multithreading paradigm for a merge sort is fork/join. The fork/join framework has been part of Java since Java 7; ForkJoinPool.commonPool(), used below, was added in Java 8. The following code demonstrates a merge sort using fork/join.
import java.util.*;
import java.util.concurrent.*;
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
private List<N> elements;
public MergeSort(List<N> elements) {
this.elements = new ArrayList<>(elements);
}
@Override
protected List<N> compute() {
if(this.elements.size() <= 1)
return this.elements;
else {
final int pivot = this.elements.size() / 2;
MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
leftTask.fork();
rightTask.fork();
List<N> left = leftTask.join();
List<N> right = rightTask.join();
return merge(left, right);
}
}
private List<N> merge(List<N> left, List<N> right) {
List<N> sorted = new ArrayList<>();
while(!left.isEmpty() || !right.isEmpty()) {
if(left.isEmpty())
sorted.add(right.remove(0));
else if(right.isEmpty())
sorted.add(left.remove(0));
else {
if( left.get(0).compareTo(right.get(0)) <= 0 ) // <= keeps equal elements in their original order
sorted.add(left.remove(0));
else
sorted.add(right.remove(0));
}
}
return sorted;
}
public static void main(String[] args) {
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,10,1)));
System.out.println("result: " + result);
}
}
While much less straightforward, the following variant of the code eliminates the excessive copying of the ArrayList. The initial unsorted list is created only once, and the calls to subList do not perform any copying themselves. Previously we copied the list each time the algorithm forked. Also, when merging, instead of creating a new list and copying values into it each time, we now reuse the left list and insert our values into it. Avoiding the extra copy step improves performance. We use a LinkedList here because inserts are rather cheap compared to an ArrayList. We also eliminate the calls to remove, which can be expensive on an ArrayList.
import java.util.*;
import java.util.concurrent.*;
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
private List<N> elements;
public MergeSort(List<N> elements) {
this.elements = elements;
}
@Override
protected List<N> compute() {
if(this.elements.size() <= 1)
return new LinkedList<>(this.elements);
else {
final int pivot = this.elements.size() / 2;
MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
leftTask.fork();
rightTask.fork();
List<N> left = leftTask.join();
List<N> right = rightTask.join();
return merge(left, right);
}
}
private List<N> merge(List<N> left, List<N> right) {
int leftIndex = 0;
int rightIndex = 0;
while(leftIndex < left.size() || rightIndex < right.size()) {
if(leftIndex >= left.size())
left.add(leftIndex++, right.get(rightIndex++));
else if(rightIndex >= right.size())
return left;
else {
if( left.get(leftIndex).compareTo(right.get(rightIndex)) <= 0 ) // <= keeps equal elements in their original order
leftIndex++;
else
left.add(leftIndex++, right.get(rightIndex++));
}
}
return left;
}
public static void main(String[] args) {
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
System.out.println("result: " + result);
}
}
We can also improve the code one step further by using iterators instead of calling get directly when performing the merge. The reason for this is that get on a LinkedList by index has poor time performance (linear) so by using an iterator we eliminate slow-down caused by internally iterating the linked list on each get. The call to next on an iterator is constant time as opposed to linear time for the call to get. The following code is modified to use iterators instead.
import java.util.*;
import java.util.concurrent.*;
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
private List<N> elements;
public MergeSort(List<N> elements) {
this.elements = elements;
}
@Override
protected List<N> compute() {
if(this.elements.size() <= 1)
return new LinkedList<>(this.elements);
else {
final int pivot = this.elements.size() / 2;
MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
leftTask.fork();
rightTask.fork();
List<N> left = leftTask.join();
List<N> right = rightTask.join();
return merge(left, right);
}
}
private List<N> merge(List<N> left, List<N> right) {
ListIterator<N> leftIter = left.listIterator();
ListIterator<N> rightIter = right.listIterator();
while(leftIter.hasNext() || rightIter.hasNext()) {
if(!leftIter.hasNext()) {
leftIter.add(rightIter.next());
rightIter.remove();
}
else if(!rightIter.hasNext())
return left;
else {
N rightElement = rightIter.next();
if( leftIter.next().compareTo(rightElement) <= 0 ) // <= keeps equal elements in their original order
rightIter.previous();
else {
leftIter.previous();
leftIter.add(rightElement);
}
}
}
return left;
}
public static void main(String[] args) {
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
System.out.println("result: " + result);
}
}
Finally, the most complex version of the code: this iteration operates entirely in place. Only the initial ArrayList is created, and no additional collections are ever created. As a result the logic is particularly difficult to follow (so I saved it for last), but it should be as close to an ideal implementation as we can get.
import java.util.*;
import java.util.concurrent.*;
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
private List<N> elements;
public MergeSort(List<N> elements) {
this.elements = elements;
}
@Override
protected List<N> compute() {
if(this.elements.size() <= 1)
return this.elements;
else {
final int pivot = this.elements.size() / 2;
MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
leftTask.fork();
rightTask.fork();
List<N> left = leftTask.join();
List<N> right = rightTask.join();
merge(left, right);
return this.elements;
}
}
private void merge(List<N> left, List<N> right) {
int leftIndex = 0;
int rightIndex = 0;
while(leftIndex < left.size() ) {
if(rightIndex == 0) {
if( left.get(leftIndex).compareTo(right.get(rightIndex)) > 0 ) {
swap(left, leftIndex++, right, rightIndex++);
} else {
leftIndex++;
}
} else {
if(rightIndex >= right.size()) {
if(right.get(0).compareTo(left.get(left.size() - 1)) < 0 )
merge(left, right);
else
return;
}
else if( right.get(0).compareTo(right.get(rightIndex)) < 0 ) {
swap(left, leftIndex++, right, 0);
} else {
swap(left, leftIndex++, right, rightIndex++);
}
}
}
if(rightIndex < right.size() && rightIndex != 0)
merge(right.subList(0, rightIndex), right.subList(rightIndex, right.size()));
}
private void swap(List<N> left, int leftIndex, List<N> right, int rightIndex) {
// List.set returns the element it replaced, so the nested calls swap without a temporary
left.set(leftIndex, right.set(rightIndex, left.get(leftIndex)));
}
public static void main(String[] args) {
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(new ArrayList<>(Arrays.asList(5,9,8,7,6,1,2,3,4))));
System.out.println("result: " + result);
}
}
Why do you think a parallel sort would help? I'd think most sorting is I/O-bound, not CPU-bound. Unless your comparison does a lot of computation, a speedup is unlikely.