import argparse
import json
import logging
import os

import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
import numpy as np
from pymongo import MongoClient

# Setup logging
logging.basicConfig(format='[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s')
logger = logging.getLogger("Prediction")
logger.setLevel(logging.DEBUG)

# Categories used when no category JSON file is supplied.
_DEFAULT_CATEGORIES = (
    "Baumbestand",
    "Festweg",
    "Pflaster",
    "Wiese",
    "Wasser",
    "Gullydeckel",
)

# A ``source`` starting with this prefix selects the MongoDB code path; the
# remainder of the string is handed to ``MongoClient`` as the host.
_MONGO_PREFIX = "mongo://"


def _load_categories(category_json):
    """Return the list of category names, from *category_json* or the defaults."""
    if not category_json:
        return list(_DEFAULT_CATEGORIES)
    with open(category_json, "r", encoding="utf-8") as file:
        data = json.load(file)
    return data["categories"]


def _build_config(num_classes):
    """Build the Mask R-CNN inference config for *num_classes* categories."""
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    # Weights are read from the config's output directory.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    cfg.MODEL.DEVICE = 'cpu'
    return cfg


def _predict_from_mongo(mongo_host, image_id, predictor, metadata):
    """Analyse the image stored under *image_id* and write the result back.

    Reads the ``input`` field of the matching document in the
    ``stadtmg.predictions`` collection and stores the PNG-encoded result in
    its ``output`` field.

    Raises:
        ValueError: if the document or its image cannot be found or decoded.
        RuntimeError: if the result cannot be PNG-encoded.
    """
    logger.info("MONGOOOOO")
    client = MongoClient(mongo_host)
    try:
        collection = client.get_database("stadtmg").get_collection("predictions")
        db_object = collection.find_one({"id": image_id})
        logger.debug(db_object)
        if db_object is None:
            raise ValueError(f"No document with id '{image_id}' found in MongoDB")

        image_binary = db_object.get("input")
        if not image_binary:
            raise ValueError(f"Document '{image_id}' has no 'input' image data")

        # frombuffer replaces the deprecated binary mode of np.fromstring.
        nparr = np.frombuffer(image_binary, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        result_image = analyse_image(image, predictor, metadata)
        if result_image is None:
            raise ValueError(f"Input of document '{image_id}' could not be decoded as an image")

        # Reverse the channel order before encoding, exactly as the directory
        # path does before cv2.imwrite: OpenCV encoders expect BGR.
        is_success, im_buf_arr = cv2.imencode(
            ".png", np.ascontiguousarray(result_image[:, :, ::-1]))
        if not is_success:
            raise RuntimeError(f"Could not encode result image for document '{image_id}'")

        byte_im = im_buf_arr.tobytes()
        collection.update_one({"id": image_id}, {"$set": {
            "output": byte_im
        }})
        logger.info(collection.find_one({"id": image_id}).keys())
    finally:
        client.close()


def _predict_from_directory(source, output_dir, predictor, metadata):
    """Analyse every file below *source*, mirroring the tree into *output_dir*."""
    for path, _dirs, files in os.walk(source):
        # relpath (rather than slicing off len(source) characters) never
        # yields a leading separator, which would make os.path.join discard
        # output_dir and write to an absolute path.
        out_path = os.path.normpath(
            os.path.join(output_dir, os.path.relpath(path, source)))
        for file_name in files:
            image_path = os.path.join(path, file_name)
            logger.debug("Opening image")
            logger.info(f"Analysing image '{image_path}'")
            im = cv2.imread(image_path)
            result_image = analyse_image(
                im,
                predictor,
                metadata,
            )
            if result_image is None:
                # cv2.imread returns None for files it cannot decode.
                logger.warning(f"Skipping '{image_path}': not a readable image")
                continue
            os.makedirs(out_path, exist_ok=True)
            cv2.imwrite(os.path.join(out_path, file_name), result_image[:, :, ::-1])


def predict(
        category_json="",
        source="data/images/original",
        output_dir="data/images/predicted",
        image_id="",
):
    """Run instance segmentation on images from a directory or from MongoDB.

    Args:
        category_json: Path to a JSON file with a ``"categories"`` list. When
            empty, a built-in default category list is used.
        source: Either a directory that is walked recursively, or a string
            starting with ``mongo://`` whose remainder is passed to
            ``MongoClient``.
        output_dir: Directory that receives the annotated images (directory
            mode only); the sub-directory layout of *source* is mirrored.
        image_id: Value of the ``id`` field of the MongoDB document to process
            (MongoDB mode only).

    Raises:
        ValueError: in MongoDB mode, if the document or its image is missing
            or cannot be decoded.
        RuntimeError: in MongoDB mode, if the result cannot be PNG-encoded.
    """
    categories = _load_categories(category_json)
    logger.debug(f"Found categories: {categories}")

    metadata = MetadataCatalog.get("predict")
    metadata.set(thing_classes=categories)

    predictor = DefaultPredictor(_build_config(len(categories)))

    if source.startswith(_MONGO_PREFIX):
        _predict_from_mongo(source[len(_MONGO_PREFIX):], image_id, predictor, metadata)
    else:
        _predict_from_directory(source, output_dir, predictor, metadata)

    logger.info("Done.")


def analyse_image(
        image,
        predictor,
        metadata,
):
    """Run *predictor* on *image* and return the image with predictions drawn.

    Args:
        image: Image array as returned by ``cv2.imread`` / ``cv2.imdecode``,
            or ``None``.
        predictor: Callable returning a dict with an ``"instances"`` entry.
        metadata: Metadata passed on to the ``Visualizer``.

    Returns:
        The visualised image from ``Visualizer.get_image()``, or ``None`` when
        *image* is ``None``. NOTE(review): the channel order is presumably
        reversed relative to the input (callers flip it back before handing
        it to OpenCV) - confirm against the detectron2 Visualizer docs.
    """
    if image is None:
        return None

    logger.debug("Predicting")
    outputs = predictor(image)
    instances = outputs["instances"]
    logger.debug(instances.pred_classes)
    logger.debug(instances.pred_boxes)

    logger.debug("Drawing masks")
    # The visualizer is fed the image with its channel order reversed.
    v = Visualizer(image[:, :, ::-1], metadata=metadata, scale=1.2)
    out = v.draw_instance_predictions(instances.to("cpu"))

    logger.debug("Saving image")
    result_image = out.get_image()
    return result_image


def parse_opt():
    """Parse the command-line options; their names match ``predict``'s parameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--category_json",
        type=str,
        default="",
        help="Path to a .json file containing the definition of categories",
    )
    parser.add_argument(
        "--source",
        type=str,
        default="data/images/original",
        help="Path to images that should be detected"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data/images/predicted",
        help="Path to folder where the detected images will be placed"
    )
    parser.add_argument(
        "--image_id",
        type=str,
        default="",
        help="ID of the object in MongoDB that contains the image"
    )
    opt = parser.parse_args()
    return opt


if __name__ == '__main__':
    logger.info("Hello")
    opt = parse_opt()
    logger.info(vars(opt))
    predict(**vars(opt))