"""
SeeAlso:
* ~/data/dvc-repos/smart_expt_dvc/models/kitware/xview_dino/package_trained_model.py
Notes:
# To test if mmcv is working on your machine:
python -c "from mmcv.ops import multi_scale_deform_attn"
"""
#!/usr/bin/env python3
import scriptconfig as scfg
import ubelt as ub
from torch.utils import data
class BuildingDetectorConfig(scfg.DataConfig):
    coco_fpath = scfg.Value(None, help='input kwcoco dataset')
    out_coco_fpath = scfg.Value(None, help='output kwcoco dataset with predicted boxes')
    package_fpath = scfg.Value(None, help='pytorch packaged model')
    data_workers = 2
    window_dims = (1024, 1024)
    fixed_resolution = "1GSD"
    batch_size = 1
    window_overlap = 0.5
    device = scfg.Value(0)
    select_images = None
    track_emissions = True
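
# A minimal CLI sketch (hypothetical paths; scriptconfig maps each config
# key above to a corresponding --flag):
#
#   python -m geowatch.tasks.dino_detector.predict \
#       --coco_fpath=imgonly-KR_R001.kwcoco.zip \
#       --package_fpath=models/kitware/xview_dino.pt \
#       --out_coco_fpath=pred_boxes.kwcoco.zip \
#       --fixed_resolution=10GSD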
class WrapperDataset(data.Dataset):
    """
    Wraps a geowatch torch dataset so each item also carries a normalized
    'chw' tensor in the format the DINO detector expects.
    """
    def __init__(self, subdset):
        self.subdset = subdset

    def __len__(self):
        return len(self.subdset)

    def __getitem__(self, index):
        item = self.subdset[index]
        chw, _ = dino_preproc_item(item)
        item['chw'] = chw
        return item
def main(cmdline=1, **kwargs):
    """
    Ignore:
        /home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/kitware/xview_dino_detector/checkpoint_best_regular.pth

    Example:
        >>> # xdoctest: +SKIP
        >>> from geowatch.tasks.dino_detector.predict import *  # NOQA
        >>> import ubelt as ub
        >>> import geowatch
        >>> import kwcoco
        >>> dvc_data_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
        >>> dvc_expt_dpath = geowatch.find_dvc_dpath(tags='phase2_expt', hardware='auto')
        >>> coco_fpath = dvc_data_dpath / 'Drop6-MeanYear10GSD-V2/imgonly-KR_R001.kwcoco.zip'
        >>> package_fpath = dvc_expt_dpath / 'models/kitware/xview_dino.pt'
        >>> out_coco_fpath = ub.Path.appdir('geowatch/tests/dino/doctest0').ensuredir() / 'pred_boxes.kwcoco.zip'
        >>> kwargs = {
        >>>     'coco_fpath': coco_fpath,
        >>>     'package_fpath': package_fpath,
        >>>     'out_coco_fpath': out_coco_fpath,
        >>>     'fixed_resolution': '10GSD',
        >>>     'window_dims': (256, 256),
        >>> }
        >>> cmdline = 0
        >>> _ = main(cmdline=cmdline, **kwargs)
        >>> out_dset = kwcoco.CocoDataset(out_coco_fpath)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> import kwimage
        >>> kwplot.plt.ion()
        >>> gid = out_dset.images()[0]
        >>> annots = out_dset.annots(gid=gid)
        >>> dets = annots.detections
        >>> list(map(len, out_dset.images().annots))
        >>> config = out_dset.dataset['info'][-1]['properties']['config']
        >>> delayed = out_dset.coco_image(gid).imdelay(channels='red|green|blue', resolution=config['fixed_resolution'])
        >>> rgb = kwimage.normalize_intensity(delayed.finalize())
        >>> kwplot.imshow(rgb, doclf=1)
        >>> top_dets = dets.compress(dets.scores > 0.2)
        >>> top_dets.draw()
    """
    import rich
    config = BuildingDetectorConfig.cli(cmdline=cmdline, data=kwargs, strict=True)
    rich.print('config = ' + ub.urepr(config, nl=1))

    from geowatch.tasks.fusion.utils import load_model_from_package
    import torch
    device = config.device
    model = load_model_from_package(config.package_fpath)
    model = model.eval()
    model = model.to(device)
    model.device = device

    # Specific hacks for this specific model
    model.building_id = ub.invert_dict(model.id2name)['Building']

    from geowatch.tasks.fusion.datamodules import kwcoco_datamodule
    datamodule = kwcoco_datamodule.KWCocoVideoDataModule(
        test_dataset=config.coco_fpath,
        batch_size=config.batch_size,
        fixed_resolution=config.fixed_resolution,
        num_workers=config.data_workers,
        window_dims=config.window_dims,
        select_images=config.select_images,
        window_overlap=config.window_overlap,
        force_bad_frames=True,
        resample_invalid_frames=0,
        time_steps=1,
        include_sensors=['WV', 'WV1'],
        channels="(WV):red|green|blue,(WV,WV1):pan",
    )
    datamodule.setup('test')
    torch_dataset = datamodule.torch_datasets['test']
    dino_dset = WrapperDataset(torch_dataset)

    # loader = torch_dataset.make_loader()
    # from geowatch.tasks.fusion.datamodules.kwcoco_dataset import worker_init_fn
    loader = torch.utils.data.DataLoader(
        dino_dset, batch_size=config.batch_size, num_workers=config.data_workers,
        shuffle=False, pin_memory=False,
        # worker_init_fn=worker_init_fn,
        collate_fn=ub.identity,  # disable collation
    )
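    # With collation disabled each "batch" is a plain list of dataset items;
    # dino_predict stacks the 'chw' tensors itself before the forward pass.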
    batch_iter = iter(loader)
    coco_dset = torch_dataset.sampler.dset

    from kwutil import util_progress
    from geowatch.utils import process_context
    proc = process_context.ProcessContext(
        name='box.predict',
        config=dict(config),
        track_emissions=config.track_emissions,
    )
    proc.start()
    proc.add_disk_info(coco_dset.fpath)

    # pman = util_progress.ProgressManager('progiter')
    import kwimage
    pman = util_progress.ProgressManager()
    gid_to_dets_accum = ub.ddict(list)
    images = coco_dset.images()
    gid_to_cocoimg = ub.dzip(images, images.coco_images)
    # torch.set_grad_enabled(False)
    with pman, torch.no_grad():
        for batch in pman.progiter(batch_iter, total=len(loader), desc='🦖 dino box detector'):
            batch_dets = dino_predict(model, batch)
            # TODO: downweight the scores of boxes on the edge of the window.
            for item, dets in zip(batch, batch_dets):
                frame0 = item['frames'][0]
                gid = frame0['gid']
                # Compute the transform from this window outspace back to
                # image space: first undo the output-space scale and offset
                # (back to video space), then apply the video-to-image warp.
                sl_y, sl_x = frame0['output_space_slice']
                offset_x = sl_x.start
                offset_y = sl_y.start
                vidspace_offset = (offset_x, offset_y)
                scale_out_from_vid = frame0['scale_outspace_from_vid']
                scale_vid_from_out = 1 / scale_out_from_vid
                warp_vid_from_out = kwimage.Affine.affine(
                    scale=scale_vid_from_out,
                    offset=vidspace_offset)
                coco_img = gid_to_cocoimg[gid]
                warp_img_from_vid = coco_img.warp_img_from_vid
                warp_img_from_out = warp_img_from_vid @ warp_vid_from_out
                imgspace_dets = dets.warp(warp_img_from_out)
                gid_to_dets_accum[gid].append(imgspace_dets)
    unseen_gids = set(coco_dset.images()) - set(gid_to_dets_accum)
    print('unseen_gids = {}'.format(ub.urepr(unseen_gids, nl=1)))
    coco_dset.clear_annotations()

    obj = proc.stop()
    from kwcoco.util import util_json
    obj = util_json.ensure_json_serializable(obj)

    out_dset = coco_dset.copy()
    # TODO: graceful bundle changes
    out_dset.reroot(absolute=True)
    out_dset.dataset['info'].append(obj)
    out_fpath = ub.Path(config.out_coco_fpath)
    out_fpath.parent.ensuredir()
    out_dset.fpath = out_fpath
    out_dset._ensure_json_serializable()

    import kwimage
    gid_to_dets = ub.udict(gid_to_dets_accum).map_values(kwimage.Detections.concatenate)
    for gid, dets in gid_to_dets.items():
        for cls in dets.classes:
            out_dset.ensure_category(cls)
        dets = dets.non_max_supress(thresh=0.2, perclass=True)
        for ann in dets.to_coco(dset=out_dset):
            out_dset.add_annotation(**ann, image_id=gid)

    # Hack, could make filtering in the dataloader easier.
    images = out_dset.images()
    wv_images = images.compress([
        s == 'WV' for s in images.lookup('sensor_coarse')])
    out_dset = out_dset.subset(wv_images)

    pred_dpath = ub.Path(out_dset.fpath).parent.absolute()
    rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]')
    out_dset.fpath = out_fpath
    out_dset.dump(out_dset.fpath, indent=' ')
    return out_dset
def dino_preproc_item(item):
    """
    Normalize a dataset item into the 3-channel float tensor DINO expects,
    falling back on the panchromatic band when it has fewer nan pixels than
    the RGB bands.
    """
    import torchvision.transforms.functional as F
    import kwimage
    import torch
    frames = item['frames']
    modes = frames[0]['modes']
    if 'red|green|blue' in modes:
        chw = modes['red|green|blue']
        is_nan = torch.isnan(chw)
        rgb_nan_frac = is_nan.sum() / is_nan.numel()
    else:
        rgb_nan_frac = 1.0

    # This condition is trivially true, so the pan fallback is always
    # considered; pan is used whenever it has a smaller nan fraction.
    if rgb_nan_frac >= 0.0:
        if 'pan' in modes:
            pan_chw = modes['pan']
            pan_is_nan = torch.isnan(pan_chw)
            pan_nan_frac = pan_is_nan.sum() / pan_is_nan.numel()
            if rgb_nan_frac > pan_nan_frac:
                pan_hwc = pan_chw.permute(1, 2, 0).cpu().numpy()
                pan_hwc = kwimage.atleast_3channels(pan_hwc)
                chw = torch.Tensor(pan_hwc).permute(2, 0, 1)

    hwc = chw.permute(1, 2, 0).cpu().numpy()
    normed_rgb = kwimage.normalize_intensity(hwc)
    # Standard ImageNet normalization constants
    hardcoded_mean = [0.485, 0.456, 0.406]
    hardcoded_std = [0.229, 0.224, 0.225]
    chw = F.to_tensor(normed_rgb)
    chw = F.normalize(chw, mean=hardcoded_mean, std=hardcoded_std)
    chw = torch.nan_to_num(chw)
    return chw, normed_rgb
def dino_predict(model, batch):
    """
    Run the DINO detector on a list of preprocessed items and return a list
    of kwimage.Detections in output-window space, one per item.
    """
    import torch
    import kwimage
    dino_batch_items = []
    for item in batch:
        chw = item['chw']
        dino_batch_items.append(chw)
    dino_batch = torch.stack(dino_batch_items, dim=0)
    device = model.device
    bchw = dino_batch.to(device)
    raw_output = model.forward(bchw)

    # One (height, width) row per batch item
    target_sizes = torch.Tensor([item['chw'].shape[1:3] for item in batch]).to(device)
    outputs = model.postprocessors['bbox'](raw_output, target_sizes)
    batch_dets = []
    for output in outputs:
        dets = kwimage.Detections(
            boxes=kwimage.Boxes(output['boxes'].cpu().numpy(), 'ltrb'),
            class_idxs=output['labels'].cpu().numpy(),
            scores=output['scores'].cpu().numpy(),
            classes=list(model.id2name.values()),
        )
        FILTER_NON_BUILDING = 0
        if FILTER_NON_BUILDING:
            dets = dets.compress(dets.class_idxs == model.building_id)
        # Do a very small threshold first
        dets = dets.compress(dets.scores > 0.01)
        batch_dets.append(dets)
    return batch_dets
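
# Usage sketch for ``dino_predict`` (assuming ``model`` and ``batch`` are
# constructed as in ``main`` above):
#
#   batch_dets = dino_predict(model, batch)
#   # batch_dets[i] is a kwimage.Detections in window output space; ``main``
#   # warps it back into image space before writing annotations.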
if __name__ == '__main__':
    """
    CommandLine:
        xdoctest -m geowatch.tasks.dino_detector.predict
    """
    main()