MG-KI_Flaechenanalyse/source/train.py

import argparse
import json
import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultTrainer


def train(
    data_json="data/json/train_data.json",
):
    # Load the dataset definition produced during data preparation.
    with open(data_json, "r") as file:
        data = json.load(file)
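
    # Assumed layout of the JSON file (inferred from the keys read above and from
    # Detectron2's standard dataset-dict format; illustrative only, not taken from
    # the repository):
    #   {
    #     "categories":   ["class_a", "class_b", ...],
    #     "train_images": [{"file_name": ..., "height": ..., "width": ...,
    #                       "image_id": ..., "annotations": [...]}, ...],
    #     "test_images":  [ ...same structure as train_images... ]
    #   }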

    # Register the train/test splits and their class names with Detectron2.
    DatasetCatalog.register("train", lambda: data["train_images"])
    DatasetCatalog.register("test", lambda: data["test_images"])
    MetadataCatalog.get("train").set(thing_classes=data["categories"])
    MetadataCatalog.get("test").set(thing_classes=data["categories"])

    # Build a config starting from the Mask R-CNN R50-FPN 3x baseline in the
    # Detectron2 model zoo.
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    )
    cfg.DATASETS.TRAIN = ("train",)
    cfg.DATASETS.TEST = ("test",)
    cfg.DATALOADER.NUM_WORKERS = 4
    # Fine-tune from COCO-pretrained weights.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    )
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 10000
    cfg.SOLVER.STEPS = []  # constant learning rate, no decay steps
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # Must match the number of entries in data["categories"].
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 6

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    # Resume from the last checkpoint in OUTPUT_DIR if one exists, otherwise
    # start from the COCO weights configured above.
    trainer.resume_or_load(resume=True)
    trainer.train()
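
    # Note: DefaultTrainer writes checkpoints to cfg.OUTPUT_DIR (Detectron2's
    # default is "./output"); the final weights are saved there as model_final.pth.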


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_json",
        type=str,
        default="data/json/train_data.json",
        help="Path to a .json file containing the definition of train data",
    )
    opt = parser.parse_args()
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    train(**vars(opt))
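

# Example invocation, using the --data_json flag defined above:
#   python train.py --data_json data/json/train_data.json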