from transformers.models.detr.modeling_detr import (
DetrObjectDetectionOutput, build_position_encoding, DetrDecoder,
DetrMLPPredictionHead, DetrHungarianMatcher, DetrLoss
)
from transformers import DetrConfig
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from einops import rearrange
[docs]
class DetrDecoderForObjectDetection(nn.Module):
def __init__(self, config: DetrConfig, d_model=10, d_hidden=100):
super().__init__()
self.config = config
self.object_queries = build_position_encoding(config)
self.query_position_embeddings = nn.Embedding(config.num_queries, d_model)
self.decoder = DetrDecoder(config)
self.bbox_predictor = DetrMLPPredictionHead(
input_dim=d_model, hidden_dim=100, output_dim=4, num_layers=3
)
self.class_labels_classifier = nn.Linear(
config.d_model, config.num_labels + 1
)
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
[docs]
def forward(
self,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[List[dict]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
loss_only: Optional[bool] = None,
pred_boxes: Optional[torch.FloatTensor] = None,
logits: Optional[torch.FloatTensor] = None
) -> Union[Tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
if not loss_only:
#return_dict = return_dict if return_dict is not None else self.config.use_return_dict
pixel_values = inputs_embeds.permute(0, 3, 1, 2)
batch_size, num_channels, height, width = pixel_values.shape
pixel_mask = torch.ones(((batch_size, height, width))).to(pixel_values.device)
object_queries = self.object_queries(pixel_values, pixel_mask)
object_queries = object_queries.flatten(2).permute(0, 2, 1)
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
queries = torch.zeros_like(query_position_embeddings)
inputs_embeds = rearrange(inputs_embeds, 'b h w c -> b (h w) c')
decoder_outputs = self.decoder(
inputs_embeds=queries,
attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=inputs_embeds,
)
sequence_output = decoder_outputs[0]
# class logits + predicted bounding boxes
logits = self.class_labels_classifier(sequence_output)
pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
loss, loss_dict, auxiliary_outputs = None, None, None
if labels is not None:
# First: create the matcher
matcher = DetrHungarianMatcher(
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
)
# Second: create the criterion
losses = ["labels", "boxes", "cardinality"]
criterion = DetrLoss(
matcher=matcher,
num_classes=self.config.num_labels,
eos_coef=self.config.eos_coefficient,
losses=losses,
)
criterion.to(pred_boxes.device)
# Third: compute the losses, based on outputs and labels
outputs_loss = {}
outputs_loss["logits"] = logits
outputs_loss["pred_boxes"] = pred_boxes
if self.config.auxiliary_loss:
# broken?
raise NotImplementedError
# intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
# outputs_class = self.class_labels_classifier(intermediate)
# outputs_coord = self.bbox_predictor(intermediate).sigmoid()
# auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
# outputs_loss["auxiliary_outputs"] = auxiliary_outputs
loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
if self.config.auxiliary_loss:
aux_weight_dict = {}
for i in range(self.config.decoder_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
if not return_dict:
if auxiliary_outputs is not None:
output = (pred_boxes) + auxiliary_outputs
else:
output = (pred_boxes, logits)
return ((loss, loss_dict)) + output if loss is not None else output
return DetrObjectDetectionOutput(
loss=loss,
loss_dict=loss_dict,
logits=logits,
pred_boxes=pred_boxes,
)
def _test():
import torch
import kwimage
num = 10
true_boxes = kwimage.Boxes.random(num).tensor()
print(true_boxes.tensor)
inputs = torch.rand(num, 14, 14, 256)
config = DetrConfig()
regress = DetrDecoderForObjectDetection(config, 256)
energy = regress(inputs)
energy.retain_grad()
outputs = energy.sigmoid()
outputs.retain_grad()
out_boxes = kwimage.Boxes(outputs, 'cxywh')
ious = out_boxes.ious(true_boxes)
loss = ious.sum()
print(loss)
loss.backward()
if __name__ == '__main__':
_test()