##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Synchronized Cross-GPU Batch Normalization Module"""
import warnings
try:
    from queue import Queue
except ImportError:
    from Queue import Queue
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from ..utils.misc import EncodingDeprecationWarning
from ..functions import *
__all__ = ['DistSyncBatchNorm', 'SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']


class DistSyncBatchNorm(_BatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
Standard BN [1]_ implementation only normalize the data within each device (GPU).
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-onece implmentation described in the paper [2]_ .
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-channel over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
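
    With the default momentum :math:`m = 0.1`, the running statistics are
    updated at each training step as

    .. math::

        \hat{\mu} \leftarrow (1 - m)\,\hat{\mu} + m\,\mu_{\mathcal{B}}, \qquad
        \hat{\sigma}^2 \leftarrow (1 - m)\,\hat{\sigma}^2 + m\,\sigma^2_{\mathcal{B}}

    where :math:`\mu_{\mathcal{B}}` and :math:`\sigma^2_{\mathcal{B}}` are the
    mean and variance of the current mini-batch.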

    Because the normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it is common terminology to call this
    Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        process_group: the ``torch.distributed`` process group to synchronize
            over. When ``None``, the default (world) group is used.
            Default: ``None``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = DistSyncBatchNorm(100)
        >>> net = torch.nn.parallel.DistributedDataParallel(m)
        >>> output = net(input)
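
        A fuller sketch, assuming one process per GPU; ``rank`` and
        ``world_size`` are placeholders for the usual launcher-provided values:

        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
        >>> m = DistSyncBatchNorm(100).cuda(rank)
        >>> net = torch.nn.parallel.DistributedDataParallel(m, device_ids=[rank])
        >>> output = net(torch.randn(8, 100, 32, 32).cuda(rank))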
"""
    def __init__(self, num_features, eps=1e-5, momentum=0.1, process_group=None):
        super(DistSyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum,
                                                affine=True, track_running_stats=True)
        self.process_group = process_group

    def forward(self, x):
        need_sync = self.training or not self.track_running_stats
        process_group = None
        if need_sync:
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # Resize the input to (B, C, -1).
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)

        # dist_syncbatchnorm normalizes x with mini-batch statistics gathered
        # across the process group and updates running_mean/running_var
        # in-place with the configured momentum during training.
        y = dist_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                               self.eps, self.momentum, self.training, process_group)

        return y.view(input_shape)


class SyncBatchNorm(_BatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
Standard BN [1]_ implementation only normalize the data within each device (GPU).
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-onece implmentation described in the paper [2]_ .
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-channel over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
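
    With the default momentum :math:`m = 0.1`, the running statistics are
    updated at each training step as

    .. math::

        \hat{\mu} \leftarrow (1 - m)\,\hat{\mu} + m\,\mu_{\mathcal{B}}, \qquad
        \hat{\sigma}^2 \leftarrow (1 - m)\,\hat{\sigma}^2 + m\,\sigma^2_{\mathcal{B}}

    where :math:`\mu_{\mathcal{B}}` and :math:`\sigma^2_{\mathcal{B}}` are the
    mean and variance of the current mini-batch.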

    Because the normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it is common terminology to call this
    Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        sync: a boolean value that when set to ``True``, synchronizes across
            different gpus. Default: ``True``
        activation : str
            Name of the activation function, one of: `leaky_relu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        inplace : bool
            Whether to perform the fused activated batch normalization
            in-place. Ignored when ``activation='none'``. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
        >>> # for Inplace ABN
        >>> from functools import partial
        >>> ABN = partial(SyncBatchNorm, activation='leaky_relu', slope=0.01, sync=True, inplace=True)
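
        A fuller sketch, assuming at least one CUDA device is available:

        >>> m = SyncBatchNorm(100).cuda()
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(torch.randn(8, 100, 32, 32).cuda())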
"""
    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation="none",
                 slope=0.01, inplace=True):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
        self.activation = activation
        # In-place computation only applies when an activation is fused in.
        self.inplace = False if activation == 'none' else inplace
        self.slope = slope
        self.devices = list(range(torch.cuda.device_count()))
        # Synchronization is a no-op with a single device.
        self.sync = sync if len(self.devices) > 1 else False
        # Initialize the queues used to exchange statistics between the
        # master device and the workers.
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def _check_input_dim(self, x):
        # Accept inputs of any dimensionality; forward() flattens them to
        # (B, C, -1) before normalizing.
        pass

    def forward(self, x):
        if not self.training:
            return super().forward(x)

        # Resize the input to (B, C, -1).
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        if x.get_device() == self.devices[0]:
            # Master mode: collects statistics from the workers.
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode: sends statistics to the master.
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }
        if self.inplace:
            y, _, _ = inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                        extra, self.sync, self.training, self.momentum, self.eps,
                                        self.activation, self.slope)
        else:
            y, _, _ = syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                    extra, self.sync, self.training, self.momentum, self.eps,
                                    self.activation, self.slope)
        return y.view(input_shape)


class BatchNorm1d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm1d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("encoding.nn.{} is now deprecated in favor of encoding.nn.{}."
                      .format('BatchNorm1d', SyncBatchNorm.__name__), EncodingDeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)


class BatchNorm2d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm2d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("encoding.nn.{} is now deprecated in favor of encoding.nn.{}."
                      .format('BatchNorm2d', SyncBatchNorm.__name__), EncodingDeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)


class BatchNorm3d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm3d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("encoding.nn.{} is now deprecated in favor of encoding.nn.{}."
                      .format('BatchNorm3d', SyncBatchNorm.__name__), EncodingDeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)
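

# A minimal smoke-test sketch, not part of the library API: it assumes at
# least one CUDA device and only checks that SyncBatchNorm preserves the
# input shape; the tensor sizes below are arbitrary.
if __name__ == '__main__':
    if torch.cuda.is_available():
        bn = SyncBatchNorm(16).cuda()
        inp = torch.randn(4, 16, 8, 8).cuda()
        out = bn(inp)
        assert out.shape == inp.shape
        print('SyncBatchNorm forward OK:', tuple(out.shape))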