鹿児島ハードチル同好会

情報学部の大学生です。深層学習(Tensorflow)とかUbuntuとか音楽とかガジェットに興味があります。バンドもしてたりする

TensorflowのEmbedding Visualizationでカッコよく可視化したい

TensorBoardがカッコいい

大学でディープラーニングを用いた研究をやろうと意気込んでおります。 TensorBoardでカッコよく可視化してみたいと思ったところ、アイドル顔認識の先駆者である偉大なすぎゃーんさんのこのツイートを発見

memo.sugyan.com

ご丁寧に記事まで書いてくださっており、動かそうと思ったのですがTensorflowのバージョンが1.0系に変わってからの大幅な仕様変更でいろいろと問題があったので躓いたところを簡単にまとめました。

実行環境

学習済みモデルとデモのダウンロード

Tensorflowのページを参考に学習済みのモデルをダウンロード

Image Recognition  |  TensorFlow

ターミナルを開いて

 $ git clone https://github.com/tensorflow/models

にてクローリングします。

学習済みのモデルのディレクトリが tensorflow.models.image.imagenet から models/tutorials/image/imagenet へと変更されていました

同様にすぎゃーん氏のデモもクローリングします

 $ git clone https://github.com/sugyan/tf-embedding-visualization-demo/blob/master/README.md

デモディレクトリ内にmodels/tutorials/image/imagenetにある classify_image.pyをコピーしておきます

以下修正

このまますぐ実行できれば良いのですがそうなるとこの記事の存在意義がなくなってしまいます。 以下修正箇所です。 コメントアウトが修正前、その下が修正後のコードです

# main.py

# line 5
# from tensorflow.models.image.imagenet import classify_image
import classify_image


# line 35
# embedding_var = tf.Variable(tf.pack([tf.squeeze(x) for x in outputs], axis=0), trainable=False, name='pool3')
embedding_var = tf.Variable(tf.stack([tf.squeeze(x) for x in outputs], axis=0), trainable=False, name='pool3')


# line 40
# summary_writer = tf.train.SummaryWriter(os.path.join(basedir, 'logdir'))
summary_writer = tf.summary.FileWriter(os.path.join(basedir, 'logdir'))

# line 53 ~ 55
# for i in range(size):
#     rows.append(tf.concat(1, images[i*size:(i+1)*size]))
# jpeg = tf.image.encode_jpeg(tf.concat(0, rows))

for i in range(size):
    rows.append(tf.concat(images[i*size:(i+1)*size], 1))
jpeg = tf.image.encode_jpeg(tf.concat(rows, 0))

恐らくこれで動くかと思われます。 Tensorflowが1.0系になってからとそれ以前の情報で溢れているインターネットでこれからも頑張りたいという気持ちのもと、うどんを眺めています

参考ページ

TensorFlow の "AttributeError: 'module' object has no attribute 'xxxx'" エラーでつまづいてしまう人のための移行ガイド - Qiita

TensorFlow v1.1 / 移行 > tf.pack()はtf.stack()になった - Qiita

TypeError: Expected int32, got <prettytensor.pretty_tensor_class.Layer > of type 'Layer' instead. · Issue #48 · google/prettytensor · GitHub