Last active
December 21, 2016 09:06
-
-
Save jewel12/f9568c5cee16d834a6c148d7ee1fcb08 to your computer and use it in GitHub Desktop.
ソートしてくれ頼む
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import random | |
import difflib | |
import json | |
import csv | |
import sort_runner as sr | |
# Discount factor applied to the best Q-value of the next state.
GAMMA = 0.8
# Exploration rate every episode starts with; at 1 every action is random.
EPSILON_INIT = 1
class QL:
    """Tabular Q-learning agent that learns to sort a list with swap actions.

    ``self.q`` maps a state string (``runner.env.current_exp``) to a dict of
    ``{action: q_value}`` holding only the actions tried from that state.
    """

    @classmethod
    def with_learned_model(cls, csv_path, actions=None):
        """Build an agent from a tab-separated model file written by ``learn``.

        Each row is ``state<TAB>action<TAB>value``.

        :param csv_path: path of the model file to read.
        :param actions: candidate action objects; defaults to ``sr.actions``.
            A row's action column is matched against ``str(action)``.
        :returns: a new agent whose ``q`` table holds the loaded values.
        :raises ValueError: if a row names an action that is not a candidate.
        """
        if actions is None:
            actions = sr.actions
        # Resolve action names through a lookup table instead of eval(): the
        # file content is never executed, and the loaded keys are the very
        # objects the rest of the code uses.
        by_name = dict((str(action), action) for action in actions)
        q = {}
        with open(csv_path, 'r') as f:
            for row in csv.reader(f, delimiter='\t'):
                if not row:
                    continue  # tolerate blank lines
                state, name, value = row[0], row[1], row[2]
                if name not in by_name:
                    raise ValueError(
                        'unknown action %r in %s' % (name, csv_path))
                q.setdefault(state, {})[by_name[name]] = float(value)
        ql = cls()
        ql.q = q
        return ql

    def __init__(self):
        self.q = {}

    def learn(self, episodes=100, model_path='model.csv'):
        """Run training episodes, then dump the Q table to ``model_path``.

        :param episodes: number of episodes to run (100 by default).
        :param model_path: output file; one ``state<TAB>action<TAB>value``
            line is written per table entry.
        """
        for i in range(episodes):
            print('episode %d' % i)
            self.do_episode()
        # The context manager closes the file even if a write fails.
        with open(model_path, 'w') as f:
            for s, av in self.q.items():
                for a, v in av.items():
                    f.write('%s\t%s\t%f\n' % (s, a, v))

    def do_episode(self):
        """Play one episode on a fresh runner, updating ``self.q`` each step."""
        runner = sr.Runner()
        epsilon = EPSILON_INIT
        actions = sr.actions
        while not runner.finished():
            current = runner.env.current_exp
            a = self._choice_action(current, actions, epsilon)
            runner.step(a)
            next_state = runner.env.current_exp
            # The new estimate overwrites the old one outright; there is no
            # learning-rate blending.
            q_val = self._reward(runner) + GAMMA * self._max_q(next_state, actions)
            self._set_q_val(current, a, q_val)
            # NOTE(review): the source's indentation is lost; the decay is
            # assumed to sit inside the loop, since after it the statement
            # would have no effect.
            if epsilon > 0:
                epsilon -= 0.000001

    def _choice_action(self, current, actions, epsilon, greedy=False, verbose=False):
        """Pick the next action for state ``current``.

        A state with no table entry always gets a random action. Otherwise the
        best known action is taken when ``greedy`` is set or a random draw
        exceeds ``epsilon``; ties are broken at random. Only actions already
        recorded for the state compete for "best".
        """
        if current not in self.q:
            return random.choice(actions)
        # ``greedy`` short-circuits, so no random number is drawn for it.
        if greedy or random.random() > epsilon:
            known = self.q[current]
            max_q = max(known.values())
            next_action_candidates = [a for a, q in known.items() if q == max_q]
            if verbose:
                print('max_q: %f' % max_q)
            return random.choice(next_action_candidates)
        return random.choice(actions)

    def _set_q_val(self, current, action, q_val):
        """Store ``q_val`` for ``(current, action)``, creating the row if needed."""
        self.q.setdefault(current, {})[action] = q_val

    def _reward(self, runner):
        """Return 1 once the runner reports it is finished, else 0."""
        if runner.finished():
            return 1
        return 0

    # Alternative reward kept from the original source (disabled):
    # def _reward(self, runner):
    #     r = 0
    #     if runner.finished():
    #         return 100
    #     else:
    #         for i in range(len(runner.env.collect_answer)):
    #             ci = runner.env.current[i]
    #             if ci != runner.env.old_state[i] and ci == runner.env.collect_answer[i]:
    #                 r += 1
    #     return r

    def _max_q(self, state, actions):
        """Return the best Q-value over ``actions`` in ``state``.

        Unknown states and untried actions count as 0.
        """
        if state not in self.q:
            return 0
        known = self.q[state]
        return max(known.get(a, 0) for a in actions)

    def run(self, input):
        """Replay the learned policy greedily on ``input``, printing each step."""
        runner = sr.Runner(input)
        actions = sr.actions
        while not runner.finished():
            a = self._choice_action(runner.env.current_exp, actions, 0, greedy=True, verbose=True)
            runner.step(a, verbose=True)
            print('-----------')
