# Source code for geowatch.tasks.landcover.datasets

import logging
import itertools
import kwcoco
import kwimage
import numpy as np
import torch.utils.data
from geowatch.utils import util_kwimage

log = logging.getLogger(__name__)


class _CocoTorchDataset(torch.utils.data.Dataset):
    """
    Base dataset for landcover task

    Wraps a kwcoco dataset and exposes the images accepted by
    :meth:`_include` as a map-style torch dataset.  Each item is a copy of
    the image dictionary from ``dset.imgs`` with the array returned by
    :meth:`_load` stored under the ``'imgdata'`` key.

    Subclasses must override :meth:`_load` and may override
    :meth:`_include` to restrict which images are exposed.
    """

    def __init__(self, dset):
        """
        Args:
            dset: anything accepted by ``kwcoco.CocoDataset.coerce``.
        """
        self.dset = kwcoco.CocoDataset.coerce(dset)
        # Sorted so that an index always maps to the same image id.
        self.gids = sorted(gid for gid in self.dset.imgs.keys() if self._include(gid))

    def __len__(self):
        """Return the number of images that passed :meth:`_include`."""
        return len(self.gids)

    def __getitem__(self, idx):
        """
        Return a copy of the image dictionary with the pixels attached.

        Args:
            idx (int): position in ``self.gids``.

        Returns:
            dict: shallow copy of ``dset.imgs[gid]`` with an extra
            ``'imgdata'`` entry holding the result of :meth:`_load`.

        Raises:
            Exception: if :meth:`_load` fails; the original error is chained
            as ``__cause__``.
        """
        gid = self.gids[idx]
        # Copy so the 'imgdata' entry is not written into the dataset itself.
        img_info = self.dset.imgs[gid].copy()
        try:
            img_info['imgdata'] = self._load(gid)
        except Exception as ex:
            raise Exception('Unable to load image {}'.format(gid)) from ex
        return img_info

    def _include(self, gid):
        """
        Args:
            gid: image id to test.
        Returns: True to include the given image in this dataset.  False to exclude.
        """
        return True

    def _load(self, img_info):
        """
        Load an image and return a numpy array.

        NOTE: despite its name, ``img_info`` receives the image id;
        :meth:`__getitem__` calls ``self._load(gid)``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('subclass must override _load')

    def _load_channels_stacked(self, gid, channels_list, resolution):
        """
        Load several channels of one image and stack them along the last axis.

        Args:
            gid: image id.
            channels_list: iterable of channel specs; each entry is passed to
                :meth:`_try_load_channel`.
            resolution: forwarded to :meth:`_try_load_channel`.

        Returns:
            numpy.ndarray: float32 array produced by ``np.dstack`` of the
            (resized) channel images.

        Raises:
            ValueError: if ``channels_list`` is empty.
            Exception: if any channel cannot be loaded.
        """
        channel_images = [self._try_load_channel(gid, channels, resolution) for channels in channels_list]
        if not channel_images:
            raise ValueError('channels_list must contain at least one channel')

        # Target the largest height and the largest width independently.
        # max() over the shape tuples would compare them lexicographically
        # and could pick a shape with a smaller width, shrinking other bands.
        height = max(img.shape[0] for img in channel_images)
        width = max(img.shape[1] for img in channel_images)
        dsize = (width, height)

        channel_images = [
            imresize(img, dsize=dsize, interpolation='bilinear')
            for img in channel_images
        ]
        img = np.dstack(channel_images).astype(np.float32)
        return img

    def _try_load_channel(self, gid, channels, resolution):
        """
        Load one channel spec of an image.

        Args:
            gid: image id.
            channels: a single channel spec, or a list/tuple of alternatives
                that are tried in order until one loads.
            resolution: passed to ``coco_img.imdelay``.

        Returns:
            the finalized image data for the first spec that loads.

        Raises:
            Exception: if nothing could be loaded; the underlying error is
            chained as ``__cause__``.
        """
        if isinstance(channels, (list, tuple)):
            last_error = None
            for candidate in channels:
                try:
                    return self._try_load_channel(gid, candidate, resolution)
                except Exception as e:
                    # Remember the failure and fall through to the next alternative.
                    last_error = e
            raise Exception(
                'Unable to load any channels {} from image {}: {}'.format(channels, gid, str(last_error))
            ) from last_error
        else:
            try:
                coco_img = self.dset.coco_image(gid)

                # NOTE(review): nodata='float' presumably returns masked pixels
                # as NaN -- confirm against the kwcoco documentation.
                imdata = coco_img.imdelay(channels, space='image', resolution=resolution).finalize(nodata='float')
                return imdata
            except Exception as ex:
                # Build the diagnostic with .get() lookups only: a KeyError
                # raised here would hide the real reason the load failed.
                img = self.dset.imgs.get(gid) or {}
                actual_channels = img.get('channels', [aux.get('channels') for aux in img.get('auxiliary', [])])
                raise Exception(
                    'Unable to load {} from {} image with channels {}'.format(
                        channels,
                        img.get('sensor_coarse'),
                        actual_channels
                    )) from ex


class S2Dataset(_CocoTorchDataset):
    """
    Load S2 images and stack.

    Only images whose ``sensor_coarse`` is ``'S2'`` and whose auxiliary
    entries provide every channel in ``channels_list`` are included.
    """

    def __init__(self, dset):
        # Must be set before the base constructor runs: it filters the images
        # with _include(), which reads channels_list.
        self.channels_list = [
            'coastal',
            'blue',
            'green',
            'red',
            'B05',
            'B06',
            'B07',
            'nir',
            'B8A',
            'B09',
            'cirrus',
            'swir16',
            'swir22',
        ]
        super().__init__(dset)

    def _include(self, gid):
        """
        Args:
            gid: image id to test.
        Returns: True if the image is an S2 image providing all required channels.
        """
        img = self.dset.imgs[gid]
        # Missing metadata excludes the image instead of raising KeyError and
        # aborting construction of the whole dataset.
        if img.get('sensor_coarse') != 'S2':
            return False

        # Needed to handle time-averaged input images (which combine
        # all channels into a single image): split fused 'a|b|c' specs.
        available_channels = set()
        for aux in img.get('auxiliary', []):
            chans = aux.get('channels')
            if chans is not None:
                available_channels.update(chans.split('|'))

        return set(self.channels_list).issubset(available_channels)

    def _load(self, gid):
        """
        Load all channels of image ``gid`` stacked into one float32 array.

        Loads at ``resolution=10``.  Pixels flagged by
        ``find_samecolor_regions`` on the first band (value 0) and pixels
        equal to -9999 are replaced with NaN.
        """
        img = self._load_channels_stacked(gid, self.channels_list, resolution=10)
        is_samecolor = util_kwimage.find_samecolor_regions(
            img[:, :, 0], scale=0.4, min_region_size=49, values={0})
        # The 2D mask indexes the first two axes, so every band of a flagged
        # pixel becomes NaN.
        img[is_samecolor > 0] = np.nan
        # NOTE(review): -9999 is presumably a nodata sentinel -- confirm.
        img[img == -9999] = np.nan
        return img
class WVDataset(_CocoTorchDataset):
    """
    Load WorldView images and stack.

    Only images whose ``sensor_coarse`` is ``'WV'`` and whose auxiliary
    entries provide every channel in ``channels_list`` are included.
    """

    def __init__(self, dset):
        # Must be set before the base constructor runs: it filters the images
        # with _include(), which reads channels_list.
        self.channels_list = [
            'coastal',
            'blue',
            'green',
            'yellow',
            'red',
            'rededge',
            'nir08',
            'nir09',
        ]
        super().__init__(dset)

    def _include(self, gid):
        """
        Args:
            gid: image id to test.
        Returns: True if the image is a WV image providing all required channels.
        """
        img = self.dset.imgs[gid]
        # Missing metadata excludes the image instead of raising KeyError and
        # aborting construction of the whole dataset.
        if img.get('sensor_coarse') != 'WV':
            return False

        # Needed to handle time-averaged input images (which combine
        # all channels into a single image): split fused 'a|b|c' specs.
        available_channels = set()
        for aux in img.get('auxiliary', []):
            chans = aux.get('channels')
            if chans is not None:
                available_channels.update(chans.split('|'))

        return set(self.channels_list).issubset(available_channels)

    def _load(self, gid):
        """
        Load all channels of image ``gid`` stacked into one float32 array.

        Loads at ``resolution=2``.  Pixels flagged by
        ``find_samecolor_regions`` on the first band (value 0) and pixels
        equal to -9999 are replaced with NaN.
        """
        img = self._load_channels_stacked(gid, self.channels_list, resolution=2)
        is_samecolor = util_kwimage.find_samecolor_regions(
            img[:, :, 0], scale=0.4, min_region_size=49, values={0})
        # The 2D mask indexes the first two axes, so every band of a flagged
        # pixel becomes NaN.
        img[is_samecolor > 0] = np.nan
        # NOTE(review): -9999 is presumably a nodata sentinel -- confirm.
        img[img == -9999] = np.nan
        return img
def imresize(img, **kwargs):
    """
    Resize ``img`` with ``kwimage.imresize``, skipping the call when it is a no-op.

    Args:
        img: array whose first two axes are (height, width).
        **kwargs: forwarded unchanged to ``kwimage.imresize``.  When ``dsize``
            is given as a ``(width, height)`` tuple or list that already
            matches ``img``, the input is returned as-is.

    Returns:
        ``img`` itself when no resize is needed, otherwise the result of
        ``kwimage.imresize``.
    """
    dsize = kwargs.get('dsize')
    # Normalize so that a list such as [w, h] also takes the shortcut;
    # comparing a list with a tuple is always False.
    if isinstance(dsize, (tuple, list)) and tuple(dsize) == (img.shape[1], img.shape[0]):
        return img
    return kwimage.imresize(img, **kwargs)
if __name__ == "__main__":
    # Manual smoke test: build a WVDataset from the given kwcoco file, print
    # how many images it contains and the shape of the first loaded image.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--coco_dset", type=str, required=True)
    args = parser.parse_args()

    coco_dset = WVDataset(args.coco_dset)
    print(len(coco_dset))

    # Guard against an empty dataset: gids[0] would raise IndexError when no
    # image passes the WV filter.
    if len(coco_dset) > 0:
        img = coco_dset._load(coco_dset.gids[0])
        print(img.shape)