import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from argparse import Namespace
import pytorch_lightning as pl
from torchmetrics.classification.accuracy import Accuracy
from tqdm import tqdm
from torch import pca_lowrank as pca
import json
import torch.package
import ubelt as ub
import os
from .data.datasets import kwcoco_dataset, gridded_dataset
from .utils.attention_unet import attention_unet
[docs]
# Lightning module that trains a shared attention-UNet backbone on one or
# more self-supervised "pretext" tasks (time sort, augment, overlap).
class pretext(pl.LightningModule):
# Names of the supported pretext tasks. The position of a name in this list
# is the integer task index used throughout the class (0 = sort,
# 1 = augment, 2 = overlap), so the order must not be changed.
TASK_NAMES = [
'sort',
'augment',
'overlap'
]
def __init__(self, hparams):
    """
    Build the datasets, the shared backbone and the per-task necks, heads
    and losses.

    Args:
        hparams (dict | Namespace): configuration. A dict is converted to
            a Namespace. Attributes read here: ``train_dataset``,
            ``vali_dataset``, ``sensor``, ``bands``, ``patch_size``,
            ``tasks``, ``feature_dim_shared``, ``feature_dim_each_task``,
            ``positional_encoding``, ``attention_layers``,
            ``positional_encoding_mode`` and, optionally, ``num_channels``.
            ``tasks`` may be a sequence of task names, the single entry
            ``'all'``, or a bare string holding one task name.

    Raises:
        ValueError: if no task is given or a task name is not one of
            ``TASK_NAMES`` (or ``'all'`` on its own).
    """
    super().__init__()

    if isinstance(hparams, dict):
        hparams = Namespace(**hparams)
    self.save_hyperparameters(hparams)

    if hparams.train_dataset is not None:
        # self.trainset = gridded_dataset(hparams.train_dataset, sensor=hparams.sensor, bands=hparams.bands, patch_size=hparams.patch_size)
        self.trainset = kwcoco_dataset(hparams.train_dataset, sensor=hparams.sensor, bands=hparams.bands, patch_size=hparams.patch_size)
    else:
        self.trainset = None

    if hparams.vali_dataset is not None:
        self.valset = gridded_dataset(hparams.vali_dataset, sensor=hparams.sensor, bands=hparams.bands, patch_size=hparams.patch_size)
    else:
        self.valset = None

    print('hparams = {!r}'.format(hparams))

    if hasattr(hparams, 'num_channels'):
        # hack for loading without dataset state
        num_channels = hparams.num_channels
    elif self.trainset is not None:
        num_channels = self.trainset.num_channels
    else:
        num_channels = 6

    # determine which tasks to run
    tasks = self.hparams.tasks
    if tasks is None:
        # treat "not given" like an empty list so the error below is raised
        tasks = []
    elif isinstance(tasks, str):
        # a bare string would otherwise be iterated character by character
        tasks = [tasks]

    self.task_indices = []
    # no tasks specified
    if len(tasks) < 1:
        raise ValueError(f'tasks must be specified. Options are {", ".join(pretext.TASK_NAMES)}, or all')
    # perform all tasks
    elif len(tasks) == 1 and tasks[0].lower() == 'all':
        self.task_indices = list(range(len(pretext.TASK_NAMES)))
    # run a subset of tasks
    else:
        for task in tasks:
            if task.lower() in pretext.TASK_NAMES:
                self.task_indices.append(pretext.TASK_NAMES.index(task.lower()))
            else:
                raise ValueError(f'\'{task}\' not recognized as an available task. Options are {", ".join(pretext.TASK_NAMES)}, or all')

    # the augment task is always accompanied by the overlap task
    if pretext.TASK_NAMES.index('augment') in self.task_indices and pretext.TASK_NAMES.index('overlap') not in self.task_indices:
        self.task_indices.append(pretext.TASK_NAMES.index('overlap'))

    # shared model body
    self.backbone = attention_unet(in_channels=num_channels,
                                   out_channels=hparams.feature_dim_shared,
                                   pos_encode=hparams.positional_encoding,
                                   attention_layers=hparams.attention_layers,
                                   mode=hparams.positional_encoding_mode)

    # task specific necks; candidates are listed in TASK_NAMES order and
    # then filtered/reordered to follow self.task_indices
    self.necks = [
        self.task_neck(2 * self.hparams.feature_dim_shared, self.hparams.feature_dim_each_task),  # sort task
        self.task_neck(self.hparams.feature_dim_shared, self.hparams.feature_dim_each_task),  # augment task
        self.task_neck(self.hparams.feature_dim_shared, self.hparams.feature_dim_each_task),  # overlap task
    ]
    self.necks = nn.ModuleList([self.necks[i] for i in self.task_indices])

    # task specific heads
    self.heads = [
        self.pixel_classification_head(self.hparams.feature_dim_each_task, num_classes=2),  # sort task
        self.image_classification_head(self.hparams.feature_dim_each_task),  # augment task
        self.image_classification_head(self.hparams.feature_dim_each_task),  # overlap task
    ]
    self.heads = nn.ModuleList([self.heads[i] for i in self.task_indices])

    # task specific criterion
    self.criteria = [
        # BinaryFocalLoss(gamma=self.hparams.focal_gamma), # sort task
        nn.NLLLoss(),  # sort task
        nn.TripletMarginLoss(),  # augment task
        nn.TripletMarginLoss(),  # overlap task
    ]
    self.criteria = [self.criteria[i] for i in self.task_indices]

    # task specific metrics
    self.sort_accuracy = Accuracy(task='binary')
