This is the Python source code for the post "Reinforcement Learning Example for Planning Tasks Using Q Learning and Dyna-Q".
""" This part of code is the Dyna-Q learning brain, which allows agent to make decision. All decisions and learning processes are made in here. (MIT license) """ import numpy as np import pandas as pd from copy import deepcopy class QLearningTable: def __init__(self, actions, learning_rate=0.1, reward_decay=0.9, e_greedy=0.9, agent=""): self.actions = actions # a list = learning_rate self.gamma = reward_decay self.epsilon = e_greedy self.agent=agent self.q_table = pd.DataFrame(columns=self.actions) def choose_action(self, observation): self.check_state_exist(observation) # action selection if self.agent == "RANDOM_AGENT": action = np.random.choice(self.actions) return action if np.random.uniform() < self.epsilon: # choose best action state_action = self.q_table.ix[observation, :] state_action = state_action.reindex(np.random.permutation(state_action.index)) max_value=0 for act in list(self.q_table.columns.values): if self.q_table.ix[observation, act] >= max_value : max_action= act max_value= self.q_table.ix[observation, act] action=max_action else: # choose random action action = np.random.choice(self.actions) return action def learn(self, s, a, r, s_, dn): self.check_state_exist(s_) q_predict = self.q_table.ix[s, a] if s_ != 'terminal' and dn != True: q_target = r + self.gamma * self.q_table.ix[s_, :].max() # next state is not terminal else: q_target = r # next state is terminal self.q_table.ix[s, a] += * (q_target - q_predict) # update def check_state_exist(self, state): if state not in self.q_table.index: # append new state to q table self.q_table = self.q_table.append( pd.Series( [0]*len(self.actions), index=self.q_table.columns, name=state, ) ) class EnvModel: """Similar to the memory buffer in DQN, you can store past experiences in here. 
Alternatively, the model can generate next state and reward signal accurately.""" def __init__(self, actions): self.actions = actions self.database = pd.DataFrame(columns=actions, dtype=np.object) def store_transition(self, s, a, r, s_): if s not in self.database.index: self.database = self.database.append( pd.Series( [None] * len(self.actions), index=self.database.columns, name=s, ))[s, a] = deepcopy((r, s_)) def sample_s_a(self): s = np.random.choice(self.database.index) a = np.random.choice(self.database.ix[s].dropna().index) # filter out the None value return s, a def get_r_s_(self, s, a): r, s_ = self.database.ix[s, a] return r, s_ def get_env(self): print (self.database)