import torch
import torch.nn as nn
from argparse import Namespace
import pytorch_lightning as pl
from .models import UNet, UNet_blur
from .drop0_datasets import drop0_pairs
[docs]
class time_sort(pl.LightningModule):
def __init__(self, hparams):
super().__init__()
if isinstance(hparams, dict):
hparams = Namespace(**hparams)
self.criterion = nn.BCEWithLogitsLoss()
self.accuracy = pl.metrics.classification.Accuracy()
self.save_hyperparameters(hparams)
if self.hparams.backbone == 'unet':
self.backbone = UNet(self.hparams.in_channels, hparams.feature_dim)
elif self.hparams.backbone == 'unet_blur':
self.backbone = UNet_blur(self.hparams.in_channels, hparams.feature_dim)
self.classifier = self.head(2 * hparams.feature_dim)
self.accuracy = pl.metrics.Accuracy()
self.train_data_fpath = self.hparams.train_dataset
self.val_data_fpath = self.hparams.val_dataset
[docs]
def head(self, in_channels):
return nn.Sequential(
#nn.Conv2d(in_channels, in_channels // 2, 7, bias=False, padding=3),
# nn.ReLU(),
#nn.BatchNorm2d(in_channels // 2),
nn.Conv2d(in_channels, 1, 1, bias=False, padding=0),
)
[docs]
def forward(self, image1, image2, date1, date2):
image1 = self.backbone(image1)
image2 = self.backbone(image2)
return image1, image2, date1, date2
[docs]
def shared_step(self, batch):
image1, image2, date1, date2 = batch['image1'], batch['image2'], batch['date1'], batch['date2']
image1, image2, date1, date2 = self(image1, image2, date1, date2)
prediction = self.classifier(torch.cat((image1, image2), dim=1))
labels = torch.tensor([
tuple(date1[x]) < tuple(date2[x]) for x in range(date1.shape[0])
]).float().cuda()
labels = labels.unsqueeze(1).unsqueeze(1).repeat(1, image1.shape[2], image1.shape[3]).unsqueeze(1)
loss = self.criterion(prediction, labels)
accuracy = self.accuracy((prediction > 0.), labels.int())
output = {
# 'prediction': prediction,
# 'labels': labels,
'accuracy': accuracy,
'loss': loss,
}
return output
[docs]
def training_step(self, batch, batch_idx):
output = self.shared_step(batch)
self.log('acc', output['accuracy'])
self.log('loss', output['loss'])
return output
[docs]
def validation_step(self, batch, batch_idx):
output = self.shared_step(batch)
output = {key + "_val": val for key, val in output.items()}
self.log('val_acc', output['accuracy_val'])
self.log('val_loss', output['loss_val'])
return output
[docs]
def train_dataloader(self):
return torch.utils.data.DataLoader(
drop0_pairs(
coco_dset=self.train_data_fpath,
sensor=self.hparams.sensor,
panchromatic=self.hparams.panchromatic,
video=self.hparams.train_video,
min_time_step=self.hparams.min_time_step
),
batch_size=self.hparams.batch_size,
num_workers=self.hparams.workers
)
[docs]
def val_dataloader(self):
return torch.utils.data.DataLoader(
drop0_pairs(
coco_dset=self.val_data_fpath,
sensor=self.hparams.sensor,
panchromatic=self.hparams.panchromatic,
video=self.hparams.val_video,
min_time_step=self.hparams.min_time_step
),
batch_size=self.hparams.batch_size,
num_workers=self.hparams.workers
)