A set union find algorithm

Developer https://www.devze.com 2023-01-03 17:53 Source: Internet
I have thousands of lines, each containing between 1 and 100 numbers; each line defines a group of numbers and a relationship among them. I need to get the sets of related numbers.

A small example: suppose I have these 7 lines of data:

T1 T2
T3 
T4
T5
T6 T1
T5 T4
T3 T4 T7

I need a reasonably fast algorithm to find that the sets here are:

T1 T2 T6 (because T1 is related to T2 in line 1, and T1 is related to T6 in line 5)
T3 T4 T5 T7 (because T5 is with T4 in line 6, and T3 is with T4 and T7 in line 7)

But when the sets are very big, it is painfully slow to search for each T(x) in every big set, take unions of sets, etc.

Do you have a hint for doing this in a less brute-force manner?

I'm trying to do this in Python.


Once you have built the data structure, exactly what queries do you want to run against it? Show us your existing code. What is a T(x)? You talk about "groups of numbers" but your sample data shows T1, T2, etc.; please explain.

Have you read this: http://en.wikipedia.org/wiki/Disjoint-set_data_structure

Try looking at this Python implementation: http://code.activestate.com/recipes/215912-union-find-data-structure/

Or you can lash up something simple and understandable yourself, e.g.:

[Update: totally new code]

class DisjointSet(object):

    def __init__(self):
        self.leader = {} # maps a member to the group's leader
        self.group = {} # maps a group leader to the group (which is a set)

    def add(self, a, b):
        leadera = self.leader.get(a)
        leaderb = self.leader.get(b)
        if leadera is not None:
            if leaderb is not None:
                if leadera == leaderb: return # nothing to do
                groupa = self.group[leadera]
                groupb = self.group[leaderb]
                if len(groupa) < len(groupb):
                    a, leadera, groupa, b, leaderb, groupb = b, leaderb, groupb, a, leadera, groupa
                groupa |= groupb
                del self.group[leaderb]
                for k in groupb:
                    self.leader[k] = leadera
            else:
                self.group[leadera].add(b)
                self.leader[b] = leadera
        else:
            if leaderb is not None:
                self.group[leaderb].add(a)
                self.leader[a] = leaderb
            else:
                self.leader[a] = self.leader[b] = a
                self.group[a] = set([a, b])

data = """T1 T2
T3 T4
T5 T1
T3 T6
T7 T8
T3 T7
T9 TA
T1 T9"""
# data is chosen to demonstrate each of 5 paths in the code
from pprint import pprint as pp
ds = DisjointSet()
for line in data.splitlines():
    x, y = line.split()
    ds.add(x, y)
    print()
    print(x, y)
    pp(ds.leader)
    pp(ds.group)

and here is the output from the last step:

T1 T9
{'T1': 'T1',
 'T2': 'T1',
 'T3': 'T3',
 'T4': 'T3',
 'T5': 'T1',
 'T6': 'T3',
 'T7': 'T3',
 'T8': 'T3',
 'T9': 'T1',
 'TA': 'T1'}
{'T1': set(['T1', 'T2', 'T5', 'T9', 'TA']),
 'T3': set(['T3', 'T4', 'T6', 'T7', 'T8'])}
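DisjointSet.add above takes exactly two items per call, so the question's variable-length lines need some glue: a line with k items takes k - 1 calls, and a single-item line such as T3 needs a call like add(x, x). A minimal sketch that folds the same leader/group scheme into one function over whole lines follows; `group_lines` is a hypothetical helper, and the input is assumed to be an iterable of whitespace-separated strings.

```python
def group_lines(lines):
    """Group related items using a leader map plus per-leader member sets.

    Merges the smaller group into the larger one and relabels only the
    moved members; accepts lines of any length, including single items.
    """
    leader = {}  # item -> leader of the item's group
    group = {}   # leader -> set of the group's members
    for line in lines:
        items = line.split()
        if not items:
            continue  # skip blank lines
        for item in items:
            if item not in leader:
                leader[item] = item  # a new item starts its own group
                group[item] = {item}
        first = leader[items[0]]
        for item in items[1:]:
            other = leader[item]
            if other == first:
                continue  # already in the same group
            if len(group[first]) < len(group[other]):
                first, other = other, first  # keep the larger group
            group[first] |= group[other]
            for member in group.pop(other):
                leader[member] = first  # relabel only the moved members
    return list(group.values())


sample = ["T1 T2", "T3", "T4", "T5", "T6 T1", "T5 T4", "T3 T4 T7"]
print(sorted(sorted(g) for g in group_lines(sample)))
# [['T1', 'T2', 'T6'], ['T3', 'T4', 'T5', 'T7']]
```

Because each item is relabelled only when it moves into a group at least as large as its own, an item's group size at least doubles on each move, so each item is relabelled at most O(log n) times.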


Treat your numbers T1, T2, etc. as graph vertices. Any two numbers appearing together on a line are joined by an edge. Then your problem amounts to finding all the connected components in this graph. You can do this by starting with T1, then doing a breadth-first or depth-first search to find all vertices reachable from that point. Mark all these vertices as belonging to equivalence class T1. Then find the next unmarked vertex Ti, find all the yet-unmarked nodes reachable from there, and label them as belonging to equivalence class Ti. Continue until all the vertices are marked.

For a graph with n vertices and e edges, this algorithm requires O(n + e) time and space to build the adjacency lists, and O(n + e) time to identify all the connected components once the graph structure is built, because the searches visit each vertex once and examine each edge twice (once from each end).
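The marking procedure described above can be sketched as follows, assuming the input lines are available as strings; `label_components` is a hypothetical name. It uses a stack-based search (any traversal that reaches every vertex reachable from the start works) and labels each vertex with the vertex its search started from. Rather than joining every pair of items on a line, which for a 100-item line means 4,950 edges, it links each item to the first item on its line; that star pattern has the same connected components with far fewer edges.

```python
def label_components(lines):
    """Map every item to the label of its equivalence class.

    The label is the vertex the class's search started from: T1 first,
    then the next unmarked vertex, and so on.
    """
    adj = {}  # vertex -> set of neighbours; dicts keep insertion order
    for line in lines:
        items = line.split()
        for item in items:
            adj.setdefault(item, set())  # also registers single-item lines
        # Star edges: link each item to the first item on its line.
        for other in items[1:]:
            adj[items[0]].add(other)
            adj[other].add(items[0])

    label = {}
    for start in adj:  # vertices in first-seen order
        if start in label:
            continue  # already marked by an earlier search
        label[start] = start
        stack = [start]  # stack-based search from this unmarked vertex
        while stack:
            vertex = stack.pop()
            for nei in adj[vertex]:
                if nei not in label:
                    label[nei] = start
                    stack.append(nei)
    return label


sample = ["T1 T2", "T3", "T4", "T5", "T6 T1", "T5 T4", "T3 T4 T7"]
print(sorted(label_components(sample).items()))
# [('T1', 'T1'), ('T2', 'T1'), ('T3', 'T3'), ('T4', 'T3'), ('T5', 'T3'), ('T6', 'T1'), ('T7', 'T3')]
```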


You can use a union find data structure to achieve this goal.

The pseudocode for such an algorithm is as follows:

func find( var element )
    while ( element is not the root ) element = element's parent
    return element
end func

func union( var setA, var setB )
    var rootA = find( setA ), rootB = find( setB )
    if ( rootA is equal to rootB ) return
    else
        set rootB as rootA's parent
end func

(Taken from http://www.algorithmist.com/index.php/Union_Find)
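A Python rendering of that pseudocode might look like the sketch below. It adds two standard refinements that the quoted pseudocode does not have: path halving in find and union by size in union. The class name `UnionFind` and the helper `union_find_groups` are hypothetical, and every item on a line is assumed to be related to the first item on that line.

```python
from collections import defaultdict


class UnionFind:
    """Forest of parent pointers; an element that is its own parent is a root."""

    def __init__(self):
        self.parent = {}
        self.size = {}  # only kept up to date for roots

    def find(self, element):
        # Unknown elements start as singleton sets (covers lines like "T3").
        if element not in self.parent:
            self.parent[element] = element
            self.size[element] = 1
        # Walk up to the root, pointing each visited node at its
        # grandparent on the way (path halving).
        while self.parent[element] != element:
            self.parent[element] = self.parent[self.parent[element]]
            element = self.parent[element]
        return element

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return  # already in the same set
        # Union by size: hang the smaller tree under the larger root.
        if self.size[root_a] < self.size[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        self.size[root_a] += self.size[root_b]


def union_find_groups(lines):
    """Return the groups of related items as sorted lists."""
    uf = UnionFind()
    for line in lines:
        items = line.split()
        for item in items:
            uf.union(items[0], item)  # union(x, x) just registers x
    groups = defaultdict(set)
    for element in list(uf.parent):
        groups[uf.find(element)].add(element)
    return sorted(sorted(g) for g in groups.values())


sample = ["T1 T2", "T3", "T4", "T5", "T6 T1", "T5 T4", "T3 T4 T7"]
print(union_find_groups(sample))
# [['T1', 'T2', 'T6'], ['T3', 'T4', 'T5', 'T7']]
```

With both refinements, a sequence of operations runs in nearly linear time overall; without them, find can degrade to walking a long chain of parents.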


As Jim pointed out above, you are essentially looking for the connected components of a simple undirected graph where the nodes are your entities (T1, T2 and so on), and edges represent the pairwise relations between them. A simple implementation of connected component search is based on breadth-first search: you start a BFS from the first entity, find all the related entities, then start another BFS from the first entity not yet found, and so on, until you have found them all. A simple implementation of BFS looks like this:

from collections import deque

class BreadthFirstSearch(object):
    """Breadth-first search implementation using an adjacency list"""

    def __init__(self, adj_list):
        self.adj_list = adj_list

    def run(self, start_vertex):
        """Runs a breadth-first search from the given start vertex and
        yields the visited vertices one by one."""
        queue = deque([start_vertex])
        visited = set([start_vertex])
        adj_list = self.adj_list

        while queue:
            vertex = queue.popleft()
            yield vertex
            unseen_neis = adj_list[vertex]-visited
            visited.update(unseen_neis)
            queue.extend(unseen_neis)

def connected_components(graph):
    seen_vertices = set()
    bfs = BreadthFirstSearch(graph)
    for start_vertex in graph:
        if start_vertex in seen_vertices:
            continue
        component = list(bfs.run(start_vertex))
        yield component
        seen_vertices.update(component)

Here, adj_list (or graph) is an adjacency list: a dict that maps each vertex to the set of its neighbours in the graph. To build it from your file, you can do this:

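from collections import defaultdict

adj_list = defaultdict(set)
with open("your_file.txt") as f:
    for line in f:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        v1 = parts.pop(0)
        adj_list[v1].update(parts)  # also registers single-item lines
        for v2 in parts:
            adj_list[v2].add(v1)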

Then you can run:

components = list(connected_components(adj_list))

Of course, implementing the whole algorithm in pure Python tends to be slower than an implementation in C with a more efficient graph data structure. You might consider using igraph or some other graph library like NetworkX to do the job instead. Both libraries contain implementations of connected component search; in igraph, it boils down to this (assuming that your file contains only pairwise entries; lines with a single entry are not accepted):

>>> from igraph import load
>>> graph = load("edge_list.txt", format="ncol", directed=False)
>>> components = graph.clusters()
>>> print(graph.vs[components[0]]["name"])
['T1', 'T2', 'T6']
>>> print(graph.vs[components[1]]["name"])
['T3', 'T4', 'T5']

Disclaimer: I am one of the authors of igraph
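The igraph snippet above needs a pairwise edge list, and the ncol format holds two vertex names per line (an optional third column is an edge weight), so a line such as T3 T4 T7 from the question is not read as three vertices. A minimal sketch that rewrites variable-length lines into two-name lines before calling load() is below; `to_pairwise` is a hypothetical helper, not part of igraph, and the input is assumed to be available as a list of strings.

```python
def to_pairwise(lines):
    """Split variable-length lines into (edges, singletons).

    Each item after the first is paired with the first item on its line;
    this star pattern yields the same connected components as joining
    every pair. Single-item lines produce no edge and are returned apart.
    """
    edges, singletons = [], []
    for line in lines:
        items = line.split()
        if len(items) == 1:
            singletons.append(items[0])
        for other in items[1:]:
            edges.append((items[0], other))
    return edges, singletons


sample = ["T1 T2", "T3", "T4", "T5", "T6 T1", "T5 T4", "T3 T4 T7"]
edges, singletons = to_pairwise(sample)

# Two names per line: the shape an ncol edge list expects.
edge_text = "\n".join(f"{a} {b}" for a, b in edges)
print(edge_text)

# Items that never appear in an edge are components on their own and
# would have to be added to the graph separately.
in_edges = {v for edge in edges for v in edge}
isolated = [item for item in singletons if item not in in_edges]
print(isolated)  # [] for this sample
```

edge_text can then be written to the file passed to load().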


You can model a group using a set. In the example below, I've put the set into a Group class to make it easier to keep references to them and to track some notional 'head' item.

class Group:
    def __init__(self,head):
        self.members = set()
        self.head = head
        self.add(head)
    def add(self,member):
        self.members.add(member)
    def union(self,other):
        self.members = other.members.union(self.members)

groups = {}

for line in open("sets.dat"):
    line = line.split()
    if len(line) == 0:
        continue  # skip blank lines
    # find the group of the first item on the row
    head = line[0]
    if head not in groups:
        group = Group(head)
        groups[head] = group
    else:
        group = groups[head]
    # for each other item on the row, merge the groups
    for node in line[1:]:
        if node not in groups:
            # it's a new node, straight into the group
            group.add(node)
            groups[node] = group
        elif head not in groups[node].members:
            # merge two groups
            new_members = groups[node]
            group.union(new_members)
            for migrate in new_members.members:
                groups[migrate] = group
# list them
for k, v in groups.items():
    if k == v.head:
        print(v.members)

Output is:

set(['T6', 'T2', 'T1'])
set(['T4', 'T5', 'T3'])