[docs]
def forward(self, image_stack, positional_encoding=None):
    """
    Run the shared backbone and squash its output with ``tanh``.

    Args:
        image_stack: input handed to ``self.backbone`` unchanged.
        positional_encoding: forwarded to the backbone as its second
            positional argument (default ``None``).

    Returns:
        ``torch.tanh`` of the backbone output.
    """
    backbone_features = self.backbone(image_stack, positional_encoding)
    return torch.tanh(backbone_features)
[docs]
def shared_step(self, batch):
    """
    Run every enabled pretext task on one batch and sum their losses.

    Args:
        batch (dict): must contain ``'image1'``, ``'image2'``,
            ``'offset_image1'`` and ``'augmented_image1'`` tensors of equal
            shape. ``'time_sort_label'`` is read only when the sort task
            (index 0) is enabled.

    Returns:
        dict: always holds ``'loss'`` (sum of the enabled task losses, with
        gradient) and ``'before_after_heatmap'`` (sort prediction, or
        ``None`` when the sort task is disabled). Depending on the enabled
        tasks it also holds ``'time_accuracy'``, ``'loss_time_sort'``,
        ``'loss_offset'`` and ``'loss_augmented'`` (the losses detached).
    """
    # get features of each image from shared model body
    image_stack = torch.stack(
        [batch['image1'], batch['image2'], batch['offset_image1'], batch['augmented_image1']],
        dim=1,
    ).to(self.device).float()

    # positional_encoding must be set to none to produce viable pretext task results
    out = self(image_stack, positional_encoding=None)
    image1_features = out[:, 0, :, :, :]
    image2_features = out[:, 1, :, :, :]
    offset_image1_features = out[:, 2, :, :, :]
    augmented_image1_features = out[:, 3, :, :, :]

    losses = []
    output = {}

    # Time Sort task
    if 0 in self.task_indices:
        module_list_idx = self.task_indices.index(0)

        # broadcast the per-sample label over the spatial size of the input
        # (assumes time_sort_label has one entry per sample -- TODO confirm)
        time_labels = batch['time_sort_label']
        time_sort_labels = time_labels.unsqueeze(1).unsqueeze(1).repeat(1, image_stack.shape[-2], image_stack.shape[-1]).to(self.device)

        time_sort_prediction = self.necks[module_list_idx](torch.cat((image1_features, image2_features), dim=1))
        time_sort_prediction = self.heads[module_list_idx](time_sort_prediction)

        # evaluate
        loss_time = self.criteria[module_list_idx](time_sort_prediction.flatten(2).float(), time_sort_labels.flatten(1).long())
        predicted_class = torch.max(time_sort_prediction.detach(), 1)[1]
        time_accuracy = self.sort_accuracy(predicted_class, time_sort_labels.int())

        losses.append(loss_time)
        output['time_accuracy'] = time_accuracy
        output['loss_time_sort'] = loss_time.detach()
        output['before_after_heatmap'] = time_sort_prediction
    else:
        output['before_after_heatmap'] = None

    # Stays None unless the overlap task produces the offset embedding; the
    # augment task reuses it as its negative sample when available.
    image1_offset_overlap_out = None

    # Overlap task
    if 2 in self.task_indices:
        module_list_idx = self.task_indices.index(2)

        # forward pass through neck
        image1_overlap_out = self.necks[module_list_idx](image1_features)
        image2_overlap_out = self.necks[module_list_idx](image2_features)
        image1_offset_overlap_out = self.necks[module_list_idx](offset_image1_features)

        # forward pass through head
        image1_overlap_out = self.heads[module_list_idx](image1_overlap_out)
        image2_overlap_out = self.heads[module_list_idx](image2_overlap_out)
        image1_offset_overlap_out = self.heads[module_list_idx](image1_offset_overlap_out)

        # evaluate
        loss_offset = self.criteria[module_list_idx](image1_overlap_out, image2_overlap_out, image1_offset_overlap_out)

        losses.append(loss_offset)
        output['loss_offset'] = loss_offset.detach()

    # Augment task
    if 1 in self.task_indices:
        module_list_idx = self.task_indices.index(1)

        # image1 forward pass through neck
        image1_augment_out = self.necks[module_list_idx](image1_features)
        image1_augmented_augment_out = self.necks[module_list_idx](augmented_image1_features)

        # image1 forward pass through head
        image1_augment_out = self.heads[module_list_idx](image1_augment_out)
        image1_augmented_augment_out = self.heads[module_list_idx](image1_augmented_augment_out)

        # perform computation if not already done: without the overlap task
        # the negative sample has to be embedded with this task's own
        # neck and head
        if image1_offset_overlap_out is None:
            image1_offset_overlap_out = self.necks[module_list_idx](offset_image1_features)
            image1_offset_overlap_out = self.heads[module_list_idx](image1_offset_overlap_out)

        # image1 evaluate
        loss_augment = self.criteria[module_list_idx](image1_augment_out, image1_augmented_augment_out, image1_offset_overlap_out)

        losses.append(loss_augment)
        output['loss_augmented'] = loss_augment.detach()

    # add up loss
    loss = torch.sum(torch.stack(losses))
    output['loss'] = loss
    return output