# Train only when executed as a script, so importing this module has no
# side effects.
if __name__ == '__main__':
    ql = QL()
    ql.learn()
    # To replay a saved model instead of training:
    # import sys
    # ql = QL.with_learned_model(sys.argv[1])
    # ql.run([int(s) for s in sys.argv[2].split(',')])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import random | |
import copy | |
# Length of the list being sorted; also the range of valid swap positions.
TARGET_ELEMENT_SIZE = 10
class Runner:
    """Holds one sorting environment and counts the actions applied to it."""

    def __init__(self, input=None):
        """Create the environment.

        :param input: initial arrangement; when omitted the environment
            picks a random one.
        """
        # Pass a real list: on Python 3 a list never compares equal to a
        # range object, so a range target could never be reported as reached.
        self.env = SortEnvironment(list(range(TARGET_ELEMENT_SIZE)), input)
        self.step_num = 0

    def step(self, action, verbose=False):
        """Apply ``action`` to the environment and advance the step counter.

        With ``verbose`` set, the step number, the action and the state
        transition are printed; the number shown is the count before this
        step is added.
        """
        action.do(self.env)
        if verbose:
            print('step %d, %s' % (self.step_num, action))
            print('%s -> %s' % (self.env.old_state, self.env.current))
        self.step_num += 1

    def finished(self):
        """Return whether the environment has reached its target order."""
        return self.env.collected()
class SortEnvironment:
    """Mutable list state plus the target order it should reach.

    Attributes:
        current: the list in its present order.
        old_state: copy of ``current`` taken before the most recent swap
            (equal to the initial order until the first swap).
        collect_answer: the target order, stored as a list.
    """

    def __init__(self, collect_answer, input=None):
        """Set up the state.

        :param collect_answer: the target order (any iterable).
        :param input: initial order; when missing or empty, a random
            permutation of ``collect_answer`` is used.
        """
        # Plain attributes are used here. Properties named ``old_state`` and
        # ``collect_answer`` that return ``self.<same name>`` recurse forever
        # and, having no setter, reject these assignments on Python 3.
        # Storing a list also keeps ``collected()`` meaningful when the
        # target is passed as a ``range``.
        self.collect_answer = list(collect_answer)
        if not input:
            self.current = random.sample(self.collect_answer, len(self.collect_answer))
        else:
            # Copy so that swaps never mutate the caller's sequence.
            self.current = list(input)
        self.old_state = copy.copy(self.current)

    @property
    def current_exp(self):
        """The current order as a comma-separated string."""
        return ','.join([str(s) for s in self.current])

    def collected(self):
        """Return True when the current order equals the target order."""
        return self.current == self.collect_answer

    def biggerThan(self, pos_a, pos_b):
        """Return True when the element at ``pos_a`` exceeds the one at ``pos_b``."""
        return self.current[pos_a] > self.current[pos_b]

    def swap(self, pos_a, pos_b):
        """Exchange the elements at the two positions, saving the prior order."""
        self.old_state = copy.copy(self.current)
        self.current[pos_a], self.current[pos_b] = (
            self.current[pos_b], self.current[pos_a])
class SwapPosAAndPosB:
    """Action that swaps the elements at two fixed positions.

    Instances compare and hash by their positions, so an action rebuilt
    from its textual form matches the equivalent one used as a dict key.
    """

    def __init__(self, pos_a, pos_b):
        self.pos_a = pos_a
        self.pos_b = pos_b

    def do(self, env):
        """Apply the swap to ``env`` through its ``swap`` method."""
        env.swap(self.pos_a, self.pos_b)

    def __eq__(self, other):
        if not isinstance(other, SwapPosAAndPosB):
            return NotImplemented
        return (self.pos_a, self.pos_b) == (other.pos_a, other.pos_b)

    def __ne__(self, other):
        # Spelled out because Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash((self.pos_a, self.pos_b))

    def __repr__(self):
        return 'SwapPosAAndPosB(%d,%d)' % (self.pos_a, self.pos_b)
# One swap action per pair of distinct positions, with the larger index
# first; ordered by i, then by j.
actions = [
    SwapPosAAndPosB(i, j)
    for i in range(TARGET_ELEMENT_SIZE)
    for j in range(i)
]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment