| |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Optional, Tuple |
|
|
| from omegaconf import OmegaConf, DictConfig |
|
|
| from .. import logger |
| from ..conf import data as conf_data_dir |
| from ..data import MapillaryDataModule |
| from .run import evaluate |
|
|
|
|
# Per-split configuration overrides, keyed by split name. Only "val" is
# defined here; it pins the list of scenes used for that split.
split_overrides = {
    "val": {
        "scenes": [
            "sanfrancisco_soma",
            "sanfrancisco_hayes",
            "amsterdam",
            "berlin",
            "lemans",
            "montrouge",
            "toulouse",
            "nantes",
            "vilnius",
            "avignon",
            "helsinki",
            "milan",
            "paris",
        ],
    },
}

# Data configuration loaded from the `mapillary.yaml` file that sits next to
# the `conf.data` package.
data_cfg_train = OmegaConf.load(
    Path(conf_data_dir.__file__).parent / "mapillary.yaml"
)

# Evaluation-time data configuration: the loaded config with a few options
# overridden (values on the right win in `OmegaConf.merge`).
data_cfg = OmegaConf.merge(
    data_cfg_train,
    dict(
        return_gps=True,
        add_map_mask=True,
        max_init_error=32,
        loading=dict(val=dict(batch_size=1, num_workers=0)),
    ),
)

# Default configuration for single-frame evaluation.
default_cfg_single = OmegaConf.create(dict(data=data_cfg))

# Default configuration for sequential evaluation: same top-level entries as
# the single-frame default, plus the chunking options.
default_cfg_sequential = OmegaConf.create(
    dict(
        **default_cfg_single,
        chunking=dict(max_length=10),
    )
)
|
|
|
|
def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on a dataset split and log recall values.

    Args:
        split: Name of the split; must be a key of ``split_overrides``.
        experiment: Experiment identifier, forwarded unchanged to ``evaluate``.
        cfg: Optional user configuration (a ``DictConfig`` or a plain dict)
            merged on top of the defaults; user values take precedence.
        sequential: If true, start from ``default_cfg_sequential`` instead of
            ``default_cfg_single`` and also report the sequential metrics.
        thresholds: Thresholds passed to each metric's ``recall`` method.
        **kwargs: Extra keyword arguments forwarded to ``evaluate``.

    Returns:
        The metrics object returned by ``evaluate``.

    Raises:
        KeyError: If ``split`` has no entry in ``split_overrides``.
    """
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # The split overrides (e.g. the scene list) are dataset options. The
    # defaults keep all dataset options under the "data" key, and only that
    # sub-config is handed to the data module below, so the override must be
    # nested under "data" to have any effect on the dataset.
    default = OmegaConf.merge(default, {"data": split_overrides[split]})
    cfg = OmegaConf.merge(default, cfg)
    dataset = MapillaryDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment, cfg, dataset, split, sequential=sequential, **kwargs
    )

    keys = [
        "xy_max_error",
        "xy_gps_error",
        "yaw_max_error",
    ]
    if sequential:
        keys += [
            "xy_seq_error",
            "xy_gps_seq_error",
            "yaw_seq_error",
            "yaw_gps_seq_error",
        ]
    for k in keys:
        # A missing metric is reported but does not abort the summary.
        if k not in metrics:
            logger.warning("Key %s not in metrics.", k)
            continue
        # Recall at each threshold, rounded to two decimals for logging.
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, turn the trailing
    # positional dotlist into a config, and run the evaluation.
    cli = argparse.ArgumentParser()
    argument_specs = (
        ("--experiment", dict(type=str, required=True)),
        ("--split", dict(type=str, default="val", choices=["val"])),
        ("--sequential", dict(action="store_true")),
        ("--output_dir", dict(type=Path)),
        ("--num", dict(type=int)),
        ("dotlist", dict(nargs="*")),
    )
    for flag, spec in argument_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    run(
        split=options.split,
        experiment=options.experiment,
        cfg=OmegaConf.from_cli(options.dotlist),
        sequential=options.sequential,
        output_dir=options.output_dir,
        num=options.num,
    )
|
|