参考:https://github.com/juandes/pikachu-detection/blob/master/detection_video.py
在之前的文章中,我们利用 TensorFlow 的目标检测 API 训练了模型,并用图片验证了模型的有效性。本文的目的是将模型应用到视频检测中,实现对视频流的实时检测。
————————————— 2018.12.3 更新 —————————————
抱歉,前段时间一直在做语义分割的项目,没有时间测试目标检测的接口。今天终于抽空完成了视频流的检测,话不多说,直接上代码。代码主要参考了官方提供的
object_detection_tutorial.ipynb 中的内容,视频处理采用 OpenCV 库。
# coding: utf-8
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
import cv2
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops
def _version_tuple(version):
    """Return the leading (major, minor, patch) integers of a version string.

    Non-numeric suffixes such as '-rc1' are ignored and missing components
    count as 0, e.g. '1.12.0-rc1' -> (1, 12, 0) and '2.0' -> (2, 0, 0).
    """
    numbers = []
    for piece in version.split('.')[:3]:
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


# Compare versions numerically: a plain string comparison ranks '1.12.0'
# below '1.4.0' (character '1' < '4') and would wrongly reject newer releases.
if _version_tuple(tf.__version__) < (1, 4, 0):
    raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')
from utils import label_map_util
from utils import visualization_utils as vis_util
# Model preparation
# Path to the frozen detection graph (.pb): the exported model that is actually
# run for object detection. The directory suffix presumably records the
# training step the graph was exported at — unverified.
PATH_TO_FROZEN_GRAPH = 'inference_graph_67886/frozen_inference_graph.pb'
# Path to the label map (.pbtxt) that provides the label string drawn next to
# each detected box.
PATH_TO_LABELS = os.path.join('training', 'object-detection.pbtxt')
# Passed as max_num_classes when converting the label map to categories;
# presumably the number of classes the model was trained on — confirm against
# the training config.
NUM_CLASSES = 12
def _load_frozen_graph(graph_path):
    """Deserialize a frozen inference graph (.pb) into a fresh tf.Graph.

    The GraphDef is imported with ``name=''`` so its tensors keep their
    original, unprefixed names (e.g. ``'image_tensor:0'``).
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(graph_path, 'rb') as fid:
            graph_def.ParseFromString(fid.read())
        tf.import_graph_def(graph_def, name='')
    return graph


def _load_category_index(labels_path, num_classes):
    """Build the category index used to label drawn boxes from a label map file."""
    label_map = label_map_util.load_labelmap(labels_path)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=num_classes, use_display_name=True)
    return label_map_util.create_category_index(categories)


def detect_in_video(video_path='test_images/20171206/0-0-0.avi',
                    output_path='test_images/20171206/0-0-0_result.avi',
                    fps=None, min_score_thresh=.20, display=True):
    """Run the frozen detector on every frame of a video and save an annotated copy.

    Frames are read with OpenCV and converted BGR -> RGB for the model. Each
    frame gets its detected boxes and labels drawn on it, is converted back to
    BGR, and is written as MJPG to ``output_path``. Processing stops at the
    end of the video, or when 'q' is pressed in the preview window.

    Args:
        video_path: input video file.
        output_path: where the annotated video is written.
        fps: output frame rate. ``None`` uses the rate reported by the input
            video, falling back to 25 when none is reported.
        min_score_thresh: detections scoring below this are not drawn.
        display: show each annotated frame in an OpenCV window.

    Raises:
        IOError: if ``video_path`` cannot be opened.
    """
    detection_graph = _load_frozen_graph(PATH_TO_FROZEN_GRAPH)
    category_index = _load_category_index(PATH_TO_LABELS, NUM_CLASSES)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError('Cannot open video: {}'.format(video_path))
    if fps is None:
        # cap.get may report 0 when the container carries no frame rate.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25

    # The writer's frame size must match the frames written to it, so it is
    # created lazily from the first frame's shape instead of a hard-coded size.
    out = None
    try:
        with detection_graph.as_default():
            with tf.Session(graph=detection_graph) as sess:
                image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                # Per-detection boxes, confidence scores and class ids.
                detection_boxes = detection_graph.get_tensor_by_name(
                    'detection_boxes:0')
                detection_scores = detection_graph.get_tensor_by_name(
                    'detection_scores:0')
                detection_classes = detection_graph.get_tensor_by_name(
                    'detection_classes:0')
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        # End of stream (or a read failure): frame is None,
                        # so stop before cvtColor fails on it.
                        break
                    # OpenCV decodes frames as BGR; the model is fed RGB.
                    color_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    boxes, scores, classes = sess.run(
                        [detection_boxes, detection_scores, detection_classes],
                        feed_dict={image_tensor: np.expand_dims(color_frame, axis=0)})
                    # Draws the boxes/labels onto color_frame in place.
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        color_frame,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        use_normalized_coordinates=True,
                        line_thickness=8,
                        min_score_thresh=min_score_thresh)
                    output_bgr = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
                    if out is None:
                        height, width = output_bgr.shape[:2]
                        out = cv2.VideoWriter(
                            output_path,
                            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                            fps, (width, height))
                    out.write(output_bgr)
                    if display:
                        # imshow expects BGR, so show the converted frame.
                        cv2.imshow('frame', output_bgr)
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            break
    finally:
        # Release capture/writer even if inference or drawing raised.
        if out is not None:
            out.release()
        cap.release()
        if display:
            cv2.destroyAllWindows()
def main():
    """Script entry point: run detect_in_video() with its default settings."""
    detect_in_video()


if __name__ == '__main__':
    main()