TensorFlowのデータフォーマットTFRecordを読み込む

TensorFlowの保存形式TFRecordの中身を読み込んで見る

TensorFlowにはTFRecordというバイナリ保存形式があるのですが、それを読み込んでみました。

'./tf_records/wiki_aa.tf_record'というTFRecordファイルから中身を読み出してみます。

import tensorflow as tf
import sys
import tokenization

input_file = './tf_records/wiki_aa.tf_record'

tokenizer = tokenization.FullTokenizer(model_file='new.model', vocab_file='new.vocab')
count =0
for example in tf.io.tf_record_iterator(input_file):
    result = tf.train.Example.FromString(example)
    print(result.features.feature["input_ids"])
    print(type(result))
    
    token_ids = list(result.features.feature["input_ids"].int64_list.value)
    
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    for i, id in enumerate(token_ids):
        print("{:<10} {:10}".format(id, tokens[i]))

    print(len(tokens))
    print(''.join(tokens))

    if count == 1:  
        sys.exit()
    count += 1

tf.io.tf_record_iterator(input_file)でTFRecord形式のファイルを読み込んで、イテレータオブジェクトを返します。

exampleに一つのデータ・セットが入ります。

print(result.features.feature["input_ids"])の結果は

int64_list {
value: 5
value: 4956
value: 17
value: 49
# 省略
value: 0
value: 0
value: 0
value: 0
}

 

 

コメント