Multi-class cross-entropy loss function


#1

Hi @jakub_czakon,

I am trying to use a multi-output cross-entropy loss function for the DSTL dataset. I took a look at the Open Solution Mapping Challenge loss functions here:

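import torch.nn as nn
import torch.nn.functional as F


def multiclass_segmentation_loss(output, target):
    # squeeze(1) drops a size-1 channel axis: (N, 1, H, W) -> (N, H, W)
    target = target.squeeze(1).long()
    cross_entropy = nn.CrossEntropyLoss()
    return cross_entropy(output, target)


def cross_entropy(output, target, squeeze=False):
    if squeeze:
        target = target.squeeze(1)
    # F.nll_loss expects log-probabilities, e.g. the output of log_softmax
    return F.nll_loss(output, target)


def multi_output_cross_entropy(outputs, targets):
    # averages cross_entropy over paired (output, target) items
    losses = []
    for output, target in zip(outputs, targets):
        loss = cross_entropy(output, target, squeeze=True)
        losses.append(loss)
    return sum(losses) / len(losses)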
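Note that the quoted cross_entropy helper calls F.nll_loss directly, which only matches a true cross entropy if the network already ends in log_softmax; nn.CrossEntropyLoss applies log_softmax internally and takes raw logits. A small check of that relationship, using random tensors whose shapes are illustrative assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes: batch of 2, 10 classes, 8x8 pixels.
logits = torch.randn(2, 10, 8, 8)         # raw network output (N, C, H, W)
target = torch.randint(0, 10, (2, 8, 8))  # class index per pixel (N, H, W)

# nn.CrossEntropyLoss applies log_softmax internally, so it takes raw logits.
ce = nn.CrossEntropyLoss()(logits, target)

# F.nll_loss does not; it must be fed log-probabilities to give the same value.
nll_on_log_probs = F.nll_loss(F.log_softmax(logits, dim=1), target)
nll_on_logits = F.nll_loss(logits, target)  # runs, but computes something else

print(ce.item(), nll_on_log_probs.item(), nll_on_logits.item())
```

The first two values agree; the third does not, and PyTorch raises no error for it, so the mismatch is easy to miss.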

In my DSTL dataset generator, I generate a mask for each class and put each one in its own channel, so I have a 10-channel mask.

In the U-Net model, I set the number of input channels to 3 (RGB images only) and the number of output channels to 10.

These are the last few lines where I generate the image and mask:

# mask generator
mask = self.mask_generator.mask(id=id_, height=h, width=w)

if mask is None:
    raise ValueError('Could not generate concatenated mask!')

# swap color axis because
# numpy image: H x W x C
# torch image: C x H x W

image = resize(image, 256, 256).transpose((2, 0, 1))
mask = resize(mask, 256, 256).transpose((2, 0, 1))

image = torch.from_numpy(image).float()
mask = torch.from_numpy(mask).long()

return image, mask
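Because nn.CrossEntropyLoss expects one class index per pixel (a target of shape (N, H, W)) rather than one channel per class, one option is to collapse the stacked mask into an index map before returning it. A minimal sketch, assuming the mask is a NumPy array of shape (H, W, C) with exactly one active class per pixel (a pixel with no active class would silently map to index 0, so a background channel may be needed); mask_to_class_index is a hypothetical helper:

```python
import numpy as np
import torch


def mask_to_class_index(mask_hwc):
    """Hypothetical helper: collapse a one-hot (H, W, C) mask into an (H, W) index map."""
    # argmax over the channel axis returns the active class of each pixel
    class_map = np.argmax(mask_hwc, axis=-1)
    return torch.from_numpy(class_map).long()


# Toy 2x2 mask with 3 classes, one active class per pixel (illustrative only).
mask = np.zeros((2, 2, 3), dtype=np.uint8)
mask[0, 0, 2] = 1
mask[0, 1, 1] = 1
mask[1, 0, 0] = 1
mask[1, 1, 2] = 1

target = mask_to_class_index(mask)
print(target.tolist())  # [[2, 1], [0, 2]]
```

In the generator this would replace the mask transpose, so each sample's target is (256, 256) and default DataLoader batching gives (N, 256, 256), matching an (N, 10, 256, 256) network output. Collapsing after the resize avoids interpolating class indices.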

When I use the multiclass_segmentation_loss function, I get the following error:

  File "/tool/python/conda/env/gis36/lib/python3.6/site-packages/torch/nn/functional.py", line 1334, in nll_loss
    return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
RuntimeError: invalid argument 1: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4 at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:14
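The squeeze(1) in multiclass_segmentation_loss only removes a channel axis of size 1, so it is written for an (N, 1, H, W) mask of class indices. On a 10-channel mask it is a no-op, and the 4D target reaches the loss unchanged. A small reproduction with dummy tensors (sizes are illustrative; the exact error text differs between PyTorch versions):

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 10, 4, 4  # illustrative sizes

output = torch.randn(N, C, H, W)
criterion = nn.CrossEntropyLoss()

# A 10-channel one-hot mask: squeeze(1) leaves it unchanged because dim 1 has size 10.
one_hot_mask = torch.zeros(N, C, H, W, dtype=torch.long)
one_hot_mask[:, 0] = 1
print(one_hot_mask.squeeze(1).shape)  # torch.Size([2, 10, 4, 4])

try:
    criterion(output, one_hot_mask.squeeze(1))
    rejected = False
except (RuntimeError, ValueError):
    rejected = True  # the 4D integer target is refused

# A single-channel mask of class indices is what squeeze(1) expects.
index_mask = torch.randint(0, C, (N, 1, H, W))
loss = criterion(output, index_mask.squeeze(1))  # target is now (N, H, W)
print(rejected, loss.item())
```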

If I use the multi_output_cross_entropy, I get the following error:

  File "/tool/python/conda/env/gis36/lib/python3.6/site-packages/torch/nn/functional.py", line 1341, in nll_loss
    out_size, target.size()))
ValueError: Expected target size (10, 256), got torch.Size([10, 256, 256])
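This second error comes from zip() iterating a batched tensor along dim 0, so each output and target inside multi_output_cross_entropy is a single image. A short illustration with tensors shaped like the ones in the post (the values are dummies):

```python
import torch

outputs = torch.randn(2, 10, 256, 256)                    # illustrative U-Net output (N, C, H, W)
targets = torch.zeros(2, 10, 256, 256, dtype=torch.long)  # 10-channel mask (N, C, H, W)

# zip() over tensors iterates along dim 0, i.e. sample by sample
shapes = [(output.shape, target.squeeze(1).shape) for output, target in zip(outputs, targets)]
for output_shape, target_shape in shapes:
    print(output_shape, target_shape)  # both torch.Size([10, 256, 256])
```

F.nll_loss reads a 3D input as (N, C, d1): the 10 classes become the batch and the 256 image rows become the "classes", so it expects a target of shape (10, 256), which is exactly the size in the ValueError. The function is meant to loop over several separate outputs, not over the samples of one batch.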

I would like to compute a per-pixel cross-entropy loss over all pixels in each image and all images in a batch.

Would you be able to tell me how I should write this loss function and make sure that the input and target shapes match?
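For reference, the per-pixel loss described here is the negative log-probability of the true class at every pixel, averaged over all pixels of all images in the batch. Written out by hand and compared against F.cross_entropy (per_pixel_cross_entropy and the shapes are illustrative, not part of the Open Solution code):

```python
import torch
import torch.nn.functional as F


def per_pixel_cross_entropy(logits, target):
    """Mean cross entropy over every pixel of every image in the batch.

    logits: (N, C, H, W) raw scores; target: (N, H, W) class indices.
    """
    log_probs = F.log_softmax(logits, dim=1)  # (N, C, H, W)
    # pick the log-probability of the true class at each pixel -> (N, H, W)
    true_class_log_probs = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    return -true_class_log_probs.mean()


torch.manual_seed(0)
logits = torch.randn(4, 10, 16, 16)
target = torch.randint(0, 10, (4, 16, 16))

manual = per_pixel_cross_entropy(logits, target)
builtin = F.cross_entropy(logits, target)  # default reduction averages over N*H*W
print(manual.item(), builtin.item())
```

The two values agree, so a single nn.CrossEntropyLoss call with (N, 10, H, W) logits and an (N, H, W) index target already covers every pixel of every image.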


#2

I think you actually want to use vanilla cross_entropy, since you have just one output (with 10 classes). Multi-output is for exotic situations with a fork-structured output.

So I would just go with cross entropy, or a weighted sum of cross entropy and soft Dice loss.
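A sketch of the weighted cross entropy plus soft Dice combination suggested above. The Dice formulation (softmax probabilities against a one-hot target, averaged over classes), the smooth constant, the default weights and the class name CrossEntropySoftDiceLoss are all assumptions for illustration, not the Open Solution implementation. It uses F.one_hot, which needs PyTorch 1.1 or later:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropySoftDiceLoss(nn.Module):
    """Hypothetical weighted sum: ce_weight * CE + dice_weight * (1 - soft Dice)."""

    def __init__(self, ce_weight=1.0, dice_weight=1.0, smooth=1.0):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, logits, target):
        # logits: (N, C, H, W) raw scores; target: (N, H, W) class indices
        ce = self.cross_entropy(logits, target)

        probs = F.softmax(logits, dim=1)
        num_classes = logits.size(1)
        # (N, H, W) -> (N, H, W, C) -> (N, C, H, W)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

        # soft Dice per class, summing over batch and spatial dims
        dims = (0, 2, 3)
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        dice_loss = 1 - dice.mean()

        return self.ce_weight * ce + self.dice_weight * dice_loss


criterion = CrossEntropySoftDiceLoss(ce_weight=1.0, dice_weight=0.5)
logits = torch.randn(2, 10, 32, 32, requires_grad=True)
target = torch.randint(0, 10, (2, 32, 32))
loss = criterion(logits, target)
loss.backward()
print(loss.item())
```

The Dice term works on softmax probabilities rather than hard argmax predictions so that it stays differentiable; the weights are hyperparameters to tune.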