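"""Fine-tune detectron2's COCO Mask R-CNN (R50-FPN, 3x) on a custom dataset.

The dataset splits and class names are read from a JSON file passed via
``--data_json`` (default: ``data/json/train_data.json``).
"""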
import argparse
import json
import os

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultTrainer
from detectron2.model_zoo import model_zoo


def _register_split(name, records, thing_classes):
    """Register one dataset split and its class names with detectron2's catalogs.

    Any existing registration under ``name`` is dropped first, because
    ``DatasetCatalog.register`` refuses duplicate names. Without this, calling
    ``train`` twice in one process (e.g. from a notebook) would fail.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    # Bind ``records`` as a default so the loader returns exactly this list,
    # independent of any later rebinding in the caller's scope.
    DatasetCatalog.register(name, lambda records=records: records)
    MetadataCatalog.get(name).set(thing_classes=thing_classes)


def _build_cfg(num_classes):
    """Build the Mask R-CNN R50-FPN 3x fine-tuning config for ``num_classes`` classes."""
    config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.DATASETS.TRAIN = ("train",)
    cfg.DATASETS.TEST = ("test",)
    cfg.DATALOADER.NUM_WORKERS = 4
    # Start from the COCO-pretrained checkpoint of the same architecture.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 10000
    # Empty STEPS: no learning-rate decay milestones.
    cfg.SOLVER.STEPS = []
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    return cfg


def train(
    data_json="data/json/train_data.json",
):
    """Fine-tune Mask R-CNN on the dataset described by ``data_json``.

    Args:
        data_json: Path to a JSON file with the keys:

            - ``"train_images"`` / ``"test_images"``: records returned as-is
              by the registered dataset loaders. Presumably these are lists
              of detectron2 dataset dicts — TODO confirm with the data
              producer.
            - ``"categories"``: class names, used as ``thing_classes`` and to
              size the ROI heads.

    Raises:
        KeyError: if one of the required keys is missing from the JSON.
        ValueError: if ``"categories"`` is empty.
    """
    with open(data_json, "r", encoding="utf-8") as file:
        data = json.load(file)

    categories = data["categories"]
    if not categories:
        raise ValueError(f"{data_json}: 'categories' must list at least one class")

    _register_split("train", data["train_images"], categories)
    _register_split("test", data["test_images"], categories)

    # Derive the head size from the metadata so NUM_CLASSES always matches
    # the number of thing_classes declared in the JSON.
    cfg = _build_cfg(num_classes=len(categories))

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    # resume=True: per detectron2 docs, continue from the last checkpoint in
    # cfg.OUTPUT_DIR if one exists, otherwise load cfg.MODEL.WEIGHTS.
    trainer.resume_or_load(resume=True)
    trainer.train()


def parse_opt():
    """Read the training script's command-line options.

    Returns:
        argparse.Namespace: holds ``data_json``, the path to the JSON dataset
        definition (``"data/json/train_data.json"`` when the flag is omitted).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_json",
        default="data/json/train_data.json",
        type=str,
        help="Path to a .json file containing the definition of train data",
    )
    return cli.parse_args()


if __name__ == "__main__":
    # Forward every parsed CLI option to train() as a keyword argument.
    train(**vars(parse_opt()))