1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > Tensorflow:tfrecord数据读取和保存

Tensorflow:tfrecord数据读取和保存

时间:2023-10-26 06:16:42

相关推荐

Tensorflow:tfrecord数据读取和保存

TFRecord 是Google官方推荐的一种数据格式,是Google专门为TensorFlow设计的一种数据格式。

实际上,TFRecord是一种二进制文件,其能更好的利用内存,其内部包含了多个tf.train.Example, 而Example是protocol buffer(protobuf) 数据标准[3][4]的实现,在一个Example消息体中包含了一系列的tf.train.feature属性,而 每一个feature 是一个key-value的键值对,其中,key 是string类型,而value 的取值有三种:

bytes_list: 可以存储string 和byte两种数据类型。

float_list: 可以存储float(float32)与double(float64) 两种数据类型 。

int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64 。

值的一提的是,TensorFlow 源码中到处可见.proto 的文件,且这些文件定义了TensorFlow重要的数据结构部分,且多种语言可直接使用这类数据,很强大。

在数据集较小时,我们会把数据全部加载到内存里方便快速导入,但当数据量超过内存大小时,就只能放在硬盘上来一点点读取,这时就不得不考虑数据的移动、读取、处理等速度。使用TFRecord就是为了提速和节约空间的。对于大型数据,对比其余数据格式,protocol buffer类型的数据优势很明显。

tf.Example的数据类型

tf.train.Feature 消息类型可以接受以下三种类型(请参阅 .proto 文件)。大多数其他通用类型也可以强制转换成下面的其中一种:

tf.train.BytesList(可强制转换自以下类型)

string

byte

tf.train.FloatList(可强制转换自以下类型)

float (float32)

double (float64)

tf.train.Int64List(可强制转换自以下类型)

bool

enum

int32

uint32

int64

uint64

[tf.Example 的数据类型]

-柚子皮-

写入和读取

Python 中写入TFRecord 文件

示例:

user_features = df.loc[:, cols_dict['user_dense_cols'] + cols_dict['user_sparse_cols']].values

goods_features = df.loc[:, cols_dict['goods_dense_cols'] + cols_dict['goods_sparse_cols']].values

labels = df.loc[:, cols_dict['label_cols']].values

writer = tf.python_io.TFRecordWriter(out_file_name)

for id, (user_feature, goods_feature, label) in enumerate(

zip(user_features, goods_features, labels)):

if id == 0:

print("user_feature.shape:", user_feature.shape, "dtype:", user_feature.dtype)

print("goods_feature.shape:", goods_feature.shape, "dtype:", goods_feature.dtype)

print("label.shape:", label.shape, "dtype:", label.dtype)

""" 2. 定义features """

# 将一个样例转化为Example Protocol Buffer, 并将所有的信息写入这个数据结构

example = tf.train.Example(

features=tf.train.Features(

feature={

'user_feature': tf.train.Feature(float_list=tf.train.FloatList(value=user_feature.tolist())),

'goods_feature': tf.train.Feature(float_list=tf.train.FloatList(value=goods_feature.tolist())),

'label': tf.train.Feature(float_list=tf.train.FloatList(value=label.tolist())),

}))

""" 3. 序列化,写入"""

serialized = example.SerializeToString() # 将一个Example写入TFRecord文件

writer.write(serialized)

Note:写入向量本身就是list时"value=vectors[i]"没有中括号[TFRecord + Dataset 进行数据的写入和读取]。如果是np.array类型,可能需要tolist()转换一下(否则可能总是1维的无法强转?)。

tf.data 读取 TFRecord 文件

示例:

# Create a description of the features.

feature_description = {

# 无法自动识别size;加size后变成3维,中间1维无法识别,后面加转换

'user_feature': tf.FixedLenSequenceFeature((), tf.float32, default_value=0.0, allow_missing=True),

'goods_feature': tf.FixedLenSequenceFeature((), tf.float32, default_value=0.0, allow_missing=True),

'label': tf.FixedLenFeature((), tf.float32, default_value=0.0),

'extra': tf.FixedLenSequenceFeature((), tf.float32, default_value=0.0, allow_missing=True)

}

def _parse_function(example_proto, is_test=True):

# Parse the input `tf.Example` proto using the dictionary above.

result = tf.io.parse_single_example(example_proto, feature_description)

print("result:", result)

uf_len, gf_len, lab_len, ext_len = args.features_size

result['user_feature'] = tf.reshape(result['user_feature'], [uf_len, ])

result['goods_feature'] = tf.reshape(result['goods_feature'], [gf_len, ])

