I'm trying to refactor a component that currently produces a Seq[X] using a fairly expensive recursive algorithm so that it produces a Stream[X] instead. That way the Xs can be loaded or calculated on demand, and the producer doesn't have to guess beforehand how much digging it will need to do to satisfy the consumer.
From what I've read, this is an ideal use for an "unfold", so that's the route I've been trying to take.
Here's my unfold function, derived from David Pollak's example, which has been vetted by a certain Mr. Morris:
def unfold[T,R](init: T)(f: T => Option[(R,T)]): Stream[R] = f(init) match {
  case None => Stream[R]()
  case Some((r,v)) => r #:: unfold(v)(f)
}
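For readers new to unfold, here is a minimal sketch of how it behaves, assuming Scala 2.13 or later, where LazyList replaces the now-deprecated Stream (the definition is otherwise the same as the one above). The object UnfoldBasics and the values naturals and countdown are hypothetical illustrations. Because the tail is only built when a consumer asks for it, even a step function that never returns None is safe as long as the consumer only takes a finite prefix.

```scala
// Assumes Scala 2.13+: LazyList stands in for the deprecated Stream.
object UnfoldBasics {
  // Same shape as the unfold above, but producing a LazyList.
  def unfold[T, R](init: T)(f: T => Option[(R, T)]): LazyList[R] = f(init) match {
    case None         => LazyList.empty[R]           // step function says "done"
    case Some((r, v)) => r #:: unfold(v)(f)          // tail is deferred until demanded
  }

  // Infinite: the step function never returns None, yet defining this is safe
  // because later elements are only computed when a consumer asks for them.
  val naturals: LazyList[Int] = unfold(0)(n => Some((n, n + 1)))

  // Finite: stop once the state drops below zero.
  def countdown(from: Int): LazyList[Int] =
    unfold(from)(n => if (n < 0) None else Some((n, n - 1)))
}
```

Scala 2.13 also ships LazyList.unfold in the standard library with the same argument order (initial state first, then the step function), so a hand-written unfold is mainly needed on older versions.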
And here's a little tree to try our luck with:
case class Node[A](data: A, children: List[Node[A]]) {
  override def toString = "Node(" + data + ", children=(" +
    children.map(_.data).mkString(",") +
    "))"
}
val tree = Node("root", List(
  Node("/a", List(
    Node("/a/1", Nil),
    Node("/a/2", Nil)
  )),
  Node("/b", List(
    Node("/b/1", List(
      Node("/b/1/x", Nil),
      Node("/b/1/y", Nil)
    )),
    Node("/b/2", List(
      Node("/b/2/x", Nil),
      Node("/b/2/y", Nil),
      Node("/b/2/z", Nil)
    ))
  ))
))
And finally, here's my failed attempt at a breadth-first traversal that uses unfold:
val initial = List(tree)
val traversed = ScalaUtils.unfold(initial) {
  case node :: Nil =>
    Some((node, node.children))
  case node :: nodes =>
    Some((node, nodes))
  case x =>
    None
}
assertEquals(12, traversed.size) // Fails, 8 elements found
/*
traversed foreach println =>
Node(root, children=(/a,/b))
Node(/a, children=(/a/1,/a/2))
Node(/b, children=(/b/1,/b/2))
Node(/b/1, children=(/b/1/x,/b/1/y))
Node(/b/2, children=(/b/2/x,/b/2/y,/b/2/z))
Node(/b/2/x, children=())
Node(/b/2/y, children=())
Node(/b/2/z, children=())
*/
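To see where the four missing nodes go, here is a sketch that replays the step function above and records, for each emitted node, the work list left afterwards. It assumes Scala 2.13+ (LazyList instead of Stream); BuggyTrace, step, trace and size are hypothetical names, and Node gains a default Nil for brevity.

```scala
// Assumes Scala 2.13+ (LazyList). Node and unfold mirror the question's code.
object BuggyTrace {
  final case class Node[A](data: A, children: List[Node[A]] = Nil)

  def unfold[T, R](init: T)(f: T => Option[(R, T)]): LazyList[R] = f(init) match {
    case None         => LazyList.empty[R]
    case Some((r, v)) => r #:: unfold(v)(f)
  }

  val tree: Node[String] = Node("root", List(
    Node("/a", List(Node("/a/1"), Node("/a/2"))),
    Node("/b", List(
      Node("/b/1", List(Node("/b/1/x"), Node("/b/1/y"))),
      Node("/b/2", List(Node("/b/2/x"), Node("/b/2/y"), Node("/b/2/z")))
    ))
  ))

  // The question's step function, unchanged: only the LAST node in the list
  // ever contributes its children; every other node's children are dropped.
  def step(state: List[Node[String]]): Option[(Node[String], List[Node[String]])] =
    state match {
      case node :: Nil   => Some((node, node.children))
      case node :: nodes => Some((node, nodes))
      case _             => None
    }

  // Emit (visited node, remaining work list) pairs so the lost children show up.
  val trace: List[(String, List[String])] =
    unfold(List(tree)) { s =>
      step(s).map { case (node, next) => ((node.data, next.map(_.data)), next) }
    }.toList

  // Strict recursive node count, for comparison with the expected 12.
  def size[A](n: Node[A]): Int = 1 + n.children.map(c => size(c)).sum
}
```

The second entry, ("/a", List("/b")), shows the problem: the node :: nodes case hands back only the remaining siblings, so /a/1 and /a/2 (and later /b/1/x and /b/1/y) never enter the work list. Only the last node in the list has its children expanded, via the node :: Nil case.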
Can anyone give me some hints as to how to fix (or rewrite) my traversal logic so that all the nodes are returned? Thanks!
You just forgot to add the children of every node except the last one in the work list (the node :: nodes case simply drops them):
val traversed = unfold(initial) {
  case node :: Nil =>
    Some((node, node.children))
  case node :: nodes =>
    // breadth-first
    Some((node, nodes ::: node.children))
    // or depth-first: Some((node, node.children ::: nodes))
  case x =>
    None
}
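Since Nil ::: node.children is just node.children, the node :: Nil case in the fix is already covered by the general case, so each traversal order needs only one cons case plus Nil. A minimal sketch under the same Scala 2.13+ assumption (LazyList rather than Stream); TraversalOrders, bfs and dfs are hypothetical names.

```scala
// Assumes Scala 2.13+ (LazyList in place of the deprecated Stream).
object TraversalOrders {
  final case class Node[A](data: A, children: List[Node[A]] = Nil)

  // Same unfold as in the question, producing a LazyList.
  def unfold[T, R](init: T)(f: T => Option[(R, T)]): LazyList[R] = f(init) match {
    case None         => LazyList.empty[R]
    case Some((r, v)) => r #:: unfold(v)(f)
  }

  // Breadth-first: children are appended behind the pending nodes (FIFO).
  def bfs[A](root: Node[A]): LazyList[Node[A]] =
    unfold(List(root)) {
      case node :: rest => Some((node, rest ::: node.children))
      case Nil          => None
    }

  // Depth-first pre-order: children are pushed in front of the pending nodes (LIFO).
  def dfs[A](root: Node[A]): LazyList[Node[A]] =
    unfold(List(root)) {
      case node :: rest => Some((node, node.children ::: rest))
      case Nil          => None
    }

  // The question's tree.
  val tree: Node[String] = Node("root", List(
    Node("/a", List(Node("/a/1"), Node("/a/2"))),
    Node("/b", List(
      Node("/b/1", List(Node("/b/1/x"), Node("/b/1/y"))),
      Node("/b/2", List(Node("/b/2/x"), Node("/b/2/y"), Node("/b/2/z")))
    ))
  ))
}
```

The only difference between the two is which side of ::: the children go on: behind the pending nodes gives breadth-first order, in front of them gives depth-first pre-order.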
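Here is a complete version of Moritz' answer. The catch-all last case (case x => None) is replaced by an explicit case Nil => None; after the two cons cases, Nil is the only list left for it to match, so the behaviour is unchanged, but the termination condition is now explicit: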
case class CNode[A](data: A, children: List[CNode[A]] = Nil) {
  override def toString: String = if (children.isEmpty) s"node($data)" else
    s"node($data, children=(${ children.map(_.data).mkString(",") }))"
}
object Main extends App {
  def unfold[T, R](init: T)(f: T => Option[(R, T)]): Stream[R] = f(init) match {
    case None => Stream[R]()
    case Some((r, v)) => r #:: unfold(v)(f)
  }
  val tree = List(
    CNode("root", List(
      CNode("/a", List(
        CNode("/a/1", Nil),
        CNode("/a/2", Nil)
      )),
      CNode("/b", List(
        CNode("/b/1", List(
          CNode("/b/1/x", Nil),
          CNode("/b/1/y", Nil)
        )),
        CNode("/b/2", List(
          CNode("/b/2/x", Nil),
          CNode("/b/2/y", Nil),
          CNode("/b/2/z", Nil)
        ))
      ))
    ))
  )
  val traversed = unfold(tree) {
    case node :: Nil =>
      Some((node, node.children))
    case node :: nodes =>
      // breadth-first
      Some((node, nodes ::: node.children))
      // or depth-first: Some((node, node.children ::: nodes))
    case Nil =>
      None
  }
  println(traversed.force.mkString("\n"))
}
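One cost of the List-based breadth-first version is that nodes ::: node.children copies the whole pending list on every step, so each step is linear in the size of the frontier. Below is a sketch that swaps in scala.collection.immutable.Queue (amortized constant-time enqueue and dequeue) and the standard LazyList.unfold, assuming Scala 2.13+; QueueBfs, bfsBy and bfs are hypothetical names. Describing the tree by a children function also shows the on-demand behaviour the question was after: it works even on an infinite tree, as long as the consumer only takes a prefix.

```scala
// Assumes Scala 2.13+: LazyList.unfold and Queue.enqueueAll are standard there.
import scala.collection.immutable.Queue

object QueueBfs {
  final case class CNode[A](data: A, children: List[CNode[A]] = Nil)

  // Lazy breadth-first traversal of any tree described by a `children` function
  // (bfsBy is a hypothetical helper, not a library method).
  def bfsBy[A](root: A)(children: A => List[A]): LazyList[A] =
    LazyList.unfold(Queue(root)) { queue =>
      // dequeueOption is None once the queue is empty, which ends the LazyList.
      queue.dequeueOption.map { case (node, rest) =>
        (node, rest.enqueueAll(children(node))) // children go to the back: FIFO
      }
    }

  def bfs[A](root: CNode[A]): LazyList[CNode[A]] = bfsBy(root)(_.children)

  // The question's tree.
  val tree: CNode[String] = CNode("root", List(
    CNode("/a", List(CNode("/a/1"), CNode("/a/2"))),
    CNode("/b", List(
      CNode("/b/1", List(CNode("/b/1/x"), CNode("/b/1/y"))),
      CNode("/b/2", List(CNode("/b/2/x"), CNode("/b/2/y"), CNode("/b/2/z")))
    ))
  ))
}
```

For example, QueueBfs.bfsBy(1)(n => List(2 * n, 2 * n + 1)).take(7).toList walks the first three levels of an unbounded binary tree and returns List(1, 2, 3, 4, 5, 6, 7); nodes are only generated as the consumer pulls them.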