By default, MATLAB's sort function deals with ties/repeated elements by preserving the order of the elements, that is:
>> [srt,idx] = sort([1 0 1])
srt =
0 1 1
idx =
2 1 3
Note that the two elements with value 1 (input positions 1 and 3) arbitrarily end up in sorted positions 2 and 3, respectively. idx = [2 3 1], however, would be an equally valid sort, since in([2 3 1]) is also 0 1 1.
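To make the ambiguity concrete, here is a small check that two different index vectors reproduce the same sorted output. The names idxStable and idxOther are only illustrative:

```matlab
in  = [1 0 1];
srt = sort(in);            % [0 1 1]

idxStable = [2 1 3];       % what sort returns: ties keep their input order
idxOther  = [2 3 1];       % the tied elements (input positions 1 and 3) swapped

% Both index vectors reproduce the same sorted output.
okStable = isequal(in(idxStable), srt);
okOther  = isequal(in(idxOther),  srt);
```

Both okStable and okOther come out true, which is exactly why more than one idx is consistent with srt.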
I would like a function [srt,all_idx] = sort_ties(in) that explicitly returns all possible values for idx that are consistent with the sorted output. Of course this would only happen in the case of ties or repeated elements, and all_idx would be dimension nPossibleSorts x length(in).
I got started on a recursive algorithm for doing this, but quickly realized that things were getting out of hand and someone must have solved this before! Any suggestions?
I had a similar idea to what R. M. suggested. However, this solution is generalized to handle any number of repeated elements in the input vector. The code first sorts the input (using the function SORT), then loops over each unique value to generate all the permutations of the indices for that value (using the function PERMS), storing the results in a cell array. These per-value index permutations are then combined into the full set of permutations of the sorted index by replicating them appropriately with the functions KRON and REPMAT:
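function [srt,all_idx] = sort_ties(in,varargin)
  % Sort as usual; extra arguments (e.g. 'descend') pass through to SORT.
  % Note: the concatenations below assume a row vector input.
  [srt,idx] = sort(in,varargin{:});
  uniqueValues = srt(logical([1 diff(srt)]));  % first element of each run
  nValues = numel(uniqueValues);
  if nValues == numel(srt)  % no ties: only one valid index vector
    all_idx = idx;
    return
  end
  % Permute the indices belonging to each unique value
  permCell = cell(1,nValues);
  for iValue = 1:nValues
    valueIndex = idx(srt == uniqueValues(iValue));
    if numel(valueIndex) == 1
      permCell{iValue} = valueIndex;
    else
      permCell{iValue} = perms(valueIndex);
    end
  end
  % Replicate each block so that every combination appears exactly once
  nPerms = cellfun('size',permCell,1);
  for iValue = 1:nValues
    N = prod(nPerms(1:iValue-1));    % number of combinations of earlier blocks
    M = prod(nPerms(iValue+1:end));  % number of combinations of later blocks
    permCell{iValue} = repmat(kron(permCell{iValue},ones(N,1)),M,1);
  end
  all_idx = [permCell{:}];
end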
And here are some sample results:
>> [srt,all_idx] = sort_ties([0 2 1 2 2 1])
srt =
0 1 1 2 2 2
all_idx =
1 6 3 5 4 2
1 3 6 5 4 2
1 6 3 5 2 4
1 3 6 5 2 4
1 6 3 4 5 2
1 3 6 4 5 2
1 6 3 4 2 5
1 3 6 4 2 5
1 6 3 2 4 5
1 3 6 2 4 5
1 6 3 2 5 4
1 3 6 2 5 4
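The KRON/REPMAT replication step can be hard to picture, so here is a minimal stand-alone sketch of it for the same input, with the per-value permutation blocks written out by hand. The names P1, P2, P3 and C1, C2, C3 are illustrative and not part of the function above:

```matlab
in  = [0 2 1 2 2 1];
srt = sort(in);

% Per-value index permutations
P1 = 1;                  % value 0 occurs once, at position 1
P2 = perms([3 6]);       % value 1 occurs at positions 3 and 6    -> 2 rows
P3 = perms([2 4 5]);     % value 2 occurs at positions 2, 4 and 5 -> 6 rows

nTotal = size(P1,1) * size(P2,1) * size(P3,1);   % 1 * 2 * 6 = 12 rows

% Earlier blocks cycle fastest (tiled with repmat); later blocks change
% slowest (each row repeated with kron).
C1 = repmat(P1, nTotal, 1);
C2 = repmat(P2, size(P3,1), 1);
C3 = kron(P3, ones(size(P2,1), 1));

all_idx = [C1 C2 C3];

% Every row should reproduce the sorted output, and no row should repeat.
ok      = isequal(in(all_idx), repmat(srt, nTotal, 1));
nUnique = size(unique(all_idx, 'rows'), 1);
```

With three blocks the pattern is easy to see: pairing a tiled block with a row-repeated block covers every combination exactly once.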
Consider the example A=[1,2,3,2,5,6,2]. You want to find the indices where 2 occurs, and get all possible permutations of those indices.
For the first step, use unique in combination with histc to find the repeated element and the indices where it occurs.
uniqA=unique(A);
B=histc(A,uniqA);
You get B=[1 3 1 1 1]. Now you know which value in uniqA is repeated and how many times. To get the indices, use
repeatIndices=find(A==uniqA(B==max(B)));
which gives the indices as [2, 4, 7]. Lastly, for all possible permutations of these indices, use the perms function.
perms(repeatIndices)
ans =
7 4 2
7 2 4
4 7 2
4 2 7
2 4 7
2 7 4
I believe this does what you wanted. You can write a wrapper function around all this so that you have something compact like out=sort_ties(in). You probably should include a conditional around the repeatIndices line, so that if B is all ones, you don't proceed any further (i.e., there are no ties).
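A rough sketch of that guarded version, written as a plain script. It swaps in accumarray and the third output of unique in place of histc, and the names group, counts, iMax and tiePerms are my own. Like the steps above, it only handles a single repeated value (the first most-repeated one), so treat it as a starting point, not a full solution:

```matlab
A = [1 2 3 2 5 6 2];

[uniqA, ~, group] = unique(A);        % group(i) is the index of A(i) in uniqA
counts = accumarray(group(:), 1).';   % occurrences of each unique value

if all(counts == 1)
    tiePerms = [];                    % no ties: nothing to permute
else
    [~, iMax] = max(counts);          % first most-repeated value only
    repeatIndices = find(A == uniqA(iMax));
    tiePerms = perms(repeatIndices);  % one row per ordering of the tied indices
end
```

Taking the second output of max avoids a problem with uniqA(B==max(B)), which returns more than one value when two different elements are repeated equally often.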
Here is a possible solution I believe to be correct, but it's somewhat inefficient because of the duplicates it generates initially. It's pretty neat otherwise, but I still suspect it can be done better.
function [srt,idx] = tie_sort(in,order)
L = length(in);
[srt,idx] = sort(in,order);
for j = 1:L-1 % for each position in the sorted array, look for repeats following it
    for k = j+1:L
        % if a repeat is found, add the swapped permutations to the list of possible sorts
        if srt(j) == srt(k)
            swapped = 1:L; swapped(j) = k; swapped(k) = j;
            add_idx = idx(:,swapped);
            idx = cat(1,idx,add_idx);
            idx = unique(idx,'rows'); % remove identical copies
        else % already sorted, so there is no need to keep looking
            break;
        end
    end
end
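One way to sanity-check the number of rows a function like this returns: each group of n tied values can be ordered in n! ways, so the row count should equal the product of those factorials. A minimal sketch, with illustrative variable names:

```matlab
in = [0 2 1 2 2 1];

[~, ~, group] = unique(in);                 % group label for each element
counts = accumarray(group(:), 1);           % size of each tie group: [1; 2; 3]

nPossibleSorts = prod(factorial(counts));   % 1! * 2! * 3! = 12
```

Comparing nPossibleSorts against size(idx,1) is a cheap way to confirm that no valid ordering was missed or duplicated.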