エンコードに使用された機能が不明なTFRecordファイルの読み取り

Aug 24 2020

私はTensorFlowを初めて使用しますが、これは非常に初心者の質問かもしれません。使用したい機能(たとえば、「画像」、「ラベル」)の知識を使用して、カスタムデータセットがTFRecordファイルに変換される例を見てきました。また、このTFRecordファイルを解析し直すときに、このデータセットを使用できるようにするには、事前に機能(つまり、「画像」、「ラベル」)を知っている必要があります。

私の質問は、機能が事前にわからないTFRecordファイルをどのように解析するのかということです。誰かが私にTFRecordファイルをくれて、これに関連するすべての機能をデコードしたいとします。

私が言及しているいくつかの例は次のとおりです:リンク1、リンク2

回答

jdehesa Aug 25 2020 at 01:19

これが役立つかもしれない何かです。これは、レコードファイルを調べて、機能に関する利用可能な情報を保存する関数です。最初のレコードだけを見てその情報を返すように変更できますが、場合によっては、一部のまたは可変サイズの機能にのみ存在するオプション機能がある場合は、すべてのレコードを表示すると便利な場合があります。

import tensorflow as tf

def list_record_features(tfrecords_path):
    # Dict of extracted feature information
    features = {}
    # Iterate records
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
        # Get record bytes
        example_bytes = rec.numpy()
        # Parse example protobuf message
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        # Iterate example features
        for key, value in example.features.feature.items():
            # Kind of data in the feature
            kind = value.WhichOneof('kind')
            # Size of data in the feature
            size = len(getattr(value, kind).value)
            # Check if feature was seen before
            if key in features:
                # Check if values match, use None otherwise
                kind2, size2 = features[key]
                if kind != kind2:
                    kind = None
                if size != size2:
                    size = None
            # Save feature data
            features[key] = (kind, size)
    return features

あなたはそれをこのように使うことができます

import tensorflow as tf

tfrecords_path = 'data.tfrecord'
# Make some test records
with tf.io.TFRecordWriter(tfrecords_path) as writer:
    for i in range(10):
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    # Fixed length
                    'id': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[i])),
                    # Variable length
                    'data': tf.train.Feature(
                        float_list=tf.train.FloatList(value=range(i))),
                }))
        writer.write(example.SerializeToString())
# Print extracted feature information
features = list_record_features(tfrecords_path)
print(*features.items(), sep='\n')
# ('id', ('int64_list', 1))
# ('data', ('float_list', None))