Un Arbre Binaire de Recherche en Python

Ce billet date de plusieurs années, ses informations peuvent être devenues obsolètes.

Alors bande de FDP, vous aimez bien les arbres ?

Ce qui va suivre est l'implémentation expliquée d'un ABR en Python.

Conseil pour comprendre vite : faire des schémas de l'arbre.

Le code est sur GitHub.

Définition d'un ABR

Un arbre binaire de recherche (ABR) est une structure de donnée composée de nœuds. Chaque nœud a au plus 2 enfants ordonnés d'une manière particulière :

  • les enfants à gauche d'un nœud ont des valeurs inférieures à lui
  • les enfants à droite d'un nœud ont des valeurs supérieures à lui

Et cela doit être vrai pour chaque nœud de l'arbre.

Créer l'arbre

Un ABR est composé de nœuds. Chaque nœud contient obligatoirement une valeur et optionnellement un parent et 2 enfants (un à droite et un à gauche) :

class Node:

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        self.parent = None

    def __str__(self):
        return str(self.data)

Ici, __init__() et __str__() sont des méthodes spéciales de Python.

J'ai choisi de n'utiliser qu'une seule classe Node car si chaque nœud doit respecter les propriétés d'un ABR, alors n'importe quelle méthode devrait pouvoir être exécutée pour chaque nœud.

Dans beaucoup d'exemples que j'ai vu, les gens utilisent une autre classe BST pour envelopper les nœuds et conserver une référence à la racine. D'autres préfèrent garder la classe Node minimale et la manipuler via des fonctions externes. Chacun fait c'qui lui plaît.

Insérer un nœud

Lors de l'ajout d'un nœud, il faut respecter les propriétés de l'arbre. On compare la valeur à insérer avec les valeurs des nœuds existants en commençant par la racine. Pour chaque nœud, si la valeur à insérer lui est inférieure, alors on continue la comparaison en descendant à gauche, sinon à droite. Le dernier nœud inspecté sera le parent du nœud à ajouter. Si le nœud à insérer possède une valeur plus petite que son parent, il deviendra son fils gauche, sinon son fils droit :

def insert(self, data):
    if data < self.data:
        if self.left is None:
            self.left = Node(data)
            self.left.parent = self
        else:
            self.left.insert(data)
    elif data > self.data:
        if self.right is None:
            self.right = Node(data)
            self.right.parent = self
        else:
            self.right.insert(data)

Avec cette implémentation, on choisit de créer un ABR sans doublons.

Pour qu'un arbre existe, il faut qu'il ait une racine. On nomme bst (comme Binary Search Tree) la variable qui contient la racine de l'arbre :

bst = Node(12)  # Root of the BST.
bst.insert(6)
bst.insert(14)
…

Traverser l'arbre

Ça signifie visiter tous ses nœuds. Il existe des méthodes générales de parcours de graphe et des méthodes spécifiques aux arbres binaires basées sur la récursivité.

En traversant notre arbre avec un parcours infixe (ou in-order car il donne les valeurs dans l'ordre croissant) et en conservant la trace du niveau dans lequel nous sommes, on peut en générer une représentation graphique :

def pprint(self, level=0):
    if self.right:
        self.right.pprint(level + 1)
    print(f"{' ' * 4 * level}{self.data}")
    if self.left:
        self.left.pprint(level + 1)

Ça fait pprint du pauvre, mais ça fait le taf si on incline la tête pour lire le résultat :) Vous pouvez aller voir un pprint plus costaud dans la librairie Binarytree.

On va aussi rendre notre classe Node iterable avec __iter__() et lui faire produire une traversée in-order, ça nous servira plus tard :

def __iter__(self):
    if self.left:
        for node in self.left:
            yield node
    yield self.data
    if self.right:
        for node in self.right:
            yield node

Rechercher un nœud

Pour trouver le nœud qui correspond à une valeur, il suffit de comparer cette dernière à celle des autres nœuds en partant de la racine et en respectant les propriétés de l'arbre jusqu'à trouver le bon nœud :

def get(self, data):
    if data < self.data:
        return self.left.get(data) if self.left else None
    elif data > self.data:
        return self.right.get(data) if self.right else None
    return self

