geowatch.tasks.fusion.datamodules.smart_mixins module

class geowatch.tasks.fusion.datamodules.smart_mixins.SMARTDataMixin[source]

Bases: object

check_balanced_sample_tree(num=4096)[source]

Developer function that reports statistics about how the nested sample pool is sampling regions.

Example

>>> # xdoctest: +REQUIRES(env:DVC_DPATH)
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
>>> import geowatch
>>> import ndsampler
>>> import kwcoco
>>> dvc_dpath = geowatch.find_dvc_dpath()
>>> coco_fpath = dvc_dpath / 'Cropped-Drop3-TA1-2022-03-10/data_nowv_train.kwcoco.json'
>>> coco_dset = kwcoco.CocoDataset(coco_fpath)
>>> sampler = ndsampler.CocoSampler(coco_dset)
>>> self = KWCocoVideoDataset(
>>>     sampler,
>>>     time_dims=5, window_dims=(256, 256),
>>>     window_overlap=0,
>>>     #channels="ASI|MF_Norm|AF|EVI|red|green|blue|swir16|swir22|nir",
>>>     channels="blue|green|red|nir|swir16|swir22",
>>>     neg_to_pos_ratio=0, time_sampling='soft2', diff_inputs=0, temporal_dropout=0.5,
>>> )
>>> #self.requested_tasks['change'] = False
if 0:
    infos = []
    for num in [500, 1000, 2500, 5000, 7500, 10000, 20000]:
        row = self.check_balanced_sample_tree(num=num)
        infos.append(row)

    df = pd.DataFrame(infos)
    import kwplot
    sns = kwplot.autosns()

    data = df.melt(id_vars=['num'])
    data['style'] = 'raw'
    data.loc[data.variable.apply(lambda x: 'gids' in x), 'style'] = 'gids'
    data.loc[data.variable.apply(lambda x: 'region' in x), 'style'] = 'region'
    data['region'] = data.variable.apply(lambda x: x.split('_', 2)[-1].replace('seen', '') if 'R' in x else x)
    sns.lineplot(data=data, x='num', y='value', style='style', hue='region')

    # NOTE(review): ``info``, ``ideal_seen``, and ``ideal_frac`` are not defined in this snippet.
    frac_seen = info['frac_gids_seen']
    frac_seen['num'] = num
    frac_seen['ideal_seen'] = ideal_seen
    frac_seen['ideal_frac'] = ideal_frac

geowatch.tasks.fusion.datamodules.smart_mixins.draw_cloudmask_viz(qa_data, rgb_data)[source]

Helper that draws a cloud-mask visualization from ``qa_data`` and ``rgb_data``.