[docs]
def training_step(self, batch, batch_idx):
    """
    Lightning training hook: compute the task losses, log them, return them.

    The ``'before_after_heatmap'`` entry produced by ``shared_step`` is
    removed before logging, so only the remaining entries are logged and
    returned.
    """
    metrics = self.shared_step(batch)
    del metrics['before_after_heatmap']
    self.log_dict(metrics)
    return metrics
[docs]
def validation_step(self, batch, batch_idx):
    """
    Lightning validation hook: like ``training_step`` but every key of the
    logged/returned dict is prefixed with ``val_``.

    The ``'before_after_heatmap'`` entry from ``shared_step`` is dropped
    before the keys are renamed.
    """
    metrics = self.shared_step(batch)
    del metrics['before_after_heatmap']

    prefixed = {}
    for name, value in metrics.items():
        prefixed["val_" + name] = value

    self.log_dict(prefixed)
    return prefixed
[docs]
def predict(self, batch):
    """
    Switch the module to eval mode and return the shared features of
    ``batch``, i.e. the result of calling the module on it.

    Note that the module is left in eval mode afterwards.
    """
    self.eval()
    return self(batch)
[docs]
def train_dataloader(self):
    """
    Build the shuffled training ``DataLoader`` over ``self.trainset``.

    Batch size and worker count come from ``hparams.batch_size`` and
    ``hparams.workers``; memory pinning is disabled.
    """
    loader_options = {
        'batch_size': self.hparams.batch_size,
        'num_workers': self.hparams.workers,
        'shuffle': True,
        'pin_memory': False,
    }
    return DataLoader(self.trainset, **loader_options)
[docs]
def val_dataloader(self):
    """
    Build the validation ``DataLoader`` over ``self.valset``.

    Returns:
        DataLoader | None: ``None`` when ``hparams.vali_dataset`` is
        ``None``; otherwise an unshuffled loader using
        ``hparams.batch_size`` and ``hparams.workers``.
    """
    if self.hparams.vali_dataset is None:
        return None

    loader_options = {
        'batch_size': self.hparams.batch_size,
        'num_workers': self.hparams.workers,
        'pin_memory': False,
    }
    return DataLoader(self.valset, **loader_options)
