Note: this post is translated from stackoverflow.com. Original title: algorithm - How to properly implement disjoint set data structure for finding spanning forests in Python?
algorithm data-structures python disjoint-sets disjoint-union


Posted on 2020-04-04 00:30:17

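Recently I wanted to solve a problem from Google Kick Start 2019, and I tried to implement Cherries Mesh from Round E by following the analysis. Here is the link to the problem and its analysis: https://codingcompetitions.withgoogle.com/kickstart/round/0000000000050edb/0000000000170721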

Here is my code:

t = int(input())
for k in range(1,t+1):
    n, q = map(int,input().split())
    se = list()
    for _ in range(q):
        a,b = map(int,input().split())
        se.append((a,b))
    l = [{x} for x in range(1,n+1)]
    #print(se)
    for s in se:
        i = 0
        while ({s[0]}.isdisjoint(l[i])):
            i += 1
        j = 0
        while ({s[1]}.isdisjoint(l[j])):
            j += 1
        if i!=j:
            l[i].update(l[j])
            l.pop(j)
        #print(l)
    count = q+2*(len(l)-1)
    print('Case #',k,': ',count,sep='')



This passes the sample cases but not the test cases. As far as I can tell, it should be correct. Am I doing something wrong?


Asked by: user24071
Viewed 146 times

Answer by trincot, 2020-01-31 23:56

Two issues:

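  • The algorithm you use to check whether an edge links two disjoint sets, and to join them if it does, is inefficient. For every edge it scans the list of sets linearly (twice), and l.pop(j) is itself linear, so the loop can take O(n·q) time. The Union-Find algorithm on a disjoint-set data structure is far more efficient, with nearly constant amortized time per operation.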
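  • The final count should not depend on the original number of black edges. Those black edges may form cycles, and an edge that closes a cycle must not be counted. Instead, count how many edges there are in total, regardless of color. Since the solution represents a minimum spanning tree, it has exactly n-1 edges, each costing at least 1. Of these, len(l)-1 must be red (one for each extra disjoint set, which you already computed), and each red edge costs 1 more. The total is therefore (n-1) + (len(l)-1) = n + len(l) - 2.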
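To double-check the counting argument on small inputs, here is a brute-force sketch (all function names are hypothetical). It assumes the weights implied by the answer's code below: black edges cost 1, and every other pair of nodes is joined by a red edge costing 2. It compares three quantities:

  • a Kruskal minimum spanning tree over the complete graph;
  • the closed form (n-1) + (c-1), where c is the number of components formed by the black edges alone;
  • the question's formula q + 2*(c-1).

```python
from itertools import combinations


def brute_force_mst_cost(n, black_edges):
    """Kruskal's algorithm on the complete graph: black edges weigh 1, all others 2."""
    black = {frozenset(edge) for edge in black_edges}
    edges = sorted(
        (1 if frozenset((a, b)) in black else 2, a, b)
        for a, b in combinations(range(1, n + 1), 2)
    )
    label = list(range(n + 1))  # label[v] = component id of node v (index 0 unused)
    cost = 0
    for weight, a, b in edges:
        if label[a] != label[b]:
            old, new = label[b], label[a]
            label = [new if x == old else x for x in label]  # naive O(n) relabel
            cost += weight
    return cost


def black_components(n, black_edges):
    """Count connected components of the graph formed by the black edges only (DFS)."""
    neighbours = {v: [] for v in range(1, n + 1)}
    for a, b in black_edges:
        neighbours[a].append(b)
        neighbours[b].append(a)
    seen = set()
    components = 0
    for start in range(1, n + 1):
        if start in seen:
            continue
        components += 1
        seen.add(start)
        stack = [start]
        while stack:
            v = stack.pop()
            for w in neighbours[v]:
                if w not in seen:
                    seen.add(w)
                    stack.append(w)
    return components


def corrected_count(n, black_edges):
    c = black_components(n, black_edges)
    return (n - 1) + (c - 1)  # n-1 tree edges of cost >= 1, plus 1 extra per red edge


def question_count(n, black_edges):
    c = black_components(n, black_edges)
    return len(black_edges) + 2 * (c - 1)  # the formula used in the question


triangle = [(1, 2), (2, 3), (1, 3)]
print(brute_force_mst_cost(3, triangle), corrected_count(3, triangle), question_count(3, triangle))
# prints: 2 2 3
```

The triangle example shows the problem. The question's formula counts all three black edges, although a spanning tree of three nodes uses only two of them.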

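I would also suggest using meaningful variable names, which make the code much easier to understand. Single-letter variables like t, q or s are not very helpful.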
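Applying only the counting fix and the naming advice to the question's own list-of-sets approach gives something like the sketch below (cherries_mesh_slow is a hypothetical name). It returns the correct count, but it keeps the linear scans and list.pop, so it remains too slow for large inputs. That slowness is what Union-Find addresses.

```python
def cherries_mesh_slow(node_count, black_edges):
    """The question's list-of-sets approach with the corrected final count.

    Still O(node_count * len(black_edges)) in the worst case because of the
    linear scans and list.pop, so it is only suitable for small inputs.
    """
    components = [{node} for node in range(1, node_count + 1)]
    for start, end in black_edges:
        # Linear scan for the set that contains each endpoint.
        start_index = next(i for i, group in enumerate(components) if start in group)
        end_index = next(i for i, group in enumerate(components) if end in group)
        if start_index != end_index:
            components[start_index].update(components[end_index])
            components.pop(end_index)
    component_count = len(components)
    # n-1 tree edges cost at least 1 each; each of the component_count-1 red edges costs 1 more.
    return (node_count - 1) + (component_count - 1)
```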

There are several ways to implement Union-Find. Here I define a Node class with these methods:

# Implementation of Union-Find (Disjoint Set)
class Node:
    def __init__(self):
        self.parent = self
        self.rank = 0

    def find(self):
        if self.parent.parent != self.parent:
            self.parent = self.parent.find()
        return self.parent

    def union(self, other):
        node = self.find()
        other = other.find()
        if node == other:
            return True # was already in same set
        if node.rank > other.rank:
            node, other = other, node
        node.parent = other
        other.rank = max(other.rank, node.rank + 1)
        return False # was not in same set, but now is

testcount = int(input())
for testid in range(1, testcount + 1):
    nodecount, blackcount = map(int, input().split())
    # use Union-Find data structure
    nodes = [Node() for _ in range(nodecount)]
    blackedges = []
    for _ in range(blackcount):
        start, end = map(int, input().split())
        blackedges.append((nodes[start - 1], nodes[end - 1]))

    # Start with assumption that all edges on MST are red:
    sugarcount = nodecount * 2 - 2
    for start, end in blackedges:
        if not start.union(end): # When edge connects two disjoint sets:
            sugarcount -= 1 # Use this black edge instead of red one

    print('Case #{}: {}'.format(testid, sugarcount))
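A common alternative to the Node class is to store the forest in two flat lists indexed by node number. The sketch below is that swapped-in variant, not the answer's code, and the function names solve and main are hypothetical. It differs in three ways:

  • it uses union by size instead of union by rank;
  • it uses iterative path halving instead of recursive path compression;
  • it reads the whole input at once with sys.stdin.buffer.read(), which is usually faster than many input() calls.

The counting logic is unchanged: start from 2(n-1) and subtract 1 for every black edge that joins two different sets.

```python
import sys


def solve(node_count, black_edges):
    """Minimum sugar for one test case; nodes are numbered 1..node_count."""
    parent = list(range(node_count + 1))  # parent[v] == v means v is a root
    size = [1] * (node_count + 1)

    def find(v):
        # Iterative find with path halving: point each visited node to its grandparent.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    sugar = 2 * (node_count - 1)  # assume every spanning-tree edge is red (cost 2)
    for a, b in black_edges:
        root_a, root_b = find(a), find(b)
        if root_a == root_b:
            continue  # this black edge would close a cycle, so it is not used
        if size[root_a] < size[root_b]:
            root_a, root_b = root_b, root_a
        parent[root_b] = root_a  # union by size: attach the smaller tree to the larger
        size[root_a] += size[root_b]
        sugar -= 1  # this black edge replaces a red one
    return sugar


def main():
    data = sys.stdin.buffer.read().split()
    if not data:
        return
    test_count = int(data[0])
    pos = 1
    for test_id in range(1, test_count + 1):
        node_count, black_count = int(data[pos]), int(data[pos + 1])
        pos += 2
        edges = []
        for _ in range(black_count):
            edges.append((int(data[pos]), int(data[pos + 1])))
            pos += 2
        print(f"Case #{test_id}: {solve(node_count, edges)}")


if __name__ == "__main__":
    main()
```

With union by rank (as in the Node class) or union by size (as here), tree height stays logarithmic, so the recursive find above does not hit Python's recursion limit. The flat-list version mainly avoids per-node object overhead.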