
Merge sort from "Programming Scala" causes stack overflow

devze.com (https://www.devze.com), 2022-12-19 09:09, source: web

A direct cut and paste of the following algorithm:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}

causes a StackOverflowError on lists of 5,000 elements.

Is there any way to optimize this so that this doesn't occur?


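It overflows because merge isn't tail-recursive: each call has to wait for its recursive call to return before it can cons its element onto the result. You can fix this either by using a non-strict collection or by making merge tail-recursive.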

The latter solution goes like this:

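def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  // acc holds the merged prefix in reverse order; the recursive call is the
  // last thing each case does, so @tailrec makes scalac verify it becomes a loop
  @annotation.tailrec
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys.reverse ::: acc
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) merge(xs1, ys, x :: acc)
        else merge(xs, ys1, y :: acc)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    // merge hands back its result reversed, so flip it into ascending order
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  }
}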
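The same accumulate-then-reverse pattern, stripped down to a single list, may make the difference easier to see. This is a minimal sketch; `AddOne`, `addOneNaive` and `addOneAcc` are hypothetical names, not from the book:

```scala
import scala.annotation.tailrec

object AddOne {
  // Not tail-recursive: the cons (::) runs after the recursive call returns,
  // so every element costs one stack frame. Same shape as x :: merge(xs1, ys).
  def addOneNaive(xs: List[Int]): List[Int] = xs match {
    case Nil    => Nil
    case h :: t => (h + 1) :: addOneNaive(t)
  }

  // Tail-recursive: the recursive call is the last thing evaluated, so scalac
  // compiles it into a loop. Results pile up in acc in reverse order.
  @tailrec
  def addOneAcc(xs: List[Int], acc: List[Int] = Nil): List[Int] = xs match {
    case Nil    => acc.reverse
    case h :: t => addOneAcc(t, (h + 1) :: acc)
  }
}
```

On a list of a million elements `addOneAcc` finishes normally, whereas `addOneNaive` is likely to overflow the stack under default JVM settings.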

Using non-strictness means either passing parameters by name or using non-strict collections such as Stream. The following code uses Stream only where it's needed to prevent the stack overflow, and List everywhere else:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList
  } 
}
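`Stream` has been deprecated since Scala 2.13 in favor of `LazyList`. Assuming Scala 2.13 or later, a sketch of the same approach with `LazyList` (the wrapper object `LazyMergeSort` is only for illustration):

```scala
object LazyMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // x #:: merge(...) does not call merge right away: the tail is evaluated
    // only when toList demands it, so the calls never nest on the stack
    def merge(left: List[T], right: List[T]): LazyList[T] = (left, right) match {
      case (x :: xs1, y :: _) if less(x, y) => x #:: merge(xs1, right)
      case (_ :: _, y :: ys1)               => y #:: merge(left, ys1)
      case _ => if (left.isEmpty) right.to(LazyList) else left.to(LazyList)
    }
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(less)(ys), msort(less)(zs)).toList
    }
  }
}
```

For example, `LazyMergeSort.msort[Int](_ < _)(List(5, 3, 9, 1))` returns `List(1, 3, 5, 9)`.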


I was just playing around with Scala's TailCalls (trampolining support), which I suspect wasn't available when this question was originally posted. Here's a recursive, immutable version of the merge from Rex's answer.

import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = {

  def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head<b.head) {
      tailcall(build(a.head::s,a.tail,b))
    } else {
      tailcall(build(b.head::s,a,b.tail))
    }
  }

  build(List(),x,y).result.reverse
}

It runs just as fast as the mutable version on big List[Long]s with Scala 2.9.1 on 64-bit OpenJDK (Debian Squeeze, amd64, on an i7).
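View bounds like `T <% Ordered[T]` are deprecated in later Scala versions. A sketch of the same trampolined merge taking an implicit `Ordering[T]` instead (the object name `TrampolinedMerge` is illustrative; the algorithm is unchanged):

```scala
import scala.util.control.TailCalls._

object TrampolinedMerge {
  def merge[T](x: List[T], y: List[T])(implicit ord: Ordering[T]): List[T] = {
    // s accumulates the merged prefix in reverse order; each step returns a
    // TailRec description instead of recursing directly on the JVM stack
    def build(s: List[T], a: List[T], b: List[T]): TailRec[List[T]] =
      if (a.isEmpty) done(b.reverse ::: s)
      else if (b.isEmpty) done(a.reverse ::: s)
      else if (ord.lt(a.head, b.head)) tailcall(build(a.head :: s, a.tail, b))
      else tailcall(build(b.head :: s, a, b.tail))

    // .result runs the trampoline in a loop; reverse restores ascending order
    build(Nil, x, y).result.reverse
  }
}
```

Since `build` only ever calls itself in tail position, `@annotation.tailrec` would also turn it into a loop here; `TailCalls` earns its keep mainly when functions call each other (mutual recursion), which `@tailrec` can't handle.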


Just in case Daniel's solutions didn't make it clear enough: the problem is that merge's recursion is as deep as the combined length of the two lists being merged, and because it isn't tail recursion, it can't be converted into iteration.

Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to this:

def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc:List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match { 
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc.reverse
}

but it keeps track of all the variables for you.

(A tail-recursive method is one that calls itself only to get the complete answer it passes back; it never calls itself and then does something with the result before returning it. Also, tail-call optimization can't be applied if the method might be overridden in a subclass, so it generally works only for methods in objects, methods marked final or private, methods of classes marked final, and local functions such as merge here.)
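A minimal sketch of that last point, using a hypothetical `Countdown` class: `@tailrec` asks scalac to reject any method it can't turn into a loop, and on a public method of a non-final class it refuses unless the method is marked `final` or `private`.

```scala
import scala.annotation.tailrec

class Countdown {
  // Removing `final` makes scalac reject @tailrec: a subclass could override
  // countDown and change what the "recursive" call actually invokes.
  @tailrec
  final def countDown(n: Int, steps: Int = 0): Int =
    if (n == 0) steps                // base case: hand back the finished answer
    else countDown(n - 1, steps + 1) // tail call: nothing left to do afterwards
}
```

Because the call compiles to a loop, `new Countdown().countDown(1000000)` returns `1000000` without growing the stack.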
