diff --git a/csp/agents.py b/csp/agents.py new file mode 100644 index 00000000..8ed3fa1f --- /dev/null +++ b/csp/agents.py @@ -0,0 +1,533 @@ +"""Implement Agents and Environments (Chapters 1-2). + +The class hierarchies are as follows: + +Object ## A physical object that can exist in an environment + Agent + Wumpus + RandomAgent + ReflexVacuumAgent + ... + Dirt + Wall + ... + +Environment ## An environment holds objects, runs simulations + XYEnvironment + VacuumEnvironment + WumpusEnvironment + +EnvFrame ## A graphical representation of the Environment + +""" + +from utils import * +import random, copy + +#______________________________________________________________________________ + +class Object: + """This represents any physical object that can appear in an Environment. + You subclass Object to get the objects you want. Each object can have a + .__name__ slot (used for output only).""" + def __repr__(self): + return '<%s>' % getattr(self, '__name__', self.__class__.__name__) + + def is_alive(self): + """Objects that are 'alive' should return true.""" + return hasattr(self, 'alive') and self.alive + + def display(self, canvas, x, y, width, height): + """Display an image of this Object on the canvas.""" + pass + +class Agent(Object): + """An Agent is a subclass of Object with one required slot, + .program, which should hold a function that takes one argument, the + percept, and returns an action. (What counts as a percept or action + will depend on the specific environment in which the agent exists.) + Note that 'program' is a slot, not a method. If it were a method, + then the program could 'cheat' and look at aspects of the agent. + It's not supposed to do that: the program can only look at the + percepts. An agent program that needs a model of the world (and of + the agent itself) will have to build and maintain its own model. + There is an optional slots, .performance, which is a number giving + the performance measure of the agent in its environment.""" + + def __init__(self): + def program(percept): + return raw_input('Percept=%s; action? ' % percept) + self.program = program + self.alive = True + +def TraceAgent(agent): + """Wrap the agent's program to print its input and output. This will let + you see what the agent is doing in the environment.""" + old_program = agent.program + def new_program(percept): + action = old_program(percept) + print '%s perceives %s and does %s' % (agent, percept, action) + return action + agent.program = new_program + return agent + +#______________________________________________________________________________ + +class TableDrivenAgent(Agent): + """This agent selects an action based on the percept sequence. + It is practical only for tiny domains. + To customize it you provide a table to the constructor. [Fig. 2.7]""" + + def __init__(self, table): + "Supply as table a dictionary of all {percept_sequence:action} pairs." + ## The agent program could in principle be a function, but because + ## it needs to store state, we make it a callable instance of a class. + Agent.__init__(self) + percepts = [] + def program(percept): + percepts.append(percept) + action = table.get(tuple(percepts)) + return action + self.program = program + + +class RandomAgent(Agent): + "An agent that chooses an action at random, ignoring all percepts." + def __init__(self, actions): + Agent.__init__(self) + self.program = lambda percept: random.choice(actions) + + +#______________________________________________________________________________ + +loc_A, loc_B = (0, 0), (1, 0) # The two locations for the Vacuum world + +class ReflexVacuumAgent(Agent): + "A reflex agent for the two-state vacuum environment. [Fig. 2.8]" + + def __init__(self): + Agent.__init__(self) + def program((location, status)): + if status == 'Dirty': return 'Suck' + elif location == loc_A: return 'Right' + elif location == loc_B: return 'Left' + self.program = program + + +def RandomVacuumAgent(): + "Randomly choose one of the actions from the vaccum environment." + return RandomAgent(['Right', 'Left', 'Suck', 'NoOp']) + + +def TableDrivenVacuumAgent(): + "[Fig. 2.3]" + table = {((loc_A, 'Clean'),): 'Right', + ((loc_A, 'Dirty'),): 'Suck', + ((loc_B, 'Clean'),): 'Left', + ((loc_B, 'Dirty'),): 'Suck', + ((loc_A, 'Clean'), (loc_A, 'Clean')): 'Right', + ((loc_A, 'Clean'), (loc_A, 'Dirty')): 'Suck', + # ... + ((loc_A, 'Clean'), (loc_A, 'Clean'), (loc_A, 'Clean')): 'Right', + ((loc_A, 'Clean'), (loc_A, 'Clean'), (loc_A, 'Dirty')): 'Suck', + # ... + } + return TableDrivenAgent(table) + + +class ModelBasedVacuumAgent(Agent): + "An agent that keeps track of what locations are clean or dirty." + def __init__(self): + Agent.__init__(self) + model = {loc_A: None, loc_B: None} + def program((location, status)): + "Same as ReflexVacuumAgent, except if everything is clean, do NoOp" + model[location] = status ## Update the model here + if model[loc_A] == model[loc_B] == 'Clean': return 'NoOp' + elif status == 'Dirty': return 'Suck' + elif location == loc_A: return 'Right' + elif location == loc_B: return 'Left' + self.program = program + +#______________________________________________________________________________ + +class Environment: + """Abstract class representing an Environment. 'Real' Environment classes + inherit from this. Your Environment will typically need to implement: + percept: Define the percept that an agent sees. + execute_action: Define the effects of executing an action. + Also update the agent.performance slot. + The environment keeps a list of .objects and .agents (which is a subset + of .objects). Each agent has a .performance slot, initialized to 0. + Each object has a .location slot, even though some environments may not + need this.""" + + def __init__(self,): + self.objects = []; self.agents = [] + + object_classes = [] ## List of classes that can go into environment + + def percept(self, agent): + "Return the percept that the agent sees at this point. Override this." + abstract + + def execute_action(self, agent, action): + "Change the world to reflect this action. Override this." + abstract + + def default_location(self, object): + "Default location to place a new object with unspecified location." + return None + + def exogenous_change(self): + "If there is spontaneous change in the world, override this." + pass + + def is_done(self): + "By default, we're done when we can't find a live agent." + for agent in self.agents: + if agent.is_alive(): return False + return True + + def step(self): + """Run the environment for one time step. If the + actions and exogenous changes are independent, this method will + do. If there are interactions between them, you'll need to + override this method.""" + if not self.is_done(): + actions = [agent.program(self.percept(agent)) + for agent in self.agents] + for (agent, action) in zip(self.agents, actions): + self.execute_action(agent, action) + self.exogenous_change() + + def run(self, steps=1000): + """Run the Environment for given number of time steps.""" + for step in range(steps): + if self.is_done(): return + self.step() + + def add_object(self, object, location=None): + """Add an object to the environment, setting its location. Also keep + track of objects that are agents. Shouldn't need to override this.""" + object.location = location or self.default_location(object) + self.objects.append(object) + if isinstance(object, Agent): + object.performance = 0 + self.agents.append(object) + return self + + +class XYEnvironment(Environment): + """This class is for environments on a 2D plane, with locations + labelled by (x, y) points, either discrete or continuous. Agents + perceive objects within a radius. Each agent in the environment + has a .location slot which should be a location such as (0, 1), + and a .holding slot, which should be a list of objects that are + held """ + + def __init__(self, width=10, height=10): + update(self, objects=[], agents=[], width=width, height=height) + + def objects_at(self, location): + "Return all objects exactly at a given location." + return [obj for obj in self.objects if obj.location == location] + + def objects_near(self, location, radius): + "Return all objects within radius of location." + radius2 = radius * radius + return [obj for obj in self.objects + if distance2(location, obj.location) <= radius2] + + def percept(self, agent): + "By default, agent perceives objects within radius r." + return [self.object_percept(obj, agent) + for obj in self.objects_near(agent)] + + def execute_action(self, agent, action): + if action == 'TurnRight': + agent.heading = turn_heading(agent.heading, -1) + elif action == 'TurnLeft': + agent.heading = turn_heading(agent.heading, +1) + elif action == 'Forward': + self.move_to(agent, vector_add(agent.heading, agent.location)) + elif action == 'Grab': + objs = [obj for obj in self.objects_at(agent.location) + if obj.is_grabable(agent)] + if objs: + agent.holding.append(objs[0]) + elif action == 'Release': + if agent.holding: + agent.holding.pop() + agent.bump = False + + def object_percept(self, obj, agent): #??? Should go to object? + "Return the percept for this object." + return obj.__class__.__name__ + + def default_location(self, object): + return (random.choice(self.width), random.choice(self.height)) + + def move_to(object, destination): + "Move an object to a new location." + + def add_object(self, object, location=(1, 1)): + Environment.add_object(self, object, location) + object.holding = [] + object.held = None + self.objects.append(object) + + def add_walls(self): + "Put walls around the entire perimeter of the grid." + for x in range(self.width): + self.add_object(Wall(), (x, 0)) + self.add_object(Wall(), (x, self.height-1)) + for y in range(self.height): + self.add_object(Wall(), (0, y)) + self.add_object(Wall(), (self.width-1, y)) + +def turn_heading(self, heading, inc, + headings=[(1, 0), (0, 1), (-1, 0), (0, -1)]): + "Return the heading to the left (inc=+1) or right (inc=-1) in headings." + return headings[(headings.index(heading) + inc) % len(headings)] + +#______________________________________________________________________________ +## Vacuum environment + +class TrivialVacuumEnvironment(Environment): + """This environment has two locations, A and B. Each can be Dirty or Clean. + The agent perceives its location and the location's status. This serves as + an example of how to implement a simple Environment.""" + + def __init__(self): + Environment.__init__(self) + self.status = {loc_A:random.choice(['Clean', 'Dirty']), + loc_B:random.choice(['Clean', 'Dirty'])} + + def percept(self, agent): + "Returns the agent's location, and the location status (Dirty/Clean)." + return (agent.location, self.status[agent.location]) + + def execute_action(self, agent, action): + """Change agent's location and/or location's status; track performance. + Score 10 for each dirt cleaned; -1 for each move.""" + if action == 'Right': + agent.location = loc_B + agent.performance -= 1 + elif action == 'Left': + agent.location = loc_A + agent.performance -= 1 + elif action == 'Suck': + if self.status[agent.location] == 'Dirty': + agent.performance += 10 + self.status[agent.location] = 'Clean' + + def default_location(self, object): + "Agents start in either location at random." + return random.choice([loc_A, loc_B]) + +class Dirt(Object): pass +class Wall(Object): pass + +class VacuumEnvironment(XYEnvironment): + """The environment of [Ex. 2.12]. Agent perceives dirty or clean, + and bump (into obstacle) or not; 2D discrete world of unknown size; + performance measure is 100 for each dirt cleaned, and -1 for + each turn taken.""" + def __init__(self, width=10, height=10): + XYEnvironment.__init__(self, width, height) + self.add_walls() + + object_classes = [Wall, Dirt, ReflexVacuumAgent, RandomVacuumAgent, + TableDrivenVacuumAgent, ModelBasedVacuumAgent] + + def percept(self, agent): + """The percept is a tuple of ('Dirty' or 'Clean', 'Bump' or 'None'). + Unlike the TrivialVacuumEnvironment, location is NOT perceived.""" + status = if_(self.find_at(Dirt, agent.location), 'Dirty', 'Clean') + bump = if_(agent.bump, 'Bump', 'None') + return (status, bump) + + def execute_action(self, agent, action): + if action == 'Suck': + if self.find_at(Dirt, agent.location): + agent.performance += 100 + agent.performance -= 1 + XYEnvironment.execute_action(self, agent, action) + +#______________________________________________________________________________ + +class SimpleReflexAgent(Agent): + """This agent takes action based solely on the percept. [Fig. 2.13]""" + + def __init__(self, rules, interpret_input): + Agent.__init__(self) + def program(percept): + state = interpret_input(percept) + rule = rule_match(state, rules) + action = rule.action + return action + self.program = program + +class ReflexAgentWithState(Agent): + """This agent takes action based on the percept and state. [Fig. 2.16]""" + + def __init__(self, rules, udpate_state): + Agent.__init__(self) + state, action = None, None + def program(percept): + state = update_state(state, action, percept) + rule = rule_match(state, rules) + action = rule.action + return action + self.program = program + +#______________________________________________________________________________ +## The Wumpus World + +class Gold(Object): pass +class Pit(Object): pass +class Arrow(Object): pass +class Wumpus(Agent): pass +class Explorer(Agent): pass + +class WumpusEnvironment(XYEnvironment): + object_classes = [Wall, Gold, Pit, Arrow, Wumpus, Explorer] + def __init__(self, width=10, height=10): + XYEnvironment.__init__(self, width, height) + self.add_walls() + ## Needs a lot of work ... + + +#______________________________________________________________________________ + +def compare_agents(EnvFactory, AgentFactories, n=10, steps=1000): + """See how well each of several agents do in n instances of an environment. + Pass in a factory (constructor) for environments, and several for agents. + Create n instances of the environment, and run each agent in copies of + each one for steps. Return a list of (agent, average-score) tuples.""" + envs = [EnvFactory() for i in range(n)] + return [(A, test_agent(A, steps, copy.deepcopy(envs))) + for A in AgentFactories] + +def test_agent(AgentFactory, steps, envs): + "Return the mean score of running an agent in each of the envs, for steps" + total = 0 + for env in envs: + agent = AgentFactory() + env.add_object(agent) + env.run(steps) + total += agent.performance + return float(total)/len(envs) + +#______________________________________________________________________________ + +_docex = """ +a = ReflexVacuumAgent() +a.program +a.program((loc_A, 'Clean')) ==> 'Right' +a.program((loc_B, 'Clean')) ==> 'Left' +a.program((loc_A, 'Dirty')) ==> 'Suck' +a.program((loc_A, 'Dirty')) ==> 'Suck' + +e = TrivialVacuumEnvironment() +e.add_object(TraceAgent(ModelBasedVacuumAgent())) +e.run(5) + +## Environments, and some agents, are randomized, so the best we can +## give is a range of expected scores. If this test fails, it does +## not necessarily mean something is wrong. +envs = [TrivialVacuumEnvironment() for i in range(100)] +def testv(A): return test_agent(A, 4, copy.deepcopy(envs)) +testv(ModelBasedVacuumAgent) +(7 < _ < 11) ==> True +testv(ReflexVacuumAgent) +(5 < _ < 9) ==> True +testv(TableDrivenVacuumAgent) +(2 < _ < 6) ==> True +testv(RandomVacuumAgent) +(0.5 < _ < 3) ==> True +""" + +#______________________________________________________________________________ +# GUI - Graphical User Interface for Environments +# If you do not have Tkinter installed, either get a new installation of Python +# (Tkinter is standard in all new releases), or delete the rest of this file +# and muddle through without a GUI. + +''' +import Tkinter as tk + +class EnvFrame(tk.Frame): + def __init__(self, env, title='AIMA GUI', cellwidth=50, n=10): + update(self, cellwidth = cellwidth, running=False, delay=1.0) + self.n = n + self.running = 0 + self.delay = 1.0 + self.env = env + tk.Frame.__init__(self, None, width=(cellwidth+2)*n, height=(cellwidth+2)*n) + #self.title(title) + # Toolbar + toolbar = tk.Frame(self, relief='raised', bd=2) + toolbar.pack(side='top', fill='x') + for txt, cmd in [('Step >', self.env.step), ('Run >>', self.run), + ('Stop [ ]', self.stop)]: + tk.Button(toolbar, text=txt, command=cmd).pack(side='left') + tk.Label(toolbar, text='Delay').pack(side='left') + scale = tk.Scale(toolbar, orient='h', from_=0.0, to=10, resolution=0.5, + command=lambda d: setattr(self, 'delay', d)) + scale.set(self.delay) + scale.pack(side='left') + # Canvas for drawing on + self.canvas = tk.Canvas(self, width=(cellwidth+1)*n, + height=(cellwidth+1)*n, background="white") + self.canvas.bind('', self.left) ## What should this do? + self.canvas.bind('', self.edit_objects) + self.canvas.bind('', self.add_object) + if cellwidth: + c = self.canvas + for i in range(1, n+1): + c.create_line(0, i*cellwidth, n*cellwidth, i*cellwidth) + c.create_line(i*cellwidth, 0, i*cellwidth, n*cellwidth) + c.pack(expand=1, fill='both') + self.pack() + + + def background_run(self): + if self.running: + self.env.step() + ms = int(1000 * max(float(self.delay), 0.5)) + self.after(ms, self.background_run) + + def run(self): + print 'run' + self.running = 1 + self.background_run() + + def stop(self): + print 'stop' + self.running = 0 + + def left(self, event): + print 'left at ', event.x/50, event.y/50 + + def edit_objects(self, event): + """Choose an object within radius and edit its fields.""" + pass + + def add_object(self, event): + ## This is supposed to pop up a menu of Object classes; you choose the one + ## You want to put in this square. Not working yet. + menu = tk.Menu(self, title='Edit (%d, %d)' % (event.x/50, event.y/50)) + for (txt, cmd) in [('Wumpus', self.run), ('Pit', self.run)]: + menu.add_command(label=txt, command=cmd) + menu.tk_popup(event.x + self.winfo_rootx(), + event.y + self.winfo_rooty()) + + #image=PhotoImage(file=r"C:\Documents and Settings\pnorvig\Desktop\wumpus.gif") + #self.images = [] + #self.images.append(image) + #c.create_image(200,200,anchor=NW,image=image) + +#v = VacuumEnvironment(); w = EnvFrame(v); +''' diff --git a/csp/csp.py b/csp/csp.py new file mode 100644 index 00000000..9347599a --- /dev/null +++ b/csp/csp.py @@ -0,0 +1,450 @@ +"""CSP (Constraint Satisfaction Problems) problems and solvers. (Chapter 5).""" + +from __future__ import generators +from utils import * +import search +import types + +class CSP(search.Problem): + """This class describes finite-domain Constraint Satisfaction Problems. + A CSP is specified by the following three inputs: + vars A list of variables; each is atomic (e.g. int or string). + domains A dict of {var:[possible_value, ...]} entries. + neighbors A dict of {var:[var,...]} that for each variable lists + the other variables that participate in constraints. + constraints A function f(A, a, B, b) that returns true if neighbors + A, B satisfy the constraint when they have values A=a, B=b + In the textbook and in most mathematical definitions, the + constraints are specified as explicit pairs of allowable values, + but the formulation here is easier to express and more compact for + most cases. (For example, the n-Queens problem can be represented + in O(n) space using this notation, instead of O(N^4) for the + explicit representation.) In terms of describing the CSP as a + problem, that's all there is. + + However, the class also supports data structures and methods that help you + solve CSPs by calling a search function on the CSP. Methods and slots are + as follows, where the argument 'a' represents an assignment, which is a + dict of {var:val} entries: + assign(var, val, a) Assign a[var] = val; do other bookkeeping + unassign(var, a) Do del a[var], plus other bookkeeping + nconflicts(var, val, a) Return the number of other variables that + conflict with var=val + curr_domains[var] Slot: remaining consistent values for var + Used by constraint propagation routines. + The following methods are used only by graph_search and tree_search: + succ() Return a list of (action, state) pairs + goal_test(a) Return true if all constraints satisfied + The following are just for debugging purposes: + nassigns Slot: tracks the number of assignments made + display(a) Print a human-readable representation + """ + + def __init__(self, vars, domains, neighbors, constraints): + "Construct a CSP problem. If vars is empty, it becomes domains.keys()." + vars = vars or domains.keys() + update(self, vars=vars, domains=domains, + neighbors=neighbors, constraints=constraints, + initial={}, curr_domains=None, pruned=None, nassigns=0) + + def assign(self, var, val, assignment): + """Add {var: val} to assignment; Discard the old value if any. + Do bookkeeping for curr_domains and nassigns.""" + self.nassigns += 1 + assignment[var] = val + if self.curr_domains: + if self.fc: + self.forward_check(var, val, assignment) + if self.mac: + AC3(self, [(Xk, var) for Xk in self.neighbors[var]]) + + def unassign(self, var, assignment): + """Remove {var: val} from assignment; that is backtrack. + DO NOT call this if you are changing a variable to a new value; + just call assign for that.""" + if var in assignment: + # Reset the curr_domain to be the full original domain + if self.curr_domains: + self.curr_domains[var] = self.domains[var][:] + del assignment[var] + + def nconflicts(self, var, val, assignment): + "Return the number of conflicts var=val has with other variables." + # Subclasses may implement this more efficiently + def conflict(var2): + val2 = assignment.get(var2, None) + return val2 != None and not self.constraints(var, val, var2, val2) + return count_if(conflict, self.neighbors[var]) + + def forward_check(self, var, val, assignment): + "Do forward checking (current domain reduction) for this assignment." + if self.curr_domains: + # Restore prunings from previous value of var + for (B, b) in self.pruned[var]: + self.curr_domains[B].append(b) + self.pruned[var] = [] + # Prune any other B=b assignement that conflict with var=val + for B in self.neighbors[var]: + if B not in assignment: + for b in self.curr_domains[B][:]: + if not self.constraints(var, val, B, b): + self.curr_domains[B].remove(b) + self.pruned[var].append((B, b)) + + def display(self, assignment): + "Show a human-readable representation of the CSP." + # Subclasses can print in a prettier way, or display with a GUI + print 'CSP:', self, 'with assignment:', assignment + + ## These methods are for the tree and graph search interface: + + def succ(self, assignment): + "Return a list of (action, state) pairs." + if len(assignment) == len(self.vars): + return [] + else: + var = find_if(lambda v: v not in assignment, self.vars) + result = [] + for val in self.domains[var]: + if self.nconflicts(self, var, val, assignment) == 0: + a = assignment.copy; a[var] = val + result.append(((var, val), a)) + return result + + def goal_test(self, assignment): + "The goal is to assign all vars, with all constraints satisfied." + return (len(assignment) == len(self.vars) and + every(lambda var: self.nconflicts(var, assignment[var], + assignment) == 0, + self.vars)) + + ## This is for min_conflicts search + + def conflicted_vars(self, current): + "Return a list of variables in current assignment that are in conflict" + return [var for var in self.vars + if self.nconflicts(var, current[var], current) > 0] + +#______________________________________________________________________________ +# CSP Backtracking Search + +def backtracking_search(csp, mcv=False, lcv=False, fc=False, mac=False): + """Set up to do recursive backtracking search. Allow the following options: + mcv - If true, use Most Constrained Variable Heuristic + lcv - If true, use Least Constraining Value Heuristic + fc - If true, use Forward Checking + mac - If true, use Maintaining Arc Consistency. [Fig. 5.3] + >>> backtracking_search(australia) + {'WA': 'B', 'Q': 'B', 'T': 'B', 'V': 'B', 'SA': 'G', 'NT': 'R', 'NSW': 'R'} + """ + if fc or mac: + csp.curr_domains, csp.pruned = {}, {} + for v in csp.vars: + csp.curr_domains[v] = csp.domains[v][:] + csp.pruned[v] = [] + update(csp, mcv=mcv, lcv=lcv, fc=fc, mac=mac) + return recursive_backtracking({}, csp) + +def recursive_backtracking(assignment, csp): + """Search for a consistent assignment for the csp. + Each recursive call chooses a variable, and considers values for it.""" + if len(assignment) == len(csp.vars): + return assignment + var = select_unassigned_variable(assignment, csp) + for val in order_domain_values(var, assignment, csp): + if csp.fc or csp.nconflicts(var, val, assignment) == 0: + csp.assign(var, val, assignment) + result = recursive_backtracking(assignment, csp) + if result is not None: + return result + csp.unassign(var, assignment) + return None + +def select_unassigned_variable(assignment, csp): + "Select the variable to work on next. Find" + if csp.mcv: # Most Constrained Variable + unassigned = [v for v in csp.vars if v not in assignment] + return argmin_random_tie(unassigned, + lambda var: -num_legal_values(csp, var, assignment)) + else: # First unassigned variable + for v in csp.vars: + if v not in assignment: + return v + +def order_domain_values(var, assignment, csp): + "Decide what order to consider the domain variables." + if csp.curr_domains: + domain = csp.curr_domains[var] + else: + domain = csp.domains[var][:] + if csp.lcv: + # If LCV is specified, consider values with fewer conflicts first + key = lambda val: csp.nconflicts(var, val, assignment) + domain.sort(lambda(x,y): cmp(key(x), key(y))) + while domain: + yield domain.pop() + +def num_legal_values(csp, var, assignment): + if csp.curr_domains: + return len(csp.curr_domains[var]) + else: + return count_if(lambda val: csp.nconflicts(var, val, assignment) == 0, + csp.domains[var]) + +#______________________________________________________________________________ +# Constraint Propagation with AC-3 + +def AC3(csp, queue=None): + """[Fig. 5.7]""" + if queue == None: + queue = [(Xi, Xk) for Xi in csp.vars for Xk in csp.neighbors[Xi]] + while queue: + (Xi, Xj) = queue.pop() + if remove_inconsistent_values(csp, Xi, Xj): + for Xk in csp.neighbors[Xi]: + queue.append((Xk, Xi)) + +def remove_inconsistent_values(csp, Xi, Xj): + "Return true if we remove a value." + removed = False + for x in csp.curr_domains[Xi][:]: + # If Xi=x conflicts with Xj=y for every possible y, eliminate Xi=x + if every(lambda y: not csp.constraints(Xi, x, Xj, y), + csp.curr_domains[Xj]): + csp.curr_domains[Xi].remove(x) + removed = True + return removed + +#______________________________________________________________________________ +# Min-conflicts hillclimbing search for CSPs + +def min_conflicts(csp, max_steps=1000000): + """Solve a CSP by stochastic hillclimbing on the number of conflicts.""" + # Generate a complete assignement for all vars (probably with conflicts) + current = {}; csp.current = current + for var in csp.vars: + val = min_conflicts_value(csp, var, current) + csp.assign(var, val, current) + # Now repeapedly choose a random conflicted variable and change it + for i in range(max_steps): + conflicted = csp.conflicted_vars(current) + if not conflicted: + return current + var = random.choice(conflicted) + val = min_conflicts_value(csp, var, current) + csp.assign(var, val, current) + return None + +def min_conflicts_value(csp, var, current): + """Return the value that will give var the least number of conflicts. + If there is a tie, choose at random.""" + return argmin_random_tie(csp.domains[var], + lambda val: csp.nconflicts(var, val, current)) + +#______________________________________________________________________________ +# Map-Coloring Problems + +class UniversalDict: + """A universal dict maps any key to the same value. We use it here + as the domains dict for CSPs in which all vars have the same domain. + >>> d = UniversalDict(42) + >>> d['life'] + 42 + """ + def __init__(self, value): self.value = value + def __getitem__(self, key): return self.value + def __repr__(self): return '{Any: %r}' % self.value + +def different_values_constraint(A, a, B, b): + "A constraint saying two neighboring variables must differ in value." + return a != b + +def MapColoringCSP(colors, neighbors): + """Make a CSP for the problem of coloring a map with different colors + for any two adjacent regions. Arguments are a list of colors, and a + dict of {region: [neighbor,...]} entries. This dict may also be + specified as a string of the form defined by parse_neighbors""" + + if isinstance(neighbors, str): + neighbors = parse_neighbors(neighbors) + return CSP(neighbors.keys(), UniversalDict(colors), neighbors, + different_values_constraint) + +def parse_neighbors(neighbors, vars=[]): + """Convert a string of the form 'X: Y Z; Y: Z' into a dict mapping + regions to neighbors. The syntax is a region name followed by a ':' + followed by zero or more region names, followed by ';', repeated for + each region name. If you say 'X: Y' you don't need 'Y: X'. + >>> parse_neighbors('X: Y Z; Y: Z') + {'Y': ['X', 'Z'], 'X': ['Y', 'Z'], 'Z': ['X', 'Y']} + """ + dict = DefaultDict([]) + for var in vars: + dict[var] = [] + specs = [spec.split(':') for spec in neighbors.split(';')] + for (A, Aneighbors) in specs: + A = A.strip(); + dict.setdefault(A, []) + for B in Aneighbors.split(): + dict[A].append(B) + dict[B].append(A) + return dict + +australia = MapColoringCSP(list('RGB'), + 'SA: WA NT Q NSW V; NT: WA Q; NSW: Q V; T: ') + +usa = MapColoringCSP(list('RGBY'), + """WA: OR ID; OR: ID NV CA; CA: NV AZ; NV: ID UT AZ; ID: MT WY UT; + UT: WY CO AZ; MT: ND SD WY; WY: SD NE CO; CO: NE KA OK NM; NM: OK TX; + ND: MN SD; SD: MN IA NE; NE: IA MO KA; KA: MO OK; OK: MO AR TX; + TX: AR LA; MN: WI IA; IA: WI IL MO; MO: IL KY TN AR; AR: MS TN LA; + LA: MS; WI: MI IL; IL: IN; IN: KY; MS: TN AL; AL: TN GA FL; MI: OH; + OH: PA WV KY; KY: WV VA TN; TN: VA NC GA; GA: NC SC FL; + PA: NY NJ DE MD WV; WV: MD VA; VA: MD DC NC; NC: SC; NY: VT MA CA NJ; + NJ: DE; DE: MD; MD: DC; VT: NH MA; MA: NH RI CT; CT: RI; ME: NH; + HI: ; AK: """) +#______________________________________________________________________________ +# n-Queens Problem + +def queen_constraint(A, a, B, b): + """Constraint is satisfied (true) if A, B are really the same variable, + or if they are not in the same row, down diagonal, or up diagonal.""" + return A == B or (a != b and A + a != B + b and A - a != B - b) + +class NQueensCSP(CSP): + """Make a CSP for the nQueens problem for search with min_conflicts. + Suitable for large n, it uses only data structures of size O(n). + Think of placing queens one per column, from left to right. + That means position (x, y) represents (var, val) in the CSP. + The main structures are three arrays to count queens that could conflict: + rows[i] Number of queens in the ith row (i.e val == i) + downs[i] Number of queens in the \ diagonal + such that their (x, y) coordinates sum to i + ups[i] Number of queens in the / diagonal + such that their (x, y) coordinates have x-y+n-1 = i + We increment/decrement these counts each time a queen is placed/moved from + a row/diagonal. So moving is O(1), as is nconflicts. But choosing + a variable, and a best value for the variable, are each O(n). + If you want, you can keep track of conflicted vars, then variable + selection will also be O(1). + >>> len(backtracking_search(NQueensCSP(8))) + 8 + >>> len(min_conflicts(NQueensCSP(8))) + 8 + """ + def __init__(self, n): + """Initialize data structures for n Queens.""" + CSP.__init__(self, range(n), UniversalDict(range(n)), + UniversalDict(range(n)), queen_constraint) + update(self, rows=[0]*n, ups=[0]*(2*n - 1), downs=[0]*(2*n - 1)) + + def nconflicts(self, var, val, assignment): + """The number of conflicts, as recorded with each assignment. + Count conflicts in row and in up, down diagonals. If there + is a queen there, it can't conflict with itself, so subtract 3.""" + n = len(self.vars) + c = self.rows[val] + self.downs[var+val] + self.ups[var-val+n-1] + if assignment.get(var, None) == val: + c -= 3 + return c + + def assign(self, var, val, assignment): + "Assign var, and keep track of conflicts." + oldval = assignment.get(var, None) + if val != oldval: + if oldval is not None: # Remove old val if there was one + self.record_conflict(assignment, var, oldval, -1) + self.record_conflict(assignment, var, val, +1) + CSP.assign(self, var, val, assignment) + + def unassign(self, var, assignment): + "Remove var from assignment (if it is there) and track conflicts." + if var in assignment: + self.record_conflict(assignment, var, assignment[var], -1) + CSP.unassign(self, var, assignment) + + def record_conflict(self, assignment, var, val, delta): + "Record conflicts caused by addition or deletion of a Queen." + n = len(self.vars) + self.rows[val] += delta + self.downs[var + val] += delta + self.ups[var - val + n - 1] += delta + + def display(self, assignment): + "Print the queens and the nconflicts values (for debugging)." + n = len(self.vars) + for val in range(n): + for var in range(n): + if assignment.get(var,'') == val: ch ='Q' + elif (var+val) % 2 == 0: ch = '.' + else: ch = '-' + print ch, + print ' ', + for var in range(n): + if assignment.get(var,'') == val: ch ='*' + else: ch = ' ' + print str(self.nconflicts(var, val, assignment))+ch, + print + +#______________________________________________________________________________ +# The Zebra Puzzle + +def Zebra(): + "Return an instance of the Zebra Puzzle." + Colors = 'Red Yellow Blue Green Ivory'.split() + Pets = 'Dog Fox Snails Horse Zebra'.split() + Drinks = 'OJ Tea Coffee Milk Water'.split() + Countries = 'Englishman Spaniard Norwegian Ukranian Japanese'.split() + Smokes = 'Kools Chesterfields Winston LuckyStrike Parliaments'.split() + vars = Colors + Pets + Drinks + Countries + Smokes + domains = {} + for var in vars: + domains[var] = range(1, 6) + domains['Norwegian'] = [1] + domains['Milk'] = [3] + neighbors = parse_neighbors("""Englishman: Red; + Spaniard: Dog; Kools: Yellow; Chesterfields: Fox; + Norwegian: Blue; Winston: Snails; LuckyStrike: OJ; + Ukranian: Tea; Japanese: Parliaments; Kools: Horse; + Coffee: Green; Green: Ivory""", vars) + for type in [Colors, Pets, Drinks, Countries, Smokes]: + for A in type: + for B in type: + if A != B: + if B not in neighbors[A]: neighbors[A].append(B) + if A not in neighbors[B]: neighbors[B].append(A) + def zebra_constraint(A, a, B, b, recurse=0): + same = (a == b) + next_to = abs(a - b) == 1 + if A == 'Englishman' and B == 'Red': return same + if A == 'Spaniard' and B == 'Dog': return same + if A == 'Chesterfields' and B == 'Fox': return next_to + if A == 'Norwegian' and B == 'Blue': return next_to + if A == 'Kools' and B == 'Yellow': return same + if A == 'Winston' and B == 'Snails': return same + if A == 'LuckyStrike' and B == 'OJ': return same + if A == 'Ukranian' and B == 'Tea': return same + if A == 'Japanese' and B == 'Parliaments': return same + if A == 'Kools' and B == 'Horse': return next_to + if A == 'Coffee' and B == 'Green': return same + if A == 'Green' and B == 'Ivory': return (a - 1) == b + if recurse == 0: return zebra_constraint(B, b, A, a, 1) + if ((A in Colors and B in Colors) or + (A in Pets and B in Pets) or + (A in Drinks and B in Drinks) or + (A in Countries and B in Countries) or + (A in Smokes and B in Smokes)): return not same + raise 'error' + return CSP(vars, domains, neighbors, zebra_constraint) + +def solve_zebra(algorithm=min_conflicts, **args): + z = Zebra() + ans = algorithm(z, **args) + for h in range(1, 6): + print 'House', h, + for (var, val) in ans.items(): + if val == h: print var, + print + return ans['Zebra'], ans['Water'], z.nassigns, ans, + + diff --git a/csp/search.py b/csp/search.py new file mode 100644 index 00000000..cb0c07dd --- /dev/null +++ b/csp/search.py @@ -0,0 +1,736 @@ +"""Search (Chapters 3-4) + +The way to use this code is to subclass Problem to create a class of problems, +then create problem instances and solve them with calls to the various search +functions.""" + +from __future__ import generators +from utils import * +import agents +import math, random, sys, time, bisect, string + +#______________________________________________________________________________ + +class Problem: + """The abstract class for a formal problem. You should subclass this and + implement the method successor, and possibly __init__, goal_test, and + path_cost. Then you will create instances of your subclass and solve them + with the various search functions.""" + + def __init__(self, initial, goal=None): + """The constructor specifies the initial state, and possibly a goal + state, if there is a unique goal. Your subclass's constructor can add + other arguments.""" + self.initial = initial; self.goal = goal + + def successor(self, state): + """Given a state, return a sequence of (action, state) pairs reachable + from this state. If there are many successors, consider an iterator + that yields the successors one at a time, rather than building them + all at once. Iterators will work fine within the framework.""" + abstract + + def goal_test(self, state): + """Return True if the state is a goal. The default method compares the + state to self.goal, as specified in the constructor. Implement this + method if checking against a single self.goal is not enough.""" + return state == self.goal + + def path_cost(self, c, state1, action, state2): + """Return the cost of a solution path that arrives at state2 from + state1 via action, assuming cost c to get up to state1. If the problem + is such that the path doesn't matter, this function will only look at + state2. If the path does matter, it will consider c and maybe state1 + and action. The default method costs 1 for every step in the path.""" + return c + 1 + + def value(self): + """For optimization problems, each state has a value. Hill-climbing + and related algorithms try to maximize this value.""" + abstract +#______________________________________________________________________________ + +class Node: + """A node in a search tree. Contains a pointer to the parent (the node + that this is a successor of) and to the actual state for this node. Note + that if a state is arrived at by two paths, then there are two nodes with + the same state. Also includes the action that got us to this state, and + the total path_cost (also known as g) to reach the node. Other functions + may add an f and h value; see best_first_graph_search and astar_search for + an explanation of how the f and h values are handled. You will not need to + subclass this class.""" + + def __init__(self, state, parent=None, action=None, path_cost=0): + "Create a search tree Node, derived from a parent by an action." + update(self, state=state, parent=parent, action=action, + path_cost=path_cost, depth=0) + if parent: + self.depth = parent.depth + 1 + + def __repr__(self): + return "" % (self.state,) + + def path(self): + "Create a list of nodes from the root to this node." + x, result = self, [self] + while x.parent: + result.append(x.parent) + x = x.parent + return result + + def expand(self, problem): + "Return a list of nodes reachable from this node. [Fig. 3.8]" + return [Node(next, self, act, + problem.path_cost(self.path_cost, self.state, act, next)) + for (act, next) in problem.successor(self.state)] + +#______________________________________________________________________________ + +class SimpleProblemSolvingAgent(agents.Agent): + """Abstract framework for problem-solving agent. [Fig. 3.1]""" + def __init__(self): + Agent.__init__(self) + state = [] + seq = [] + + def program(percept): + state = self.update_state(state, percept) + if not seq: + goal = self.formulate_goal(state) + problem = self.formulate_problem(state, goal) + seq = self.search(problem) + action = seq[0] + seq[0:1] = [] + return action + + self.program = program + +#______________________________________________________________________________ +## Uninformed Search algorithms + +def tree_search(problem, fringe): + """Search through the successors of a problem to find a goal. + The argument fringe should be an empty queue. + Don't worry about repeated paths to a state. [Fig. 3.8]""" + fringe.append(Node(problem.initial)) + while fringe: + node = fringe.pop() + if problem.goal_test(node.state): + return node + fringe.extend(node.expand(problem)) + return None + +def breadth_first_tree_search(problem): + "Search the shallowest nodes in the search tree first. [p 74]" + return tree_search(problem, FIFOQueue()) + +def depth_first_tree_search(problem): + "Search the deepest nodes in the search tree first. [p 74]" + return tree_search(problem, Stack()) + +def graph_search(problem, fringe): + """Search through the successors of a problem to find a goal. + The argument fringe should be an empty queue. + If two paths reach a state, only use the best one. [Fig. 3.18]""" + closed = {} + fringe.append(Node(problem.initial)) + while fringe: + node = fringe.pop() + if problem.goal_test(node.state): + return node + if node.state not in closed: + closed[node.state] = True + fringe.extend(node.expand(problem)) + return None + +def breadth_first_graph_search(problem): + "Search the shallowest nodes in the search tree first. [p 74]" + return graph_search(problem, FIFOQueue()) + +def depth_first_graph_search(problem): + "Search the deepest nodes in the search tree first. [p 74]" + return graph_search(problem, Stack()) + +def depth_limited_search(problem, limit=50): + "[Fig. 3.12]" + def recursive_dls(node, problem, limit): + cutoff_occurred = False + if problem.goal_test(node.state): + return node + elif node.depth == limit: + return 'cutoff' + else: + for successor in node.expand(problem): + result = recursive_dls(successor, problem, limit) + if result == 'cutoff': + cutoff_occurred = True + elif result != None: + return result + if cutoff_occurred: + return 'cutoff' + else: + return None + # Body of depth_limited_search: + return recursive_dls(Node(problem.initial), problem, limit) + +def iterative_deepening_search(problem): + "[Fig. 3.13]" + for depth in xrange(sys.maxint): + result = depth_limited_search(problem, depth) + if result is not 'cutoff': + return result + +#______________________________________________________________________________ +# Informed (Heuristic) Search + +def best_first_graph_search(problem, f): + """Search the nodes with the lowest f scores first. + You specify the function f(node) that you want to minimize; for example, + if f is a heuristic estimate to the goal, then we have greedy best + first search; if f is node.depth then we have depth-first search. + There is a subtlety: the line "f = memoize(f, 'f')" means that the f + values will be cached on the nodes as they are computed. So after doing + a best first search you can examine the f values of the path returned.""" + f = memoize(f, 'f') + return graph_search(problem, PriorityQueue(min, f)) + +greedy_best_first_graph_search = best_first_graph_search + # Greedy best-first search is accomplished by specifying f(n) = h(n). + +def astar_search(problem, h=None): + """A* search is best-first graph search with f(n) = g(n)+h(n). + You need to specify the h function when you call astar_search. + Uses the pathmax trick: f(n) = max(f(n), g(n)+h(n)).""" + h = h or problem.h + def f(n): + return max(getattr(n, 'f', -infinity), n.path_cost + h(n)) + return best_first_graph_search(problem, f) + +#______________________________________________________________________________ +## Other search algorithms + +def recursive_best_first_search(problem): + "[Fig. 4.5]" + def RBFS(problem, node, flimit): + if problem.goal_test(node.state): + return node + successors = expand(node, problem) + if len(successors) == 0: + return None, infinity + for s in successors: + s.f = max(s.path_cost + s.h, node.f) + while True: + successors.sort(lambda x,y: x.f - y.f) # Order by lowest f value + best = successors[0] + if best.f > flimit: + return None, best.f + alternative = successors[1] + result, best.f = RBFS(problem, best, min(flimit, alternative)) + if result is not None: + return result + return RBFS(Node(problem.initial), infinity) + + +def hill_climbing(problem): + """From the initial node, keep choosing the neighbor with highest value, + stopping when no neighbor is better. [Fig. 4.11]""" + current = Node(problem.initial) + while True: + neighbor = argmax(expand(node, problem), Node.value) + if neighbor.value() <= current.value(): + return current.state + current = neighbor + +def exp_schedule(k=20, lam=0.005, limit=100): + "One possible schedule function for simulated annealing" + return lambda t: if_(t < limit, k * math.exp(-lam * t), 0) + +def simulated_annealing(problem, schedule=exp_schedule()): + "[Fig. 4.5]" + current = Node(problem.initial) + for t in xrange(sys.maxint): + T = schedule(t) + if T == 0: + return current + next = random.choice(expand(node. problem)) + delta_e = next.path_cost - current.path_cost + if delta_e > 0 or probability(math.exp(delta_e/T)): + current = next + +def online_dfs_agent(a): + "[Fig. 4.12]" + pass #### more + +def lrta_star_agent(a): + "[Fig. 4.12]" + pass #### more + +#______________________________________________________________________________ +# Genetic Algorithm + +def genetic_search(problem, fitness_fn, ngen=1000, pmut=0.0, n=20): + """Call genetic_algorithm on the appropriate parts of a problem. + This requires that the problem has a successor function that generates + reasonable states, and that it has a path_cost function that scores states. + We use the negative of the path_cost function, because costs are to be + minimized, while genetic-algorithm expects a fitness_fn to be maximized.""" + states = [s for (a, s) in problem.successor(problem.initial_state)[:n]] + random.shuffle(states) + fitness_fn = lambda s: - problem.path_cost(0, s, None, s) + return genetic_algorithm(states, fitness_fn, ngen, pmut) + +def genetic_algorithm(population, fitness_fn, ngen=1000, pmut=0.0): + """[Fig. 4.7]""" + def reproduce(p1, p2): + c = random.randrange(len(p1)) + return p1[:c] + p2[c:] + + for i in range(ngen): + new_population = [] + for i in len(population): + p1, p2 = random_weighted_selections(population, 2, fitness_fn) + child = reproduce(p1, p2) + if random.uniform(0,1) > pmut: + child.mutate() + new_population.append(child) + population = new_population + return argmax(population, fitness_fn) + +def random_weighted_selection(seq, n, weight_fn): + """Pick n elements of seq, weighted according to weight_fn. + That is, apply weight_fn to each element of seq, add up the total. + Then choose an element e with probability weight[e]/total. + Repeat n times, with replacement. """ + totals = []; runningtotal = 0 + for item in seq: + runningtotal += weight_fn(item) + totals.append(runningtotal) + selections = [] + for s in range(n): + r = random.uniform(0, totals[-1]) + for i in range(len(seq)): + if totals[i] > r: + selections.append(seq[i]) + break + return selections + + +#_____________________________________________________________________________ +# The remainder of this file implements examples for the search algorithms. + +#______________________________________________________________________________ +# Graphs and Graph Problems + +class Graph: + """A graph connects nodes (verticies) by edges (links). Each edge can also + have a length associated with it. The constructor call is something like: + g = Graph({'A': {'B': 1, 'C': 2}) + this makes a graph with 3 nodes, A, B, and C, with an edge of length 1 from + A to B, and an edge of length 2 from A to C. You can also do: + g = Graph({'A': {'B': 1, 'C': 2}, directed=False) + This makes an undirected graph, so inverse links are also added. The graph + stays undirected; if you add more links with g.connect('B', 'C', 3), then + inverse link is also added. You can use g.nodes() to get a list of nodes, + g.get('A') to get a dict of links out of A, and g.get('A', 'B') to get the + length of the link from A to B. 'Lengths' can actually be any object at + all, and nodes can be any hashable object.""" + + def __init__(self, dict=None, directed=True): + self.dict = dict or {} + self.directed = directed + if not directed: self.make_undirected() + + def make_undirected(self): + "Make a digraph into an undirected graph by adding symmetric edges." + for a in self.dict.keys(): + for (b, distance) in self.dict[a].items(): + self.connect1(b, a, distance) + + def connect(self, A, B, distance=1): + """Add a link from A and B of given distance, and also add the inverse + link if the graph is undirected.""" + self.connect1(A, B, distance) + if not self.directed: self.connect1(B, A, distance) + + def connect1(self, A, B, distance): + "Add a link from A to B of given distance, in one direction only." + self.dict.setdefault(A,{})[B] = distance + + def get(self, a, b=None): + """Return a link distance or a dict of {node: distance} entries. + .get(a,b) returns the distance or None; + .get(a) returns a dict of {node: distance} entries, possibly {}.""" + links = self.dict.setdefault(a, {}) + if b is None: return links + else: return links.get(b) + + def nodes(self): + "Return a list of nodes in the graph." + return self.dict.keys() + +def UndirectedGraph(dict=None): + "Build a Graph where every edge (including future ones) goes both ways." + return Graph(dict=dict, directed=False) + +def RandomGraph(nodes=range(10), min_links=2, width=400, height=300, + curvature=lambda: random.uniform(1.1, 1.5)): + """Construct a random graph, with the specified nodes, and random links. + The nodes are laid out randomly on a (width x height) rectangle. + Then each node is connected to the min_links nearest neighbors. + Because inverse links are added, some nodes will have more connections. + The distance between nodes is the hypotenuse times curvature(), + where curvature() defaults to a random number between 1.1 and 1.5.""" + g = UndirectedGraph() + g.locations = {} + ## Build the cities + for node in nodes: + g.locations[node] = (random.randrange(width), random.randrange(height)) + ## Build roads from each city to at least min_links nearest neighbors. + for i in range(min_links): + for node in nodes: + if len(g.get(node)) < min_links: + here = g.locations[node] + def distance_to_node(n): + if n is node or g.get(node,n): return infinity + return distance(g.locations[n], here) + neighbor = argmin(nodes, distance_to_node) + d = distance(g.locations[neighbor], here) * curvature() + g.connect(node, neighbor, int(d)) + return g + +romania = UndirectedGraph(Dict( + A=Dict(Z=75, S=140, T=118), + B=Dict(U=85, P=101, G=90, F=211), + C=Dict(D=120, R=146, P=138), + D=Dict(M=75), + E=Dict(H=86), + F=Dict(S=99), + H=Dict(U=98), + I=Dict(V=92, N=87), + L=Dict(T=111, M=70), + O=Dict(Z=71, S=151), + P=Dict(R=97), + R=Dict(S=80), + U=Dict(V=142))) +romania.locations = Dict( + A=( 91, 492), B=(400, 327), C=(253, 288), D=(165, 299), + E=(562, 293), F=(305, 449), G=(375, 270), H=(534, 350), + I=(473, 506), L=(165, 379), M=(168, 339), N=(406, 537), + O=(131, 571), P=(320, 368), R=(233, 410), S=(207, 457), + T=( 94, 410), U=(456, 350), V=(509, 444), Z=(108, 531)) + +australia = UndirectedGraph(Dict( + T=Dict(), + SA=Dict(WA=1, NT=1, Q=1, NSW=1, V=1), + NT=Dict(WA=1, Q=1), + NSW=Dict(Q=1, V=1))) +australia.locations = Dict(WA=(120, 24), NT=(135, 20), SA=(135, 30), + Q=(145, 20), NSW=(145, 32), T=(145, 42), V=(145, 37)) + +class GraphProblem(Problem): + "The problem of searching a graph from one node to another." + def __init__(self, initial, goal, graph): + Problem.__init__(self, initial, goal) + self.graph = graph + + def successor(self, A): + "Return a list of (action, result) pairs." + return [(B, B) for B in self.graph.get(A).keys()] + + def path_cost(self, cost_so_far, A, action, B): + return cost_so_far + (self.graph.get(A,B) or infinity) + + def h(self, node): + "h function is straight-line distance from a node's state to goal." + locs = getattr(self.graph, 'locations', None) + if locs: + return int(distance(locs[node.state], locs[self.goal])) + else: + return infinity + +#______________________________________________________________________________ + +#### NOTE: NQueensProblem not working properly yet. + +class NQueensProblem(Problem): + """The problem of placing N queens on an NxN board with none attacking + each other. A state is represented as an N-element array, where the + a value of r in the c-th entry means there is a queen at column c, + row r, and a value of None means that the c-th column has not been + filled in left. We fill in columns left to right.""" + def __init__(self, N): + self.N = N + self.initial = [None] * N + + def successor(self, state): + "In the leftmost empty column, try all non-conflicting rows." + if state[-1] is not None: + return [] ## All columns filled; no successors + else: + def place(col, row): + new = state[:] + new[col] = row + return new + col = state.index(None) + return [(row, place(col, row)) for row in range(self.N) + if not self.conflicted(state, row, col)] + + def conflicted(self, state, row, col): + "Would placing a queen at (row, col) conflict with anything?" + for c in range(col-1): + if self.conflict(row, col, state[c], c): + return True + return False + + def conflict(self, row1, col1, row2, col2): + "Would putting two queens in (row1, col1) and (row2, col2) conflict?" + return (row1 == row2 ## same row + or col1 == col2 ## same column + or row1-col1 == row2-col2 ## same \ diagonal + or row1+col1 == row2+col2) ## same / diagonal + + def goal_test(self, state): + "Check if all columns filled, no conflicts." + if state[-1] is None: + return False + for c in range(len(state)): + if self.conflicted(state, state[c], c): + return False + return True + +#______________________________________________________________________________ +## Inverse Boggle: Search for a high-scoring Boggle board. A good domain for +## iterative-repair and related search tehniques, as suggested by Justin Boyan. + +ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + +cubes16 = ['FORIXB', 'MOQABJ', 'GURILW', 'SETUPL', + 'CMPDAE', 'ACITAO', 'SLCRAE', 'ROMASH', + 'NODESW', 'HEFIYE', 'ONUDTK', 'TEVIGN', + 'ANEDVZ', 'PINESH', 'ABILYT', 'GKYLEU'] + +def random_boggle(n=4): + """Return a random Boggle board of size n x n. + We represent a board as a linear list of letters.""" + cubes = [cubes16[i % 16] for i in range(n*n)] + random.shuffle(cubes) + return map(random.choice, cubes) + +## The best 5x5 board found by Boyan, with our word list this board scores +## 2274 words, for a score of 9837 + +boyan_best = list('RSTCSDEIAEGNLRPEATESMSSID') + +def print_boggle(board): + "Print the board in a 2-d array." + n2 = len(board); n = exact_sqrt(n2) + for i in range(n2): + if i % n == 0: print + if board[i] == 'Q': print 'Qu', + else: print str(board[i]) + ' ', + print + +def boggle_neighbors(n2, cache={}): + """"Return a list of lists, where the i-th element is the list of indexes + for the neighbors of square i.""" + if cache.get(n2): + return cache.get(n2) + n = exact_sqrt(n2) + neighbors = [None] * n2 + for i in range(n2): + neighbors[i] = [] + on_top = i < n + on_bottom = i >= n2 - n + on_left = i % n == 0 + on_right = (i+1) % n == 0 + if not on_top: + neighbors[i].append(i - n) + if not on_left: neighbors[i].append(i - n - 1) + if not on_right: neighbors[i].append(i - n + 1) + if not on_bottom: + neighbors[i].append(i + n) + if not on_left: neighbors[i].append(i + n - 1) + if not on_right: neighbors[i].append(i + n + 1) + if not on_left: neighbors[i].append(i - 1) + if not on_right: neighbors[i].append(i + 1) + cache[n2] = neighbors + return neighbors + +def exact_sqrt(n2): + "If n2 is a perfect square, return its square root, else raise error." + n = int(math.sqrt(n2)) + assert n * n == n2 + return n + +##_____________________________________________________________________________ + +class Wordlist: + """This class holds a list of words. You can use (word in wordlist) + to check if a word is in the list, or wordlist.lookup(prefix) + to see if prefix starts any of the words in the list.""" + def __init__(self, filename, min_len=3): + lines = open(filename).read().upper().split() + self.words = [word for word in lines if len(word) >= min_len] + self.words.sort() + self.bounds = {} + for c in ALPHABET: + c2 = chr(ord(c) + 1) + self.bounds[c] = (bisect.bisect(self.words, c), + bisect.bisect(self.words, c2)) + + def lookup(self, prefix, lo=0, hi=None): + """See if prefix is in dictionary, as a full word or as a prefix. + Return two values: the first is the lowest i such that + words[i].startswith(prefix), or is None; the second is + True iff prefix itself is in the Wordlist.""" + words = self.words + i = bisect.bisect_left(words, prefix, lo, hi) + if i < len(words) and words[i].startswith(prefix): + return i, (words[i] == prefix) + else: + return None, False + + def __contains__(self, word): + return self.words[bisect.bisect_left(self.words, word)] == word + + def __len__(self): + return len(self.words) + +##_____________________________________________________________________________ + +class BoggleFinder: + """A class that allows you to find all the words in a Boggle board. """ + + wordlist = None ## A class variable, holding a wordlist + + def __init__(self, board=None): + if BoggleFinder.wordlist is None: + BoggleFinder.wordlist = Wordlist("../data/wordlist") + self.found = {} + if board: + self.set_board(board) + + def set_board(self, board=None): + "Set the board, and find all the words in it." + if board is None: + board = random_boggle() + self.board = board + self.neighbors = boggle_neighbors(len(board)) + self.found = {} + for i in range(len(board)): + lo, hi = self.wordlist.bounds[board[i]] + self.find(lo, hi, i, [], '') + return self + + def find(self, lo, hi, i, visited, prefix): + """Looking in square i, find the words that continue the prefix, + considering the entries in self.wordlist.words[lo:hi], and not + revisiting the squares in visited.""" + if i in visited: + return + wordpos, is_word = self.wordlist.lookup(prefix, lo, hi) + if wordpos is not None: + if is_word: + self.found[prefix] = True + visited.append(i) + c = self.board[i] + if c == 'Q': c = 'QU' + prefix += c + for j in self.neighbors[i]: + self.find(wordpos, hi, j, visited, prefix) + visited.pop() + + def words(self): + "The words found." + return self.found.keys() + + scores = [0, 0, 0, 0, 1, 2, 3, 5] + [11] * 100 + + def score(self): + "The total score for the words found, according to the rules." + return sum([self.scores[len(w)] for w in self.words()]) + + def __len__(self): + "The number of words found." + return len(self.found) + +##_____________________________________________________________________________ + +def boggle_hill_climbing(board=None, ntimes=100, print_it=True): + """Solve inverse Boggle by hill-climbing: find a high-scoring board by + starting with a random one and changing it.""" + finder = BoggleFinder() + if board is None: + board = random_boggle() + best = len(finder.set_board(board)) + for _ in range(ntimes): + i, oldc = mutate_boggle(board) + new = len(finder.set_board(board)) + if new > best: + best = new + print best, _, board + else: + board[i] = oldc ## Change back + if print_it: + print_boggle(board) + return board, best + +def mutate_boggle(board): + i = random.randrange(len(board)) + oldc = board[i] + board[i] = random.choice(random.choice(cubes16)) ##random.choice(boyan_best) + return i, oldc + +#______________________________________________________________________________ + +## Code to compare searchers on various problems. + +class InstrumentedProblem(Problem): + """Delegates to a problem, and keeps statistics.""" + + def __init__(self, problem): + self.problem = problem + self.succs = self.goal_tests = self.states = 0 + self.found = None + + def successor(self, state): + "Return a list of (action, state) pairs reachable from this state." + result = self.problem.successor(state) + self.succs += 1; self.states += len(result) + return result + + def goal_test(self, state): + "Return true if the state is a goal." + self.goal_tests += 1 + result = self.problem.goal_test(state) + if result: + self.found = state + return result + + def __getattr__(self, attr): + if attr in ('succs', 'goal_tests', 'states'): + return self.__dict__[attr] + else: + return getattr(self.problem, attr) + + def __repr__(self): + return '<%4d/%4d/%4d/%s>' % (self.succs, self.goal_tests, + self.states, str(self.found)[0:4]) + +def compare_searchers(problems, header, searchers=[breadth_first_tree_search, + breadth_first_graph_search, depth_first_graph_search, + iterative_deepening_search, depth_limited_search, + astar_search]): + def do(searcher, problem): + p = InstrumentedProblem(problem) + searcher(p) + return p + table = [[name(s)] + [do(s, p) for p in problems] for s in searchers] + print_table(table, header) + +def compare_graph_searchers(): + compare_searchers(problems=[GraphProblem('A', 'B', romania), + GraphProblem('O', 'N', romania), + GraphProblem('Q', 'WA', australia)], + header=['Searcher', 'Romania(A,B)', 'Romania(O, N)', 'Australia']) + diff --git a/csp/utils.py b/csp/utils.py new file mode 100644 index 00000000..87728c1d --- /dev/null +++ b/csp/utils.py @@ -0,0 +1,714 @@ +"""Provide some widely useful utilities. Safe for "from utils import *". + +""" + +from __future__ import generators +import operator, math, random, copy, sys, os.path, bisect + +#______________________________________________________________________________ +# Compatibility with Python 2.2 and 2.3 + +# The AIMA code is designed to run in Python 2.2 and up (at some point, +# support for 2.2 may go away; 2.2 was released in 2001, and so is over +# 3 years old). The first part of this file brings you up to 2.4 +# compatibility if you are running in Python 2.2 or 2.3: + +try: bool, True, False ## Introduced in 2.3 +except NameError: + class bool(int): + "Simple implementation of Booleans, as in PEP 285" + def __init__(self, val): self.val = val + def __int__(self): return self.val + def __repr__(self): return ('False', 'True')[self.val] + + True, False = bool(1), bool(0) + +try: sum ## Introduced in 2.3 +except NameError: + def sum(seq, start=0): + """Sum the elements of seq. + >>> sum([1, 2, 3]) + 6 + """ + return reduce(operator.add, seq, start) + +try: enumerate ## Introduced in 2.3 +except NameError: + def enumerate(collection): + """Return an iterator that enumerates pairs of (i, c[i]). PEP 279. + >>> list(enumerate('abc')) + [(0, 'a'), (1, 'b'), (2, 'c')] + """ + ## Copied from PEP 279 + i = 0 + it = iter(collection) + while 1: + yield (i, it.next()) + i += 1 + + +try: reversed ## Introduced in 2.4 +except NameError: + def reversed(seq): + """Iterate over x in reverse order. + >>> list(reversed([1,2,3])) + [3, 2, 1] + """ + if hasattr(seq, 'keys'): + raise ValueError("mappings do not support reverse iteration") + i = len(seq) + while i > 0: + i -= 1 + yield seq[i] + + +try: sorted ## Introduced in 2.4 +except NameError: + def sorted(seq, cmp=None, key=None, reverse=False): + """Copy seq and sort and return it. + >>> sorted([3, 1, 2]) + [1, 2, 3] + """ + seq2 = copy.copy(seq) + if key: + if cmp == None: + cmp = __builtins__.cmp + seq2.sort(lambda x,y: cmp(key(x), key(y))) + else: + if cmp == None: + seq2.sort() + else: + seq2.sort(cmp) + if reverse: + seq2.reverse() + return seq2 + +try: + set, frozenset ## set builtin introduced in 2.4 +except NameError: + try: + import sets ## sets module introduced in 2.3 + set, frozenset = sets.Set, sets.ImmutableSet + except (NameError, ImportError): + class BaseSet: + "set type (see http://docs.python.org/lib/types-set.html)" + + + def __init__(self, elements=[]): + self.dict = {} + for e in elements: + self.dict[e] = 1 + + def __len__(self): + return len(self.dict) + + def __iter__(self): + for e in self.dict: + yield e + + def __contains__(self, element): + return element in self.dict + + def issubset(self, other): + for e in self.dict.keys(): + if e not in other: + return False + return True + + def issuperset(self, other): + for e in other: + if e not in self: + return False + return True + + + def union(self, other): + return type(self)(list(self) + list(other)) + + def intersection(self, other): + return type(self)([e for e in self.dict if e in other]) + + def difference(self, other): + return type(self)([e for e in self.dict if e not in other]) + + def symmetric_difference(self, other): + return type(self)([e for e in self.dict if e not in other] + + [e for e in other if e not in self.dict]) + + def copy(self): + return type(self)(self.dict) + + def __repr__(self): + elements = ", ".join(map(str, self.dict)) + return "%s([%s])" % (type(self).__name__, elements) + + __le__ = issubset + __ge__ = issuperset + __or__ = union + __and__ = intersection + __sub__ = difference + __xor__ = symmetric_difference + + class frozenset(BaseSet): + "A frozenset is a BaseSet that has a hash value and is immutable." + + def __init__(self, elements=[]): + BaseSet.__init__(elements) + self.hash = 0 + for e in self: + self.hash |= hash(e) + + def __hash__(self): + return self.hash + + class set(BaseSet): + "A set is a BaseSet that does not have a hash, but is mutable." + + def update(self, other): + for e in other: + self.add(e) + return self + + def intersection_update(self, other): + for e in self.dict.keys(): + if e not in other: + self.remove(e) + return self + + def difference_update(self, other): + for e in self.dict.keys(): + if e in other: + self.remove(e) + return self + + def symmetric_difference_update(self, other): + to_remove1 = [e for e in self.dict if e in other] + to_remove2 = [e for e in other if e in self.dict] + self.difference_update(to_remove1) + self.difference_update(to_remove2) + return self + + def add(self, element): + self.dict[element] = 1 + + def remove(self, element): + del self.dict[element] + + def discard(self, element): + if element in self.dict: + del self.dict[element] + + def pop(self): + key, val = self.dict.popitem() + return key + + def clear(self): + self.dict.clear() + + __ior__ = update + __iand__ = intersection_update + __isub__ = difference_update + __ixor__ = symmetric_difference_update + + + + +#______________________________________________________________________________ +# Simple Data Structures: infinity, Dict, Struct + +infinity = 1.0e400 + +def Dict(**entries): + """Create a dict out of the argument=value arguments. + >>> Dict(a=1, b=2, c=3) + {'a': 1, 'c': 3, 'b': 2} + """ + return entries + +class DefaultDict(dict): + """Dictionary with a default value for unknown keys.""" + def __init__(self, default): + self.default = default + + def __getitem__(self, key): + if key in self: return self.get(key) + return self.setdefault(key, copy.deepcopy(self.default)) + + def __copy__(self): + copy = DefaultDict(self.default) + copy.update(self) + return copy + +class Struct: + """Create an instance with argument=value slots. + This is for making a lightweight object whose class doesn't matter.""" + def __init__(self, **entries): + self.__dict__.update(entries) + + def __cmp__(self, other): + if isinstance(other, Struct): + return cmp(self.__dict__, other.__dict__) + else: + return cmp(self.__dict__, other) + + def __repr__(self): + args = ['%s=%s' % (k, repr(v)) for (k, v) in vars(self).items()] + return 'Struct(%s)' % ', '.join(args) + +def update(x, **entries): + """Update a dict; or an object with slots; according to entries. + >>> update({'a': 1}, a=10, b=20) + {'a': 10, 'b': 20} + >>> update(Struct(a=1), a=10, b=20) + Struct(a=10, b=20) + """ + if isinstance(x, dict): + x.update(entries) + else: + x.__dict__.update(entries) + return x + +#______________________________________________________________________________ +# Functions on Sequences (mostly inspired by Common Lisp) +# NOTE: Sequence functions (count_if, find_if, every, some) take function +# argument first (like reduce, filter, and map). + +def removeall(item, seq): + """Return a copy of seq (or string) with all occurences of item removed. + >>> removeall(3, [1, 2, 3, 3, 2, 1, 3]) + [1, 2, 2, 1] + >>> removeall(4, [1, 2, 3]) + [1, 2, 3] + """ + if isinstance(seq, str): + return seq.replace(item, '') + else: + return [x for x in seq if x != item] + +def unique(seq): + """Remove duplicate elements from seq. Assumes hashable elements. + >>> unique([1, 2, 3, 2, 1]) + [1, 2, 3] + """ + return list(set(seq)) + +def product(numbers): + """Return the product of the numbers. + >>> product([1,2,3,4]) + 24 + """ + return reduce(operator.mul, numbers, 1) + +def count_if(predicate, seq): + """Count the number of elements of seq for which the predicate is true. + >>> count_if(callable, [42, None, max, min]) + 2 + """ + f = lambda count, x: count + (not not predicate(x)) + return reduce(f, seq, 0) + +def find_if(predicate, seq): + """If there is an element of seq that satisfies predicate; return it. + >>> find_if(callable, [3, min, max]) + + >>> find_if(callable, [1, 2, 3]) + """ + for x in seq: + if predicate(x): return x + return None + +def every(predicate, seq): + """True if every element of seq satisfies predicate. + >>> every(callable, [min, max]) + 1 + >>> every(callable, [min, 3]) + 0 + """ + for x in seq: + if not predicate(x): return False + return True + +def some(predicate, seq): + """If some element x of seq satisfies predicate(x), return predicate(x). + >>> some(callable, [min, 3]) + 1 + >>> some(callable, [2, 3]) + 0 + """ + for x in seq: + px = predicate(x) + if px: return px + return False + +def isin(elt, seq): + """Like (elt in seq), but compares with is, not ==. + >>> e = []; isin(e, [1, e, 3]) + True + >>> isin(e, [1, [], 3]) + False + """ + for x in seq: + if elt is x: return True + return False + +#______________________________________________________________________________ +# Functions on sequences of numbers +# NOTE: these take the sequence argument first, like min and max, +# and like standard math notation: \sigma (i = 1..n) fn(i) +# A lot of programing is finding the best value that satisfies some condition; +# so there are three versions of argmin/argmax, depending on what you want to +# do with ties: return the first one, return them all, or pick at random. + + +def argmin(seq, fn): + """Return an element with lowest fn(seq[i]) score; tie goes to first one. + >>> argmin(['one', 'to', 'three'], len) + 'to' + """ + best = seq[0]; best_score = fn(best) + for x in seq: + x_score = fn(x) + if x_score < best_score: + best, best_score = x, x_score + return best + +def argmin_list(seq, fn): + """Return a list of elements of seq[i] with the lowest fn(seq[i]) scores. + >>> argmin_list(['one', 'to', 'three', 'or'], len) + ['to', 'or'] + """ + best_score, best = fn(seq[0]), [] + for x in seq: + x_score = fn(x) + if x_score < best_score: + best, best_score = [x], x_score + elif x_score == best_score: + best.append(x) + return best + +def argmin_random_tie(seq, fn): + """Return an element with lowest fn(seq[i]) score; break ties at random. + Thus, for all s,f: argmin_random_tie(s, f) in argmin_list(s, f)""" + best_score = fn(seq[0]); n = 0 + for x in seq: + x_score = fn(x) + if x_score < best_score: + best, best_score = x, x_score; n = 1 + elif x_score == best_score: + n += 1 + if random.randrange(n) == 0: + best = x + return best + +def argmax(seq, fn): + """Return an element with highest fn(seq[i]) score; tie goes to first one. + >>> argmax(['one', 'to', 'three'], len) + 'three' + """ + return argmin(seq, lambda x: -fn(x)) + +def argmax_list(seq, fn): + """Return a list of elements of seq[i] with the highest fn(seq[i]) scores. + >>> argmax_list(['one', 'three', 'seven'], len) + ['three', 'seven'] + """ + return argmin_list(seq, lambda x: -fn(x)) + +def argmax_random_tie(seq, fn): + "Return an element with highest fn(seq[i]) score; break ties at random." + return argmin_random_tie(seq, lambda x: -fn(x)) +#______________________________________________________________________________ +# Statistical and mathematical functions + +def histogram(values, mode=0, bin_function=None): + """Return a list of (value, count) pairs, summarizing the input values. + Sorted by increasing value, or if mode=1, by decreasing count. + If bin_function is given, map it over values first.""" + if bin_function: values = map(bin_function, values) + bins = {} + for val in values: + bins[val] = bins.get(val, 0) + 1 + if mode: + return sorted(bins.items(), key=lambda v: v[1], reverse=True) + else: + return sorted(bins.items()) + +def log2(x): + """Base 2 logarithm. + >>> log2(1024) + 10.0 + """ + return math.log10(x) / math.log10(2) + +def mode(values): + """Return the most common value in the list of values. + >>> mode([1, 2, 3, 2]) + 2 + """ + return histogram(values, mode=1)[0][0] + +def median(values): + """Return the middle value, when the values are sorted. + If there are an odd number of elements, try to average the middle two. + If they can't be averaged (e.g. they are strings), choose one at random. + >>> median([10, 100, 11]) + 11 + >>> median([1, 2, 3, 4]) + 2.5 + """ + n = len(values) + values = sorted(values) + if n % 2 == 1: + return values[n/2] + else: + middle2 = values[(n/2)-1:(n/2)+1] + try: + return mean(middle2) + except TypeError: + return random.choice(middle2) + +def mean(values): + """Return the arithmetic average of the values.""" + return sum(values) / float(len(values)) + +def stddev(values, meanval=None): + """The standard deviation of a set of values. + Pass in the mean if you already know it.""" + if meanval == None: meanval = mean(values) + return math.sqrt(sum([(x - meanval)**2 for x in values]) / (len(values)-1)) + +def dotproduct(X, Y): + """Return the sum of the element-wise product of vectors x and y. + >>> dotproduct([1, 2, 3], [1000, 100, 10]) + 1230 + """ + return sum([x * y for x, y in zip(X, Y)]) + +def vector_add(a, b): + """Component-wise addition of two vectors. + >>> vector_add((0, 1), (8, 9)) + (8, 10) + """ + return tuple(map(operator.add, a, b)) + +def probability(p): + "Return true with probability p." + return p > random.uniform(0.0, 1.0) + +def num_or_str(x): + """The argument is a string; convert to a number if possible, or strip it. + >>> num_or_str('42') + 42 + >>> num_or_str(' 42x ') + '42x' + """ + if isnumber(x): return x + try: + return int(x) + except ValueError: + try: + return float(x) + except ValueError: + return str(x).strip() + +def normalize(numbers, total=1.0): + """Multiply each number by a constant such that the sum is 1.0 (or total). + >>> normalize([1,2,1]) + [0.25, 0.5, 0.25] + """ + k = total / sum(numbers) + return [k * n for n in numbers] + +## OK, the following are not as widely useful utilities as some of the other +## functions here, but they do show up wherever we have 2D grids: Wumpus and +## Vacuum worlds, TicTacToe and Checkers, and markov decision Processes. + +orientations = [(1,0), (0, 1), (-1, 0), (0, -1)] + +def turn_right(orientation): + return orientations[orientations.index(orientation)-1] + +def turn_left(orientation): + return orientations[(orientations.index(orientation)+1) % len(orientations)] + +def distance((ax, ay), (bx, by)): + "The distance between two (x, y) points." + return math.hypot((ax - bx), (ay - by)) + +def distance2((ax, ay), (bx, by)): + "The square of the distance between two (x, y) points." + return (ax - bx)**2 + (ay - by)**2 + +def clip(vector, lowest, highest): + """Return vector, except if any element is less than the corresponding + value of lowest or more than the corresponding value of highest, clip to + those values. + >>> clip((-1, 10), (0, 0), (9, 9)) + (0, 9) + """ + return type(vector)(map(min, map(max, vector, lowest), highest)) +#______________________________________________________________________________ +# Misc Functions + +def printf(format, *args): + """Format args with the first argument as format string, and write. + Return the last arg, or format itself if there are no args.""" + sys.stdout.write(str(format) % args) + return if_(args, args[-1], format) + +def caller(n=1): + """Return the name of the calling function n levels up in the frame stack. + >>> caller(0) + 'caller' + >>> def f(): + ... return caller() + >>> f() + 'f' + """ + import inspect + return inspect.getouterframes(inspect.currentframe())[n][3] + +def memoize(fn, slot=None): + """Memoize fn: make it remember the computed value for any argument list. + If slot is specified, store result in that slot of first argument. + If slot is false, store results in a dictionary.""" + if slot: + def memoized_fn(obj, *args): + if hasattr(obj, slot): + return getattr(obj, slot) + else: + val = fn(obj, *args) + setattr(obj, slot, val) + return val + else: + def memoized_fn(*args): + if not memoized_fn.cache.has_key(args): + memoized_fn.cache[args] = fn(*args) + return memoized_fn.cache[args] + memoized_fn.cache = {} + return memoized_fn + +def if_(test, result, alternative): + """Like C++ and Java's (test ? result : alternative), except + both result and alternative are always evaluated. However, if + either evaluates to a function, it is applied to the empty arglist, + so you can delay execution by putting it in a lambda. + >>> if_(2 + 2 == 4, 'ok', lambda: expensive_computation()) + 'ok' + """ + if test: + if callable(result): return result() + return result + else: + if callable(alternative): return alternative() + return alternative + +def name(object): + "Try to find some reasonable name for the object." + return (getattr(object, 'name', 0) or getattr(object, '__name__', 0) + or getattr(getattr(object, '__class__', 0), '__name__', 0) + or str(object)) + +def isnumber(x): + "Is x a number? We say it is if it has a __int__ method." + return hasattr(x, '__int__') + +def issequence(x): + "Is x a sequence? We say it is if it has a __getitem__ method." + return hasattr(x, '__getitem__') + +def print_table(table, header=None, sep=' ', numfmt='%g'): + """Print a list of lists as a table, so that columns line up nicely. + header, if specified, will be printed as the first row. + numfmt is the format for all numbers; you might want e.g. '%6.2f'. + (If you want different formats in differnt columns, don't use print_table.) + sep is the separator between columns.""" + justs = [if_(isnumber(x), 'rjust', 'ljust') for x in table[0]] + if header: + table = [header] + table + table = [[if_(isnumber(x), lambda: numfmt % x, x) for x in row] + for row in table] + maxlen = lambda seq: max(map(len, seq)) + sizes = map(maxlen, zip(*[map(str, row) for row in table])) + for row in table: + for (j, size, x) in zip(justs, sizes, row): + print getattr(str(x), j)(size), sep, + print + +def AIMAFile(components, mode='r'): + "Open a file based at the AIMA root directory." + import utils + dir = os.path.dirname(utils.__file__) + return open(apply(os.path.join, [dir] + components), mode) + +def DataFile(name, mode='r'): + "Return a file in the AIMA /data directory." + return AIMAFile(['..', 'data', name], mode) + + +#______________________________________________________________________________ +# Queues: Stack, FIFOQueue, PriorityQueue + +class Queue: + """Queue is an abstract class/interface. There are three types: + Stack(): A Last In First Out Queue. + FIFOQueue(): A First In First Out Queue. + PriorityQueue(lt): Queue where items are sorted by lt, (default <). + Each type supports the following methods and functions: + q.append(item) -- add an item to the queue + q.extend(items) -- equivalent to: for item in items: q.append(item) + q.pop() -- return the top item from the queue + len(q) -- number of items in q (also q.__len()) + Note that isinstance(Stack(), Queue) is false, because we implement stacks + as lists. If Python ever gets interfaces, Queue will be an interface.""" + + def __init__(self): + abstract + + def extend(self, items): + for item in items: self.append(item) + +def Stack(): + """Return an empty list, suitable as a Last-In-First-Out Queue.""" + return [] + +class FIFOQueue(Queue): + """A First-In-First-Out Queue.""" + def __init__(self): + self.A = []; self.start = 0 + def append(self, item): + self.A.append(item) + def __len__(self): + return len(self.A) - self.start + def extend(self, items): + self.A.extend(items) + def pop(self): + e = self.A[self.start] + self.start += 1 + if self.start > 5 and self.start > len(self.A)/2: + self.A = self.A[self.start:] + self.start = 0 + return e + +class PriorityQueue(Queue): + """A queue in which the minimum (or maximum) element (as determined by f and + order) is returned first. If order is min, the item with minimum f(x) is + returned first; if order is max, then it is the item with maximum f(x).""" + def __init__(self, order=min, f=lambda x: x): + update(self, A=[], order=order, f=f) + def append(self, item): + bisect.insort(self.A, (self.f(item), item)) + def __len__(self): + return len(self.A) + def pop(self): + if self.order == min: + return self.A.pop(0)[1] + else: + return self.A.pop()[1] + +## Fig: The idea is we can define things like Fig[3,10] later. +## Alas, it is Fig[3,10] not Fig[3.10], because that would be the same as Fig[3.1] +Fig = {} + + +