Skip to content

Instantly share code, notes, and snippets.

@sebinsua
Last active August 23, 2023 21:37
Show Gist options
  • Save sebinsua/67f58b8f7e9eb78da79b17d959e9dc73 to your computer and use it in GitHub Desktop.
Save sebinsua/67f58b8f7e9eb78da79b17d959e9dc73 to your computer and use it in GitHub Desktop.
An implementation of a generator called `follow` that traverses linked lists and other data structures by repeatedly applying a step function until a stop condition is met.
from operator import eq, attrgetter as attr
from typing import TypeVar, Callable, Iterator, Optional
def eq_or_none(x, y):
    """Default stop condition for ``follow``.

    Stop when the step function produced ``None`` (the path ran out) or a
    value equal to the current one (a fixed point). Equality is ``==``, so
    objects with a value-based ``__eq__`` are compared by value, not identity.
    """
    return y is None or eq(x, y)


T = TypeVar('T')


def follow(
    start: T,
    next_fn: Callable[[T], Optional[T]],
    stop: Callable[[T, Optional[T]], bool] = eq_or_none,
) -> Iterator[Optional[T]]:
    """Yield ``start`` and each successor produced by repeatedly applying ``next_fn``.

    The walk is ``start, next_fn(start), next_fn(next_fn(start)), ...``.
    Before a successor is yielded, ``stop(current, successor)`` is consulted;
    if it returns True the successor is *not* yielded and iteration ends. With
    the default ``stop`` the walk therefore ends at the last non-None value or
    at a fixed point (a value whose successor compares equal to it).

    ``next_fn`` may raise ``StopIteration`` to signal that no successor exists
    (e.g. a search that hit a dead end). In that case a final ``None`` is
    yielded so consumers such as ``last`` can tell the path was abandoned.

    Args:
        start: First value of the walk; always yielded.
        next_fn: Maps a value to its successor.
        stop: Predicate ``(current, successor) -> bool`` that ends the walk.

    Yields:
        ``start``, then each accepted successor, then possibly ``None``.
    """
    current = start
    yield current
    while True:
        # Guard *every* call to next_fn, including the first one: under PEP 479
        # a StopIteration escaping a generator body is turned into RuntimeError.
        try:
            candidate = next_fn(current)
        except StopIteration:
            yield None
            return
        if stop(current, candidate):
            return
        current = candidate
        yield current
def last(iterable: Iterator[Optional[T]]) -> Optional[T]:
    """Exhaust ``iterable`` and return its final item, or None if it is empty."""
    final = None
    # Rebinding the loop variable on each pass leaves it holding the last item.
    for final in iterable:
        pass
    return final
class ListNode:
    """Node of a singly linked list.

    Two nodes compare equal when their ``value`` attributes are equal; the
    ``next`` links are not compared. Because ``__eq__`` is defined without
    ``__hash__``, instances are unhashable.
    """

    def __init__(self, value, next_node=None):
        self.value = value
        self.next = next_node  # successor node, or None at the tail

    def __eq__(self, other):
        if isinstance(other, ListNode):
            return self.value == other.value
        return NotImplemented

    def __repr__(self):
        return "ListNode(value={})".format(self.value)
# A three-node linked list: following ``next`` ends at the tail, because the
# step function returns None there.
head = ListNode(1, ListNode(2, ListNode(3)))
assert last(follow(head, attr("next"))) == ListNode(3)

# A chain of representatives ending in a self-loop: the walk stops at the
# fixed point 11, whose representative is itself.
representatives = {0: 4, 1: 5, 4: 7, 7: 8, 8: 11, 11: 11}
assert last(follow(0, lambda key: representatives[key])) == 11
class TreeNode:
    """Binary tree node with optional ``left`` and ``right`` children.

    Equality compares only ``value``; children are ignored. Instances are
    unhashable because ``__eq__`` is defined without ``__hash__``.
    """

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

    def __eq__(self, other):
        return self.value == other.value if isinstance(other, TreeNode) else NotImplemented

    def __repr__(self):
        return "TreeNode(value={value})".format(value=self.value)