[docs]
def configure_optimizers(self):
    """
    Create the AdamW optimizer and its StepLR schedule.

    Reads ``hparams.learning_rate``, ``hparams.step_size`` and
    ``hparams.lr_gamma``.

    Returns:
        dict: ``{'optimizer': ..., 'lr_scheduler': ...}``.
    """
    hp = self.hparams
    optimizer = torch.optim.AdamW(self.parameters(), lr=hp.learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=hp.step_size,
        gamma=hp.lr_gamma,
    )
    return {'optimizer': optimizer, 'lr_scheduler': scheduler}
[docs]
def task_neck(self, in_chan, out_chan):
    """
    Build a task neck: BatchNorm -> ReLU -> 3x3 conv -> BatchNorm -> ReLU.

    The convolution uses stride 1 and padding 1.

    Args:
        in_chan (int): number of input channels.
        out_chan (int): number of output channels.

    Returns:
        nn.Sequential: the assembled neck.
    """
    kernel, step = 3, 1
    pad = (kernel - 1) // 2
    layers = [
        nn.BatchNorm2d(in_chan),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_chan, out_chan, kernel, step, pad),
        nn.BatchNorm2d(out_chan),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
[docs]
def pixel_classification_head(self, in_chan, num_classes):
    """
    Build a per-pixel classifier made only of 1x1 convolutions.

    Two (conv -> BatchNorm -> ReLU) stages keep ``in_chan`` channels; a
    final bias-free 1x1 conv maps to ``num_classes`` channels and
    ``LogSoftmax`` is applied over the channel dimension.

    Args:
        in_chan (int): number of input channels.
        num_classes (int): number of output classes.

    Returns:
        nn.Sequential: the assembled head.
    """
    kernel, step = 1, 1
    pad = (kernel - 1) // 2

    layers = []
    for _ in range(2):
        layers.append(nn.Conv2d(in_chan, in_chan, kernel, step, pad))
        layers.append(nn.BatchNorm2d(in_chan))
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Conv2d(in_chan, num_classes, 1, bias=False, padding=0))
    layers.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*layers)
[docs]
def image_classification_head(self, in_chan):
    """
    Build the image-level head, which only flattens each sample into a
    single vector (all dimensions after the batch dimension are merged).

    Args:
        in_chan (int): accepted for interface symmetry; not used.

    Returns:
        nn.Sequential: a sequential container holding one ``nn.Flatten``.
    """
    flatten = nn.Flatten(start_dim=1, end_dim=-1)
    return nn.Sequential(flatten)
# def image_classification_head(self, in_chan):
# return nn.Sequential(
# nn.AdaptiveAvgPool2d(output_size=(1, 1)),
# nn.Flatten(1, -1),
# nn.Linear(in_chan, 8),
# )
[docs]
def generate_pca_matrix(self, save_path, loader, reduction_dim=6, num_batches=50):
    """
    Estimate a PCA projection of the shared features from a few batches.

    Args:
        save_path: destination handed to ``torch.save``; when falsy the
            projector is not written to disk.
        loader: iterable of batch dicts holding ``'image1'``, ``'image2'``,
            ``'offset_image1'`` and ``'augmented_image1'`` tensors.
        reduction_dim (int): ``q`` passed to ``torch.pca_lowrank``.
        num_batches (int): maximum number of batches read from ``loader``
            (default 50, the previously hard-coded value).

    Returns:
        torch.Tensor: the transposed ``V`` factor returned by
        ``torch.pca_lowrank`` (presumably of shape
        ``(reduction_dim, feature_dim_shared)`` -- TODO confirm).

    Raises:
        ValueError: if ``loader`` yields no batches.

    NOTE(review): the module's train/eval mode is left untouched, exactly
    as before; confirm whether features should be gathered in eval mode.
    """
    from itertools import islice

    # progress-bar length: never promise more batches than the loader has
    try:
        total = min(num_batches, len(loader))
    except TypeError:
        # loaders without a length (e.g. over iterable datasets)
        total = num_batches

    feature_collection = []
    with torch.set_grad_enabled(False):
        # TODO: option to cache or specify a specific projection matrix?
        # islice stops after num_batches without fetching an extra batch
        for batch in tqdm(islice(loader, num_batches), desc='Calculating PCA matrix', total=total):
            image_stack = torch.stack([batch['image1'], batch['image2'], batch['offset_image1'], batch['augmented_image1']], dim=1)
            # same device/dtype handling as shared_step
            features = self.forward(image_stack.to(self.device).float())
            feature_collection.append(features.cpu())
            # drop references early to keep peak memory down
            features = None
            image_stack = None

        if not feature_collection:
            raise ValueError('loader produced no batches; cannot compute a PCA matrix')

        # move the channel axis last, then flatten to one row per pixel
        stack = torch.cat(feature_collection, dim=0).permute(0, 1, 3, 4, 2).reshape(-1, self.hparams.feature_dim_shared)
        _, _, projector = pca(stack, q=reduction_dim)
        stack = None
        projector = projector.permute(1, 0)

    if save_path:
        torch.save(projector, save_path)
    return projector
