java中的接口是类吗
279
2022-08-26
python tensorflow框架(python是什么意思)
二进制读取案例
import tensorflow as tfimport osos.environ['TF_CPP_MIN_LOG_LEVEL']='2'class Cifar(object): def __init__(self): self.height = 32 self.weight = 32 self.channels = 3 #图像像素 self.image_bytes = self.height * self.weight * self.channels #图像的标签 self.label_bytes = 1 #一个样本 self.all_bytes = self.image_bytes + self.label_bytes def read_and_decoded(self,file_list): # 1.构建文件队列 file_queue = tf.train.string_input_producer(file_list) # 2.读取与解码 #读取 reader = tf.FixedLengthRecordReader(self.all_bytes) key,value = reader.read(file_queue) print(key) print(value) #解码 decoded_value = tf.decode_raw(value, tf.uint8) print(decoded_value) #目标值切片 label = tf.slice(decoded_value, [0], [self.label_bytes]) image = tf.slice(decoded_value, [self.label_bytes], [self.image_bytes]) print(label) print(image) #恢复张量shape,先channels、height、weight image_reshape = tf.reshape(image, shape=[self.channels, self.height, self.weight]) #装置,将原本读取的channels、height、width--装置为tensorflow支持的排列-----> height、width、channels image_transpose = tf.transpose(image_reshape, [1, 2, 0]) print(image_transpose) #调整图像类型,方便矩阵计算 image_casted = tf.cast(image_transpose, dtype=tf.float32) print(image_casted) # 3.批处理 label_batch, image_batch = tf.train.batch([label, image_casted], batch_size=100, num_threads=1, capacity=100) print(label_batch) print(image_batch) with tf.Session() as sess: #开启线程 #线程协调 coords = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coords) new_key,\\ new_value, \\ new_decoded_value, \\ new_label, \\ new_image, \\ new_image_reshape,\\ new_image_transpose,\\ new_image_casted,\\ new_label_batch,\\ new_image_batch= sess.run([key, value, decoded_value, label, image, image_reshape, image_transpose, image_casted, label_batch, image_batch]) print(new_key) print(new_value) print(new_decoded_value) print(new_label) print(new_image) print(new_image_reshape) coords.request_stop() coords.join(threads) print(new_image_transpose) print(new_image_casted) print(new_label_batch) print(new_image_batch) return Noneif __name__ == '__main__': filename = os.listdir("./datasources/datasets/cifar-10-batches-py") file_list = [os.path.join("./datasources/datasets/cifar-10-batches-py/", file) for file in filename ] cifar = Cifar() cifar.read_and_decoded(file_list)
存储——TFRecords
是一种二进制文件,能够更好的利用内存,根方便复制和移动,不需要单独的标签文件使用步骤:
获取数据将数据填入到Example协议内存块(protocol buffter)将协议内存块序列化为字符串,通过tf.python_io.TFRecordWriter写入到TFRecords文件
文件格式*.tfrecords
Example内部结构
options具体要看值的类型
例子:
""" 序列化数据,使用TFRecords文件存储 """ def save_to_tfrecord(self, image_batch, label_batch ): with tf.python_io.TFRecordWriter("cifar.tfrecords") as wirter: #因为有100个样本 for i in range(100): image = image_batch[i].tostring() label = label_batch[i][0] example = tf.train.Example(features = tf.train.Features(feature = { "image":tf.train.Feature(bytes_list = tf.train.BytesList(value=[image])), "label":tf.train.Feature(int64_list = tf.train.Int64List(value=[label])), })) # 将序列化后的example写入文件 wirter.write(example.SerializeToString()) return Noneif __name__ == '__main__': filename = os.listdir("./datasources/datasets/cifar-10-batches-py") file_list = [os.path.join("./datasources/datasets/cifar-10-batches-py/", file) for file in filename ] cifar = Cifar() image, label = cifar.read_and_decoded(file_list) cifar.save_to_tfrecord(image, label)
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
发表评论
暂时没有评论,来抢沙发吧~