On peut emballer get() dans __getitem__() pour trouver les nœuds avec une syntaxe minimale :

def __getitem__(self, key):
    node = self.get(key)
    if node:
        return node
    raise KeyError(key)

On gagne donc en kif d'utilisation :

node = bst.get(9)
node = bst[9]

Trouver le minimum

D'après les propriétés de l'arbre, le nœud ayant la plus petite valeur se trouve forcément le plus à gauche possible :

def min(self):
    node = self
    while node.left:
        node = node.left
    return node

Trouver le maximum

C'est la logique inverse de la recherche du minimum, on descend le plus à droite possible :

def max(self):
    node = self
    while node.right:
        node = node.right
    return node

Compter les enfants d'un nœud

Ici on exploite le fait que les booléens sont un sous-type des entiers en Python :

def count_children(self):
    return bool(self.left) + bool(self.right)

Déterminer de quel bord est l'enfant

L'enfant est-il le fils gauche de son parent ?

def is_left_child(self):
    return self.parent and self is self.parent.left

Ou bien son fils droit ?

def is_right_child(self):
    return self.parent and self is self.parent.right

Trouver le successeur d'un nœud

Le successeur d'un nœud 𝒙 est le nœud ayant la plus petite valeur supérieure à 𝒙. Autrement dit : le nœud ayant la prochaine plus grande valeur, soit le prochain nœud d'une traversée in-order.

Puisque tous les nœuds du sous-arbre gauche de 𝒙 ont une valeur plus petite, son successeur doit se trouver dans son sous-arbre droit. Et toujours d'après les propriétés de l'arbre, ça sera la plus petite valeur de ce sous-arbre droit : son dernier fils gauche s'il y a des descendants à gauche, sinon son premier nœud.

Si le nœud n'a pas de sous-arbre droit, alors il faut remonter dans l'arbre pour trouver une valeur plus grande car c'est le seul endroit où on pourra en trouver une. Le successeur sera le premier des parents tel que 𝒙 apparaît dans son sous-arbre gauche (c'est plus facile à comprendre si on fait un schéma de l'arbre et qu'on a ses propriétés bien en tête) :

def get_successor(self):
    if self.right:
        return self.right.min()
    node = self
    while node.is_right_child():
        node = node.parent
    return node.parent

Si on n'avait pas conservé un lien vers le parent dans notre classe, il aurait fallu scanner tout l'arbre pour trouver le successeur.

Trouver le prédécesseur d'un nœud

C'est la logique inverse de celle de la recherche du successeur :

def get_predecessor(self):
    if self.left:
        return self.left.max()
    node = self
    while node.is_left_child():
        node = node.parent
    return node.parent

Supprimer un nœud

C'est la partie qui me fait suer car on s'embrouille hyper vite entre les nœuds, leurs parents et leurs enfants.