[docs]
def on_save_checkpoint(self, checkpoint):
    """
    Lightning hook: alongside every checkpoint, compute and store a PCA
    projection matrix for the current epoch.

    Does nothing when ``hparams.pca_projection_path`` is falsy. Otherwise
    the directory is created if needed and the matrix is written to
    ``pretext_pca_<current_epoch>.pt`` inside it, using a fresh training
    dataloader and ``hparams.reduction_dim``.
    """
    projection_dir = self.hparams.pca_projection_path
    if not projection_dir:
        return

    file_name = f'pretext_pca_{self.current_epoch!s}.pt'
    target = os.path.join(projection_dir, file_name)
    os.makedirs(projection_dir, exist_ok=True)
    self.generate_pca_matrix(
        save_path=target,
        loader=self.train_dataloader(),
        reduction_dim=self.hparams.reduction_dim,
    )
[docs]
def save_package(self, package_path):
    """
    Export this model as a ``torch.package`` archive.

    The archive is written to ``<package_path>/pretext_package.pt``; the
    directory is created when it does not exist. Any ``hparams.yaml`` found
    directly inside ``package_path`` is embedded next to the package
    header.

    Args:
        package_path: directory that receives ``pretext_package.pt``.

    Attributes listed in ``unsaved_attributes`` are set to ``None`` while
    the model is pickled and restored afterwards, even if the export
    fails.
    """
    model = self

    package_dir = ub.Path(package_path)
    os.makedirs(package_dir, exist_ok=True)
    package_path = os.path.join(package_dir, 'pretext_package.pt')

    backup_attributes = {}
    unsaved_attributes = [
        'trainer',
        'train_dataloader',
        'val_dataloader',
        'test_dataloader',
        '_load_state_dict_pre_hooks',
    ]
    for key in unsaved_attributes:
        # tolerate attributes that do not exist on this model
        val = getattr(model, key, None)
        if val is not None:
            backup_attributes[key] = val

    # Look for metadata in the target directory. Globbing below the .pt
    # file path itself (as was done before) can never match anything.
    # NOTE(review): confirm the target directory is where hparams.yaml is
    # expected to live.
    metadata_fpaths = []
    metadata_fpaths += list(package_dir.glob('hparams.yaml'))

    try:
        for key in backup_attributes:
            setattr(model, key, None)

        arch_name = 'pretext_model.pkl'
        module_name = 'watch_tasks_invariants'

        with torch.package.PackageExporter(package_path) as exp:
            # only the invariants code is bundled; everything else is
            # resolved from the loading environment
            exp.extern('**', exclude=['geowatch.tasks.invariants.**'])
            exp.intern('geowatch.tasks.invariants.**', allow_empty=False)

            package_header = {
                'version': '0.1.0',
                'arch_name': arch_name,
                'module_name': module_name,
            }
            exp.save_text(
                'package_header', 'package_header.json',
                json.dumps(package_header)
            )
            exp.save_pickle(module_name, arch_name, model)

            for meta_fpath in metadata_fpaths:
                with open(meta_fpath, 'r') as file:
                    text = file.read()
                exp.save_text('package_header', meta_fpath.name, text)
    finally:
        # put the temporarily removed attributes back
        for key, val in backup_attributes.items():
            setattr(model, key, val)
[docs]
@classmethod
def load_package(cls, package_path):
    """
    Load a model from a torch package file.

    DEPRECATE IN FAVOR OF geowatch.tasks.fusion.utils.load_model_from_package

    TODO:
        - [ ] Make the logic that defines the save_package and load_package
            methods with appropriate package header data a lightning
            abstraction.
    """
    # NOTE: there is no guarantee that this loads an instance of THIS
    # model. The model is defined by the package, and the tool that loads
    # it is agnostic to the model contained in said package.
    # This classmethod exists as a convenience more than anything else.
    from geowatch.tasks.fusion.utils import load_model_from_package
    return load_model_from_package(package_path)