This is the Python source code of RL_brain.py for the post Reinforcement Learning Example for Planning Tasks Using Q Learning and Dyna-Q.
"""
This part of code is the Dyna-Q learning brain, which allows agent to make decision.
All decisions and learning processes are made in here.
(MIT license)
"""
import numpy as np
import pandas as pd
from copy import deepcopy
class QLearningTable:
    def __init__(self, actions, learning_rate=0.1, reward_decay=0.9, e_greedy=0.9, agent=""):
        self.actions = actions  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.agent = agent
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if self.agent == "RANDOM_AGENT":
            return np.random.choice(self.actions)
        if np.random.uniform() < self.epsilon:
            # choose the best action; shuffle first so ties are broken at random
            state_action = self.q_table.loc[observation, :]
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # choose a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_, dn):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal' and not dn:
            # Q-learning target: r + gamma * max_a' Q(s', a')
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        # Q(s, a) <- Q(s, a) + lr * (target - prediction)
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append the new state to the q table, initialised with zeros
            self.q_table.loc[state] = [0.0] * len(self.actions)
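
# --- Usage sketch (not from the original post) -----------------------------
# A minimal example of how QLearningTable is typically driven by an
# environment loop. `env` is a hypothetical object with Gym-style
# reset()/step(); in the post, this role is played by the maze environment.
def run_q_learning_episode(env, rl):
    """Run one tabular Q-learning episode and return the accumulated reward."""
    s = env.reset()
    total_r, done = 0.0, False
    while not done:
        a = rl.choose_action(str(s))           # epsilon-greedy action
        s_, r, done = env.step(a)              # hypothetical env API
        rl.learn(str(s), a, r, str(s_), done)  # one-step Q-learning update
        total_r += r
        s = s_
    return total_r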

class EnvModel:
    """Similar to the memory buffer in DQN, you can store past experiences here.
    Alternatively, the model can generate the next state and reward signal accurately."""
    def __init__(self, actions):
        self.actions = actions
        self.database = pd.DataFrame(columns=actions, dtype=object)

    def store_transition(self, s, a, r, s_):
        if s not in self.database.index:
            # add an empty row for the new state
            self.database.loc[s] = [None] * len(self.actions)
        self.database.at[s, a] = deepcopy((r, s_))

    def sample_s_a(self):
        s = np.random.choice(self.database.index)
        a = np.random.choice(self.database.loc[s].dropna().index)  # filter out the None values
        return s, a

    def get_r_s_(self, s, a):
        r, s_ = self.database.loc[s, a]
        return r, s_

    def get_env(self):
        print(self.database)
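
# --- Dyna-Q sketch (not from the original post) -----------------------------
# Shows how QLearningTable and EnvModel combine into Dyna-Q: after every real
# step, the agent stores the transition in the model and then replays n
# simulated transitions sampled from it. _ToyChainEnv is a made-up stand-in
# for the post's maze environment, included only so the sketch runs on its own.
class _ToyChainEnv:
    """1-D chain: action 1 moves right, action 0 moves left; state 4 is terminal."""
    def reset(self):
        self.pos = 0
        return self.pos

    def step(self, action):
        self.pos = min(self.pos + 1, 4) if action == 1 else max(self.pos - 1, 0)
        done = self.pos == 4
        return self.pos, (1.0 if done else 0.0), done

if __name__ == "__main__":
    env = _ToyChainEnv()
    rl = QLearningTable(actions=[0, 1])
    env_model = EnvModel(actions=[0, 1])
    for episode in range(20):
        s, done = env.reset(), False
        while not done:
            a = rl.choose_action(str(s))
            s_, r, done = env.step(a)
            rl.learn(str(s), a, r, str(s_), done)             # learn from the real step
            env_model.store_transition(str(s), a, r, str(s_))
            for _ in range(10):                               # n planning steps per real step
                ms, ma = env_model.sample_s_a()
                mr, ms_ = env_model.get_r_s_(ms, ma)
                rl.learn(ms, ma, mr, ms_, False)              # learn from a simulated step
            s = s_
    print(rl.q_table)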