垃圾分类(CNN,加入数据增强和L2正则化)
import glob
import os
import random
import time

import keras
import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.models import Sequential, load_model
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img

# Dataset root: one sub-directory per class, e.g. datasets/cardboard/*.jpg.
base_path = "datasets"

# Network input resolution (height, width).
IMG_SIZE = (300, 300)
# Fraction of each class directory reserved for the 'validation' subset.
# The training and validation iterators must use the same value so that the
# two subsets are complementary slices of the same directory.
VALIDATION_SPLIT = 0.1
# Where crate_model() writes the trained model and use_model() reads it.
MODEL_PATH = os.path.join('rubbish', 'rubbish_model.h5')

# Random augmentations applied to training images only.
AUGMENTATION = dict(
    shear_range=0.1,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
)


def _invert(class_indices):
    """Turn Keras' ``{class_name: index}`` mapping into ``{index: class_name}``."""
    return {index: name for name, index in class_indices.items()}


def _flow(subset, batch_size, augment=False):
    """Return a categorical directory iterator over ``base_path``.

    Args:
        subset: ``'training'`` or ``'validation'``. This selects which side of
            ``VALIDATION_SPLIT`` is read.
        batch_size: Images per batch.
        augment: When True, apply ``AUGMENTATION`` on top of the rescale.
    """
    # Map 8-bit pixel values into [0, 1]. Training and validation must use the
    # same divisor, or the model sees differently-scaled inputs.
    options = {'rescale': 1. / 255, 'validation_split': VALIDATION_SPLIT}
    if augment:
        options.update(AUGMENTATION)
    return ImageDataGenerator(**options).flow_from_directory(
        base_path,
        target_size=IMG_SIZE,
        batch_size=batch_size,
        class_mode='categorical',
        subset=subset,
        seed=0)


def _show_predictions(model, generator, labels, grid, batch_index=1):
    """Plot one batch in a ``grid`` x ``grid`` layout, titled with predicted vs. true class.

    At most ``grid * grid`` images are drawn. Fewer are drawn if the batch is
    smaller. ``batch_index`` is clamped to the last available batch.
    """
    batch_index = max(0, min(batch_index, len(generator) - 1))
    test_x, test_y = generator[batch_index]
    preds = model.predict(test_x)
    count = min(grid * grid, len(test_x))
    plt.figure(figsize=(4 * grid, 4 * grid))
    for i in range(count):
        plt.subplot(grid, grid, i + 1)
        plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])],
                                          labels[np.argmax(test_y[i])]))
        # Images are already rescaled to [0, 1], which imshow accepts as floats.
        plt.imshow(test_x[i])
    plt.show()


def look_dataset_num(sample_size=6):
    """Print how many ``*/*.jpg`` files are under ``base_path`` and show a random sample.

    Up to ``sample_size`` images (default 6) are drawn in a 2 x 3 grid.
    """
    img_list = glob.glob(os.path.join(base_path, "*/*.jpg"))
    print(len(img_list))
    # Never ask random.sample for more items than exist (it would raise ValueError).
    # The 2 x 3 subplot grid below holds at most 6 images.
    sample = random.sample(img_list, min(sample_size, len(img_list), 6))
    for i, img_path in enumerate(sample):
        img = img_to_array(load_img(img_path), dtype=np.uint8)
        plt.subplot(2, 3, i + 1)
        plt.imshow(img.squeeze())
    plt.show()


def crate_model(batch_size=16, epochs=100):
    """Train the CNN classifier, save it to ``MODEL_PATH`` and plot sample predictions.

    The function name keeps its original spelling so that existing callers
    still work.

    Args:
        batch_size: Images per batch for training and validation.
        epochs: Number of full passes over the training subset.
    """
    start = time.time()
    train_generator = _flow('training', batch_size, augment=True)
    validation_generator = _flow('validation', batch_size)

    labels = _invert(train_generator.class_indices)
    print('train classes', labels)
    print('validation classes', _invert(validation_generator.class_indices))

    # Model definition and training.
    model = Sequential([
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               input_shape=IMG_SIZE + (3,)),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=64, kernel_size=3, padding='same', activation='relu',
               kernel_regularizer=keras.regularizers.l2(0.001)),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               kernel_regularizer=keras.regularizers.l2(0.001)),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               kernel_regularizer=keras.regularizers.l2(0.001)),
        MaxPooling2D(pool_size=2),
        Flatten(),
        Dense(64, activation='relu'),
        # One output unit per class sub-directory found on disk.
        Dense(len(train_generator.class_indices), activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

    # len(iterator) is the number of batches that covers the subset once.
    # This replaces hard-coded sample counts divided by 32, which disagreed
    # with the real batch size. NOTE(review): fit_generator is deprecated in
    # newer Keras in favour of model.fit; it is kept here for the Keras 2.x
    # API that this file imports.
    model.fit_generator(
        train_generator,
        epochs=epochs,
        steps_per_epoch=len(train_generator),
        validation_data=validation_generator,
        validation_steps=len(validation_generator))

    # Make sure the output directory exists, so a long training run does not
    # fail at the save step.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    model.save(MODEL_PATH)

    # Results: show up to 16 validation images with predicted vs. true labels.
    _show_predictions(model, validation_generator, labels, grid=4)

    print('运行time', time.time() - start)


def use_model(batch_size=36):
    """Load the model from ``MODEL_PATH`` and plot predictions for one validation batch."""
    validation_generator = _flow('validation', batch_size)
    # Class names come from the directory listing, so the validation iterator
    # gives the same mapping as the training one. There is no need to build
    # an augmented training iterator just to read it.
    labels = _invert(validation_generator.class_indices)
    model = load_model(MODEL_PATH)
    _show_predictions(model, validation_generator, labels, grid=6)


if __name__ == '__main__':
    # Uncomment to inspect the dataset or to (re)train before running inference.
    # look_dataset_num()
    # crate_model()
    use_model()
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
暂时没有评论,来抢沙发吧~