Machine Learning/Algorithm

Q-learning grid world 예제 코드

바나나인간 2019. 9. 8. 03:24

Processing을 이용해 Q-learning 기반의 grid world 예제를 작성해 보았다.

 

 

 

 

 

 

 

QLearningAgent.pde

import grafica.*;

// ---- Board layout ----
int gridSize = 15;      // number of cells along each side of the board
int cellSize = 50;      // side length of one cell, in drawing units
Cell[][] grid;          // board, indexed grid[x][y]

// ---- Episode / training state ----
PVector goal;           // goal cell, in grid coordinates
PVector agent;          // agent's current cell, in grid coordinates
int trainStep = 0;      // episode counter
int moveStep = 0;       // moves taken so far in the current episode
boolean reset = true;   // when true, draw() starts a new episode

// NOTE(review): never opened or written anywhere in this sketch; looks unused.
PrintWriter output;

// ---- Training-result plot (grafica) ----
// Must stay declared before `points`, whose initializer reads it.
int nPoints = 500;      // passed to GPointsArray's constructor (presumably an initial capacity)
GPointsArray points = new GPointsArray(nPoints);
GPlot plot;

void setup() {
  size(1000, 700, P3D);
    
  grid = new Cell[gridSize][gridSize];
  
  for(int i = 0; i < gridSize; i++){
    for(int j = 0; j < gridSize; j++){
      grid[i][j] = new Cell(i*cellSize, j*cellSize, cellSize);
    }
  }
  
  goal = new PVector(int(gridSize/2), int(gridSize/2));
  for(int i = 0; i < gridSize; i++){
    for(int j = 0; j < gridSize; j++){
      if(j == 0){
        grid[i][j].reward[3] = -1;
      } else if(j == gridSize - 1){
        grid[i][j].reward[1] = -1;
      }
      if(i == 0){
        grid[i][j].reward[2] = -1;
      } else if(i == gridSize - 1){
        grid[i][j].reward[0] = -1;
      }
    }
  }

  grid[int(goal.x) - 1][int(goal.y)].reward[0] = 1000;
  grid[int(goal.x)][int(goal.y) - 1].reward[1] = 1000;
  grid[int(goal.x) + 1][int(goal.y)].reward[2] = 1000;
  grid[int(goal.x)][int(goal.y) + 1].reward[3] = 1000;
  grid[int(goal.x)][int(goal.y)].goal = true;

  plot = new GPlot(this);

  plot.setPos(5, 5);
  plot.setTitleText("Training result");
  plot.getXAxis().setAxisLabelText("Episodes");
  plot.getYAxis().setAxisLabelText("Move steps");
  plot.setDim(150, 150);
  plot.activatePanning();
  
}

/**
 * Processing per-frame callback. Advances the Q-learning agent by one move,
 * updates the Q-table, and renders the grid and the training plot.
 *
 * Move indices: 0 = x+1, 1 = y+1, 2 = x-1, 3 = y-1. A move whose reward is
 * -1 would leave the board, so it is never taken.
 *
 * Update rule (deterministic grid, learning rate 1, discount 0.5):
 *   Q(s, a) = r(s, a) + int(0.5 * max_a' Q(s', a'))
 * The max term is floored at 0. It is taken as 0 when s' is the goal, because
 * the goal is a terminal state.
 *
 * When the agent reaches the goal, reset is set. On the next frame the
 * episode's move count is appended to the plot, and a new episode starts from
 * a random non-goal cell. Once trainStep exceeds 500, noLoop() is called so
 * the sketch stops after the current frame.
 */
