diff --git a/2024/day16/day16.py b/2024/day16/day16.py new file mode 100644 index 0000000..bf4a781 --- /dev/null +++ b/2024/day16/day16.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass +from heapq import heappop, heappush + + +DIRECTIONS = { + ">": 1 + 0j, # EAST + "v": 0 + 1j, # SOUTH + "<": -1 + 0j, # WEST + "^": 0 - 1j, # NORTH +} + +DIRECTIONS_RV = {v: k for k,v in DIRECTIONS.items()} + + +@dataclass(frozen=True) +class Node: + pos: complex + direction: complex + cost: int = 0 + parent: "Node" = None + + def __lt__(self, other): + return self.cost < other.cost + + +def find_start_and_goal(grid): + start, goal = None, None + for y, row in enumerate(grid): + for x, _ in enumerate(row): + if grid[y][x] == "S": + start = complex(x, y) + elif grid[y][x] == "E": + goal = complex(x, y) + return start, goal + + +def get_pos(grid, pos): + x, y = int(pos.real), int(pos.imag) + if 0 <= x < len(grid[0]) and 0 <= y < len(grid): + return grid[y][x] + return None + + +def search(grid, start_node, end_pos): + """ + Returns the shortest path between start and end using Dijkstra's algorithm + """ + queue = [start_node] + visited = set() + best_costs = {} + while queue != []: + node = heappop(queue) + visited.add(node) + + if node.pos == end_pos: + return node + + if node.cost > best_costs.get((node.pos, node.direction), 99999999): # already found a better path to this pos + continue + + best_costs[(node.pos, node.direction)] = node.cost + + # visit each neighbor + # go in the same direction + n1 = Node(node.pos + node.direction, node.direction, node.cost + 1, node) + if get_pos(grid, n1.pos) != "#" and n1 not in visited: + heappush(queue, n1) + # turn clockwise + turned = node.direction * 1j + n2 = Node(node.pos + turned, turned, node.cost + 1000 + 1, node) + if get_pos(grid, n2.pos) != "#" and n2 not in visited: + heappush(queue, n2) + # turn counterclockwise + turned = node.direction * -1j + n3 = Node(node.pos + turned, turned, node.cost + 1000 + 1, node) + if get_pos(grid, n3.pos) != "#" and n3 not in visited: + heappush(queue, n3) + + return None + + +def print_grid(grid): + for row in grid: + print("".join(row)) + + +def main(grid, debug=False): + start, goal = find_start_and_goal(grid) + direction = 1 + 0j # initial direction is east + end_node = search(grid, Node(start, direction), goal) + total_cost = end_node.cost + print("Part 1: ", total_cost) + if debug: + # compute path + n = end_node + path = [] + if n is not None: + while n.parent is not None: + path.insert(0, n) + n = n.parent + + for n in path: + x, y = int(n.pos.real), int(n.pos.imag) + grid[y][x] = "O" + print(f"Pos {x},{y} Direction {DIRECTIONS_RV[n.direction]} Cost {n.cost}") + print_grid(grid) + input() + print(chr(27) + "[2J") # clear terminal + + +if __name__ == "__main__": + import sys + infile = sys.argv[1] if 1 < len(sys.argv) else "example.txt" + with open(infile) as f: + grid = [list(l.rstrip()) for l in f.readlines()] + main(grid) +