Skip to content

Commit 460d882

Browse files
committedJan 6, 2017
increase TF version compatibility from 0.8.0 to 0.12.0
1 parent aceeade commit 460d882

File tree

4 files changed

+19
-9
lines changed

4 files changed

+19
-9
lines changed
 

‎README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ This implementation contains:
1818
- Python 2.7 or Python 3.3+
1919
- [gym](https://github.com/openai/gym)
2020
- [tqdm](https://github.com/tqdm/tqdm)
21-
- [OpenCV2](http://opencv.org/)
22-
- [TensorFlow 0.8.0](https://github.com/tensorflow/tensorflow/tree/r0.8)
21+
- [SciPy](http://www.scipy.org/install.html) or [OpenCV2](http://opencv.org/)
22+
- [TensorFlow 0.12.0](https://github.com/tensorflow/tensorflow/tree/r0.12)
2323

2424

2525
## Usage

‎dqn/agent.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def build_dqn(self):
224224
q_summary = []
225225
avg_q = tf.reduce_mean(self.q, 0)
226226
for idx in xrange(self.env.action_size):
227-
q_summary.append(tf.histogram_summary('q/%s' % idx, avg_q[idx]))
228-
self.q_summary = tf.merge_summary(q_summary, 'q_summary')
227+
q_summary.append(tf.summary.histogram('q/%s' % idx, avg_q[idx]))
228+
self.q_summary = tf.summary.merge(q_summary, 'q_summary')
229229

230230
# target network
231231
with tf.variable_scope('target'):
@@ -312,15 +312,15 @@ def build_dqn(self):
312312

313313
for tag in scalar_summary_tags:
314314
self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_'))
315-
self.summary_ops[tag] = tf.scalar_summary("%s-%s/%s" % (self.env_name, self.env_type, tag), self.summary_placeholders[tag])
315+
self.summary_ops[tag] = tf.summary.scalar("%s-%s/%s" % (self.env_name, self.env_type, tag), self.summary_placeholders[tag])
316316

317317
histogram_summary_tags = ['episode.rewards', 'episode.actions']
318318

319319
for tag in histogram_summary_tags:
320320
self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_'))
321-
self.summary_ops[tag] = tf.histogram_summary(tag, self.summary_placeholders[tag])
321+
self.summary_ops[tag] = tf.summary.histogram(tag, self.summary_placeholders[tag])
322322

323-
self.writer = tf.train.SummaryWriter('./logs/%s' % self.model_dir, self.sess.graph)
323+
self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph)
324324

325325
tf.initialize_all_variables().run()
326326

‎dqn/environment.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import cv2
21
import gym
32
import random
43
import numpy as np
4+
from utils import rgb2gray, imresize
55

66
class Environment(object):
77
def __init__(self, config):
@@ -40,7 +40,7 @@ def _random_step(self):
4040

4141
@ property
4242
def screen(self):
43-
return cv2.resize(cv2.cvtColor(self._screen, cv2.COLOR_RGB2GRAY)/255., self.dims)
43+
return imresize(rgb2gray(self._screen)/255., self.dims)
4444
#return cv2.resize(cv2.cvtColor(self._screen, cv2.COLOR_BGR2YCR_CB)/255., self.dims)[:,:,0]
4545

4646
@property

‎dqn/utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
import time
22
import cPickle
3+
import numpy as np
34
import tensorflow as tf
45

6+
try:
7+
from scipy.misc import imresize
8+
except:
9+
import cv2
10+
imresize = cv2.resize
11+
12+
def rgb2gray(image):
13+
return np.dot(image[...,:3], [0.299, 0.587, 0.114])
14+
515
def timeit(f):
616
def timed(*args, **kwargs):
717
start_time = time.time()

0 commit comments

Comments
 (0)
Please sign in to comment.