# Source code for encoding.utils.metrics
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import threading
import numpy as np
import torch
# Public API of this module: the names exported by a star-import.
__all__ = ['accuracy', 'get_pixacc_miou',
'SegmentationMetric', 'batch_intersection_union', 'batch_pix_accuracy',
'pixel_accuracy', 'intersection_and_union']
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: score tensor; the top-k is taken along dimension 1 and the
            result is transposed with ``.t()``, so at most 2 dimensions.
        target: tensor of ground-truth class indices; its first dimension
            is used as the batch size.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        each holding the percentage (0-100) of samples whose target is among
        the ``k`` highest-scoring predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the maxk largest scores per sample, sorted descending.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # `correct` may keep the transposed (non-contiguous) layout of
            # `pred`, in which case `.view(-1)` raises for k > 1; `.reshape`
            # returns a view when possible and copies otherwise.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def get_pixacc_miou(total_correct, total_label, total_inter, total_union):
    """Convert accumulated counts into ``(pixAcc, mIoU)``.

    ``np.spacing(1)`` is added to each denominator so that zero counts do
    not cause a division by zero. The IoU ratio is computed element-wise and
    then averaged with ``.mean()``.
    """
    eps = np.spacing(1)
    per_class_iou = 1.0 * total_inter / (eps + total_union)
    pix_acc = 1.0 * total_correct / (eps + total_label)
    return pix_acc, per_class_iou.mean()
class SegmentationMetric(object):
    """Computes pixAcc and mIoU metric scores.

    Running totals of correct pixels, labeled pixels, per-class intersection
    and per-class union are accumulated across calls to :meth:`update` and
    turned into the final metrics by :meth:`get`.
    """

    def __init__(self, nclass):
        """Store the number of classes and start with zeroed totals."""
        self.nclass = nclass
        # Guards the running totals when `update` fans out to worker threads.
        self.lock = threading.Lock()
        self.reset()

    def update(self, labels, preds):
        """Accumulate statistics for one batch, or for a list of batches.

        Args:
            labels: a label tensor, or a list/tuple of label tensors paired
                element-wise (via ``zip``) with ``preds``.
            preds: a prediction tensor, or a list/tuple of prediction tensors.

        Raises:
            TypeError: if ``preds`` is neither a ``torch.Tensor`` nor a
                list/tuple.
        """
        def evaluate_worker(self, label, pred):
            correct, labeled = batch_pix_accuracy(
                pred, label)
            inter, union = batch_intersection_union(
                pred, label, self.nclass)
            with self.lock:
                self.total_correct += correct
                self.total_label += labeled
                self.total_inter += inter
                self.total_union += union
            return

        if isinstance(preds, torch.Tensor):
            evaluate_worker(self, labels, preds)
        elif isinstance(preds, (list, tuple)):
            # One thread per (label, pred) pair; all are joined before returning.
            # NOTE: an exception raised inside a worker thread is not
            # re-raised here, so that pair's contribution is simply missing.
            threads = [threading.Thread(target=evaluate_worker,
                                        args=(self, label, pred),
                                        )
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
        else:
            # The previous `raise NotImplemented` raised a non-exception, which
            # Python reports as an unrelated-looking TypeError. Raise the same
            # exception type deliberately, with a usable message.
            raise TypeError(
                "preds must be a torch.Tensor or a list/tuple of tensors, "
                "got {}".format(type(preds).__name__))

    def get_all(self):
        """Return the raw totals ``(correct, label, inter, union)``."""
        return self.total_correct, self.total_label, self.total_inter, self.total_union

    def get(self):
        """Return ``(pixAcc, mIoU)`` computed from the accumulated totals."""
        return get_pixacc_miou(self.total_correct, self.total_label, self.total_inter, self.total_union)

    def reset(self):
        """Zero all running totals."""
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0
        return
def batch_pix_accuracy(output, target):
    """Batch Pixel Accuracy

    Args:
        output: score tensor; the predicted class is the index of the
            maximum along dimension 1
        target: label tensor compared element-wise with the prediction

    Returns:
        ``(pixel_correct, pixel_labeled)``: the number of labeled pixels
        predicted correctly, and the number of labeled pixels.

    Both prediction and target are shifted by +1 before comparison, so any
    target value below 0 ends up <= 0 and is excluded by the ``> 0`` mask.
    """
    predicted = torch.max(output, 1)[1].cpu().numpy().astype('int64') + 1
    truth = target.cpu().numpy().astype('int64') + 1
    labeled_mask = truth > 0
    pixel_labeled = np.sum(labeled_mask)
    pixel_correct = np.sum((predicted == truth) * labeled_mask)
    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(output, target, nclass):
    """Batch Intersection of Union

    Args:
        output: score tensor; the predicted class is the index of the
            maximum along dimension 1
        target: label tensor compared element-wise with the prediction
        nclass: number of categories (int)

    Returns:
        ``(area_inter, area_union)``: two arrays of ``nclass`` per-class
        pixel counts.

    Prediction and target are shifted by +1, so values that end up <= 0
    fall outside the histogram range ``(1, nclass)`` and are not counted.
    """
    predicted = torch.max(output, 1)[1].cpu().numpy().astype('int64') + 1
    truth = target.cpu().numpy().astype('int64') + 1
    # Zero out predictions wherever the target is unlabeled.
    predicted = predicted * (truth > 0).astype(predicted.dtype)
    # Keeps the class id where prediction and target agree, 0 elsewhere.
    matched = predicted * (predicted == truth)

    def class_areas(values):
        # One bin per class over the shifted id range [1, nclass].
        counts, _ = np.histogram(values, bins=nclass, range=(1, nclass))
        return counts

    # areas of intersection and union
    area_inter = class_areas(matched)
    area_union = class_areas(predicted) + class_areas(truth) - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
# ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
def pixel_accuracy(im_pred, im_lab):
    """Count correctly predicted and labeled pixels for one image.

    Pixels whose ground-truth label is not > 0 are treated as unlabeled and
    ignored, so predictions there are not penalized.

    Returns:
        ``(pixel_correct, pixel_labeled)``; the accuracy itself is
        ``pixel_correct / pixel_labeled`` and is left to the caller.
    """
    prediction, labels = np.asarray(im_pred), np.asarray(im_lab)
    valid = labels > 0
    hits = (prediction == labels) * valid
    return np.sum(hits), np.sum(valid)
def intersection_and_union(im_pred, im_lab, num_class):
    """Per-class intersection and union pixel counts for one image.

    Pixels whose ground-truth label is not > 0 are treated as unlabeled and
    their predictions are zeroed. Counts are taken with ``num_class - 1``
    histogram bins over the range ``(1, num_class - 1)``.

    Returns:
        ``(area_inter, area_union)``: two arrays of length ``num_class - 1``.
    """
    labels = np.asarray(im_lab)
    # Remove classes from unlabeled pixels in gt image.
    prediction = np.asarray(im_pred) * (labels > 0)
    # Class id where prediction and label agree, 0 elsewhere.
    overlap = prediction * (prediction == labels)
    n_bins = num_class - 1

    def per_class(values):
        return np.histogram(values, bins=n_bins, range=(1, n_bins))[0]

    area_inter = per_class(overlap)
    area_union = per_class(prediction) + per_class(labels) - area_inter
    return area_inter, area_union