result['label'] = tf.reshape(result['label'], [lab_len, ])

result['extra'] = tf.reshape(result['extra'], [ext_len, ])

return result

raw_dataset = tf.data.TFRecordDataset(self.input_file_names,buffer_size=args.tfrecord_bufsize)

print("raw_dataset shape:{}\n".format(raw_dataset))

dataset = raw_dataset.map(functools.partial(_parse_function, is_test=args.is_test))

dataset怎么用参考[Tensorflow:dataset数据读取]

Note:

如果事先知道shape,就用FixedLenFeature/FixedLenSequenceFeature,不知道就用VarLenFeature。[TF record笔记]

Instead of usingtf.io.FixedLenFeaturefor parsing an array, usetf.io.FixedLenSequenceFeature.(for TensorFlow 1, usetf.instead oftf.io.)[How to convert Float array/list to TFRecord?]

1 像上面的features,写入向量本身就是list时,直接读会丢失了shape信息,后面如果想用shape信息的话,还是需要强制转换tf.reshape,否则shape为None,后面不好用。不知道有没有其它好方法。

2直接在'user_feature': tf.FixedLenSequenceFeature((7), tf.float32, default_value=0.0, allow_missing=True)里指定shape,会出错:ValueError: Cannot reshape a tensor with 1 elements to shape [7] (7 elements) for 'ParseSingleExample/Reshape' (op: 'Reshape') with input shapes: [], [1] and with input tensors computed as partial shapes: input[1] = [7].

[tensorflow错误日志-tfrecord读取错误]

其它读取示例

[TFRecord - TensorFlow 官方推荐的数据格式]

[使用TFRecord存取多个数据]

list格式数据保存:value_poi[tfrecord格式的内容解析及样例]

如何检查 Tensorflow .tfrecord 文件的正确性?

tfrecord文件损坏或者出错时,会在dataset读取文件时报错:

1DataLossError (see above for traceback): truncated record at 234873870

2 ERROR:tensorflow:Exception in QueueRunner: corrupted record at 52284962154

3 DataLossError (see above for traceback): corrupted record at 52284962154

可能原因:1由于tf.data.TFRecordDataset里面的tf.gfile接口连接hdfs或者公有云不稳定,导致最终tfrecord数据无法读取到。[公有云运行TensorFlow训练作业出现错误:truncated record at]

2数据生成时出错(写文件时服务器中断打开的文件)

3 上传到hdfs上时有损坏(或者也可能是1导致的)java.io.IOException: Failed on local exception: java.io.InterruptedIOException: Interrupted while waiting for IO on channel java.nio.channels.SocketChannel

解决:1的话重试? 2 3的话,可能需要删除出错文件,重新上传或者重新生成。

示例:a solution for the rare corruption that happens during the training.todo待测试

dataset = tf.data.Dataset.range(10)

dataset = dataset.apply(tf.data.experimental.ignore_errors())

iterator = dataset.make_one_shot_iterator()

[DataLossError: truncated record at 40109222]

示例:捕获DataLossError错误,但是不起作用,不知道是不是需要在sess.run()的时候捕获todo。

if_not_get_next_suc = True

while if_not_get_next_suc:

try:

self.raw_inputs = self.dataset_iterator.get_next()

if_not_get_next_suc = False

except tf.errors.DataLossError as e:

print(traceback.format_exc())

print("Error: skip data loss error!!!")

[TensorFlow问题处理:DataLossError: corrupted record at XXX]

示例:检查文件是否正确

tffilenames = []

for root, _, fs in os.walk(local_tfrecord_dir):

tffilenames += [os.path.join(root, f) for f in fs]

for fid, tffilename in enumerate(tffilenames):

try:

for id, example in enumerate(tf.python_io.tf_record_iterator(tffilename)):

r = tf.train.Example.FromString(example)

if id == 0:

print(r)

except Exception as e:

print("bad tffilename:", tffilename)

# print(traceback.format_exc())

print(e)

示例2:检查某个tfrecord文件的所有key

a.py:

import tensorflow as tf

print(tf.train.Example.FromString(list(tf.python_io.tf_record_iterator('train_data_tfrecord/data.tfrecords_0007'))[0]))

python a.py | grep key

python a.py | grep key | wc -l

python a.py | grep list # 查看key对应的数据类型

[python - 如何检查 Tensorflow .tfrecord 文件?]

[TF corrupted record while training]

from:-柚子皮-

ref:官方文档[TFRecord 和 tf.Example]

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。