
Matlab - Speeding up a Nested For Loop

Developer https://www.devze.com 2022-12-15 12:33 Source: web
A simple question, but I'm not so great with MATLAB. I have vectors x (n x 1) and y (m x 1), and w = [x;y]. I want to define M ((n+m) x 1) as M(i) = the number of elements of x that are less than or equal to w(i) (w is sorted in ascending order). This just isn't cutting it:

N = n + m;
M = zeros(N,1);
for i = 1:N
  for j = 1:n
    if x(j) <= w(i)
      M(i) = M(i) + 1;
    end
  end
end

It's not a particularly smart way to do it, and in some of my data the lengths n and m are around 100000.

Thanks!


This may look cryptic, but it should give you the same result as your nested loops:

M = histc(-x(:)',[-fliplr(w(:)') inf]);
M = cumsum(fliplr(M(1:N)))';

The above assumes w has been sorted in ascending order.

Explanation

Your sorted vector w can be thought of as bin edges for creating a histogram with the HISTC function. Once you count the number of values that fall in each bin (i.e. between consecutive edges), a cumulative sum over those bins using the CUMSUM function gives you your vector M. The reason the code above looks so messy (with negations and the function FLIPLR in it) is that you want to count values in x less than or equal to each value in w, but HISTC bins data in the following way:

n(k) counts the value x(i) if edges(k) <= x(i) < edges(k+1).

Notice that strict less-than is used for the upper limit of each bin. You want to flip this behavior so that data are binned according to the rule edges(k) < x(i) <= edges(k+1). That can be achieved by negating the values to be binned, negating the edges, flipping the edges (since the edges input to HISTC must be monotonically nondecreasing), and then flipping the bin counts that are returned. The value inf is appended as the last edge so that, after negation and flipping, the first bin counts every value in x that is less than or equal to the lowest value in w.
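To see the negate-and-flip trick on numbers small enough to check by hand, here is a sketch on toy data (the values of x and y are illustrative assumptions, chosen so that w has no repeated entries). It prints w next to the histc/cumsum result and a brute-force count using the same definition as the nested loops.

```matlab
% Toy data (illustrative assumption); x does not need to be sorted
x = [5; 1; 3];
y = [4; 2];
w = sort([x; y]);            % w = [1 2 3 4 5]', sorted ascending as assumed
N = numel(w);

% Negate and flip so HISTC's [a, b) bins act like (a, b] bins on w
counts = histc(-x(:)', [-fliplr(w(:)') inf]);
% After flipping back: counts(1) = #{x <= w(1)}, counts(i) = #{w(i-1) < x <= w(i)}
counts = fliplr(counts(1:N));

% Running total gives M(i) = #{x <= w(i)}
M = cumsum(counts)';

% Brute-force reference with the same definition as the nested loops
Mref = arrayfun(@(wi) sum(x <= wi), w);
disp([w M Mref])
```

Each x value lands in the bin whose upper edge is that value itself, so the cumulative counts step up at w = 1, 3 and 5.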

If you wanted to find values in x that are simply less than each value in w, the code would be much simpler:

M = histc(x(:)',[-inf w(:)']);
M = cumsum(M(1:N))';
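A quick way to see the difference between the two variants is to run the strict version on toy data (illustrative assumption) and compare it with a brute-force "<=" count. Because w contains every element of x, the two counts differ exactly at the positions where w(i) is one of the x values.

```matlab
% Toy data (illustrative assumption); w sorted ascending
x = [5; 1; 3];
w = sort([x; 4; 2]);         % w = [1 2 3 4 5]'
N = numel(w);

% HISTC bins are [edges(k), edges(k+1)), so no negation is needed for "<"
counts = histc(x(:)', [-inf w(:)']);
Mlt = cumsum(counts(1:N))';  % Mlt(i) = #{x < w(i)}

% Brute-force references for "<" and "<="
Mlt_ref = arrayfun(@(wi) sum(x < wi), w);
Mle_ref = arrayfun(@(wi) sum(x <= wi), w);
disp([w Mlt Mle_ref])        % columns: w, #{x < w}, #{x <= w}
```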


At a minimum the inner loop can be replaced with:

M(i)=sum(x<=w(i))

This should give a substantial performance improvement. You might then consider using arrayfun:

M = arrayfun(@(wi)( sum( x <= wi ) ), w);

arrayfun is less likely to provide substantial gains over the outer for loop but might be worth a try.

edit: I should note that neither w nor x needs to be sorted for this operation to work correctly.

edit: fwiw, I decided to do some actual performance testing, so I ran this program:

n = 100000; m = n;

N = n + m;

x = rand(n, 1);
w = [x; rand(m, 1)];

tic;
M = zeros(N,1);
for i = 1:N
  for j = 1:n
    if x(j) <= w(i)
      M(i) = M(i) + 1;
    end
  end
end
perf = toc;
fprintf(1, 'Original : %4.3f sec\n', perf);

w = sort(w); % presorted, so don't incur time cost;
tic;
M = histc(-x(:)',[-fliplr(w(:)') inf]);
M = cumsum(fliplr(M(1:N)))';
perf = toc;
fprintf(1, 'gnovice : %4.3f sec\n', perf);

tic;
M = zeros(N,1);
for i = 1:N
    M(i)=sum(x<=w(i));
end
perf = toc;
fprintf(1, 'mine/loop : %4.3f sec\n', perf);

tic;
M = arrayfun(@(wi)( sum( x <= wi ) ), w);
perf = toc;
fprintf(1, 'mine/arrayfun : %4.3f sec\n', perf);

and got these results for n = 1000:

Original : 0.086 sec
gnovice : 0.002 sec
mine/loop : 0.042 sec
mine/arrayfun : 0.070 sec

and for n = 100000:

Original : too long to tell ( >> 1m )
gnovice : 0.050 sec
mine/loop : too long to tell ( >> 1m )
mine/arrayfun : too long to tell ( >> 1m )


try this one instead:

% entry (i,j) of the N-by-n logical array is x(j) <= w(i); summing along rows gives M(i)
M = sum(bsxfun(@le, x(:)', w(:)), 2);
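Comparing every element of x against every element of w in a single bsxfun call builds an N-by-n logical array; for n = m = 100000 that is 2e10 elements, roughly 20 GB at one byte per logical, which is unlikely to fit in memory. A possible middle ground (a sketch, not taken from the answers here) keeps the comparison vectorized but walks through w in blocks, so only a blockSize-by-n array exists at a time. blockSize and the random test sizes are assumptions to tune; the arrayfun line is only there to check the result on small inputs.

```matlab
% Illustrative data (assumption): sizes kept small so the check runs quickly
n = 2000; m = 3000;
x = rand(n, 1);
w = sort([x; rand(m, 1)]);
N = n + m;

blockSize = 500;             % rows of w handled per pass (tuning assumption)
M = zeros(N, 1);
for s = 1:blockSize:N
  e = min(s + blockSize - 1, N);
  % (e-s+1)-by-n logical array: entry (r, j) is x(j) <= w(s+r-1)
  M(s:e) = sum(bsxfun(@le, x(:)', w(s:e)), 2);
end

% Brute-force reference for comparison
Mref = arrayfun(@(wi) sum(x <= wi), w);
```

Larger blocks mean fewer passes but a bigger temporary array, so blockSize trades speed against memory.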


I haven't used MATLAB for a while, but this should work:

  • Sort x in ascending order with the built-in sort function: x = sort(x);

  • Use a loop with a moving index so that x is traversed only once (this relies on w being sorted as well):

    M = zeros(N,1);
    j = 1;
    for i = 1:N
      while j <= n && x(j) <= w(i)
        M(i) = M(i) + 1;
        j = j+1;
      end
    end
    
  • Finally, accumulate the running sum:

    for i = 2:N
      M(i) = M(i-1) + M(i);
    end
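The same merge idea can also be written without an explicit loop. The sketch below is not from the original answers: it sorts x and w together with sortrows, placing x entries before equal w entries on ties (and then ordering by original position, so the result is fully determined), and reads a running count of x entries off at the w positions. It assumes, as in the question, that w is sorted ascending; the toy values are illustrative.

```matlab
% Toy data (illustrative assumption); w must be sorted ascending
x = [5; 1; 3];
w = sort([x; 4; 2]);                  % w = [1 2 3 4 5]'
n = numel(x);
N = numel(w);

v = [x(:); w(:)];                     % all values in one column
isX = [true(n, 1); false(N, 1)];      % marks which entries came from x
% Sort by value; on ties put x entries (key 0) before w entries (key 1),
% then by original position so the order is fully determined
[~, order] = sortrows([v, double(~isX), (1:n+N)']);

cumX = cumsum(isX(order));            % number of x entries seen so far
% Read the running count at each w entry; w entries keep their original
% (ascending) order because of the position key
M = cumX(~isX(order));
disp(M')                              % expected: 1 1 2 2 3
```

Like the loop, this costs one sort of n+N values plus linear work, instead of the n*N comparisons of the nested loops.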
    


I don't have Matlab in front of me so I can't confirm that this is 100% functional, but you might want to try something like:

M = arrayfun(@(val) length(find(x <= val)), w);