célesta une octave plus haute.
[minwii.git] / src / pgu / algo.py
1 """Some handy algorithms for use in games, etc.
2
3 <p>please note that this file is alpha, and is subject to modification in
4 future versions of pgu!</p>
5 """
6
7 # The manhattan distance metric
def manhattan_dist(a, b):
    """returns the manhattan distance between two points

    <pre>manhattan_dist(a,b): return distance</pre>

    <dl>
    <dt>a<dd>first point; only a[0] and a[1] are used
    <dt>b<dd>second point; only b[0] and b[1] are used
    </dl>

    <p>returns the sum of the absolute differences along both axes</p>
    """
    horizontal = abs(a[0] - b[0])
    vertical = abs(a[1] - b[1])
    return horizontal + vertical
10
class node:
    """a single search node, as created by astar() for each visited position

    <dl>
    <dt>prev<dd>the node this one was reached from, or None for the first node
    <dt>pos<dd>the position of this node
    <dt>dest<dd>the position the search is heading for
    <dt>dist<dd>a distance function dist(a,b) used to estimate the remaining cost
    </dl>

    <p>g is the number of steps taken from the first node, h is
    dist(pos,dest) and f is their sum.</p>
    """
    def __init__(self, prev, pos, dest, dist):
        self.prev, self.pos, self.dest = prev, pos, dest
        # Every step costs 1; a node without a predecessor starts at 0.
        if prev is None:
            self.g = 0
        else:
            self.g = prev.g + 1
        self.h = dist(pos, dest)
        self.f = self.g + self.h
18
19
def astar(start,end,layer,dist=manhattan_dist):
    """uses the a* algorithm to find a path

    <pre>astar(start,end,layer,dist): return [list of positions]</pre>

    <dl>
    <dt>start<dd>start position, an (x,y) tuple
    <dt>end<dd>end position, an (x,y) tuple
    <dt>layer<dd>a grid where zero cells are open and non-zero cells are walls; it is indexed as layer[y][x]
    <dt>dist<dd>a distance function dist(a,b) - manhattan distance is used by default
    </dl>

    <p>returns a list of positions leading from start to end. the start
    position itself is not part of the list. an empty list is returned when
    the layer is empty, when start or end is outside of the layer or blocked,
    when no path exists, and when start equals end.</p>
    """
    import heapq

    if len(layer) == 0 or len(layer[0]) == 0:
        return [] #an empty layer has no valid positions

    w,h = len(layer[0]),len(layer)
    if start[0] < 0 or start[1] < 0 or start[0] >= w or start[1] >= h:
        return [] #start outside of layer
    if end[0] < 0 or end[1] < 0 or end[0] >= w or end[1] >= h:
        return [] #end outside of layer

    if layer[start[1]][start[0]]:
        return [] #start is blocked
    if layer[end[1]][end[0]]:
        return [] #end is blocked

    # frontier maps a position to the best node found for it that still has
    # to be expanded; closed maps a position to its already expanded node.
    frontier = {}
    closed = {}

    # The heap orders nodes by f. The running counter breaks ties in favour
    # of the node pushed first and keeps node objects from being compared.
    heap = []
    pushed = 0

    cur = node(None, start, end, dist)
    frontier[cur.pos] = cur
    heapq.heappush(heap, (cur.f, pushed, cur))

    while frontier:
        cur = heapq.heappop(heap)[2]
        # A node replaced by a cheaper one for the same position is stale.
        if frontier.get(cur.pos) is not cur: continue
        del frontier[cur.pos]
        closed[cur.pos] = cur
        if cur.pos == end: break

        # Only orthogonal moves are generated.
        for dx,dy in [(0,-1),(1,0),(0,1),(-1,0)]:
            pos = cur.pos[0]+dx,cur.pos[1]+dy
            # Check if the point lies in the grid
            if pos[0] < 0 or pos[1] < 0 or pos[0] >= w or pos[1] >= h:
                continue
            # Check if the point is a wall (rows are y, columns are x)
            if layer[pos[1]][pos[0]]:
                continue
            # Corner checks: these only add a restriction if diagonal
            # offsets are ever added to the move list above.
            if layer[cur.pos[1]+dy][cur.pos[0]]: continue
            if layer[cur.pos[1]][cur.pos[0]+dx]: continue

            new = node(cur, pos, end, dist)
            # Keep the node already known for pos unless the new one is
            # strictly cheaper.
            if pos in frontier and new.f >= frontier[pos].f: continue
            if pos in closed and new.f >= closed[pos].f: continue
            closed.pop(pos, None)
            frontier[pos] = new
            pushed += 1
            heapq.heappush(heap, (new.f, pushed, new))

    if cur.pos != end:
        return []

    # Walk back through the predecessors; the first node (start) has no
    # predecessor and is left out of the result.
    path = []
    while cur.prev is not None:
        path.append(cur.pos)
        cur = cur.prev
    path.reverse()
    return path
91
92
def getline(a,b):
    """returns a path of points from a to b

    <pre>getline(a,b): return [list of points]</pre>

    <dl>
    <dt>a<dd>starting point, an (x,y) pair of integers
    <dt>b<dd>ending point, an (x,y) pair of integers
    </dl>

    <p>returns a list of points from a to b, both end points included. one
    point is produced per step along the axis with the larger distance.</p>
    """

    x1,y1 = a
    x2,y2 = b
    dx,dy = abs(x2-x1),abs(y2-y1)

    # Direction of travel along each axis.
    if x2 >= x1: step_x = 1
    else: step_x = -1

    if y2 >= y1: step_y = 1
    else: step_y = -1

    if dx >= dy:
        # x is the major axis: x advances every step, y only on overflow.
        major,minor = dx,dy
        major_move = (step_x,0)
        minor_move = (0,step_y)
    else:
        # y is the major axis: y advances every step, x only on overflow.
        major,minor = dy,dx
        major_move = (0,step_y)
        minor_move = (step_x,0)

    # Floor division keeps the error term an integer on every Python version.
    err = major//2

    path = []
    x,y = x1,y1
    count = 0
    while count <= major:
        path.append((x,y))
        err += minor
        if err > major:
            err -= major
            x += minor_move[0]
            y += minor_move[1]
        x += major_move[0]
        y += major_move[1]
        count += 1
    return path