# Example binary search tree:
#
#         5
#        / \
#       3   8
#      / \   \
#     2   4   10
bst = TreeNode(
    5,
    left=TreeNode(3, left=TreeNode(2), right=TreeNode(4)),
    right=TreeNode(8, right=TreeNode(10)),
)
def search_bst(bst, value):
    """Return the node of binary search tree ``bst`` holding ``value``, or None.

    The search is expressed as a walk with ``follow``. Each step descends
    toward ``value``. A node holding ``value`` maps to itself, which is a fixed
    point that stops ``follow``. A dead end raises ``StopIteration``, which
    ``follow`` reports by yielding ``None``, the value ``last`` then returns.

    Args:
        bst: Root ``TreeNode`` of the tree, or None for an empty tree.
        value: Value to look for; must support ``==``/``<``/``>`` against node values.
    """
    # An empty tree contains nothing. Without this guard the first step would
    # fail with AttributeError on ``None.value``.
    if bst is None:
        return None

    # NOTE(review): a dead end on the very first step (a leaf root that does not
    # hold ``value``) relies on ``follow`` guarding its first ``next_fn`` call.
    def bst_next(node):
        if value == node.value:
            return node  # fixed point: follow stops on this node
        if node.left and value < node.value:
            return node.left
        if node.right and value > node.value:
            return node.right
        raise StopIteration(
            "We have reached the end of the tree without finding the value."
        )

    return last(follow(bst, bst_next))


# Searching for value 8 in the binary search tree.
assert search_bst(bst, 8) == TreeNode(8)
# Searching for a missing value in the binary search tree.
assert search_bst(bst, 17) is None
# Searching an empty tree finds nothing.
assert search_bst(None, 8) is None
class CycleDetected(Exception):
    """Signals that a cycle was detected while walking a linked list."""
def slow_and_fast(head):
    """Advance a slow (one-step) and a fast (two-step) pointer through a linked list.

    Returns the last ``(tortoise, hare)`` pair reached before the hare would
    step past the tail. The tortoise is then the middle node; for an even
    length it is the first of the two middle nodes. For an empty list
    (``head is None``) returns ``(None, None)``.

    Raises:
        CycleDetected: if the two pointers ever land on the same node.
    """
    if head is None:
        return None, None

    def next_fn(tortoise_and_hare):
        tortoise, hare = tortoise_and_hare
        if hare.next is None or hare.next.next is None:
            # The hare cannot take a full two-node step: the list ends here.
            return tortoise.next, None
        next_tortoise, next_hare = tortoise.next, hare.next.next
        # Identity, not ``==``: ListNode.__eq__ compares values, so distinct
        # nodes that hold equal values would otherwise look like a meeting.
        if next_tortoise is next_hare:
            raise CycleDetected("Cycle detected")
        return next_tortoise, next_hare

    def stop_fn(_current, candidate):
        # Stop once the hare has run off the end of the list.
        _tortoise, hare = candidate
        return hare is None

    return last(follow((head, head), next_fn, stop_fn))
def has_cycle(head):
    """Return True if the linked list starting at ``head`` contains a cycle.

    Detection is delegated to ``slow_and_fast``, which signals a cycle by
    raising ``CycleDetected``.
    """
    cycle_found = False
    try:
        slow_and_fast(head)
    except CycleDetected:
        cycle_found = True
    return cycle_found
# Find the middle of a linked list (of even and odd lengths). For an even
# length, the first of the two middle nodes is returned.
even_length = ListNode(1, ListNode(2, ListNode(3, ListNode(4, ListNode(5, ListNode(6))))))
mid_even, _ = slow_and_fast(even_length)
# ListNode equality compares values only, so this checks that node 3 is the middle.
assert mid_even == ListNode(3, ListNode(4, ListNode(5, ListNode(6))))

odd_length = ListNode(
    1, ListNode(2, ListNode(3, ListNode(4, ListNode(5, ListNode(6, ListNode(7))))))
)
mid_odd, _ = slow_and_fast(odd_length)
assert mid_odd == ListNode(4, ListNode(5, ListNode(6, ListNode(7))))

# Test cycle detection: an acyclic list, then the same list with 3 linked back to 2.
head = ListNode(1, ListNode(2, ListNode(3)))
assert not has_cycle(head)
head.next.next.next = head.next
assert has_cycle(head)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment