Why does the same algorithm run much slower in Scala than in C#, and how can I make it faster?

Developer (https://www.devze.com), 2023-02-26 21:36. Source: Internet

The algorithm builds every possible variant of a sequence from the list of variants given for each member of that sequence (in effect, a Cartesian product).

C# code:

static void Main(string[] args)
{
  var arg = new List<List<int>>();
  int i = 0;
  for (int j = 0; j < 5; j++)
  {
    arg.Add(new List<int>());
    for (int j1 = i; j1 < i + 3; j1++)
    {
      //if (j1 != 5)
      arg[j].Add(j1);
    }
    i += 3;
  }
  List<Utils<int>.Variant<int>> b2 = new List<Utils<int>.Variant<int>>();
  //int[][] bN;

  var s = System.Diagnostics.Stopwatch.StartNew();
  //for(int j = 0; j < 10;j++)
    b2 = Utils<int>.Produce2(arg);
  s.Stop();
  Console.WriteLine(s.ElapsedMilliseconds);

}


public class Variant<T>
{
  public T element;
  public Variant<T> previous;
}


public static List<Variant<T>> Produce2(List<List<T>> input)
{
  var ret = new List<Variant<T>>();
  foreach (var form in input)
  {
    var newRet = new List<Variant<T>>(ret.Count * form.Count);
    foreach (var el in form)
    {
      if (ret.Count == 0)
      {
        newRet.Add(new Variant<T>{ element = el, previous = null });
      }
      else
      {
        foreach (var variant in ret)
        {
          var buf = new Variant<T> { previous = variant, element = el };
          newRet.Add(buf);
        }
      }
    }
    ret = newRet;
  }
  return ret;
}

Scala code:

object test {
def main() {
 var arg = new Array[Array[Int]](5)
 var i = 0
 var init = 0
 while (i<5)
 {
  var buf = new Array[Int](3)
  var j = 0
  while (j<3)
  {
   buf(j) = init
   init = init+1
   j = j + 1
  }
  arg(i)=buf
  i = i + 1
 }
 println("Hello, world!")
 val start = System.currentTimeMillis
 var res = Produce(arg)
 val stop = System.currentTimeMillis
 println(stop-start)
 /*for(list <- res)
  {
   for(el <- list)
    print(el+" ")
   println
  }*/
 println(res.length)
}

 def Produce[T](input:Array[Array[T]]):Array[Variant[T]]=
  {
   var ret = new Array[Variant[T]](1)
   for(val forms <- input)
   {
    if(forms!=null)
    {
     var newRet = new Array[Variant[T]](forms.length*ret.length)
     if(ret.length>0)
     {
      for(val prev <-ret)
       if(prev!=null)
       for(val el <-forms)
       {
        newRet = newRet:+new Variant[T](el,prev)
       }
     }
     else
     {
      for(val el <- forms)
        {
         newRet = newRet:+new Variant[T](el,null)
        }
     }
     ret = newRet
    }
   }
   return ret
  }


}

class Variant[T](var element:T, previous:Variant[T])
{
}


As others have said, the difference is in how you're using the collections. A Scala Array is the same thing as a plain Java array ([]), which in turn behaves like a plain C# array. Scala is clever enough to do what you ask (namely, copy the entire array with a new element on the end), but not so clever as to tell you that you'd be better off using a different collection. For example, if you change Array to ArrayBuffer and append in place with += (rather than :+, which builds a new collection every time), it should be much faster (comparable to C#).
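A minimal sketch of the ArrayBuffer idea, mirroring the loop structure of the C# Produce2. It assumes a stand-in Variant case class (redefined here so the block compiles on its own), and the names ProduceWithBuffer and produce are illustrative. The point is only that += appends in place instead of copying.

```scala
import scala.collection.mutable.ArrayBuffer

// Stand-in for the question's Variant class (an assumption for this sketch).
final case class Variant[T](element: T, previous: Variant[T])

object ProduceWithBuffer {
  // Same nesting as the C# Produce2: one pass per position in the sequence.
  def produce[T](input: Array[Array[T]]): ArrayBuffer[Variant[T]] = {
    var ret = ArrayBuffer.empty[Variant[T]]
    for (forms <- input) {
      val newRet = ArrayBuffer.empty[Variant[T]]
      for (el <- forms) {
        if (ret.isEmpty) newRet += Variant[T](el, null)      // first position: no predecessor
        else for (prev <- ret) newRet += Variant[T](el, prev) // += appends in place, no copy
      }
      ret = newRet
    }
    ret
  }

  def main(args: Array[String]): Unit = {
    val arg = Array.tabulate(5, 3)((i, j) => i * 3 + j)
    println(produce(arg).length)
  }
}
```

Each variant only stores its last element plus a link to the shorter variant it extends, so nothing is copied when a layer is built.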

Actually, though, you'd be better off not using for loops at all. One of the strengths of Scala's collections library is that you have a wide variety of powerful operations at your disposal. In this case, you want to take every item from forms and convert it into a Variant. That's what map does.
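As a small illustration of the map step, assuming a stand-in Variant case class (not the exact class from the question) and an illustrative MapDemo object, wrapping every element of forms could look like this:

```scala
// Stand-in for the question's Variant class (an assumption for this sketch).
final case class Variant[T](element: T, previous: Variant[T])

object MapDemo {
  val forms: List[Int] = List(0, 1, 2)

  // map applies the function to every element and collects the results,
  // so no explicit loop or manual appending is needed.
  val firstLayer: List[Variant[Int]] = forms.map(el => Variant[Int](el, null))

  def main(args: Array[String]): Unit = println(firstLayer)
}
```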

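Also, your Scala code doesn't seem to actually work: ret starts out as an array holding a single null, so the prev != null check skips every element and the result is just an array of 243 nulls.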

If you want all possible variants from each member, you really want to use recursion. This implementation does what you say you want:

object test {
  def produce[T](input: Array[Array[T]], index: Int = 0): Array[List[T]] = {
    if (index >= input.length) Array()
    else if (index == input.length-1) input(index).map(elem => List(elem))
    else {
      produce(input, index+1).flatMap(variant => {
        input(index).map(elem => elem :: variant)
      })
    }
  }

  def main(args: Array[String]): Unit = {
    val arg = Array.tabulate(5,3)((i,j) => i*3+j)
    println("Hello, world!")
    val start = System.nanoTime
    val res = produce(arg)
    val stop = System.nanoTime
    println("Time elapsed (ms): " + (stop-start)/1000000L)
    println("Result length: " + res.length)
    // Array.deep was removed in Scala 2.13; mkString prints the contents on any version
    println(res.mkString("Array(", ", ", ")"))
  }
}

Let's unpack this a little. First, we've replaced your entire construction of the initial variants with a single tabulate call. tabulate takes a target size (5x3, here) and a function that maps each pair of indices in that rectangle to the value stored there.
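To see what tabulate produces for the 5x3 case, here is a tiny sketch (the TabulateDemo object and its grid field are illustrative names):

```scala
object TabulateDemo {
  // Row i, column j holds i*3 + j, so the rows are 0..2, 3..5, and so on up to 12..14.
  val grid: Array[Array[Int]] = Array.tabulate(5, 3)((i, j) => i * 3 + j)

  def main(args: Array[String]): Unit =
    grid.foreach(row => println(row.mkString(" ")))
}
```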

We've also made produce a recursive function. (Normally we'd make it tail-recursive, but let's keep things as simple as we can for now.) How do you generate all variants? Well, the set of all variants is clearly (every possibility at this position) combined with (every variant of the later positions). So we write that down recursively.

Note that if we build variants recursively like this, all the tails of the variants end up the same, which makes List a perfect data structure: it's a singly-linked immutable list, so instead of having to copy all those tails over and over again, we just point to them.
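The sharing can be checked directly with reference equality. A small sketch, with illustrative names:

```scala
object SharedTails {
  val tail: List[Int] = List(2, 3)

  // Prepending with :: allocates one new cell that points at the existing tail.
  val a: List[Int] = 0 :: tail
  val b: List[Int] = 1 :: tail

  def main(args: Array[String]): Unit = {
    // eq compares references: both lists end in the very same tail object.
    println(a.tail eq b.tail)
  }
}
```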

Now, how do we actually do the recursion? Well, if there's no data at all, we had better return an empty array (i.e. if index is past the end of the array). If we're on the last element of the array of variations, we basically want each element to turn into a list of length 1, so we use map to do exactly that (elem => List(elem)). Finally, if we are not at the end, we get the results from the rest (which is produce(input, index+1)) and make variants with each element.

Let's take the inner loop first: input(index).map(elem => elem :: variant). This takes each element of the alternatives at position index and sticks it onto the front of an existing variant, which gives us a new batch of variants. Fair enough, but where do we get the existing variant from? We produce it from the rest of the input: produce(input, index+1). The only trick is that we need to use flatMap: it takes each element, produces a collection out of it, and glues all those collections together.
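The difference between map and flatMap in that step is easiest to see on a tiny input. In this sketch, rest plays the role of produce(input, index+1) and here plays the role of input(index); both names, and the FlatMapDemo object, are made up for the illustration.

```scala
object FlatMapDemo {
  val rest: List[List[Int]] = List(List(3), List(4)) // variants of the later positions
  val here: List[Int] = List(1, 2)                   // alternatives at this position

  // map keeps one inner collection per existing variant (nested result).
  val nested: List[List[List[Int]]] =
    rest.map(variant => here.map(elem => elem :: variant))

  // flatMap builds the same inner collections and concatenates them.
  val flat: List[List[Int]] =
    rest.flatMap(variant => here.map(elem => elem :: variant))

  def main(args: Array[String]): Unit = {
    println(nested)
    println(flat)
  }
}
```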

I encourage you to throw printlns in various places to see what's going on.

Finally, note that with your test size, it's actually an insignificant amount of work; you can't accurately measure that, even if you switch to the higher-resolution System.nanoTime as I did. You'd need something like tabulate(12,3) before it gets significant (about 500,000 variants produced).


The :+ method of Array (more precisely, of ArrayOps) always creates a copy of the array. So instead of a constant-time operation you have one that is O(n) in the length of the array. You do it inside nested loops, so the total cost grows quadratically with the size of the result instead of linearly.
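The copying is easy to observe, because the original array is left untouched. A short sketch (AppendCopies is an illustrative name):

```scala
object AppendCopies {
  val original: Array[Int] = Array(1, 2, 3)

  // :+ allocates a new array of length 4 and copies the three old elements into it.
  val appended: Array[Int] = original :+ 4

  def main(args: Array[String]): Unit = {
    println(original.mkString(", ")) // the original array is unchanged
    println(appended.mkString(", "))
    println(appended eq original)    // a different array object
  }
}
```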

This way you more or less emulate an immutable data structure with a mutable one (which was not designed for it).

To fix it you can either use Array as a mutable data structure (but then try to avoid the endless copying), or you can switch to an immutable one. I did not check your code very carefully, but the first bet is usually List. Check the scaladoc of the various methods to see their performance characteristics.
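One possible shape of the immutable route, as a sketch rather than a drop-in replacement: it represents each variant as a List[T] instead of the question's Variant class, and the names ProduceWithList and produce are illustrative. Prepending with :: is constant time and shares the existing tail, so nothing is copied.

```scala
object ProduceWithList {
  // Walk the positions from right to left; acc holds every variant of the later positions.
  def produce[T](input: List[List[T]]): List[List[T]] =
    input.foldRight(List(List.empty[T])) { (forms, acc) =>
      for (el <- forms; variant <- acc) yield el :: variant // O(1) prepend, tail is shared
    }

  def main(args: Array[String]): Unit = {
    val arg = List.tabulate(5, 3)((i, j) => i * 3 + j)
    println(produce(arg).length)
  }
}
```

Note that, as written, an empty input yields a single empty variant rather than no variants at all; whether that is the desired convention depends on the caller.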


ret.length is not 0 all the time; right before the return it is 243. The size of the array should not be changed, and List in .NET is an abstraction on top of an array. But thank you for the point: the problem was that I used the :+ operator with an array, which, as I understand it, caused the implicit use of type LinkedList.