Il y a 3 possibilités principales, parmi lesquelles d'autres possibilités :)

  1. Si le nœud à supprimer n'a pas d'enfants (c'est une feuille), on le supprime :
    • on supprime la branche qui va de son parent vers lui pour en faire un orphelin
    • on supprime le nœud lui-même
  2. Si le nœud à supprimer possède un seul enfant, on le remplace par son fils :
    • si le nœud à supprimer a aussi un parent :
      • on corrige les branches de l'arbre pour aller de son parent vers son fils (et vice-versa) afin de sortir le nœud de l'arbre
      • on supprime le nœud
    • si le nœud à supprimer n'a pas de parent, alors c'est la racine :
      • on remplace la valeur de la racine par celle de son fils
      • on corrige les branches de l'arbre pour aller de la racine vers son petit-fils (et vice-versa) afin que le fils ne fasse plus partie de l'arbre
      • on supprime le fils
  3. Si le nœud à supprimer possède 2 enfants, on le remplace par son successeur :
    • son successeur sera forcément dans son sous-arbre droit car le nœud a 2 enfants, donc un sous-arbre droit CQFD
    • si son successeur a des enfants, ils ne pourront être qu'à sa droite car son successeur est dans notre cas la plus petite valeur du sous-arbre droit
    • on remplace sa valeur par celle de son successeur
    • on corrige les branches de l'arbre pour aller de notre nœud vers le fils de son successeur (et vice-versa) pour rendre le successeur orphelin et le sortir de l'arbre
    • on supprime son successeur

Les bouquins d'algorithmes décrivent une autre façon de faire avec deux méthodes splice() et remove_node() mais je trouve qu'un seul bloc est plus facile à comprendre :

def delete(self, data):

    node = self.get(data)

    if not node:
        return

    children_count = node.count_children()

    if children_count == 0:
        if node.is_left_child():
            node.parent.left = None
        else:
            node.parent.right = None
        del node

    elif children_count == 1:
        child = node.left or node.right
        if node.is_left_child():
            node.parent.left = child
            child.parent = node.parent
            del node
        elif node.is_right_child():
            node.parent.right = child
            child.parent = node.parent
            del node
        else:
            root = node
            root.data = child.data
            root.left = child.left
            root.right = child.right
            if child.left:
                child.left.parent = root
            if child.right:
                child.right.parent = root
            del child

    else:
        succ = node.get_successor()
        node.data = succ.data
        if succ.is_left_child():
            succ.parent.left = succ.right
        else:
            succ.parent.right = succ.right
        if succ.right:
            succ.right.parent = succ.parent
        del succ

Calculer la hauteur

Pour la déterminer, il faut se mettre d'accord sur la définition de la hauteur d'un arbre. Pour certains c'est le nombre de nœuds qui jalonnent le chemin le plus long dans l'arbre. Pour d'autres c'est plutôt le nombre de branches (on dit aussi arêtes ou edges en théorie des graphes) du chemin le plus long.

On va utiliser cette dernière définition. Pour l'implémenter on utilise la récursivité, l'évaluation commence alors par les nœuds sans enfants les plus bas dans l'arbre (les feuilles) :

  • la feuille de gauche retourne 1 + max(-1, -1) (soit 0) à son appelant
  • son appelant retourne par exemple 1 + max(0, -1) (soit 1) à son propre appelant
  • etc.

Si vous préférez un calcul de hauteur basé sur le nombre de nœuds, il suffit de remplacer -1 par 0 :

def get_height(self):
    return 1 + max(
        self.left.get_height() if self.left else -1,
        self.right.get_height() if self.right else -1
    )

Vérifier l'équilibre

Un arbre équilibré permet de conserver une complexité algorithmique constante pour certaines opérations. Or l'équilibre de l'ABR n'est pas garanti. Il existe d'autres structures d'arbres pour ça.

Pourtant on vous demande dans certains tests d'embauche de vérifier que votre ABR est équilibré… Soit…

La définition communément admise de l'équilibre d'un arbre est que la hauteur des deux sous-arbres de chaque nœud ne peut excéder 1.

Il y a plusieurs techniques pour vérifier l'équilibre. Le code ci-dessous est inspiré de celui de Cracking the Coding Interview sauf que je lève une exception pour déclencher le déroulement de pile plutôt que d'utiliser une variable avec un code d'erreur, je trouve ça plus lisible :

def _check_balance(self):
    left = self.left._check_balance() if self.left else -1
    right = self.right._check_balance() if self.right else -1
    if abs(left - right) > 1:
        raise ValueError('Unbalanced tree.')
    return max(left, right) + 1

def is_balanced(self):
    try:
        self._check_balance()
        return True
    except ValueError:
        return False

Comme _check_balance() ne va être utilisée que par is_balanced(), je la marque comme méthode privée avec un caractère underscore au début de son nom, c'est une convention. En vrai il n'y a rien de privé en Python.

Valider que notre arbre est bien un ABR

Puisque notre implémentation n'admet pas de doublons, on peut simplement faire une traversée in-order et vérifier que les valeurs sont effectivement dans un ordre croissant.

Je vous avais bien dit que le fait de rendre notre classe Node iterable allait nous servir. En plus __iter__() utilise yield, ce qui va nous permettre de ne pas gâcher d'espace mémoire :

def is_valid(self):
    prev = None
    for data in self:
        if prev and prev > data:
            return False
        prev = data
    return True

Une autre façon de faire est d'utiliser la récursivité pour comparer les minimums et maximums de chaque sous-arbre. En voici une chouette implémentation en Python.

Sources d'inspiration

Avant Quelques algorithmes de tri en Python Après Parcours de graphes en Python

Tag Kemar Joint