geowatch.tasks.fusion.methods.torchvision_nets module

Example

from geowatch.tasks.fusion.datamodules.network_io import RGBImageBatchItem
item0 = RGBImageBatchItem.demo(index=0)
item1 = RGBImageBatchItem.demo(index=1)
batch_items = [item0, item1]
self.imdata_chw.shape
self.channels

from geowatch.tasks.fusion.methods.torchvision_efficientnet import *  # NOQA
self = FCNResNet50()
batch = self.collate(batch_items)
out = self.forward(batch)

self = EfficientNetB7()
out = self.forward(batch)

class geowatch.tasks.fusion.methods.torchvision_nets.TorchvisionWrapper(*args: Any, **kwargs: Any)[source]

Bases: LightningModule

define_ots_model()[source]

This should define the backbone model.

forward(batch)[source]
collate(batch_items)[source]
class geowatch.tasks.fusion.methods.torchvision_nets.TorchvisionSegmentationWrapper(*args: Any, **kwargs: Any)[source]

Bases: TorchvisionWrapper

class geowatch.tasks.fusion.methods.torchvision_nets.TorchvisionClassificationWrapper(*args: Any, **kwargs: Any)[source]

Bases: TorchvisionWrapper

class geowatch.tasks.fusion.methods.torchvision_nets.TorchvisionDetectionWrapper(*args: Any, **kwargs: Any)[source]

Bases: TorchvisionWrapper

class geowatch.tasks.fusion.methods.torchvision_nets.EfficientNetB7(*args: Any, **kwargs: Any)[source]

Bases: TorchvisionClassificationWrapper

define_ots_model()[source]
class geowatch.tasks.fusion.methods.torchvision_nets.FCNResNet50(heads, classes=10, dataset_stats=None)[source]

Bases: TorchvisionSegmentationWrapper, WatchModuleMixins

define_ots_model()[source]
forward(batch)[source]
save_package(package_path, verbose=1)[source]
forward_step(batch, batch_idx=None, with_loss=True)[source]
class geowatch.tasks.fusion.methods.torchvision_nets.Resnet50(heads, classes=None, dataset_stats=None)[source]

Bases: TorchvisionClassificationWrapper, WatchModuleMixins

forward(batch)[source]
forward_step(batch, with_loss=False, stage='unspecified')[source]
define_ots_model()[source]
training_step(batch)[source]
on_before_batch_transfer(batch_items, dataloader_idx)[source]