advent-of-code/2024/day16/day16.py
2024-12-17 21:23:46 +01:00

118 lines
3.1 KiB
Python

from dataclasses import dataclass, field
from heapq import heappop, heappush
# Unit steps on the grid encoded as complex numbers: the real part is the
# x offset and the imaginary part is the y offset (y grows downwards).
DIRECTIONS = {
    ">": complex(1, 0),  # EAST
    "v": complex(0, 1),  # SOUTH
    "<": complex(-1, 0),  # WEST
    "^": complex(0, -1),  # NORTH
}
# Reverse lookup: unit step -> arrow symbol.
DIRECTIONS_RV = dict(zip(DIRECTIONS.values(), DIRECTIONS.keys()))
@dataclass(frozen=True)
class Node:
    """
    One step of a route: a position, the heading used to reach it, and the
    accumulated cost.

    Equality and hashing are based on ``pos``, ``direction`` and ``cost``
    only. ``parent`` is deliberately left out: including it makes every
    ``hash()`` / ``==`` walk the entire chain of ancestors, which costs
    O(route length) per call and raises ``RecursionError`` on long routes.
    It is also left out of ``repr`` for the same reason.
    """

    pos: complex  # grid position of this step
    direction: complex  # unit step that was taken to arrive at ``pos``
    cost: int = 0  # accumulated cost of the route up to and including this step
    # Previous step of the route (None for the first node); bookkeeping only,
    # so it takes no part in comparison, hashing or repr.
    parent: "Node | None" = field(default=None, compare=False, repr=False)

    def __lt__(self, other):
        """Order nodes by accumulated cost so a heap pops the cheapest first."""
        return self.cost < other.cost
def find_start_and_goal(grid):
    """
    Scan the grid for the "S" and "E" cells.

    Returns a ``(start, goal)`` tuple of complex coordinates (x in the real
    part, y in the imaginary part). An entry is None when its marker does not
    occur; when a marker occurs several times the last occurrence wins.
    """
    located = {"S": None, "E": None}
    for row_index, line in enumerate(grid):
        for col_index, cell in enumerate(line):
            if cell in located:
                located[cell] = complex(col_index, row_index)
    return located["S"], located["E"]
def get_pos(grid, pos):
    """
    Return the cell at complex coordinate ``pos``, or None when it is off-grid.

    ``pos.real`` is the column (x) and ``pos.imag`` the row (y); both are
    truncated to int. The column is checked against the length of the row
    actually addressed, so an empty grid or rows of differing length (e.g. a
    blank trailing input line) yield None instead of raising IndexError.
    """
    x, y = int(pos.real), int(pos.imag)
    if not 0 <= y < len(grid):
        return None
    row = grid[y]
    if not 0 <= x < len(row):
        return None
    return row[x]
def search(grid, start_node, end_pos):
    """
    Returns the cheapest route between start and end using Dijkstra's algorithm

    A search state is a (position, direction) pair. From each state the
    neighbours are one step straight ahead (cost 1), a quarter turn either
    way followed by a step (cost 1000 + 1), or a half turn followed by a step
    (cost 2 * 1000 + 1). The half turn only ever matters when the start faces
    a dead end, but without it such mazes are wrongly reported as unsolvable.

    Cells holding "#" and cells outside the grid are never entered.

    Returns the Node that reached ``end_pos`` (follow ``.parent`` to walk the
    route backwards), or None when ``end_pos`` cannot be reached.
    """
    queue = [start_node]
    # States whose cheapest cost is final. Keyed on (pos, direction) tuples
    # rather than on nodes, so membership tests never depend on how a node
    # was reached.
    settled = set()
    while queue:
        node = heappop(queue)
        if node.pos == end_pos:
            return node
        state = (node.pos, node.direction)
        if state in settled:
            # A route at least as cheap was already expanded from this state;
            # expanding ties again would only duplicate work.
            continue
        settled.add(state)
        moves = (
            (node.direction, 1),  # go in the same direction
            (node.direction * 1j, 1000 + 1),  # turn clockwise
            (node.direction * -1j, 1000 + 1),  # turn counterclockwise
            (-node.direction, 2 * 1000 + 1),  # turn around
        )
        for heading, step_cost in moves:
            target = node.pos + heading
            cell = get_pos(grid, target)
            # get_pos reports None for off-grid positions; treating that as
            # blocked keeps the search finite on grids without a wall border.
            if cell is None or cell == "#":
                continue
            if (target, heading) in settled:
                continue
            heappush(queue, Node(target, heading, node.cost + step_cost, node))
    return None
def print_grid(grid):
    """Write the grid to stdout, one line per row with the row's cells joined."""
    for rendered in map("".join, grid):
        print(rendered)
def main(grid, debug=False):
    """
    Solve part 1 for ``grid`` and print the cost of the cheapest route.

    grid:  list of rows, each a mutable list of single-character cells.
    debug: when true, replay the route step by step afterwards. Each step
           marks its cell with "O" in ``grid`` (the caller's grid is
           modified), prints the step and the grid, waits for Enter and
           clears the terminal.

    Raises ValueError when the grid lacks an "S" or "E" cell, or when no
    route connects them.
    """
    start, goal = find_start_and_goal(grid)
    if start is None or goal is None:
        raise ValueError("grid must contain both a start tile 'S' and an end tile 'E'")
    direction = 1 + 0j  # initial direction is east
    end_node = search(grid, Node(start, direction), goal)
    if end_node is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("no route from 'S' to 'E' exists in this grid")
    total_cost = end_node.cost
    print("Part 1: ", total_cost)
    if not debug:
        return
    # compute path: follow the parent links back from the goal, then flip the
    # list so it runs start -> goal. The start node itself (parent is None)
    # is not part of the replay.
    path = []
    n = end_node
    while n.parent is not None:
        path.append(n)
        n = n.parent
    path.reverse()
    for n in path:
        x, y = int(n.pos.real), int(n.pos.imag)
        grid[y][x] = "O"
        print(f"Pos {x},{y} Direction {DIRECTIONS_RV[n.direction]} Cost {n.cost}")
        print_grid(grid)
        input()  # pause until the user presses Enter
        print(chr(27) + "[2J")  # clear terminal
if __name__ == "__main__":
    import sys

    # The puzzle input path is the first command-line argument; without one
    # the sample file is used.
    infile = sys.argv[1] if len(sys.argv) > 1 else "example.txt"
    with open(infile) as handle:
        # One list of characters per input line, trailing whitespace removed.
        grid = [list(line.rstrip()) for line in handle]
    main(grid)