问题描述
InvalidArgumentError: Key: 'label'. Can't parse serialized Example.
保存tfrecord
def save_tfrecords(data, label, desfile):
    """Write paired samples and labels to a TFRecord file.

    Each row of ``data`` is cast to float32 and stored as raw bytes under the
    feature key "data"; the matching row of ``label`` is stored as an int64
    list under the feature key "label".

    Args:
        data: Array with one sample per row (indexed along the first axis).
        label: Array with one label row per sample; each row must be an
            iterable of integers.
        desfile: Path of the TFRecord file to create.

    Raises:
        AssertionError: If ``data`` and ``label`` differ in their first
            dimension.
    """
    assert data.shape[0] == label.shape[0], "data and label must have the same number of rows"
    with tf.python_io.TFRecordWriter(desfile) as writer:
        for sample, sample_label in zip(data, label):
            # tobytes() yields the same bytes as the deprecated tostring()
            # alias, which no longer exists in NumPy 2.0.
            raw = sample.astype(np.float32).tobytes()
            features = tf.train.Features(
                feature={
                    "data": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[raw])
                    ),
                    "label": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=sample_label)
                    ),
                }
            )
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
加载tfrecord
def load_tfrecords(file):
    """Read a TFRecord file and print every (data, label) batch it yields.

    The records are parsed, repeated for two passes and grouped into batches
    of one example. Each batch is evaluated in a TF1 session and printed
    until the iterator is exhausted.

    Args:
        file: Path (or paths) accepted by ``tf.data.TFRecordDataset``.
    """

    def _decode_example(serialized):
        # NOTE: "label" is declared with scalar shape (). Records whose label
        # stores more than one value do not match this spec and fail to parse.
        spec = {
            "data": tf.FixedLenFeature((), tf.string),
            "label": tf.FixedLenFeature((), tf.int64),
        }
        parsed = tf.parse_single_example(serialized, spec)
        # "data" was stored as raw bytes; reinterpret them as float32 values.
        return tf.decode_raw(parsed["data"], tf.float32), parsed["label"]

    sess = tf.Session()
    pipeline = (
        tf.data.TFRecordDataset(file)  # load the TFRecord file
        .map(_decode_example)          # parse each record into tensors
        .repeat(2)                     # iterate over the data twice
        .batch(1)                      # one example per batch
    )
    next_data = pipeline.make_one_shot_iterator().get_next()

    # The one-shot iterator signals exhaustion by raising OutOfRangeError.
    try:
        while True:
            data, label = sess.run(next_data)
            print(data)
            print(label)
    except tf.errors.OutOfRangeError:
        print("End of dataset")
label:
label = [[0, 1], [0, 1], [1, 0], [1, 0]]
Stack overflow 同样的问题:
https://stackoverflow.com/questions/53499409/tensorflow-tfrecord-cant-parse-serialized-example
错误原因
tf.FixedLenFeature() is used for reading the fixed size arrays of data. And the shape of the data should be defined beforehand.
tf.FixedLenFeature() 用于读取固定大小的数组数据。
数据的shape应该提前定义。
解决办法
指定label的shape。
我这里 label = [[0, 1], [0, 1], [1, 0], [1, 0]]
每次解析一条样本,该样本的 label 含 2 个元素(size 为 2)
所以
# Each record's "label" holds two int64 values, so its shape is given as [2].
features = {
    "data": tf.FixedLenFeature((), tf.string),
    "label": tf.FixedLenFeature([2], tf.int64),
}