Source code for geowatch.tasks.landcover.model_info

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union

import kwcoco
import torch.utils.data

from . import detector
from .datasets import S2Dataset, WVDataset

log = logging.getLogger(__name__)


class ModelInfo(ABC):
    """
    Abstract description of a landcover model.

    A concrete subclass ties together three things: how to build a torch
    dataset that feeds the model, the names of the model's output classes,
    and how to load the model from a weights file.
    """

    @abstractmethod
    def create_dataset(self, coco_dset: Union[kwcoco.CocoDataset, str]) -> torch.utils.data.Dataset:
        """
        Create a torch dataset compatible with this model using a CocoDataset

        Args:
            coco_dset: CocoDataset or string filepath

        Returns:
            torch dataset
        """
        pass

    @property
    @abstractmethod
    def model_outputs(self):
        """
        Names of the segmentation classes the model predicts.

        The concrete subclasses in this module return a list of class-name
        strings (e.g. 'water', 'forest'); presumably ordered to match the
        network's output channels -- confirm against the training code.
        """
        pass

    @abstractmethod
    def load_model(self, weights_filename: Path, device):
        """
        Load the network described by this ModelInfo.

        Args:
            weights_filename: path to the saved model weights
            device: device to load the model onto; the subclasses in this
                module forward it unchanged to ``detector.load_model``

        Returns:
            the loaded model (exact type is determined by the subclass)
        """
        pass
class S2ModelInfo(ModelInfo):
    """
    This model was trained on 13-band Sentinel-2 data with 5 segmentation classes
    """

    # Number of input bands the network expects (13-band Sentinel-2 imagery).
    _NUM_CHANNELS = 13

    def create_dataset(self, coco_dset):
        """
        Wrap ``coco_dset`` (a CocoDataset or filepath) in an S2Dataset.
        """
        return S2Dataset(coco_dset)

    @property
    def model_outputs(self):
        """
        Segmentation class names produced by this model.

        A fresh list is returned on every access.
        """
        return [
            'water',
            'forest',
            'field',
            'impervious',
            'barren',
        ]

    def load_model(self, weights_filename, device):
        """
        Load the Sentinel-2 network from ``weights_filename`` onto ``device``.

        The number of network outputs is derived from ``model_outputs``, so the
        class-name list and the network head size cannot silently disagree
        (the previous ``assert`` guard disappeared under ``python -O``).
        """
        num_outputs = len(self.model_outputs)
        return detector.load_model(
            weights_filename,
            num_outputs=num_outputs,
            num_channels=self._NUM_CHANNELS,
            device=device,
        )
class WVModelInfo(ModelInfo):
    """
    This model was trained on 8-band WorldView-3 data with 5 segmentation classes
    """

    # Number of input bands the network expects (8-band WorldView-3 imagery).
    _NUM_CHANNELS = 8

    def create_dataset(self, coco_dset):
        """
        Wrap ``coco_dset`` (a CocoDataset or filepath) in a WVDataset.
        """
        return WVDataset(coco_dset)

    @property
    def model_outputs(self):
        """
        Segmentation class names produced by this model.

        A fresh list is returned on every access.
        """
        return [
            'water',
            'forest',
            'field',
            'impervious',
            'barren',
        ]

    def load_model(self, weights_filename, device):
        """
        Load the WorldView network from ``weights_filename`` onto ``device``.

        The number of network outputs is derived from ``model_outputs``, so the
        class-name list and the network head size cannot silently disagree
        (the previous ``assert`` guard disappeared under ``python -O``).
        """
        num_outputs = len(self.model_outputs)
        return detector.load_model(
            weights_filename,
            num_outputs=num_outputs,
            num_channels=self._NUM_CHANNELS,
            device=device,
        )
# Weights-file stem -> ModelInfo subclass; consulted by lookup_model_info.
__mapping = dict(
    sentinel2=S2ModelInfo,
    worldview=WVModelInfo,
)
def lookup_model_info(weights_filename: Path) -> ModelInfo:
    """
    Select the ModelInfo implementation that matches a weights file.

    The file's stem (its name without the final suffix) is looked up in the
    module-level ``__mapping`` table.

    Args:
        weights_filename: path to the weights file; a plain ``str`` path is
            also accepted.

    Returns:
        A new instance of the matching ModelInfo subclass.

    Raises:
        ValueError: if the file stem is not a known model name. ValueError is
            an ``Exception`` subclass, so existing ``except Exception``
            handlers still catch it.
    """
    # Path() accepts both str and Path, so callers may pass either.
    stem = Path(weights_filename).stem
    model_info_class = __mapping.get(stem)
    if model_info_class is None:
        raise ValueError(
            f'unknown weights file {str(weights_filename)!r}; '
            f'expected a file stem in {sorted(__mapping)}'
        )
    return model_info_class()