geowatch.tasks.fusion.datamodules.smart_mixins module
- class geowatch.tasks.fusion.datamodules.smart_mixins.SMARTDataMixin
Bases: object
- check_balanced_sample_tree(num=4096)
Developer function to check statistics about how the nested pool is sampling regions.
Example:
    >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
    >>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
    >>> import geowatch
    >>> import ndsampler
    >>> import kwcoco
    >>> dvc_dpath = geowatch.find_dvc_dpath()
    >>> coco_fpath = dvc_dpath / 'Cropped-Drop3-TA1-2022-03-10/data_nowv_train.kwcoco.json'
    >>> coco_dset = kwcoco.CocoDataset(coco_fpath)
    >>> sampler = ndsampler.CocoSampler(coco_dset)
    >>> self = KWCocoVideoDataset(
    >>>     sampler,
    >>>     time_dims=5, window_dims=(256, 256),
    >>>     window_overlap=0,
    >>>     #channels="ASI|MF_Norm|AF|EVI|red|green|blue|swir16|swir22|nir",
    >>>     channels="blue|green|red|nir|swir16|swir22",
    >>>     neg_to_pos_ratio=0, time_sampling='soft2', diff_inputs=0, temporal_dropout=0.5,
    >>> )
    >>> #self.requested_tasks['change'] = False
- if 0:
    infos = []
    for num in [500, 1000, 2500, 5000, 7500, 10000, 20000]:
        row = self.check_balanced_sample_tree(num=num)
        infos.append(row)
    df = pd.DataFrame(infos)
    import kwplot
    sns = kwplot.autosns()

    data = df.melt(id_vars=['num'])
    data['style'] = 'raw'
    data.loc[data.variable.apply(lambda x: 'gids' in x), 'style'] = 'gids'
    data.loc[data.variable.apply(lambda x: 'region' in x), 'style'] = 'region'
    data['region'] = data.variable.apply(lambda x: x.split('_', 2)[-1].replace('seen', '') if 'R' in x else x)
    sns.lineplot(data=data, x='num', y='value', style='style', hue='region')

    frac_seen = info['frac_gids_seen']
    frac_seen['num'] = num
    frac_seen['ideal_seen'] = ideal_seen
    frac_seen['ideal_frac'] = ideal_frac