2017年6月26日月曜日

TensorFlowのRNNのチュートリアル、translateの動作確認

概要

Sequence-to-Sequence Models | TensorFlowに記された通り、TensorFlow main repoとTensorFlow models repoをcloneした。続けてmodels/tutorials/rnn/translate/のtranslate.pyを実行しコーパスのダウンロード後、訓練を行った。3日前後*1でbucket 0、bucket 1のperplexity*2が1桁となったため、--decode引数を付加しtranslate.pyを実行した所、英文の仏訳が得られた。


訓練開始時引数

上掲TensorFlowチュートリアル内 Let's run it 章で紹介された通りのパラメータを指定し、訓練を開始した。

$ python translate.py --data_dir data --train_dir checkpoints --size=256 --num_layers=2 --steps_per_checkpoint=50
Preparing WMT data in data
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
Creating 2 layers of 256 units.
Created model with fresh parameters.
Reading development and training data (limit: 0).
  reading data line 100000
(後略)

SSE4.1やAVX命令への非対応を示す警告が表示されているが、対応するにはTensorFlowのコンパイルに時間がかかるため、今回は無視した。

perplexityがsingle digitsに

三日三晩訓練を続け Let's run it に書かれた通り、perplexityが1桁になるまで待った。

global step 5000 learning rate 0.4477 step-time 6.66 perplexity 44.09
  eval: bucket 0 perplexity 26.66
  eval: bucket 1 perplexity 27.62
  eval: bucket 2 perplexity 38.82
  eval: bucket 3 perplexity 41.83
(中略)
global step 10000 learning rate 0.4173 step-time 6.64 perplexity 17.19
  eval: bucket 0 perplexity 14.07
  eval: bucket 1 perplexity 15.41
  eval: bucket 2 perplexity 16.36
  eval: bucket 3 perplexity 18.33
(中略)
global step 25250 learning rate 0.2377 step-time 7.41 perplexity 8.76
  eval: bucket 0 perplexity 8.25
  eval: bucket 1 perplexity 8.48
  eval: bucket 2 perplexity 8.10
  eval: bucket 3 perplexity 9.54

bucket 0のperplexityが8.25、bucket 1は8.48まで減少した。

decode結果

$ python translate.py --decode --data_dir data --train_dir checkpoints --size=256 --num_layers=2
Reading model parameters from checkpoints/translate.ckpt-25250
> Who is the president of the United States?
Qui est la présidence des États-Unis ?

英文 "Who is the president of the United States?"が仏訳された。

注意点

2017年6月23日時点で、このコードはTensorFlow バージョン1.2に対応していないので、下記のコマンド等でバージョン1.0を指定しインストールする。


$ pip install tensorflow==1.0
decode時に訓練時に指定した引数、--size=256 --num_layers=2を忘れると、下記エラーとなるので注意。
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [2048,2048] rhs shape= [512,512]
  [[Node: save/Assign_31 = Assign[T=DT_FLOAT, _class=["loc:@embedding_attention_seq2seq/rnn/multi_rnn_cell/cell_1/gru_cell/gates/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](embedding_attention_seq2seq/rnn/multi_rnn_cell/cell_1/gru_cell/gates/weights, save/RestoreV2_31)]]
NotFoundError (see above for traceback): Key embedding_attention_seq2seq/rnn/multi_rnn_cell/cell_2/gru_cell/candidate/biases not found in checkpoint
  [[Node: save/RestoreV2_32 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_32/tensor_names, save/RestoreV2_32/shape_and_slices)]]

0 件のコメント:

コメントを投稿