void draw() {
  background(0);
  if (reset == true) {
    if (trainStep > 500) {
      noLoop();
    }
    // NOTE(review): on the very first frame this records (0, 0) before any
    // episode has run.
    points.add(trainStep, moveStep);
    plot.setPoints(points);
    moveStep = 0;
    trainStep++;
    // Random start anywhere except the goal. random(gridSize) returns
    // [0, gridSize), so int() yields 0..gridSize-1 inclusive. The previous
    // random(0, gridSize-1) could never start on the last row or column.
    do {
      agent = new PVector(int(random(gridSize)), int(random(gridSize)));
    } while (agent.x == goal.x && agent.y == goal.y);
    reset = false;
  }

  Cell here = grid[int(agent.x)][int(agent.y)];

  // Greedy action selection over legal moves only. Start from a random legal
  // move. If that move already has the highest Q-value it is kept; otherwise
  // the lowest-index legal move with the highest Q-value wins.
  int move;
  do {
    move = int(random(4));
  } while (here.reward[move] == -1);
  for (int i = 0; i < 4; i++) {
    if (here.reward[i] != -1 && here.qValue[i] > here.qValue[move]) {
      move = i;
    }
  }

  PVector nextAgent = new PVector(agent.x, agent.y);
  if (move == 0) {
    nextAgent.x++;
  } else if (move == 1) {
    nextAgent.y++;
  } else if (move == 2) {
    nextAgent.x--;
  } else {
    nextAgent.y--;
  }

  moveStep++;

  boolean reachedGoal = nextAgent.x == goal.x && nextAgent.y == goal.y;

  // max_a' Q(s', a'), floored at 0. The goal is terminal, so reaching it
  // contributes no future value.
  int nextQvalue = 0;
  if (!reachedGoal) {
    Cell next = grid[int(nextAgent.x)][int(nextAgent.y)];
    for (int i = 0; i < 4; i++) {
      if (next.qValue[i] > nextQvalue) {
        nextQvalue = next.qValue[i];
      }
    }
  }

  // NOTE(review): qValue is int, so int(0.5 * ...) truncates. The 1000 goal
  // reward halves to 0 after about ten steps, which leaves cells far from the
  // goal with no learned signal.
  here.qValue[move] = here.reward[move] + int(0.5 * nextQvalue);
  agent = nextAgent;

  if (reachedGoal) {
    reset = true;
  }

  // Flag the agent's cell only while this frame's grid is being rendered.
  grid[int(agent.x)][int(agent.y)].agent = true;

  pushMatrix();
  camera(width/2, 1100, (height/2) / tan(PI/6), width/2, height/2, 0, 0, 1, 0);
  translate(200, 0, 0);
  for (int i = 0; i < gridSize; i++) {
    for (int j = 0; j < gridSize; j++) {
      grid[i][j].display();
    }
  }
  popMatrix();

  grid[int(agent.x)][int(agent.y)].agent = false;

  plot.beginDraw();
  plot.drawBackground();
  plot.drawBox();
  plot.drawXAxis();
  plot.drawYAxis();
  plot.drawTopAxis();
  plot.drawRightAxis();
  plot.drawTitle();
  plot.getMainLayer().drawPoints();
  plot.endDraw();
}

 

Cell.pde

/**
 * One square of the grid world. It holds its drawing position and size,
 * whether it is the goal or currently holds the agent, and per-move reward
 * and Q-value tables.
 *
 * reward[] and qValue[] are indexed by move: 0 = x+1, 1 = y+1, 2 = x-1,
 * 3 = y-1 (the convention used by the sketch's draw()).
 */
class Cell {
  int x;          // rect corner x, in drawing units
  int y;          // rect corner y, in drawing units
  int size;       // side length, in drawing units
  boolean goal;   // drawn green when true
  boolean agent;  // drawn red with a box on top when true (and not the goal)

  int[] reward;   // immediate reward for each of the four moves
  int[] qValue;   // learned Q-value for each of the four moves

  Cell(int x_, int y_, int size_) {
    this.x = x_;
    this.y = y_;
    this.size = size_;
    this.goal = false;
    this.agent = false;
    // Java zero-initialises new int arrays, so every reward and Q-value
    // starts at 0 without an explicit fill loop.
    this.reward = new int[4];
    this.qValue = new int[4];
  }

  /**
   * Draws the cell with a light-grey outline. The goal is filled green. The
   * agent's cell is filled red and gets a centred box. Any other cell is
   * filled white.
   */
  void display() {
    stroke(200);
    if (goal) {
      fill(100, 255, 100);
    } else if (agent) {
      fill(255, 100, 100);
      pushMatrix();
      translate(x + size / 2, y + size / 2);
      box(size);
      popMatrix();
    } else {
      fill(255);
    }
    rect(x, y, size, size);
  }
}