import threading
import numpy as np
import torch

# Public API of this module: the names exported by a star-import.
__all__ = ['accuracy', 'get_pixacc_miou',
           'SegmentationMetric', 'batch_intersection_union', 'batch_pix_accuracy',
           'pixel_accuracy', 'intersection_and_union']

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each requested value of k.

    Args:
        output: score tensor; the top-k classes are taken along dimension 1
            and the result is transposed, so it is expected to be 2D
            (batch, classes).
        target: tensor of ground-truth class indices; ``target.size(0)`` is
            used as the batch size.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``:
        the percentage of samples whose target is among the ``k``
        highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes per sample, laid out as (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape() rather than view(): `correct` follows the transposed
            # (non-contiguous) layout of `pred`, and view(-1) on a slice with
            # more than one row raises a RuntimeError in current PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def get_pixacc_miou(total_correct, total_label, total_inter, total_union):
    """Turn accumulated counts into pixel accuracy and mean IoU.

    Args:
        total_correct: number of correctly classified labeled pixels.
        total_label: number of labeled pixels.
        total_inter: per-class intersection counts.
        total_union: per-class union counts.

    Returns:
        Tuple ``(pixAcc, mIoU)``; mIoU is the mean of the per-class ratios.
    """
    # Smallest positive spacing at 1.0; keeps the divisions defined when a
    # denominator is zero.
    eps = np.spacing(1)
    pix_acc = 1.0 * total_correct / (eps + total_label)
    per_class_iou = 1.0 * total_inter / (eps + total_union)
    return pix_acc, per_class_iou.mean()

class SegmentationMetric(object):
    """Accumulate pixel-accuracy (pixAcc) and mean-IoU (mIoU) statistics.

    Running totals are updated batch by batch through ``update`` and turned
    into final scores by ``get``. Updates to the totals are guarded by a lock
    because a list of predictions is evaluated with one thread per item.
    """

    def __init__(self, nclass):
        """Create a metric for ``nclass`` categories with zeroed totals."""
        self.nclass = nclass
        self.lock = threading.Lock()
        self.reset()

    def update(self, labels, preds):
        """Add the statistics of one batch (or a list of batches).

        Args:
            labels: a label tensor, or a list/tuple of label tensors paired
                element-wise with ``preds``.
            preds: a prediction tensor, or a list/tuple of prediction tensors.

        Raises:
            NotImplementedError: if ``preds`` is neither a tensor nor a
                list/tuple.
        """
        def evaluate_worker(self, label, pred):
            correct, labeled = batch_pix_accuracy(pred, label)
            inter, union = batch_intersection_union(pred, label, self.nclass)
            # Several workers may finish at once; serialize the accumulation.
            with self.lock:
                self.total_correct += correct
                self.total_label += labeled
                self.total_inter += inter
                self.total_union += union
            return

        if isinstance(preds, torch.Tensor):
            evaluate_worker(self, labels, preds)
        elif isinstance(preds, (list, tuple)):
            threads = [threading.Thread(target=evaluate_worker,
                                        args=(self, label, pred))
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
        else:
            # `NotImplemented` is a comparison sentinel, not an exception;
            # raising it only produces an unrelated TypeError.
            raise NotImplementedError(
                "preds must be a torch.Tensor, list or tuple, got {}".format(
                    type(preds).__name__))

    def get_all(self):
        """Return the raw totals: (correct, labeled, intersection, union)."""
        return self.total_correct, self.total_label, self.total_inter, self.total_union

    def get(self):
        """Return ``(pixAcc, mIoU)`` computed from the accumulated totals."""
        return get_pixacc_miou(self.total_correct, self.total_label,
                               self.total_inter, self.total_union)

    def reset(self):
        """Zero all accumulated totals."""
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0
        return
def batch_pix_accuracy(output, target):
    """Count correctly predicted and labeled pixels in a batch.

    Args:
        output: score tensor; the predicted class is the argmax over
            dimension 1 (documented upstream as a 4D tensor).
        target: label tensor (documented upstream as 3D). Labels below 0
            (e.g. -1) are treated as unlabeled and ignored.

    Returns:
        Tuple ``(pixel_correct, pixel_labeled)``.
    """
    _, predict = torch.max(output, 1)

    # Shift both maps by one so that class ids start at 1 and a label of -1
    # becomes 0, which marks the pixel as unlabeled.
    predict = predict.cpu().numpy().astype('int64') + 1
    target = target.cpu().numpy().astype('int64') + 1

    pixel_labeled = np.sum(target > 0)
    pixel_correct = np.sum((predict == target)*(target > 0))
    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(output, target, nclass):
    """Compute per-class intersection and union areas for a batch.

    Args:
        output: score tensor; the predicted class is the argmax over
            dimension 1 (documented upstream as a 4D tensor).
        target: label tensor (documented upstream as 3D). Labels below 0
            (e.g. -1) are treated as unlabeled and ignored.
        nclass: number of categories (int).

    Returns:
        Tuple ``(area_inter, area_union)`` of arrays with ``nclass`` entries.
    """
    _, predict = torch.max(output, 1)
    mini = 1
    maxi = nclass
    nbins = nclass

    # Shift both maps by one so that class ids run 1..nclass and a label of
    # -1 becomes 0, which falls outside the histogram range below.
    predict = predict.cpu().numpy().astype('int64') + 1
    target = target.cpu().numpy().astype('int64') + 1

    # Zero out predictions on unlabeled pixels so they are not counted.
    predict = predict * (target > 0).astype(predict.dtype)
    # Keeps the class id where prediction and label agree, 0 elsewhere.
    intersection = predict * (predict == target)

    # areas of intersection and union
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
# ref: array-based counterparts of the batch metrics above in the export list
# (the original reference link is not present in this file).
def pixel_accuracy(im_pred, im_lab):
    """Count correctly predicted and labeled pixels of one label map.

    Args:
        im_pred: predicted label map (anything ``np.asarray`` accepts).
        im_lab: ground-truth label map; pixels with label 0 (or below) are
            treated as unlabeled.

    Returns:
        Tuple ``(pixel_correct, pixel_labeled)``.
    """
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)

    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pixel_labeled = np.sum(im_lab > 0)
    pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
    return pixel_correct, pixel_labeled


def intersection_and_union(im_pred, im_lab, num_class):
    """Compute per-class intersection and union areas of one label map.

    Args:
        im_pred: predicted label map (anything ``np.asarray`` accepts).
        im_lab: ground-truth label map; pixels with label 0 (or below) are
            treated as unlabeled.
        num_class: class count; histograms use ``num_class - 1`` bins over
            the label range ``1 .. num_class - 1``.

    Returns:
        Tuple ``(area_inter, area_union)`` of arrays with ``num_class - 1``
        entries.
    """
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)

    # Remove classes from unlabeled pixels in gt image.
    im_pred = im_pred * (im_lab > 0)

    # Compute area intersection:
    intersection = im_pred * (im_pred == im_lab)
    area_inter, _ = np.histogram(intersection, bins=num_class-1,
                                 range=(1, num_class - 1))

    # Compute area union:
    area_pred, _ = np.histogram(im_pred, bins=num_class-1,
                                range=(1, num_class - 1))
    area_lab, _ = np.histogram(im_lab, bins=num_class-1,
                               range=(1, num_class - 1))
    area_union = area_pred + area_lab - area_inter
    return area_inter, area_union