import torch
import torch.nn as nn
from datetime import date
import random
from argparse import ArgumentParser, Namespace
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import numpy as np
from .utils.attention_unet import attention_unet
from .data.datasets import kwcoco_dataset, SpaceNet7, Onera
class change(pl.LightningModule):
    """Bitemporal change-detection model built on an attention U-Net backbone.

    The backbone is run on a stack of frames; the feature maps of the first
    two frames are concatenated along the channel axis and passed through a
    small convolutional head (``self.classifier``) ending in ``LogSoftmax``,
    trained with a class-weighted ``nn.NLLLoss``.

    With ``hparams.change`` set, the targets are ``batch['change_map']``.
    Otherwise a self-supervised "before/after" task is used: the frame order
    is randomly reversed and the target is a constant map holding the order
    label (1 = original order, 0 = reversed).

    Args:
        hparams: ``argparse.Namespace`` (or ``dict``) of hyperparameters; the
            keys read here correspond to this module's CLI options.
    """

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)

        # ---- backbone network ----
        if hparams.pretrained_checkpoint:
            # The checkpoint is loaded the same way whether or not
            # ``pretrained_multihead`` is set (the original had two identical
            # branches for that flag). ``load_from_checkpoint`` is a
            # classmethod, so call it on the class rather than the instance.
            pretrained_model = type(self).load_from_checkpoint(hparams.pretrained_checkpoint)
            if hparams.pretrained_encoder_only:
                # A backbone must exist before its encoder can be replaced;
                # previously ``self.backbone`` was undefined at this point.
                self.backbone = self._build_backbone(hparams)
                self.backbone.encoder = pretrained_model.backbone.encoder
            else:
                self.backbone = pretrained_model.backbone
        else:
            self.backbone = self._build_backbone(hparams)

        # ---- datasets ----
        if hparams.trainset == 'kwcoco':
            self.trainset = kwcoco_dataset(hparams.train_kwcoco, hparams.sensor, hparams.bands, hparams.patch_size,
                                           segmentation_labels=True, num_images=hparams.num_images)
            self.valset = kwcoco_dataset(hparams.val_kwcoco, hparams.sensor, hparams.bands, hparams.patch_size,
                                         segmentation_labels=True, num_images=hparams.num_images)
        elif hparams.trainset == 'spacenet':
            self.trainset = SpaceNet7(hparams.patch_size, segmentation_labels=True,
                                      num_images=hparams.num_images, train=True)
            self.valset = SpaceNet7(hparams.patch_size, segmentation_labels=True,
                                    num_images=hparams.num_images, train=False)
        elif hparams.trainset == 'onera':
            self.trainset = Onera(train=True, patch_size=hparams.patch_size, num_channels=hparams.num_channels)
            # patch_size is deliberately not passed for the validation split.
            self.valset = Onera(train=False, num_channels=hparams.num_channels)
            if hparams.num_images != 2:
                raise ValueError(
                    'The onera trainset requires num_images == 2, got {}'.format(hparams.num_images))
        else:
            raise ValueError(
                "Unknown trainset {!r}; choose from 'kwcoco', 'spacenet' or 'onera'.".format(hparams.trainset))

        # ---- loss ----
        if hparams.binary:
            num_classes = 2
            weight = torch.FloatTensor([1, hparams.pos_class_weight])
        else:
            num_classes = 6
            # Class 0 has zero weight, so pixels labelled 0 do not contribute to the loss.
            weight = torch.FloatTensor([0, 1, 1, 1, 1, 1])
        self.criterion = nn.NLLLoss(weight=weight)
        self.save_hyperparameters(hparams)
        # Features of two frames are concatenated, hence 2 * feature_dim input channels.
        self.classifier = self.head(2 * hparams.feature_dim, num_classes)

    @staticmethod
    def _build_backbone(hparams):
        """Construct a freshly initialised attention U-Net backbone from ``hparams``."""
        return attention_unet(hparams.num_channels, hparams.feature_dim,
                              pos_encode=hparams.positional_encoding,
                              num_attention_layers=hparams.num_attention_layers,
                              mode=hparams.positional_encoding_mode)

    def head(self, in_channels, out_channels, kernel_size=3, dilation=1):
        """Build the classification head.

        Two (Conv2d -> GroupNorm -> ReLU) stages at ``in_channels`` width,
        followed by a 1x1 projection to ``out_channels`` and ``LogSoftmax``
        over the channel axis.

        Padding is ``dilation * (kernel_size - 1) // 2`` so odd kernel sizes
        preserve the spatial size; for the default ``kernel_size=3`` this
        equals ``dilation``, as before.

        Note: GroupNorm uses ``in_channels // 8`` groups, which must be >= 1
        and divide ``in_channels``.
        """
        padding = dilation * (kernel_size - 1) // 2
        num_groups = in_channels // 8
        return nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size, dilation=dilation, bias=False, padding=padding),
            nn.GroupNorm(num_groups=num_groups, num_channels=in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size, dilation=dilation, bias=False, padding=padding),
            nn.GroupNorm(num_groups=num_groups, num_channels=in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 1, bias=False, padding=0),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x, positions=None):
        """Run the backbone and return ``{'predictions': backbone_output}``.

        ``x`` is the frame stack built by :meth:`shared_step` (frames stacked
        along dim 1). ``positions`` is forwarded unchanged to the backbone.
        """
        return {'predictions': self.backbone(x, positions)}

    def shared_step(self, batch):
        """Compute class log-probabilities and loss for one batch.

        Every key of ``batch`` starting with ``'image'`` is treated as one
        frame (in the batch's key order). In before/after mode the frames may
        be reversed at random and the target replaced by a constant label map.

        Returns:
            dict with ``'prediction_map'`` (classifier output, log-probabilities
            over dim 1), ``'targets'`` (the target map) and ``'loss'``.
        """
        images = [batch[key] for key in batch if key.startswith('image')]
        if self.hparams.change:
            change_map = batch['change_map'].to(self.device)
        else:
            label = random.choice([0, 1])
            change_map = label * torch.ones_like(batch['change_map']).to(self.device)
            if not label:
                images.reverse()
        images = torch.stack(images, dim=1).to(self.device)
        positions = batch['time_steps'].to(self.device) if self.hparams.positional_encoding else None
        features = self.forward(images, positions)['predictions']
        # Only the first two frames feed the classifier.
        prediction_map = self.classifier(torch.cat([features[:, 0], features[:, 1]], dim=1))
        loss = self.criterion(prediction_map, change_map.long())
        return {
            'prediction_map': prediction_map,
            'targets': change_map,
            'loss': loss,
        }

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        output = self.shared_step(batch)
        self.log('train_loss', output['loss'])
        return output['loss']

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        output = self.shared_step(batch)
        self.log('validation_loss', output['loss'])
        return output['loss']

    def validation_epoch_end(self, outputs):
        """Log whole-validation-set change metrics (change mode only).

        ``outputs`` is unused: :meth:`run_test` makes a second pass over the
        validation loader to compute the metrics.
        """
        if not self.hparams.change:
            return
        _, accuracy, class_accuracy, pr_rec = self.run_test(loader=self.val_dataloader)
        self.log('val_epoch_accuracy', accuracy)
        self.log('val_epoch_accuracy_change', class_accuracy[1])
        self.log('val_epoch_accuracy_no_change', class_accuracy[0])
        self.log('val_epoch_precision', pr_rec[0])
        self.log('val_epoch_recall', pr_rec[1])
        self.log('val_epoch_f1', pr_rec[2])

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def run_test(self, loader):
        """Evaluate on every batch of ``loader()`` and aggregate metrics.

        Args:
            loader: zero-argument callable returning an iterable of batches
                (e.g. ``self.val_dataloader``).

        Returns:
            ``(net_loss, net_accuracy, class_accuracy, pr_rec)``:

            - ``net_loss``: per-batch loss weighted by the number of target pixels.
            - ``net_accuracy``: percent of pixels whose binarized (> 0)
              prediction matches the binarized target.
            - ``class_accuracy``: percent accuracy on pixels with target 0 and 1.
            - ``pr_rec``: ``[precision, recall, F1, no-change precision,
              no-change recall]`` of the binarized predictions.

            Ratios with a zero denominator are reported as 0.0 instead of NaN.

        Raises:
            ValueError: if the loader yields no target pixels.
        """
        self.eval()
        num_classes = 2
        total_loss = 0
        total_pixels = 0
        class_correct = [0.] * num_classes
        class_total = [0.] * num_classes
        tp = tn = fp = fn = 0
        with torch.no_grad():
            for batch in loader():
                output = self.shared_step(batch)
                targets = output['targets'].detach().int()
                num_pixels = targets.numel()
                # shared_step already applied self.criterion to these exact logits/targets.
                total_loss += output['loss'].detach() * num_pixels
                total_pixels += num_pixels

                _, predicted = torch.max(output['prediction_map'].detach(), 1)
                predicted = predicted.int()
                correct = predicted == targets
                for cls in range(num_classes):
                    in_class = targets == cls
                    class_correct[cls] += torch.sum(correct[in_class])
                    class_total[cls] += torch.sum(in_class)

                pr = (predicted > 0).cpu().numpy()
                gt = (targets > 0).cpu().numpy()
                tp += np.logical_and(pr, gt).sum()
                tn += np.logical_and(~pr, ~gt).sum()
                fp += np.logical_and(pr, ~gt).sum()
                fn += np.logical_and(~pr, gt).sum()

        if total_pixels == 0:
            raise ValueError('run_test: the loader produced no target pixels')

        net_loss = total_loss / total_pixels
        net_accuracy = 100 * (tp + tn) / total_pixels
        class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001) for i in range(num_classes)]
        prec = self._safe_ratio(tp, tp + fp)
        rec = self._safe_ratio(tp, tp + fn)
        f_meas = self._safe_ratio(2 * prec * rec, prec + rec)
        prec_nc = self._safe_ratio(tn, tn + fn)
        rec_nc = self._safe_ratio(tn, tn + fp)
        return net_loss, net_accuracy, class_accuracy, [prec, rec, f_meas, prec_nc, rec_nc]

    def configure_optimizers(self):
        """Create the optimizer and a step-decay learning-rate scheduler.

        Uses the ``learning_rate``, ``weight_decay``, ``step_size`` and
        ``lr_gamma`` hyperparameters. Without this method the Lightning
        Trainer refuses to fit the module.

        NOTE(review): the choice of Adam is an assumption. The CLI defines
        only these hyperparameters, not the optimizer type, so confirm
        against the intended training recipe.
        """
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.hparams.learning_rate,
                                     weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=self.hparams.step_size,
                                                    gamma=self.hparams.lr_gamma)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Shuffled training loader over ``self.trainset``."""
        return torch.utils.data.DataLoader(self.trainset,
                                           batch_size=self.hparams.batch_size,
                                           num_workers=self.hparams.workers,
                                           shuffle=True)

    def val_dataloader(self):
        """Validation loader over ``self.valset`` with batch size 1."""
        return torch.utils.data.DataLoader(self.valset,
                                           batch_size=1,
                                           num_workers=self.hparams.workers)
def main(args):
    """Build a ``change`` model and fit it with PyTorch Lightning.

    Args:
        args: ``argparse.Namespace`` (or ``dict``) produced by this module's
            CLI parser. It is also forwarded to ``pl.Trainer.from_argparse_args``.

    TensorBoard logs are written under ``args.save_dir``. The sub-directory
    encodes the trainset, task, number of attention layers,
    positional-encoding settings and today's date.
    """
    if isinstance(args, dict):
        args = Namespace(**args)

    task = 'change' if args.change else 'before_after'
    mode = args.positional_encoding_mode if args.positional_encoding else 'none'
    log_dir = '{}/{}/{}/{}/{}/{}/{}'.format(
        args.save_dir,
        args.trainset,
        task,
        'Attention_layers:' + str(args.num_attention_layers),
        'Position:' + str(args.positional_encoding),
        mode,
        str(date.today()),
    )
    logger = TensorBoardLogger(log_dir)
    model = change(hparams=args)

    if args.change:
        # 'val_epoch_f1' is only logged (by validation_epoch_end) in change mode.
        checkpoint_callback = ModelCheckpoint(monitor='val_epoch_f1', mode='max', save_top_k=1)
    else:
        # In before/after mode only 'validation_loss' (from validation_step) is
        # available, so monitoring 'val_epoch_f1' would fail to find its key.
        checkpoint_callback = ModelCheckpoint(monitor='validation_loss', mode='min', save_top_k=1)
    lr_logger = LearningRateMonitor(logging_interval='step')

    trainer = pl.Trainer.from_argparse_args(args,
                                            logger=logger,
                                            callbacks=[checkpoint_callback, lr_logger],
                                            log_every_n_steps=30,
                                            check_val_every_n_epoch=args.check_val_every_n_epoch)
    trainer.fit(model)
if __name__ == '__main__':
    parser = ArgumentParser()
    # (flag, add_argument keyword arguments), listed in --help order.
    # NOTE(review): --valset, --kernel_size, --dilation, --remove_clouds and
    # --normalize_spacenet do not appear to be read anywhere in this module.
    _cli_options = [
        # --- train hyperparameters ---
        ('--max_epochs', dict(type=int, default=50)),
        ('--workers', dict(type=int, default=8)),
        ('--batch_size', dict(type=int, default=64)),
        ('--step_size', dict(type=int, default=10)),
        ('--lr_gamma', dict(type=float, default=.1)),
        ('--weight_decay', dict(type=float, default=1e-5)),
        ('--learning_rate', dict(type=float, default=.0001)),
        ('--save_dir', dict(default='geowatch/tasks/invariants/logs')),
        ('--gpus', dict(type=int, default=1)),
        ('--feature_dim', dict(type=int, default=64)),
        # --- head network ---
        ('--kernel_size', dict(type=int, default=3)),
        ('--dilation', dict(type=int, default=1)),
        # --- dataset ---
        ('--trainset', dict(type=str, help='Choose from: spacenet, onera, or kwcoco.', default='onera')),
        ('--valset', dict(type=str,
                          help='Choose from: spacenet, onera, or kwcoco. If blank, valset will correspond to chosen trainset.',
                          default='onera')),
        ('--patch_size', dict(type=int, default=128)),
        ('--num_channels', dict(type=int, default=10)),
        # --- kwcoco ---
        ('--train_kwcoco', dict(type=str, default='')),
        ('--val_kwcoco', dict(type=str, default='')),
        ('--sensor', dict(type=str, nargs='+', default=['S2', 'L8'])),
        ('--bands', dict(type=str, nargs='+', default=['shared'])),
        # --- spacenet ---
        ('--remove_clouds', dict(help='spacenet specific argument', action='store_true')),
        ('--normalize_spacenet', dict(help='spacenet specific argument', action='store_true')),
        ('--check_val_every_n_epoch', dict(type=int, default=5)),
        # --- pretraining ---
        ('--pretrained_checkpoint', dict(
            type=str,
            help='path to pretrained checkpoint. Leave blank for change detection training without pretraining.',
            default='')),
        ('--pretrained_multihead', dict(
            action='store_true',
            help='indicate if the pretrained checkpoint was trained in a multihead fashion')),
        ('--pretrained_encoder_only', dict(action='store_true')),
        # --- main model / task ---
        ('--binary', dict(action='store_true')),
        ('--pos_class_weight', dict(
            type=float,
            help='Weight on positive class for segmentation. Only used on binary labels.',
            default=10)),
        ('--num_images', dict(type=int, default=2)),
        ('--num_attention_layers', dict(type=int, default=4)),
        ('--positional_encoding', dict(action='store_true')),
        ('--positional_encoding_mode', dict(type=str, help='addition or concatenation', default='concatenation')),
        ('--change', dict(action='store_true')),
    ]
    for _flag, _options in _cli_options:
        parser.add_argument(_flag, **_options)
    main(parser.parse_args())