"""
typhon.retrieval.qrnn.models.pytorch.unet
=========================================
This module provides an implementation of the UNet [unet]_
architecture.
.. [unet] O. Ronneberger, P. Fischer and T. Brox, "U-net: Convolutional networks
for biomedical image segmentation", Proc. Int. Conf. Med. Image Comput.
Comput.-Assist. Intervent. (MICCAI), pp. 234-241, 2015.
"""
import torch
from torch import nn
from typhon.retrieval.qrnn.models.pytorch.common import PytorchModel
class Layer(nn.Sequential):
    """
    Basic building block of a UNet. Consists of a zero-padded convolutional
    layer followed by an optional batch norm layer and an optional
    activation layer.

    Args:
        features_in(:code:`int`): Number of features of input
        features_out(:code:`int`): Raw number of output features of the
            layer not including skip connections.
        kernel_size(:code:`int`): Kernel size to use for the conv. layer.
            The input is zero-padded by ``kernel_size // 2`` on each side,
            so odd kernel sizes preserve the spatial dimensions.
        activation(:code:`activation`): Activation class to instantiate and
            apply after the conv. layer. If ``None``, no activation is used.
        skip_connection(:code:`bool`): Whether to include skip connections, i.e.
            to include input in layer output.
        batch_norm(:code:`bool`): Whether or not to include a batch norm
            layer after the convolution. Defaults to ``True``.
    """

    def __init__(
        self,
        features_in,
        features_out,
        kernel_size=3,
        activation=nn.ReLU,
        skip_connection=False,
        batch_norm=True,
    ):
        # Padding by kernel_size // 2 keeps the spatial size unchanged for odd
        # kernels, which the channel-wise concatenation in ``forward`` needs.
        modules = [
            nn.ConstantPad2d(kernel_size // 2, 0.0),
            nn.Conv2d(features_in, features_out, kernel_size),
        ]
        if batch_norm:
            modules.append(nn.BatchNorm2d(features_out))
        if activation is not None:
            modules.append(activation())
        super().__init__(*modules)
        self._features_in = features_in
        self._features_out = features_out
        self.skip_connection = skip_connection

    @property
    def features_out(self):
        """
        The number outgoing channels of the layer, including the input
        channels when skip connections are enabled.
        """
        if self.skip_connection:
            return self._features_in + self._features_out
        return self._features_out

    def forward(self, x):
        """ Forward input through layer. """
        y = super().forward(x)
        if self.skip_connection:
            # Concatenate input and output along the channel dimension.
            y = torch.cat([x, y], dim=1)
        return y
class Block(nn.Sequential):
    """
    A block bundles a sequence of :class:`Layer` objects that all produce
    the same raw number of output features.
    """

    def __init__(
        self,
        features_in,
        features_out,
        depth=2,
        batch_norm=True,
        activation=nn.ReLU,
        kernel_size=3,
        skip_connection=None,
    ):
        """
        Args:
            features_in(:code:`int`): The number of input features of the block
            features_out(:code:`int`): The raw number of output features of
                each layer in the block.
            depth(:code:`int`): The number of layers of the block. Must be
                at least 1.
            batch_norm(:code:`bool`): Whether the layers include batch norm.
            activation(:code:`nn.Module`): Pytorch activation layer to
                use. :code:`nn.ReLU` by default.
            kernel_size(:code:`int`): Kernel size of the conv. layers.
            skip_connection(:code:`str`): Whether or not to insert skip
                connections before all layers (:code:`"all"`) or just at
                the end (:code:`"end"`). Any other value disables skip
                connections.

        Raises:
            ValueError: If ``depth`` is smaller than 1.
        """
        if depth < 1:
            raise ValueError(f"Block depth must be at least 1, got {depth}.")
        # "all": every layer concatenates its own input to its output.
        # "end": only the block as a whole concatenates its input.
        skip_connection_layer = skip_connection == "all"

        layers = []
        n_features = features_in
        for _ in range(depth):
            layer = Layer(
                n_features,
                features_out,
                activation=activation,
                batch_norm=batch_norm,
                kernel_size=kernel_size,
                skip_connection=skip_connection_layer,
            )
            layers.append(layer)
            # With per-layer skip connections the channel count grows.
            n_features = layer.features_out

        super().__init__(*layers)
        self._features_in = features_in
        self._features_out = n_features
        self.skip_connection = skip_connection == "end"

    @property
    def features_out(self):
        """
        The number outgoing channels of the block, including the input
        channels when an end-of-block skip connection is used.
        """
        if self.skip_connection:
            return self._features_in + self._features_out
        return self._features_out

    def forward(self, x):
        """ Forward input through block. """
        y = super().forward(x)
        if self.skip_connection:
            # Concatenate block input and output along the channel dimension.
            y = torch.cat([x, y], dim=1)
        return y
class DownSampler(nn.Sequential):
    """
    Halves the spatial resolution of its input using 2x2 max-pooling.
    """

    def __init__(self):
        super().__init__(nn.MaxPool2d(2))
class UpSampler(nn.Sequential):
    """
    Doubles the spatial resolution of its input with a stride-2 transposed
    convolution (kernel size 3, padding 1, output padding 1).

    Args:
        features_in: Number of input channels.
        features_out: Number of output channels.
    """

    def __init__(self, features_in, features_out):
        transposed_conv = nn.ConvTranspose2d(
            features_in,
            features_out,
            3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        super().__init__(transposed_conv)
class UNet(PytorchModel, nn.Module):
    """
    Pytorch implementation of the UNet architecture for image segmentation.

    The network applies ``n_levels - 1`` down-sampling stages, a center
    block and ``n_levels - 1`` up-sampling stages. The output of each
    up-sampler is concatenated with the output of the down-sampling block at
    the same resolution. A head of 1x1 convolutions produces one output
    channel per quantile.
    """

    def __init__(
        self, input_features, quantiles, n_features=32, n_levels=4, skip_connection=None
    ):
        """
        Args:
            input_features(``int``): The number of channels of the input image.
            quantiles(``np.array``): Array containing the quantiles to predict.
            n_features(``int``): The number of channels of the first
                convolution block; doubled after every down-sampling step.
            n_levels(``int``): The number of resolution levels, i.e. the
                input is down-sampled ``n_levels - 1`` times.
            skip_connection: Skip-connection mode forwarded to each
                :class:`Block` (``"all"``, ``"end"`` or ``None``).
        """
        nn.Module.__init__(self)
        PytorchModel.__init__(self, input_features, quantiles)

        # Down-sampling path.
        self.down_blocks = nn.ModuleList()
        self.down_samplers = nn.ModuleList()
        features_in = input_features
        features_out = n_features
        for _ in range(n_levels - 1):
            self.down_blocks.append(
                Block(features_in, features_out, skip_connection=skip_connection)
            )
            self.down_samplers.append(DownSampler())
            features_in = self.down_blocks[-1].features_out
            features_out = features_out * 2

        self.center_block = Block(
            features_in, features_out, skip_connection=skip_connection
        )

        # Up-sampling path.
        self.up_blocks = nn.ModuleList()
        self.up_samplers = nn.ModuleList()
        features_in = self.center_block.features_out
        features_out = features_out // 2
        for i in range(n_levels - 1):
            self.up_samplers.append(UpSampler(features_in, features_out))
            # The up-sampled tensor is concatenated with the features of the
            # matching down-sampling block (visited in reverse order).
            features_in = features_out + self.down_blocks[-i - 1].features_out
            self.up_blocks.append(
                Block(features_in, features_out, skip_connection=skip_connection)
            )
            features_out = features_out // 2
            features_in = self.up_blocks[-1].features_out

        self.head = nn.Sequential(
            nn.Conv2d(features_in, features_in, 1),
            nn.ReLU(),
            nn.Conv2d(features_in, features_in, 1),
            nn.ReLU(),
            nn.Conv2d(features_in, quantiles.size, 1),
        )

    def forward(self, x):
        """
        Propagate input through the network.

        Max-pooling floors odd sizes while each up-sampler exactly doubles
        them, so the spatial dimensions of ``x`` should be divisible by
        ``2 ** (n_levels - 1)``. Otherwise the skip concatenation fails.
        """
        features = []
        for (b, s) in zip(self.down_blocks, self.down_samplers):
            x = b(x)
            features.append(x)
            x = s(x)
        x = self.center_block(x)
        for (b, u, f) in zip(self.up_blocks, self.up_samplers, features[::-1]):
            x = u(x)
            x = torch.cat([x, f], 1)
            x = b(x)
        # NOTE: stored for inspection; this keeps the down-path activations
        # of the most recent call alive until the next forward pass.
        self.features = features
        return self.head(x)