
Haskell tail-recursion performance question for Levenshtein distances

Developer https://www.devze.com 2023-01-18 10:21 Source: Internet

I'm playing around with calculating Levenshtein distances in Haskell, and am a little frustrated with the following performance problem. If you implement it most 'normal' way for Haskell, like below (dist), everything works just fine:

{-# LANGUAGE NPlusKPatterns #-}
import qualified Data.List as L

dist :: (Ord a) => [a] -> [a] -> Int
dist s1 s2 = ldist s1 s2 (L.length s1, L.length s2)

ldist :: (Ord a) => [a] -> [a] -> (Int, Int) -> Int
ldist _ _ (0, 0) = 0
ldist _ _ (i, 0) = i
ldist _ _ (0, j) = j
ldist s1 s2 (i+1, j+1) = output
  where output | (s1!!(i)) == (s2!!(j)) = ldist s1 s2 (i, j)
               | otherwise = 1 + L.minimum [ldist s1 s2 (i, j)
                                          , ldist s1 s2 (i+1, j)
                                          , ldist s1 s2 (i, j+1)]

But, if you bend your brain a little and implement it as dist', it executes MUCH faster (about 10x).

dist' :: (Ord a) => [a] -> [a] -> Int
dist' o1 o2 = (levenDist o1 o2 [[]])!!0!!0 

levenDist :: (Ord a) => [a] -> [a] -> [[Int]] -> [[Int]]
levenDist s1 s2 arr@([[]]) = levenDist s1 s2 [[0]]
levenDist s1 s2 arr@([]:xs) = levenDist s1 s2 ([(L.length arr) -1]:xs)
levenDist s1 s2 arr@(x:xs) = let
    n1 = L.length s1
    n2 = L.length s2
    n_i = L.length arr
    n_j = L.length x
    match | (s2!!(n_j-1) == s1!!(n_i-2)) = True | otherwise = False
    minCost = if match      then (xs!!0)!!(n2 - n_j + 1) 
                            else L.minimum [(1 + (xs!!0)!!(n2 - n_j + 1))
                                          , (1 + (xs!!0)!!(n2 - n_j + 0))
                                          , (1 + (x!!0))
                                          ]
    dist | (n_i > n1) && (n_j > n2)  = arr 
         | n_j > n2  = []:arr `seq` levenDist s1 s2 $ []:arr
         | n_i == 1 = (n_j:x):xs `seq` levenDist s1 s2 $ (n_j:x):xs
         | otherwise = (minCost:x):xs `seq` levenDist s1 s2 $ (minCost:x):xs
    in dist 

I've tried all the usual seq tricks in the first version, but nothing seems to speed it up. This is a little unsatisfying for me, because I expected the first version to be faster because it doesn't need to evaluate the entire matrix, only the parts it needs.

Does anyone know if it is possible to get these two implementations to perform similarly, or am I just reaping the benefits of tail-recursion optimizations in the latter, and therefore need to live with its unreadability if I want performance?

Thanks, Orion


In the past I've used this very concise version with foldl and scanl from Wikibooks:

distScan :: (Ord a) => [a] -> [a] -> Int
distScan sa sb = last $ foldl transform [0 .. length sa] sb
  where
    transform xs@(x:xs') c = scanl compute (x + 1) (zip3 sa xs xs')
       where
         compute z (c', x, y) = minimum [y + 1, z + 1, x + fromEnum (c' /= c)]

I just ran this simple benchmark using Criterion:

import Criterion.Main (defaultMain, bench, nf)

test :: ([Int] -> [Int] -> Int) -> Int -> Int
test f n = f up up + f up down + f up half + f down half
  where
    up = [1..n]
    half = [1..div n 2]
    down = reverse up

main = let n = 20 in defaultMain
  [ bench "Scan" $ nf (test distScan) n
  , bench "Fast" $ nf (test dist') n
  , bench "Slow" $ nf (test dist) n
  ]

And the Wikibooks version beats both of yours pretty dramatically:

benchmarking Scan
collecting 100 samples, 51 iterations each, in estimated 683.7163 ms...
mean: 137.1582 us, lb 136.9858 us, ub 137.3391 us, ci 0.950

benchmarking Fast
collecting 100 samples, 11 iterations each, in estimated 732.5262 ms...
mean: 660.6217 us, lb 659.3847 us, ub 661.8530 us, ci 0.950...

Slow is still running after a couple of minutes.


To calculate length you need to traverse the whole list, which is an O(n) operation. More importantly, the list is then kept in memory for as long as anything still references it, which increases the memory footprint. The rule of thumb is to avoid length on lists that are expected to be long. The same applies to (!!): it walks from the head of the list every time, so it is O(n) too. Lists are not designed as a random-access data structure.

A better approach with Haskell lists is to consume them incrementally. Folds are usually the way to go for problems like this, and the Levenshtein distance can be calculated that way (see the link below). I don't know if there are better algorithms.

Another approach is to use a different data structure, not lists. For example, if you need random access, known length etc. take a look at Data.Sequence.Seq.
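A minimal sketch of that idea, assuming the containers package's Data.Sequence (Seq.fromList, Seq.length, Seq.index, Seq.scanl); distSeq is a hypothetical name, not code from any answer here. It keeps the second string and each matrix row as a Seq, so the length is O(1) and indexing is O(log(min(i, n - i))) rather than the O(i) walk that (!!) does on a list:

```haskell
import Data.List (foldl')
import qualified Data.Sequence as Seq

-- Row-by-row Levenshtein distance; the second string and each row are Seqs,
-- so Seq.length is O(1) and Seq.index is logarithmic (unlike length/(!!)).
distSeq :: Eq a => [a] -> [a] -> Int
distSeq xs ys = Seq.index (foldl' nextRow row0 (zip [1 ..] xs)) n
  where
    sb = Seq.fromList ys
    n = Seq.length sb
    -- Row 0: distance from the empty prefix of xs to every prefix of ys.
    row0 = Seq.fromList [0 .. n]
    -- Row i is built from row i - 1; its column 0 is simply i.
    nextRow prev (i, x) = Seq.scanl cell i (Seq.fromList [1 .. n])
      where
        cell left j =
          minimum
            [ Seq.index prev j + 1        -- from the cell above
            , left + 1                    -- from the cell to the left
            , Seq.index prev (j - 1)      -- from the diagonal, +1 on mismatch
                + (if x == Seq.index sb (j - 1) then 0 else 1)
            ]
```

The outer loop still walks xs as a plain list, since it is consumed strictly front to back; only the structure that needs random access is a Seq.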

Existing implementations

The second approach has been used in this implementation of the Levenshtein distance in Haskell (using arrays). You can find a foldl-based implementation in the first comment there. BTW, foldl' is usually better than foldl.
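For illustration, here is the Wikibooks-style distScan rewritten with foldl' (a sketch; distStrict and forceRow are hypothetical names). foldl' alone only evaluates each new row to weak head normal form, i.e. its first cons cell, so the sketch also forces every element of the row before moving on; otherwise thunks that refer to earlier rows can still pile up:

```haskell
import Data.List (foldl')

-- Each row holds the distances from every prefix of sa to the prefix of sb
-- consumed so far; the last element of the final row is the answer.
distStrict :: Eq a => [a] -> [a] -> Int
distStrict sa sb = last (foldl' step [0 .. length sa] sb)
  where
    step [] _ = []  -- unreachable: every row has at least one element
    step prev@(p0 : ps) c = forceRow (scanl compute (p0 + 1) (zip3 sa prev ps))
      where
        -- left: new cell to the left; diag and up: cells from the previous row
        compute left (c', diag, up) =
          minimum [up + 1, left + 1, diag + fromEnum (c' /= c)]

-- Force the spine and every element so no thunk chain survives between rows.
forceRow :: [Int] -> [Int]
forceRow xs = foldr seq () xs `seq` xs
```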


I don't follow all of your second attempt just yet, but as far as I recall the idea behind the Levenshtein algorithm is to save repeated calculation by using a matrix. In the first piece of code, you are not sharing any calculation and thus you will be repeating lots of calculations. For example, when calculating ldist s1 s2 (5,5) you'll make the calculation for ldist s1 s2 (4,4) at least three separate times (once directly, once via ldist s1 s2 (4,5), once via ldist s1 s2 (5,4)).
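One way to get that sharing while keeping the shape of the original ldist recurrence is a lazily built Data.Array used as a memo table (a sketch assuming the array package; distMemo is a hypothetical name). Each cell is a thunk that refers to its three neighbours, so every (i, j) is evaluated at most once and then reused, and the inputs are copied into arrays so character lookups are O(1) instead of (!!):

```haskell
import Data.Array (Array, array, listArray, (!))

-- Same recurrence as ldist, but each (i, j) cell lives in a lazily built
-- array, so it is computed at most once and shared by all its neighbours.
distMemo :: Eq a => [a] -> [a] -> Int
distMemo s1 s2 = table ! (m, n)
  where
    m = length s1
    n = length s2
    -- 1-based copies of the inputs for O(1) character lookup
    a1 = listArray (1, m) s1
    a2 = listArray (1, n) s2
    table :: Array (Int, Int) Int
    table = array ((0, 0), (m, n)) [((i, j), cell i j) | i <- [0 .. m], j <- [0 .. n]]
    cell 0 j = j
    cell i 0 = i
    cell i j
      | a1 ! i == a2 ! j = table ! (i - 1, j - 1)
      | otherwise =
          1 + minimum [table ! (i - 1, j - 1), table ! (i - 1, j), table ! (i, j - 1)]
```

The array still allocates a thunk for every cell, so memory is O(m*n), but the exponential recomputation of the naive version is gone.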

What you should do is define an algorithm for generating the matrix (as a list of lists, if you like). I think this is what your second piece of code is doing, but it seems to focus on calculating the matrix in a top-down manner rather than building up the matrix cleanly in an inductive style (the recursive calls in the base case are quite unusual to my eye). Unfortunately I don't have time to write out the whole thing, but thankfully someone else has: look at the first version at this address: http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#Haskell
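A sketch of building the matrix inductively as a list of rows (levMatrix and distMatrix are hypothetical names; this follows the row-by-row idea rather than copying the Wikibooks code). Row 0 is [0 .. length sb], and each following row is produced from the previous one with scanl, so no explicit indexing is needed:

```haskell
-- Row i holds the distances from the first i elements of sa
-- to every prefix of sb; row 0 is [0 .. length sb].
levMatrix :: Eq a => [a] -> [a] -> [[Int]]
levMatrix sa sb = scanl nextRow [0 .. length sb] (zip [1 ..] sa)
  where
    -- column 0 of row i is i; the other cells come from the left, above and diagonal
    nextRow prev (i, ca) = scanl step i (zip3 sb prev (drop 1 prev))
      where
        step left (cb, diag, up) = minimum [up + 1, left + 1, diag + fromEnum (ca /= cb)]

-- The distance is the bottom-right corner of the matrix.
distMatrix :: Eq a => [a] -> [a] -> Int
distMatrix sa sb = last (last (levMatrix sa sb))
```

For example, levMatrix "ab" "b" evaluates to [[0,1],[1,1],[2,1]], so distMatrix "ab" "b" is 1.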

Two more things: one, I'm not sure the Levenshtein algorithm can ever use only part of the matrix anyway, as each entry is dependent on the diagonal, vertical and horizontal neighbour. When you need the value for one corner, you'll inevitably have to evaluate the matrix all the way to the other corner. Secondly, that match | foo = True | otherwise = False line can be replaced by simply match = foo.


It is possible to have an O(N*d) algorithm, where d is the Levenshtein distance. Here's an implementation in Lazy ML by Lloyd Allison which exploits laziness to achieve the improved complexity. It works by computing only part of the matrix: a region around the main diagonal whose width is proportional to the Levenshtein distance.

Edit: I just noticed this has been translated to Haskell, with a nice image showing which elements of the matrix are computed. This should be significantly faster than the above implementations when the sequences are quite similar. Using the above benchmark:

benchmarking Scan
collecting 100 samples, 100 iterations each, in estimated 1.410004 s
mean: 141.8836 us, lb 141.4112 us, ub 142.5126 us, ci 0.950

benchmarking LAllison.d
collecting 100 samples, 169 iterations each, in estimated 1.399984 s
mean: 82.93505 us, lb 82.75058 us, ub 83.19535 us, ci 0.950
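Allison's lazy code is not reproduced here. As a hedged illustration of the same O(N*d) idea, the sketch below uses a different, strict technique: a banded dynamic program that fills only cells with |i - j| <= k, restarting with k doubled whenever the result does not fit in the band. An edit path of cost at most k can never leave that band (each step off the diagonal costs 1), so any result of at most k is exact, and the total work stays roughly proportional to N*d. bandedDist, distBanded and lookupCell are hypothetical names, and this is not Allison's algorithm:

```haskell
import Data.Array (Array, bounds, inRange, listArray, range, (!))
import Data.List (foldl')

-- Distance computed only inside the diagonal band |i - j| <= k.
-- Cells outside the band count as k + 1 and every value is capped at k + 1,
-- so the result equals min (true distance) (k + 1).
bandedDist :: Eq a => Int -> Array Int a -> Array Int a -> Int
bandedDist k a b
  | abs (m - n) > k = inf
  | otherwise = lookupCell (foldl' nextRow row0 [1 .. m]) n
  where
    (_, m) = bounds a
    (_, n) = bounds b
    inf = k + 1
    -- Columns of row i that lie inside the band.
    band i = (max 0 (i - k), min n (i + k))
    lookupCell row j
      | inRange (bounds row) j = row ! j
      | otherwise = inf
    row0 :: Array Int Int
    row0 = listArray (band 0) [min inf j | j <- range (band 0)]
    nextRow prev i = listArray (lo, hi) (tail (scanl cell inf [lo .. hi]))
      where
        (lo, hi) = band i
        -- left is the cell to the left (inf when it lies outside the band)
        cell left j
          | j == 0 = min inf i
          | otherwise =
              minimum
                [ inf
                , lookupCell prev j + 1
                , left + 1
                , lookupCell prev (j - 1) + (if a ! i == b ! j then 0 else 1)
                ]

-- Double the band width until the banded result fits inside it.
distBanded :: Eq a => [a] -> [a] -> Int
distBanded s1 s2 = go (max 1 (abs (m - n)))
  where
    m = length s1
    n = length s2
    a = listArray (1, m) s1
    b = listArray (1, n) s2
    go k
      | d <= k = d
      | otherwise = go (2 * k)
      where
        d = bandedDist k a b
```

No timings are claimed for this sketch; it only shows why similar inputs (small d) need far fewer cells than the full matrix.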


A more intuitive solution using the data-memocombinators package. Credit goes to this answer. Benchmarks are welcome, as all solutions presented here appear to be much, much slower than python-Levenshtein, which was presumably written in C. Note that I tried substituting arrays of chars for strings, to no effect.

import Data.MemoCombinators (memo2, integral)

levenshtein :: String -> String -> Int
levenshtein a b = levenshtein' (length a) (length b) where
  levenshtein' = memo2 integral integral levenshtein'' where
    levenshtein'' x y -- take x characters from a and y characters from b
      | x==0 = y
      | y==0 = x
      | a !! (x-1) == b !! (y-1) = levenshtein' (x-1) (y-1)
      | otherwise = 1 + minimum [ levenshtein' (x-1) y, 
        levenshtein' x (y-1), levenshtein' (x-1) (y-1) ]
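On the gap to python-Levenshtein: the versions in this thread work on lists or boxed, lazily filled structures. A lower-level alternative, sketched here under the assumption that the array package is available (levenshteinST is a hypothetical name, and it has not been benchmarked here), keeps only two rows in unboxed mutable arrays in the ST monad and swaps them after each row:

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newListArray, readArray, writeArray)
import Data.Array.Unboxed (UArray, listArray, (!))

-- Two-row Levenshtein distance on unboxed mutable arrays in ST.
levenshteinST :: String -> String -> Int
levenshteinST s t = runST $ do
  let m = length s
      n = length t
      sa = listArray (1, m) s :: UArray Int Char
      ta = listArray (1, n) t :: UArray Int Char
  -- prev holds row i - 1; cur is overwritten with row i
  prev0 <- newListArray (0, n) [0 .. n] :: ST s (STUArray s Int Int)
  cur0 <- newListArray (0, n) (replicate (n + 1) 0) :: ST s (STUArray s Int Int)
  let go i prev cur
        | i > m = readArray prev n
        | otherwise = do
            writeArray cur 0 i
            forM_ [1 .. n] $ \j -> do
              diag <- readArray prev (j - 1)
              up <- readArray prev j
              left <- readArray cur (j - 1)
              let cost = if sa ! i == ta ! j then 0 else 1
              writeArray cur j (minimum [up + 1, left + 1, diag + cost])
            -- the row just written becomes the previous row; the old one is reused
            go (i + 1) cur prev
  go 1 prev0 cur0
```

Memory use is O(length t) and there is no per-cell thunk, which is the main structural difference from the list-based versions.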
