"""Train a detectron2 Mask R-CNN (R50-FPN, 3x COCO schedule) on a custom dataset.

The dataset is described by a single JSON file (see ``train`` for the keys it
must contain). Run as a script: ``python train.py --data_json path/to/file.json``.
"""
import argparse
import json
import os

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultTrainer
from detectron2.model_zoo import model_zoo

# Model-zoo config used both for the architecture and the pretrained weights.
_MODEL_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"


def train(
    data_json="data/json/train_data.json",
):
    """Register the dataset described by *data_json* and train Mask R-CNN on it.

    Args:
        data_json: Path to a JSON file with the keys ``"train_images"``,
            ``"test_images"`` and ``"categories"``. The two image entries are
            returned as-is by the registered dataset functions (presumably
            lists of detectron2 dataset dicts -- confirm against the producer
            of this file). ``"categories"`` is used as ``thing_classes``, so
            it is expected to be a list of class names.

    Raises:
        ValueError: If ``"categories"`` is empty.
        KeyError: If one of the required keys is missing from the JSON.
    """
    with open(data_json, "r", encoding="utf-8") as file:
        data = json.load(file)

    categories = data["categories"]
    if not categories:
        raise ValueError(f"{data_json!r} defines no categories")

    DatasetCatalog.register("train", lambda: data["train_images"])
    DatasetCatalog.register("test", lambda: data["test_images"])
    # Both splits share the same label space; give them identical metadata.
    MetadataCatalog.get("train").set(thing_classes=categories)
    MetadataCatalog.get("test").set(thing_classes=categories)

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(_MODEL_CONFIG))
    cfg.DATASETS.TRAIN = ("train",)
    cfg.DATASETS.TEST = ("test",)
    cfg.DATALOADER.NUM_WORKERS = 4
    # Start from COCO-pretrained weights of the same architecture.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(_MODEL_CONFIG)
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 10000
    cfg.SOLVER.STEPS = []  # no LR decay steps
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # Must match the registered thing_classes; previously hard-coded to 6,
    # which silently broke datasets with a different number of categories.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(categories)

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    # resume=True continues from the last checkpoint in OUTPUT_DIR if present.
    trainer.resume_or_load(resume=True)
    trainer.train()


def parse_opt():
    """Parse command-line options; returns an ``argparse.Namespace`` with ``data_json``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_json",
        type=str,
        default="data/json/train_data.json",
        help="Path to a .json file containing the definition of train data",
    )
    opt = parser.parse_args()
    return opt


if __name__ == '__main__':
    opt = parse_opt()
    train(**vars(opt))