• 欢迎访问开心洋葱网站,在线教程,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站,欢迎加入开心洋葱 QQ群
  • 为方便开心洋葱网用户,开心洋葱官网已经开启复制功能!
  • 欢迎访问开心洋葱网站,手机也能访问哦~欢迎加入开心洋葱多维思维学习平台 QQ群
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏开心洋葱吧~~~~~~~~~~~~~!
  • 由于近期流量激增,小站的ECS没能经的起亲们的访问,本站依然没有盈利,如果各位看如果觉着文字不错,还请看官给小站打个赏~~~~~~~~~~~~~!

TensorFlow之视频流实时目标检测

人工智能 宗孝鹏 2090次浏览 0个评论

参考:https://github.com/juandes/pikachu-detection/blob/master/detection_video.py

之前的文章中,实现了利用tensorflow的目标检测API训练模型,并用图片来验证模型的有效性。本文的目的是为了将模型应用在视频检测中,实现视频流的实时检测。

—————————————2018.12.3更新———————————————-

抱歉前段时间一直在做语义分割的项目,没有时间测试目标检测的接口,今天终于抽空把视频流的检测做了,话不多说,直接上代码。代码部分主要参考官方给的

object_detection_tutorial.ipynb中的内容,视频处理采用opencv库。

# coding: utf-8
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
 
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
 
import cv2
 
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops
 
if tf.__version__ < '1.4.0':
  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')
 
from utils import label_map_util
from utils import visualization_utils as vis_util
 
# # Model preparation 
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = 'inference_graph_67886/frozen_inference_graph.pb'
 
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('training', 'object-detection.pbtxt')
 
NUM_CLASSES = 12
 
def detect_in_video():
    # VideoWriter is the responsible of creating a copy of the video
    # used for the detections but with the detections overlays. Keep in
    # mind the frame size has to be the same as original video.
    out = cv2.VideoWriter('test_images/20171206/0-0-0_result.avi', cv2.VideoWriter_fourcc(
        'M', 'J', 'P', 'G'), 25, (1280, 1024))
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
 
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)
 
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Definite input and output Tensors for detection_graph
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object
            # was detected.
            detection_boxes = detection_graph.get_tensor_by_name(
                'detection_boxes:0')
            # Each score represent how level of confidence for each of the objects.
            # Score is shown on the result image, together with the class
            # label.
            detection_scores = detection_graph.get_tensor_by_name(
                'detection_scores:0')
            detection_classes = detection_graph.get_tensor_by_name(
                'detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name(
                'num_detections:0')
            cap = cv2.VideoCapture('test_images/20171206/0-0-0.avi')
 
            while(cap.isOpened()):
                # Read the frame
                ret, frame = cap.read()
 
                # Recolor the frame. By default, OpenCV uses BGR color space.
                # This short blog post explains this better:
                # https://www.learnopencv.com/why-does-opencv-use-bgr-color-format/
                color_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
                image_np_expanded = np.expand_dims(color_frame, axis=0)
 
                # Actual detection.
                (boxes, scores, classes, num) = sess.run(
                    [detection_boxes, detection_scores,
                        detection_classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
 
                # Visualization of the results of a detection.
                # note: perform the detections using a higher threshold
                vis_util.visualize_boxes_and_labels_on_image_array(
                    color_frame,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8,
                    min_score_thresh=.20)
 
                cv2.imshow('frame', color_frame)
                output_rgb = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
                out.write(output_rgb)
 
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            out.release()
            cap.release()
            cv2.destroyAllWindows()
 
def main():
    detect_in_video()
 
if __name__ =='__main__':
    main()
 


开心洋葱 , 版权所有丨如未注明 , 均为原创丨未经授权请勿修改 , 转载请注明TensorFlow之视频流实时目标检测
喜欢 (0)

您必须 登录 才能发表评论!

加载中……