Q-learning grid world example code
바나나인간
2019. 9. 8. 03:24
I wrote a grid world example trained with Q-learning, using Processing.
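The sketch is split into two files. Each cell keeps a reward and a Q-value for each of its four actions (0 = +x, 1 = +y, 2 = -x, 3 = -y); actions that would leave the grid get reward -1, and actions that enter the goal cell get reward 1000. Every frame the agent takes one step and applies the tabular Q-learning update with discount factor 0.5 (and learning rate 1, which is fine here because the world is deterministic):

Q(s, a) <- r + 0.5 * max_a' Q(s', a')

An episode ends when the agent reaches the goal; the number of steps per episode is plotted with the grafica library, and training stops after 500 episodes.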
QLearningAgent.pde
import grafica.*;

// Grid world settings
int cellSize = 50;
int gridSize = 15;
Cell[][] grid;
PVector goal;   // goal position in grid coordinates
PVector agent;  // current agent position in grid coordinates

// Training bookkeeping
int trainStep = 0;   // episode counter
int moveStep = 0;    // steps taken in the current episode
boolean reset = true;

// Steps-per-episode plot (grafica library)
int nPoints = 500;
GPointsArray points = new GPointsArray(nPoints);
GPlot plot;
void setup() {
  size(1000, 700, P3D);

  // Build the grid of cells
  grid = new Cell[gridSize][gridSize];
  for (int i = 0; i < gridSize; i++) {
    for (int j = 0; j < gridSize; j++) {
      grid[i][j] = new Cell(i * cellSize, j * cellSize, cellSize);
    }
  }

  // Place the goal in the center of the grid
  goal = new PVector(gridSize/2, gridSize/2);

  // Mark actions that would leave the grid with reward -1
  // (action indices: 0 = +x, 1 = +y, 2 = -x, 3 = -y)
  for (int i = 0; i < gridSize; i++) {
    for (int j = 0; j < gridSize; j++) {
      if (j == 0) {
        grid[i][j].reward[3] = -1;
      } else if (j == gridSize - 1) {
        grid[i][j].reward[1] = -1;
      }
      if (i == 0) {
        grid[i][j].reward[2] = -1;
      } else if (i == gridSize - 1) {
        grid[i][j].reward[0] = -1;
      }
    }
  }

  // Moving into the goal cell from any of its four neighbors pays 1000
  grid[int(goal.x) - 1][int(goal.y)].reward[0] = 1000;
  grid[int(goal.x)][int(goal.y) - 1].reward[1] = 1000;
  grid[int(goal.x) + 1][int(goal.y)].reward[2] = 1000;
  grid[int(goal.x)][int(goal.y) + 1].reward[3] = 1000;
  grid[int(goal.x)][int(goal.y)].goal = true;

  // Steps-per-episode plot
  plot = new GPlot(this);
  plot.setPos(5, 5);
  plot.setTitleText("Training result");
  plot.getXAxis().setAxisLabelText("Episodes");
  plot.getYAxis().setAxisLabelText("Move steps");
  plot.setDim(150, 150);
  plot.activatePanning();
}
void draw() {
  background(0);

  // Start a new episode
  if (reset) {
    if (trainStep > 500) {
      noLoop();  // stop training after 500 episodes
    }
    if (trainStep > 0) {
      // Record how many steps the finished episode took
      points.add(trainStep, moveStep);
      plot.setPoints(points);
    }
    moveStep = 0;
    trainStep++;
    // Random start state anywhere on the grid except the goal
    // (random(gridSize) so the last row and column can also be starts)
    do {
      agent = new PVector(int(random(gridSize)), int(random(gridSize)));
    } while (goal.x == agent.x && goal.y == agent.y);
    reset = false;
  }

  // Pick a random action that stays on the grid (off-grid actions are
  // marked with reward -1), then override it with the greedy action
  Cell current = grid[int(agent.x)][int(agent.y)];
  int move;
  do {
    move = int(random(4));
  } while (current.reward[move] == -1);
  for (int i = 0; i < 4; i++) {
    if (current.reward[i] != -1 && current.qValue[i] > current.qValue[move]) {
      move = i;
    }
  }

  // Apply the action (0 = +x, 1 = +y, 2 = -x, 3 = -y)
  PVector nextAgent = new PVector(agent.x, agent.y);
  if (move == 0) {
    nextAgent.x++;
  } else if (move == 1) {
    nextAgent.y++;
  } else if (move == 2) {
    nextAgent.x--;
  } else {
    nextAgent.y--;
  }
  moveStep++;

  // Q-learning update with discount factor 0.5:
  // Q(s,a) = r + 0.5 * max_a' Q(s',a')
  int nextQvalue = 0;
  for (int i = 0; i < 4; i++) {
    if (grid[int(nextAgent.x)][int(nextAgent.y)].qValue[i] > nextQvalue) {
      nextQvalue = grid[int(nextAgent.x)][int(nextAgent.y)].qValue[i];
    }
  }
  current.qValue[move] = current.reward[move] + int(0.5 * nextQvalue);
  agent = nextAgent;

  // End the episode when the agent reaches the goal
  if (goal.x == agent.x && goal.y == agent.y) {
    reset = true;
  }

  // Draw the grid with the agent highlighted
  grid[int(agent.x)][int(agent.y)].agent = true;
  pushMatrix();
  camera(width/2, 1100, (height/2) / tan(PI/6), width/2, height/2, 0, 0, 1, 0);
  translate(200, 0, 0);
  for (int i = 0; i < gridSize; i++) {
    for (int j = 0; j < gridSize; j++) {
      grid[i][j].display();
    }
  }
  popMatrix();
  grid[int(agent.x)][int(agent.y)].agent = false;

  // Draw the steps-per-episode plot
  plot.beginDraw();
  plot.drawBackground();
  plot.drawBox();
  plot.drawXAxis();
  plot.drawYAxis();
  plot.drawTopAxis();
  plot.drawRightAxis();
  plot.drawTitle();
  plot.getMainLayer().drawPoints();
  plot.endDraw();
}
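One caveat: the random move above is immediately overridden by the greedy argmax, so once a cell has any positive Q-value the agent explores only through ties. A common remedy is epsilon-greedy action selection. A minimal sketch of what that could look like here, wrapping the greedy loop in draw() (the epsilon value is my own choice, not something from the code above):

// Epsilon-greedy variant: keep the random valid move with probability
// epsilon, otherwise fall back to the greedy choice as before.
float epsilon = 0.1;  // exploration rate (assumed value)
if (random(1) > epsilon) {
  for (int i = 0; i < 4; i++) {
    if (current.reward[i] != -1 && current.qValue[i] > current.qValue[move]) {
      move = i;
    }
  }
}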
Cell.pde
class Cell {
  int x;
  int y;
  int size;
  boolean goal;   // true for the goal cell
  boolean agent;  // true while the agent sits on this cell
  int[] reward;   // reward for each action (0 = +x, 1 = +y, 2 = -x, 3 = -y)
  int[] qValue;   // Q-value for each action

  Cell(int x_, int y_, int size_) {
    x = x_;
    y = y_;
    size = size_;
    goal = false;
    agent = false;
    reward = new int[4];
    qValue = new int[4];
    for (int i = 0; i < 4; i++) {
      reward[i] = 0;
      qValue[i] = 0;
    }
  }

  void display() {
    stroke(200);
    if (goal) {
      fill(100, 255, 100);  // goal cell in green
      rect(x, y, size, size);
    } else if (agent) {
      fill(255, 100, 100);  // agent drawn as a red box
      pushMatrix();
      translate(x + size/2, y + size/2);
      box(size);
      popMatrix();
      rect(x, y, size, size);
    } else {
      fill(255);
      rect(x, y, size, size);
    }
  }
}
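To run this, put the two files in a single Processing sketch folder and install the grafica library first; it should be available through Processing's Contribution Manager (Sketch > Import Library > Add Library).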