"""
import sys
sys.path.append('/home/joncrall/code/SITS-Former/code')
from model import classification_model as clf
import liberator
lib = liberator.Liberator()
lib.add_dynamic(clf.BERTClassification)
lib.expand(['model'])
print(lib.current_sourcecode())
"""
from torch.nn.modules import LayerNorm
from torch.nn.modules.transformer import TransformerEncoder
from torch.nn.modules.transformer import TransformerEncoderLayer
import math
import torch.nn as nn
import torch
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding looked up by integer time index
    (e.g. day-of-year).

    The lookup table has ``max_len + 1`` rows. Row 0 is left as zeros, so
    index 0 contributes no positional signal; row ``k`` (``1 <= k <= max_len``)
    holds the sin/cos encoding of position ``k - 1``.
    """

    def __init__(self, d_model, max_len=366):
        """
        :param d_model: length of each encoding vector (odd values supported)
        :param max_len: largest valid time index
        """
        super().__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len + 1, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        # [ceil(d_model / 2)]
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        # Keep pe[0, :] as zeros; broadcast to [max_len, ceil(d_model / 2)].
        pe[1:, 0::2] = torch.sin(position * div_term)
        # There are floor(d_model / 2) odd columns, one fewer than the even
        # columns when d_model is odd, so trim div_term to match.
        pe[1:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer moves with .to(device) and is saved in the state dict, but
        # is never a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, time):
        """
        Look up the encoding for every time index.

        :param time: integer tensor of indices into the table, typically of
            shape [batch_size, seq_length]
        :return: tensor of shape ``time.shape + (d_model,)``
        """
        # One gather over the flattened indices instead of a per-sample loop;
        # this also handles an empty batch, where stacking an empty list fails.
        flat = torch.index_select(self.pe, 0, time.reshape(-1))
        return flat.reshape(*time.shape, self.pe.size(1))
class BERTEmbedding(nn.Module):
    """
    BERT embedding built from two features:

    1. Observation embedding: each time step's multi-band image patch is
       projected to the embedding size by a lightweight 3D CNN followed by a
       linear layer.
    2. PositionalEncoding: sin/cos encoding of each time step's day-of-year.

    The output is the sum of both, followed by dropout.
    """

    def __init__(self, num_features, dropout=0.1):
        """
        :param num_features: number of input features; accepted for API
            compatibility but not used to size any layer
        :param dropout: dropout rate applied to the summed embedding

        Shape constraint: the two unpadded Conv3d layers shrink
        (bands, patch, patch) by (8, 4, 4) in total, and ``self.linear``
        expects ``channel_size[1] * 2 == 128`` flattened features, so inputs
        must satisfy ``(bands - 8) * (patch - 4) ** 2 == 2``
        (i.e. 10 bands with 5x5 patches).
        """
        super().__init__()
        channel_size = (32, 64, 256)
        kernel_size = (5, 3, 5, 3)
        # Bands act as the depth axis of the 3D convolution.
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels=1,
                      out_channels=channel_size[0],
                      kernel_size=(kernel_size[0], kernel_size[1], kernel_size[1])),
            nn.ReLU(),
            nn.BatchNorm3d(channel_size[0]),
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(in_channels=channel_size[0],
                      out_channels=channel_size[1],
                      kernel_size=(kernel_size[2], kernel_size[3], kernel_size[3])),
            nn.ReLU(),
            nn.BatchNorm3d(channel_size[1]),
        )
        self.linear = nn.Linear(in_features=channel_size[1] * 2,
                                out_features=channel_size[2])
        self.embed_size = channel_size[-1]
        self.position = PositionalEncoding(
            d_model=self.embed_size, max_len=366)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input_sequence, doy_sequence):
        """
        :param input_sequence: tensor of shape
            [batch_size, seq_length, bands, patch, patch] (square patches)
        :param doy_sequence: integer tensor of shape [batch_size, seq_length]
            indexing the positional-encoding table
        :return: tensor of shape [batch_size, seq_length, embed_size]
        """
        batch_size = input_sequence.size(0)
        seq_length = input_sequence.size(1)
        band_num = input_sequence.size(2)
        patch_size = input_sequence.size(3)  # patches are assumed square
        first_dim = batch_size * seq_length
        # Fold time into the batch and add a singleton channel axis.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs,
        # e.g. a sliced subset of bands.
        obs_embed = input_sequence.reshape(
            first_dim, band_num, patch_size, patch_size).unsqueeze(1)
        obs_embed = self.conv1(obs_embed)
        obs_embed = self.conv2(obs_embed)
        # [batch_size*seq_length, embed_size]
        obs_embed = self.linear(obs_embed.reshape(first_dim, -1))
        obs_embed = obs_embed.reshape(batch_size, seq_length, -1)
        position_embed = self.position(doy_sequence)
        x = obs_embed + position_embed  # [batch_size, seq_length, embed_size]
        return self.dropout(x)
class BERT(nn.Module):
    """
    SITS-Former encoder: observation + day-of-year embedding, followed by a
    stack of Transformer encoder layers and a final LayerNorm.
    """

    def __init__(self, num_features, hidden,
                 n_layers, attn_heads, dropout=0.1):
        """
        :param num_features: number of input features, forwarded to the embedding
        :param hidden: hidden size of the SITS-Former model; the forward pass
            only works when it equals the embedding output size (256)
        :param n_layers: number of Transformer blocks (layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate of the Transformer layers
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # NOTE: the embedding is built with its own default dropout rate;
        # ``dropout`` is only handed to the Transformer layers below.
        self.embedding = BERTEmbedding(num_features)

        layer = TransformerEncoderLayer(
            d_model=hidden,
            nhead=attn_heads,
            dim_feedforward=4 * hidden,
            dropout=dropout,
        )
        self.transformer_encoder = TransformerEncoder(
            layer, num_layers=n_layers, norm=LayerNorm(hidden))

    def forward(self, x, doy, mask):
        """
        :param x: input observations, passed to the embedding
        :param doy: day-of-year indices, one per time step
        :param mask: tensor where 0 marks a padded time step
        :return: tensor of shape [batch_size, seq_length, hidden]
        """
        # The key-padding mask flags positions to ignore with True.
        padding = mask.eq(0)
        embedded = self.embedding(input_sequence=x, doy_sequence=doy)
        # The encoder layer is built without batch_first, so it works on
        # (seq, batch, feature); swap the first two axes in and out.
        encoded = self.transformer_encoder(
            embedded.transpose(0, 1), src_key_padding_mask=padding)
        return encoded.transpose(0, 1)
class MulticlassClassification(nn.Module):
    """
    Classification head: masked max-pooling over time, then a linear layer.
    """

    def __init__(self, hidden, num_classes):
        """
        :param hidden: size of each per-time-step feature vector
        :param num_classes: number of output classes
        """
        super().__init__()
        self.linear = nn.Linear(hidden, num_classes)

    def forward(self, x, mask):
        """
        :param x: tensor of shape [batch_size, seq_length, hidden]
        :param mask: tensor of shape [batch_size, seq_length]; 1/True marks a
            valid time step and 0/False a padded one. Boolean, integer and
            floating-point masks are accepted.
        :return: logits of shape [batch_size, num_classes]
        """
        if mask.dtype == torch.bool:
            # ``1 - bool_tensor`` is not supported by torch.
            mask = mask.float()
        # Shift padded steps far below any real activation so max-pooling
        # skips them; a large finite offset is used rather than -inf.
        penalty = (1 - mask.unsqueeze(-1)) * 1e6
        x = x - penalty
        x, _ = torch.max(x, dim=1)  # max-pooling over time
        return self.linear(x)
class BERTClassification(nn.Module):
    """
    Downstream task: satellite time-series classification, combining a
    SITS-Former encoder with a max-pooling classification head.
    """

    def __init__(self, bert: BERT, num_classes):
        """
        :param bert: the SITS-Former encoder; its ``hidden`` attribute sets
            the classifier's input size
        :param num_classes: number of classes to be classified
        """
        super().__init__()
        self.bert = bert
        feature_size = bert.hidden
        self.classification = MulticlassClassification(feature_size, num_classes)

    def forward(self, x, doy, mask):
        """
        Encode the sequence and classify it.

        :return: class logits of shape [batch_size, num_classes]
        """
        features = self.bert(x, doy, mask)  # [batch_size, seq_length, embed_size]
        logits = self.classification(features, mask)
        return logits