geowatch.utils.util_netharn module¶
Ported bits of netharn. This does not include any of the analytic output_shape_for methods.
- class geowatch.utils.util_netharn.Optimizer[source]¶
Bases: object
Old netharn.api.Optimizer class. Ideally this is deprecated.
- static coerce(config={}, **kw)[source]¶
- Accepts keywords:
- optimizer / optim :
can be sgd, adam, adamw, rmsprop
- learning_rate / lr :
a float
- weight_decay / decay :
a float
- momentum:
a float, only used if the optimizer accepts it
- params:
This is a SPECIAL keyword that is handled differently. It is interpreted by netharn.hyper.Hyperparams.make_optimizer.
In the simplest case you can pass "params" as a list of torch parameter objects or a list of dictionaries containing param groups and special group options (just as you would when constructing an optimizer from scratch). We don't recommend this while using netharn unless you know what you are doing. (Note that params will correctly change device if the model is mounted.)
In the case where you do not want to group parameters with different options, it is best practice to simply not specify params.
In the case where you want to group parameters set params to either a List[Dict] or a Dict[str, Dict].
Each item / value of this collection should be a dictionary whose keys / values are the per-group optimizer options. Additionally, there should be a key "params" (note: this is the nested per-group params, not to be confused with the top-level "params").
Each per-group “params” should be either (1) a list of parameter names (preferred), (2) a string that specifies a regular expression (matching layer names will be included in this group), or (3) a list of parameter objects.
For example, the top-level params might look like:
- params={
      'head':     {'lr': 0.003, 'params': '.*head.*'},
      'backbone': {'lr': 0.001, 'params': '.*backbone.*'},
      'preproc':  {'lr': 0.0, 'params': [
          'model.conv1', 'model.norm1', 'model.relu1']}
  }
Note that head and backbone specify membership via regular expression whereas preproc explicitly specifies a list of parameter names.
Notes
pip install torch-optimizer
References
https://datascience.stackexchange.com/questions/26792/difference-between-rmsprop-with-momentum-and-adam-optimizers
https://github.com/jettify/pytorch-optimizer
CommandLine
xdoctest -m /home/joncrall/code/netharn/netharn/api.py Optimizer.coerce
Example
>>> config = {'optimizer': 'sgd', 'params': [
>>>     {'lr': 3e-3, 'params': '.*head.*'},
>>>     {'lr': 1e-3, 'params': '.*backbone.*'},
>>> ]}
>>> optim_ = Optimizer.coerce(config)
>>> # xdoctest: +REQUIRES(module:torch_optimizer)
>>> config = {'optimizer': 'DiffGrad'}
>>> optim_ = Optimizer.coerce(config, lr=1e-5)
>>> print('optim_ = {!r}'.format(optim_))
>>> assert optim_[1]['lr'] == 1e-5
>>> config = {'optimizer': 'Yogi'}
>>> optim_ = Optimizer.coerce(config)
>>> print('optim_ = {!r}'.format(optim_))
>>> Optimizer.coerce({'optimizer': 'ASGD'})
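The assertion optim_[1]['lr'] == 1e-5 above indicates that coerce returns an (optimizer_class, kwargs) pair rather than a constructed optimizer. A minimal sketch of downstream usage, assuming the returned kwargs match the optimizer constructor's signature:
>>> # Hedged sketch: build the optimizer from the coerced pair.
>>> model = ToyNet2d()
>>> optim_cls, optim_kw = Optimizer.coerce({'optimizer': 'sgd', 'lr': 1e-3})
>>> optimizer = optim_cls(model.parameters(), **optim_kw)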
- class geowatch.utils.util_netharn.Initializer[source]¶
Bases: object
Base class for initializers
- history()[source]¶
Initializer methods have histories which are short for algorithms and can be quite long for pretrained models
- get_initkw()[source]¶
Initializer methods have histories which are short for algorithms and can be quite long for pretrained models
- static coerce(config={}, **kw)[source]¶
Accepts 'init', 'pretrained', 'pretrained_fpath', 'leftover', and 'noli'.
- Parameters:
config (dict | str) – coercible configuration dictionary. If config is a string it is taken as the value for 'init'.
- Returns:
initializer_ – a tuple (initializer_cls, kw)
- Return type:
Tuple[Initializer, dict]
Examples
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> print(ub.urepr(Initializer.coerce({'init': 'noop'})))
>>> config = {
...     'init': 'pretrained',
...     'pretrained_fpath': '/fit/nice/untitled'
... }
>>> print(ub.urepr(Initializer.coerce(config)))
>>> print(ub.urepr(Initializer.coerce({'init': 'kaiming_normal'})))
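Because the return value is the class and its keyword arguments rather than an instance, typical usage presumably constructs the initializer and then applies it to a model, as the NoOp and Kaiming examples below do. A minimal sketch under that assumption:
>>> init_cls, init_kw = Initializer.coerce({'init': 'kaiming_normal'})
>>> initializer = init_cls(**init_kw)
>>> model = ToyNet2d()
>>> initializer(model)  # initializers are callables applied to models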
- class geowatch.utils.util_netharn.NoOp[source]¶
Bases: Initializer
An initializer that does nothing, which is useful when you have initialized the weights yourself.
Example
>>> import copy
>>> self = NoOp()
>>> model = ToyNet2d()
>>> old_state = sum(v.sum() for v in model.state_dict().values())
>>> self(model)
>>> new_state = sum(v.sum() for v in model.state_dict().values())
>>> assert old_state == new_state
>>> assert self.history() is None
- class geowatch.utils.util_netharn.Orthogonal(gain=1)[source]¶
Bases: Initializer
Orthogonal initialization, using the pytorch implementation
Example
>>> self = Orthogonal()
>>> model = ToyNet2d()
>>> try:
>>>     self(model)
>>> except RuntimeError:
>>>     import pytest
>>>     pytest.skip('geqrf: Lapack probably not available')
>>> layer = torch.nn.modules.Conv2d(3, 3, 3)
>>> self(layer)
- class geowatch.utils.util_netharn.KaimingUniform(param=0, mode='fan_in')[source]¶
Bases: Initializer
Same as HeUniform, but uses the pytorch implementation
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> self = KaimingUniform()
>>> model = ToyNet2d()
>>> self(model)
>>> layer = torch.nn.modules.Conv2d(3, 3, 3)
>>> self(layer)
- class geowatch.utils.util_netharn.KaimingNormal(param=0, mode='fan_in')[source]¶
Bases: Initializer
Same as HeNormal, but uses the pytorch implementation
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> self = KaimingNormal()
>>> model = ToyNet2d()
>>> self(model)
>>> layer = torch.nn.modules.Conv2d(3, 3, 3)
>>> self(layer)
- geowatch.utils.util_netharn.apply_initializer(input, func, funckw)[source]¶
Recursively initializes the input using a torch.nn.init function.
If the input is a model, then only known layer types are initialized.
- Parameters:
input (Tensor | Module) – can be a model, layer, or tensor
func (callable) – initialization function
funckw (dict) – keyword arguments passed to the initialization function
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> from torch import nn
>>> import torch
>>> class DummyNet(nn.Module):
>>>     def __init__(self, n_channels=1, n_classes=10):
>>>         super(DummyNet, self).__init__()
>>>         self.conv = nn.Conv2d(n_channels, 10, kernel_size=5)
>>>         self.norm = nn.BatchNorm2d(10)
>>>         self.param = torch.nn.Parameter(torch.rand(3))
>>> self = DummyNet()
>>> func = nn.init.kaiming_normal_
>>> apply_initializer(self, func, {})
>>> func = nn.init.constant_
>>> apply_initializer(self, func, {'val': 42})
>>> assert np.all(self.conv.weight.detach().numpy() == 42)
>>> assert np.all(self.conv.bias.detach().numpy() == 0), 'bias is always init to zero'
>>> assert np.all(self.norm.bias.detach().numpy() == 0), 'bias is always init to zero'
>>> assert np.all(self.norm.weight.detach().numpy() == 1)
>>> assert np.all(self.norm.running_mean.detach().numpy() == 0.0)
>>> assert np.all(self.norm.running_var.detach().numpy() == 1.0)
- geowatch.utils.util_netharn.trainable_layers(model, names=False)[source]¶
Returns all layers containing trainable parameters
Notes
It may be better to simply use model.named_parameters() instead in most situations. This is useful when you need the classes that contain the parameters instead of the parameters themselves.
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> import torchvision
>>> model = torchvision.models.AlexNet()
>>> list(trainable_layers(model, names=True))
- geowatch.utils.util_netharn.number_of_parameters(model, trainable=True)[source]¶
Returns number of trainable parameters in a torch module
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> model = torch.nn.Conv1d(2, 3, 5)
>>> print(number_of_parameters(model))
33
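The count decomposes as 2 * 3 * 5 = 30 kernel weights plus 3 bias terms, giving 33.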
- class geowatch.utils.util_netharn.ToyNet2d(input_channels=1, num_classes=2)[source]¶
Bases: Module
Demo model for a simple 2 class learning problem
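No doctest accompanies this class; a minimal usage sketch, assuming a standard (batch, channels, height, width) input with input_channels channels:
>>> self = ToyNet2d()
>>> inputs = torch.rand(2, 1, 4, 4)  # two single-channel 4x4 images
>>> outputs = self(inputs)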
- class geowatch.utils.util_netharn.ToyData2d(size=4, border=1, n=100, rng=None)[source]¶
Bases: Dataset
Simple black-on-white and white-on-black images.
- Parameters:
n (int, default=100) – dataset size
size (int, default=4) – width / height
border (int, default=1) – border mode
rng (RandomCoercable, default=None) – seed or random state
CommandLine
python -m netharn.data.toydata ToyData2d --show
Example
>>> self = ToyData2d()
>>> data1, label1 = self[0]
>>> data2, label2 = self[-1]
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> plt = kwplot.autoplt()
>>> kwplot.figure(fnum=1, doclf=True)
>>> kwplot.imshow(data1.numpy().squeeze(), pnum=(1, 2, 1))
>>> kwplot.imshow(data2.numpy().squeeze(), pnum=(1, 2, 2))
>>> kwplot.show_if_requested()
- class geowatch.utils.util_netharn.InputNorm(mean=None, std=None)[source]¶
Bases: Module
Normalizes the input by shifting and dividing by a scale factor.
This allows the network itself to take care of 0-mean/1-std normalization. The developer explicitly specifies what these shift and scale values are. By specifying this as a layer (instead of as a data preprocessing step), the exporter will remember and associate this information with any deployed model. This means that a user does not need to remember what these shift/scale arguments were before passing inputs to a network.
If the mean and std arguments are unspecified, this layer becomes a noop.
Example
>>> self = InputNorm(mean=50.0, std=29.0)
>>> inputs = torch.rand(2, 3, 5, 7) * 100
>>> outputs = self(inputs)
>>> # If mean and std are unspecified, this becomes a noop.
>>> assert torch.all(InputNorm()(inputs) == inputs)
>>> # Specifying either the mean or the std is ok.
>>> partial1 = InputNorm(mean=50)(inputs)
>>> partial2 = InputNorm(std=29)(inputs)
- class geowatch.utils.util_netharn.MultiLayerPerceptronNd(dim, in_channels, hidden_channels, out_channels, bias=True, dropout=None, noli='relu', norm='batch', residual=False, noli_output=False, norm_output=False, standardize_weights=False)[source]¶
Bases: Module
A multi-layer perceptron network for n dimensional data
Choose the number and size of the hidden layers, the number of output channels, whether to use residual connections, the nonlinearity, normalization, dropout, and more.
- Parameters:
dim (int) – specify if the data is 0, 1, 2, 3, or 4 dimensional.
in_channels (int) – number of input channels
hidden_channels (List[int] | int) – sizes of the hidden layers, or an int specifying the number of hidden layers (in which case channel sizes are chosen to linearly interpolate between the input and output channels)
out_channels (int) – number of output channels
dropout (float, default=None) – amount of dropout to use, between 0 and 1
norm (str, default='batch') – type of normalization layer (e.g. batch or group); set to None for no normalization.
noli (str, default='relu') – type of nonlinearity
residual (bool, default=False) – if True, includes a residual skip connection between inputs and outputs.
norm_output (bool, default=False) – if True, applies a final normalization layer to the output.
noli_output (bool, default=False) – if True, applies a final nonlinearity to the output.
standardize_weights (bool, default=False) – Use weight standardization
Example
>>> kw = {'dim': 0, 'in_channels': 2, 'out_channels': 1}
>>> model0 = MultiLayerPerceptronNd(hidden_channels=0, **kw)
>>> model1 = MultiLayerPerceptronNd(hidden_channels=1, **kw)
>>> model2 = MultiLayerPerceptronNd(hidden_channels=2, **kw)
>>> print('model0 = {!r}'.format(model0))
>>> print('model1 = {!r}'.format(model1))
>>> print('model2 = {!r}'.format(model2))
>>> kw = {'dim': 0, 'in_channels': 2, 'out_channels': 1, 'residual': True}
>>> model0 = MultiLayerPerceptronNd(hidden_channels=0, **kw)
>>> model1 = MultiLayerPerceptronNd(hidden_channels=1, **kw)
>>> model2 = MultiLayerPerceptronNd(hidden_channels=2, **kw)
>>> print('model0 = {!r}'.format(model0))
>>> print('model1 = {!r}'.format(model1))
>>> print('model2 = {!r}'.format(model2))
Example
>>> import ubelt as ub
>>> self = MultiLayerPerceptronNd(dim=1, in_channels=128, hidden_channels=3, out_channels=2)
>>> print(self)
MultiLayerPerceptronNd...
- geowatch.utils.util_netharn.rectify_nonlinearity(key=NoParam, dim=2)[source]¶
Allows dictionary based specification of a nonlinearity
Example
>>> rectify_nonlinearity('relu')
ReLU(...)
>>> rectify_nonlinearity('leaky_relu')
LeakyReLU(negative_slope=0.01...)
>>> rectify_nonlinearity(None)
None
>>> rectify_nonlinearity('swish')
- geowatch.utils.util_netharn.rectify_normalizer(in_channels, key=NoParam, dim=2, **kwargs)[source]¶
Allows dictionary based specification of a normalizing layer
- Parameters:
in_channels (int) – number of input channels
dim (int) – dimensionality
**kwargs – extra args
Example
>>> rectify_normalizer(8)
BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> rectify_normalizer(8, 'batch')
BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> rectify_normalizer(8, {'type': 'batch'})
BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> rectify_normalizer(8, 'group')
GroupNorm(4, 8, eps=1e-05, affine=True)
>>> rectify_normalizer(8, {'type': 'group', 'num_groups': 2})
GroupNorm(2, 8, eps=1e-05, affine=True)
>>> rectify_normalizer(8, dim=3)
BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> rectify_normalizer(8, None)
None
>>> rectify_normalizer(8, key={'type': 'syncbatch'})
>>> rectify_normalizer(8, {'type': 'group', 'num_groups': 'auto'})
>>> rectify_normalizer(1, {'type': 'group', 'num_groups': 'auto'})
>>> rectify_normalizer(16, {'type': 'group', 'num_groups': 'auto'})
>>> rectify_normalizer(32, {'type': 'group', 'num_groups': 'auto'})
>>> rectify_normalizer(64, {'type': 'group', 'num_groups': 'auto'})
>>> rectify_normalizer(1024, {'type': 'group', 'num_groups': 'auto'})
- class geowatch.utils.util_netharn.Identity[source]¶
Bases: Sequential
An identity-function layer.
Example
>>> import torch
>>> self = Identity()
>>> a = torch.rand(3, 3)
>>> b = self(a)
>>> assert torch.all(a == b)
- class geowatch.utils.util_netharn.Conv0d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', standardize_weights=False)[source]¶
Bases: Linear
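No doctest is given here. Since Conv0d subclasses Linear, a 0-dimensional "convolution" is presumably just a linear map over the channel axis; a minimal sketch under that assumption:
>>> self = Conv0d(in_channels=3, out_channels=5)
>>> inputs = torch.rand(2, 3)  # (batch, channels) with no spatial dimensions
>>> outputs = self(inputs)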
- class geowatch.utils.util_netharn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, standardize_weights=False)[source]¶
Bases: Conv1d
- extra_repr()¶
- class geowatch.utils.util_netharn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, standardize_weights=False)[source]¶
Bases: Conv2d
- extra_repr()¶
- class geowatch.utils.util_netharn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, standardize_weights=False)[source]¶
Bases: Conv3d
- extra_repr()¶
- geowatch.utils.util_netharn.weight_standardization_nd(dim, weight, eps)[source]¶
Note: input channels must be greater than 1!
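No example is given for this function. The following sketch shows the computation weight standardization performs according to the Qiao 2020 paper cited under ConvNormNd below (each output filter is shifted to zero mean and scaled to unit standard deviation); the exact reduction axes and eps handling used by this implementation are assumptions:
>>> # Hedged sketch for dim=2: a Conv2d weight has shape (out, in, h, w)
>>> weight = torch.rand(8, 4, 3, 3)
>>> eps = 1e-5
>>> mean = weight.mean(dim=(1, 2, 3), keepdim=True)
>>> std = weight.std(dim=(1, 2, 3), keepdim=True)
>>> w_standardized = (weight - mean) / (std + eps)
Under this reading the note above makes sense: with only a single weight element per filter (e.g. dim=0 with one input channel), the standard deviation is undefined.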
- class geowatch.utils.util_netharn.ConvNormNd(dim, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, noli='relu', norm='batch', standardize_weights=False)[source]¶
Bases: Sequential
Backbone convolution component. The convolution happens first; normalization and nonlinearity happen after the convolution.
CONV[->NORM][->NOLI]
- Parameters:
dim (int) – dimensionality of the convolutional kernel (can be 0, 1, 2, or 3).
in_channels (int)
out_channels (int)
kernel_size (int | Tuple)
stride (int | Tuple)
padding (int | Tuple)
dilation (int | Tuple)
groups (int)
bias (bool)
norm (str, dict, nn.Module) – Type of normalizer, if None, then normalization is disabled.
noli (str, dict, nn.Module) – Type of nonlinearity; if None, then the nonlinearity is disabled.
standardize_weights (bool, default=False) – Implements weight standardization as described in Qiao 2020 - “Micro-Batch Training with Batch-Channel Normalization and Weight Standardization”- https://arxiv.org/pdf/1903.10520.pdf
Example
>>> self = ConvNormNd(dim=2, in_channels=16, out_channels=64,
>>>                   kernel_size=3)
>>> print(self)
ConvNormNd(
  (conv): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1))
  (norm): BatchNorm2d(64, ...)
  (noli): ReLU(...)
)
Example
>>> self = ConvNormNd(dim=0, in_channels=16, out_channels=64)
>>> print(self)
ConvNormNd(
  (conv): Conv0d(in_features=16, out_features=64, bias=True)
  (norm): BatchNorm1d(64, ...)
  (noli): ReLU(...)
)
>>> input_shape = (None, 16)
- class geowatch.utils.util_netharn.ConvNorm1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, noli='relu', norm='batch', standardize_weights=False)[source]¶
Bases: ConvNormNd
Backbone convolution component. The convolution happens first; normalization and nonlinearity happen after the convolution.
CONV[->NORM][->NOLI]
- Parameters:
norm (str, dict, nn.Module) – Type of normalizer, if None, then normalization is disabled.
noli (str, dict, nn.Module) – Type of nonlinearity; if None, then the nonlinearity is disabled.
Example
>>> input_shape = [2, 3, 5]
>>> self = ConvNorm1d(input_shape[1], 7, kernel_size=3)
- class geowatch.utils.util_netharn.ConvNorm2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, noli='relu', norm='batch', standardize_weights=False)[source]¶
Bases: ConvNormNd
Backbone convolution component. The convolution happens first; normalization and nonlinearity happen after the convolution.
CONV[->NORM][->NOLI]
- Parameters:
norm (str, dict, nn.Module) – Type of normalizer, if None, then normalization is disabled.
noli (str, dict, nn.Module) – Type of nonlinearity; if None, then the nonlinearity is disabled.
Example
>>> input_shape = [2, 3, 5, 7]
>>> self = ConvNorm2d(input_shape[1], 11, kernel_size=3)
- class geowatch.utils.util_netharn.ConvNorm3d(in_channels, out_channels, kernel_size, stride=1, bias=True, padding=0, noli='relu', norm='batch', groups=1, standardize_weights=False)[source]¶
Bases: ConvNormNd
Backbone convolution component. The convolution happens first; normalization and nonlinearity happen after the convolution.
CONV[->NORM][->NOLI]
- Parameters:
norm (str, dict, nn.Module) – Type of normalizer, if None, then normalization is disabled.
noli (str, dict, nn.Module) – Type of nonlinearity; if None, then the nonlinearity is disabled.
Example
>>> input_shape = [2, 3, 5, 7, 11]
>>> self = ConvNorm3d(input_shape[1], 13, kernel_size=3)
- class geowatch.utils.util_netharn.Swish(beta=1.0)[source]¶
Bases: Module
When beta=1 this is the Sigmoid-weighted Linear Unit (SiL):
x * torch.sigmoid(x)
For general beta, the referenced paper defines swish(x) = x * sigmoid(beta * x).
References
https://arxiv.org/pdf/1710.05941.pdf
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> x = torch.linspace(-20, 20, 100, requires_grad=True)
>>> self = Swish()
>>> y = self(x)
>>> y.sum().backward()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.multi_plot(xydata={'beta=1': (x.data, y.data)}, fnum=1, pnum=(1, 2, 1),
>>>                   ylabel='swish(x)', xlabel='x', title='activation')
>>> kwplot.multi_plot(xydata={'beta=1': (x.data, x.grad)}, fnum=1, pnum=(1, 2, 2),
>>>                   ylabel='𝛿swish(x) / 𝛿(x)', xlabel='x', title='gradient')
>>> kwplot.show_if_requested()
- geowatch.utils.util_netharn.beta_mish(input, beta=1.5)[source]¶
- Applies the β-mish function element-wise:
- \[\beta\text{-mish}(x) = x \tanh\left(\ln\left((1 + e^{x})^{\beta}\right)\right)\]
See additional documentation for echoAI.Activation.Torch.beta_mish.
References
https://github.com/digantamisra98/Echo/blob/master/echoAI/Activation/Torch/functional.py
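The function ships without a doctest; a minimal usage sketch based on the signature above:
>>> x = torch.linspace(-5, 5, 100)
>>> y = beta_mish(x, beta=1.5)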
- class geowatch.utils.util_netharn.Mish_Function(*args, **kwargs)[source]¶
Bases: Function
Applies the mish function element-wise:
- Math:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
- Shape: the output has the same shape as the input (the function is applied element-wise).
References
https://github.com/digantamisra98/Echo/blob/master/echoAI/Activation/Torch/mish.py
Examples
>>> m = Mish()
>>> input = torch.randn(2)
>>> output = m(input)
- class geowatch.utils.util_netharn.Mish[source]¶
Bases: Module
Applies the mish function element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
- Shape: the output has the same shape as the input (the function is applied element-wise).
References
https://github.com/digantamisra98/Mish/blob/master/Mish/Torch/mish.py
https://github.com/thomasbrandon/mish-cuda
https://arxiv.org/pdf/1908.08681v2.pdf
Examples
>>> m = Mish()
>>> input = torch.randn(2)
>>> output = m(input)
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> x = torch.linspace(-20, 20, 100, requires_grad=True)
>>> self = Mish()
>>> y = self(x)
>>> y.sum().backward()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.multi_plot(xydata={'beta=1': (x.data, y.data)}, fnum=1, pnum=(1, 2, 1))
>>> kwplot.multi_plot(xydata={'beta=1': (x.data, x.grad)}, fnum=1, pnum=(1, 2, 2))
>>> kwplot.show_if_requested()
- geowatch.utils.util_netharn.default_kwargs(cls)[source]¶
Grab initkw defaults from the constructor
- Parameters:
cls (type | callable) – a class or function
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> import torch
>>> import ubelt as ub
>>> cls = torch.optim.Adam
>>> default_kwargs(cls)
>>> cls = KaimingNormal
>>> print(ub.repr2(default_kwargs(cls), nl=0))
{'mode': 'fan_in', 'param': 0}
>>> cls = NoOp
>>> default_kwargs(cls)
{}
- SeeAlso:
xinspect.get_func_kwargs(cls)
- geowatch.utils.util_netharn.padded_collate(inbatch, fill_value=-1)[source]¶
Used for detection datasets with boxes.
Example
>>> from geowatch.utils.util_netharn import *  # NOQA
>>> import torch
>>> rng = np.random.RandomState(0)
>>> inbatch = []
>>> bsize = 7
>>> for i in range(bsize):
>>>     # add an image and some dummy bboxes to the batch
>>>     img = torch.rand(3, 8, 8)  # dummy 8x8 image
>>>     n = 11 if i == 3 else rng.randint(0, 11)
>>>     boxes = torch.rand(n, 4)
>>>     item = (img, boxes)
>>>     inbatch.append(item)
>>> out_batch = padded_collate(inbatch)
>>> assert len(out_batch) == 2
>>> assert list(out_batch[0].shape) == [bsize, 3, 8, 8]
>>> assert list(out_batch[1].shape) == [bsize, 11, 4]
Example
>>> import torch
>>> rng = np.random.RandomState(0)
>>> inbatch = []
>>> bsize = 4
>>> for _ in range(bsize):
>>>     # add an image and some dummy bboxes to the batch
>>>     img = torch.rand(3, 8, 8)  # dummy 8x8 image
>>>     #boxes = torch.empty(0, 4)
>>>     boxes = torch.FloatTensor()
>>>     item = (img, [boxes])
>>>     inbatch.append(item)
>>> out_batch = padded_collate(inbatch)
>>> assert len(out_batch) == 2
>>> assert list(out_batch[0].shape) == [bsize, 3, 8, 8]
>>> #assert list(out_batch[1][0].shape) == [bsize, 0, 4]
>>> assert list(out_batch[1][0].shape) in [[0], []]  # torch 0.3 and 0.4
Example
>>> inbatch = [torch.rand(4, 4), torch.rand(8, 4),
>>>            torch.rand(0, 4), torch.rand(3, 4),
>>>            torch.rand(0, 4), torch.rand(1, 4)]
>>> out_batch = padded_collate(inbatch)
>>> assert list(out_batch.shape) == [6, 8, 4]