MG-KI_Flaechenanalyse/source/predict.py

158 lines
4.5 KiB
Python

import argparse
import json
import logging
import os
import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
import numpy as np
from pymongo import MongoClient
# Logging setup: install the line format on the root logger via basicConfig,
# then create this script's own logger and let it emit everything from DEBUG up.
_LOG_FORMAT = '[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s'
logging.basicConfig(format=_LOG_FORMAT)

logger = logging.getLogger("Prediction")
logger.setLevel(logging.DEBUG)
# Class names used when no category file is supplied.
_DEFAULT_CATEGORIES = (
    "Baumbestand",
    "Festweg",
    "Pflaster",
    "Wiese",
    "Wasser",
    "Gullydeckel",
)

# A ``source`` starting with this prefix selects the MongoDB branch; the rest
# of the string is handed to ``MongoClient`` unchanged.
_MONGO_PREFIX = "mongo://"


def _load_categories(category_json):
    """Return the class names from *category_json*, or the built-in defaults."""
    if not category_json:
        return list(_DEFAULT_CATEGORIES)
    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding when reading the category names.
    with open(category_json, "r", encoding="utf-8") as file:
        data = json.load(file)
    return data["categories"]


def _build_predictor(categories):
    """Create a CPU ``DefaultPredictor`` configured for ``len(categories)`` classes."""
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(categories)
    cfg.MODEL.DEVICE = 'cpu'
    return DefaultPredictor(cfg)


def _predict_from_mongo(uri, image_id, predictor, metadata):
    """Analyse the ``input`` image of one MongoDB document and store it as ``output``."""
    client = MongoClient(uri)
    try:
        collection = client.get_database("stadtmg").get_collection("predictions")
        db_object = collection.find_one({"id": image_id})
        if db_object is None:
            raise ValueError(f"No document with id '{image_id}' found in MongoDB")
        # Log the field names only: the document carries the raw image bytes.
        logger.debug(list(db_object))

        image_binary = db_object.get("input")
        if not image_binary:
            raise ValueError(f"Document '{image_id}' has no 'input' image data")

        # np.fromstring is deprecated for binary data; frombuffer is the
        # supported replacement and avoids a copy.
        nparr = np.frombuffer(image_binary, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        result_image = analyse_image(image, predictor, metadata)
        if result_image is None:
            raise ValueError(
                f"'input' of document '{image_id}' could not be decoded as an image")

        # Reverse the channel axis before encoding, exactly as the directory
        # branch does before cv2.imwrite, so both outputs use the same
        # channel order.
        is_success, im_buf_arr = cv2.imencode(".png", result_image[:, :, ::-1])
        if not is_success:
            raise RuntimeError(f"Could not encode the result for document '{image_id}' as PNG")

        collection.update_one({"id": image_id}, {"$set": {
            "output": im_buf_arr.tobytes()
        }})
        updated = collection.find_one({"id": image_id})
        logger.info(updated.keys() if updated is not None else None)
    finally:
        # Always release the connection, including on the error paths above.
        client.close()


def _predict_from_directory(source, output_dir, predictor, metadata):
    """Analyse every readable image below *source*, mirroring the tree into *output_dir*."""
    for path, _dirs, files in os.walk(source):
        # Slicing ``path[len(source):]`` yields a component starting with a
        # separator for sub-directories, which makes os.path.join discard
        # ``output_dir``. A relative path keeps results inside ``output_dir``.
        relative = os.path.relpath(path, source)
        out_path = output_dir if relative == os.curdir else os.path.join(output_dir, relative)

        for file_name in files:
            image_path = os.path.join(path, file_name)
            logger.debug("Opening image")
            logger.info(f"Analysing image '{image_path}'")
            result_image = analyse_image(
                cv2.imread(image_path),
                predictor,
                metadata,
            )
            if result_image is None:
                # cv2.imread returns None for files it cannot decode; skip
                # them instead of aborting the whole run.
                logger.warning(f"Skipping '{image_path}': not a readable image")
                continue

            os.makedirs(out_path, exist_ok=True)
            cv2.imwrite(os.path.join(out_path, file_name), result_image[:, :, ::-1])


def predict(
        category_json="",
        source="data/images/original",
        output_dir="data/images/predicted",
        image_id="",
):
    """Run instance segmentation on images from a directory tree or a MongoDB document.

    Args:
        category_json: Path to a JSON file with a ``"categories"`` list of
            class names. When empty, a built-in default list is used.
        source: Either a directory that is walked recursively, or a string
            starting with ``mongo://``; in that case the remainder is passed
            to ``MongoClient`` and a single document is processed.
        output_dir: Root folder for the rendered images (directory mode
            only); the sub-directory layout of ``source`` is mirrored.
        image_id: Value of the ``id`` field of the document to process
            (MongoDB mode only).

    Raises:
        ValueError: In MongoDB mode, if the document or its ``input`` field
            is missing, or the stored bytes are not a decodable image.
        RuntimeError: In MongoDB mode, if the result cannot be PNG-encoded.
    """
    categories = _load_categories(category_json)
    logger.debug(f"Found categories: {categories}")

    metadata = MetadataCatalog.get("predict")
    metadata.set(thing_classes=categories)
    predictor = _build_predictor(categories)

    if source.startswith(_MONGO_PREFIX):
        logger.info("MONGOOOOO")
        _predict_from_mongo(source[len(_MONGO_PREFIX):], image_id, predictor, metadata)
    else:
        _predict_from_directory(source, output_dir, predictor, metadata)

    logger.info("Done.")
def analyse_image(
        image,
        predictor,
        metadata,
):
    """Run the predictor on one image and return the image with the predictions drawn in.

    Args:
        image: Image array as produced by ``cv2.imread``/``cv2.imdecode``, or
            ``None`` when loading failed.
        predictor: Callable that takes the image and returns a mapping with
            an ``"instances"`` entry.
        metadata: Metadata object handed to the ``Visualizer``.

    Returns:
        The array returned by the visualiser's ``get_image()``, or ``None``
        when ``image`` is ``None``.
    """
    if image is not None:
        logger.debug("Predicting")
        prediction = predictor(image)
        detections = prediction["instances"]
        logger.debug(detections.pred_classes)
        logger.debug(detections.pred_boxes)

        logger.debug("Drawing masks")
        # The visualiser gets the image with its channel axis reversed
        # (presumably OpenCV's BGR -> RGB); callers reverse the result again
        # before writing it with OpenCV.
        reversed_channels = image[:, :, ::-1]
        visualizer = Visualizer(reversed_channels, metadata=metadata, scale=1.2)
        rendered = visualizer.draw_instance_predictions(detections.to("cpu"))

        logger.debug("Saving image")
        return rendered.get_image()

    return None
def parse_opt():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` with the string attributes
        ``category_json``, ``source``, ``output_dir`` and ``image_id``.
    """
    # (flag, default, help text) for every option, in registration order.
    option_table = (
        (
            "--category_json",
            "",
            "Path to a .json file containing the definition of categories",
        ),
        (
            "--source",
            "data/images/original",
            "Path to images that should be detected",
        ),
        (
            "--output_dir",
            "data/images/predicted",
            "Path to folder where the detected images will be placed",
        ),
        (
            "--image_id",
            "",
            "ID of the object in MongoDB that contains the image",
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, help_text in option_table:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the CLI options, log them, and forward them
    # to predict() as keyword arguments.
    logger.info("Hello")
    cli_options = vars(parse_opt())
    logger.info(cli_options)
    predict(**cli_options)
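# Example of calling predict() directly instead of going through the CLI
# (the image folder is passed via the `source` parameter):
# if __name__ == '__main__':
#     predict(
#         source="../data/images/tmp/westend/cropped",
#         output_dir="../data/images/predicted/westend",
#     )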