158 lines
4.5 KiB
Python
158 lines
4.5 KiB
Python
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
|
|
import cv2
|
|
from detectron2 import model_zoo
|
|
from detectron2.config import get_cfg
|
|
from detectron2.data import MetadataCatalog
|
|
from detectron2.engine import DefaultPredictor
|
|
from detectron2.utils.visualizer import Visualizer
|
|
import numpy as np
|
|
from pymongo import MongoClient
|
|
|
|
# Logging configuration: every record is prefixed with timestamp, level and
# logger name; the module logger itself emits everything from DEBUG upwards.
_LOG_FORMAT = '[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s'

logging.basicConfig(format=_LOG_FORMAT)
logger = logging.getLogger("Prediction")
logger.setLevel(logging.DEBUG)
|
|
|
|
|
|
def predict(
        category_json="",
        source="data/images/original",
        output_dir="data/images/predicted",
        image_id="",
):
    """Run instance segmentation on images and store the rendered results.

    Args:
        category_json: Path to a JSON file with a top-level ``"categories"``
            entry. When empty, a built-in list of categories is used.
        source: Either a directory that is walked recursively, or a string
            starting with ``"mongo://"``; in the latter case the remainder of
            the string is passed to ``MongoClient`` and a single image is read
            from the ``stadtmg.predictions`` collection.
        output_dir: Root folder for rendered images (directory mode only). The
            sub-directory layout of ``source`` is mirrored below it.
        image_id: Value of the ``id`` field of the MongoDB document to process
            (MongoDB mode only).

    Raises:
        LookupError: MongoDB mode, no document with the given ``image_id``.
        ValueError: MongoDB mode, the document has no decodable ``input`` image.
        RuntimeError: MongoDB mode, the result could not be encoded as PNG.
    """
    categories = _load_categories(category_json)
    logger.debug(f"Found categories: {categories}")

    metadata = MetadataCatalog.get("predict")
    metadata.set(thing_classes=categories)

    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

    # Weights are expected in the config's output directory.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(categories)
    cfg.MODEL.DEVICE = 'cpu'
    predictor = DefaultPredictor(cfg)

    if source.startswith("mongo://"):
        _predict_from_mongo(source, image_id, predictor, metadata)
    else:
        _predict_from_directory(source, output_dir, predictor, metadata)

    logger.info("Done.")


def _load_categories(category_json):
    """Return the category names from ``category_json`` or the built-in defaults."""
    if not category_json:
        return [
            "Baumbestand",
            "Festweg",
            "Pflaster",
            "Wiese",
            "Wasser",
            "Gullydeckel",
        ]
    with open(category_json, "r") as file:
        data = json.load(file)
    return data["categories"]


def _predict_from_mongo(source, image_id, predictor, metadata):
    """Read one image from MongoDB, analyse it and write the PNG result back."""
    logger.info("MONGOOOOO")
    # Everything after the "mongo://" prefix is handed to MongoClient unchanged.
    client = MongoClient(source[len("mongo://"):])
    try:
        collection = client.get_database("stadtmg").get_collection("predictions")

        db_object = collection.find_one({"id": image_id})
        logger.debug(db_object)
        if db_object is None:
            raise LookupError(f"No document with id '{image_id}' found in MongoDB")

        image_binary = db_object.get("input")
        if not image_binary:
            raise ValueError(f"Document with id '{image_id}' has no 'input' image")

        # np.fromstring is deprecated for binary data; frombuffer is the
        # documented replacement.
        nparr = np.frombuffer(image_binary, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        result_image = analyse_image(image, predictor, metadata)
        if result_image is None:
            # analyse_image returns None when cv2.imdecode could not decode.
            raise ValueError(
                f"'input' of document with id '{image_id}' is not a decodable image")

        # The result is flipped to OpenCV's BGR channel order before encoding,
        # exactly as the directory branch does before cv2.imwrite.
        # NOTE(review): assumes analyse_image returns RGB - confirm.
        bgr_image = np.ascontiguousarray(result_image[:, :, ::-1])
        is_success, im_buf_arr = cv2.imencode(".png", bgr_image)
        if not is_success:
            raise RuntimeError(f"Could not encode result for id '{image_id}' as PNG")
        byte_im = im_buf_arr.tobytes()

        collection.update_one({"id": image_id}, {"$set": {
            "output": byte_im
        }})
        logger.info(collection.find_one({"id": image_id}).keys())
    finally:
        client.close()


def _predict_from_directory(source, output_dir, predictor, metadata):
    """Analyse every readable image below ``source`` and mirror it into ``output_dir``."""
    for path, _dirs, files in os.walk(source):
        # relpath never starts with a separator. Slicing the prefix off (as in
        # path[len(source):]) yields e.g. "/sub", which os.path.join treats as
        # absolute and which therefore silently discards output_dir.
        out_path = os.path.normpath(
            os.path.join(output_dir, os.path.relpath(path, source)))

        for file_name in files:
            image_path = os.path.join(path, file_name)
            logger.debug("Opening image")

            logger.info(f"Analysing image '{image_path}'")
            im = cv2.imread(image_path)
            result_image = analyse_image(
                im,
                predictor,
                metadata,
            )
            if result_image is None:
                # cv2.imread returns None for files it cannot decode; skip
                # them instead of aborting the whole run.
                logger.warning(f"Skipping '{image_path}': not a readable image")
                continue

            os.makedirs(out_path, exist_ok=True)
            cv2.imwrite(os.path.join(out_path, file_name), result_image[:, :, ::-1])
|
|
|
|
|
|
def analyse_image(image, predictor, metadata):
    """Run the predictor on one image and return the rendered visualisation.

    Args:
        image: Image array as produced by ``cv2.imread``/``cv2.imdecode``, or
            ``None`` when loading failed.
        predictor: Callable returning a mapping with an ``"instances"`` entry.
        metadata: Metadata object handed to the ``Visualizer``.

    Returns:
        The array obtained from the visualizer's ``get_image()``, or ``None``
        when ``image`` is ``None``.
    """
    if image is not None:
        logger.debug("Predicting")
        detections = predictor(image)["instances"]
        logger.debug(detections.pred_classes)
        logger.debug(detections.pred_boxes)

        logger.debug("Drawing masks")
        # The channel axis is reversed before the image is handed to the
        # visualizer (presumably BGR -> RGB, given the cv2 input - confirm).
        canvas = Visualizer(image[:, :, ::-1], metadata=metadata, scale=1.2)
        rendered = canvas.draw_instance_predictions(detections.to("cpu"))
        logger.debug("Saving image")
        return rendered.get_image()

    return None
|
|
|
|
|
|
def parse_opt(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads the arguments from ``sys.argv[1:]``, which is
            the behaviour of the parameterless call.

    Returns:
        An ``argparse.Namespace`` with the attributes ``category_json``,
        ``source``, ``output_dir`` and ``image_id`` (all strings).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--category_json",
        type=str,
        default="",
        help="Path to a .json file containing the definition of categories",
    )
    parser.add_argument(
        "--source",
        type=str,
        default="data/images/original",
        help="Path to images that should be detected",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data/images/predicted",
        help="Path to folder where the detected images will be placed",
    )
    parser.add_argument(
        "--image_id",
        type=str,
        default="",
        help="ID of the object in MongoDB that contains the image",
    )
    # Passing argv through lets callers (and tests) supply arguments directly.
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI options, log them, and forward them
    # unchanged as keyword arguments to predict().
    logger.info("Hello")
    opt = parse_opt()
    cli_kwargs = vars(opt)
    logger.info(cli_kwargs)
    predict(**cli_kwargs)
|
|
|
|
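# Example: calling predict() directly instead of going through the CLI.
# (The directory parameter of predict() is named `source`.)
#
# if __name__ == '__main__':
#     predict(
#         source="../data/images/tmp/westend/cropped",
#         output_dir="../data/images/predicted/westend",
#     )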
|