Using a Linear Loss Function to Weight Output Importances

I am trying to solve a machine learning problem in which some samples matter more than others. The minimal reproducible example below is an instance of such a problem, which I will restate here:

We have a list of training samples, sample_data, and each sample has a corresponding weight in loss_weights. We are training a model to output a tensor, output. If loss_weights[i] < 0, we aim to maximize output[i]; conversely, if loss_weights[i] > 0, we aim to minimize output[i]. Moreover, it is more important to minimize or maximize output[i] when the magnitude of loss_weights[i] is larger.
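To make this encoding concrete, here is a minimal plain-Python sketch that reads each entry of loss_weights as a direction plus a priority. The helper name describe_objective is hypothetical; it only restates the rule above: the sign picks minimize vs. maximize, and the magnitude sets the importance.

```python
def describe_objective(loss_weights):
    """Hypothetical helper: map each weight to (index, direction, importance).

    A positive weight means output[i] should be minimized, a negative one
    means it should be maximized, and abs(weight) is how much that sample
    matters. Rows are returned from most to least important.
    """
    rows = []
    for i, w in enumerate(loss_weights):
        if w > 0:
            direction = "minimize"
        elif w < 0:
            direction = "maximize"
        else:
            direction = "ignore"
        rows.append((i, direction, abs(w)))
    # Largest magnitude first; sorted() is stable, so ties keep index order
    return sorted(rows, key=lambda row: row[2], reverse=True)


# The weights used in the example, flattened to a plain list
weights = [2, -0.1, 0.5, -1, 0.2, 1, -0.4]
for i, direction, importance in describe_objective(weights):
    print(f"output[{i}]: {direction} (importance {importance})")
```

With these weights, output[0] comes first (minimize, importance 2) and output[1] comes last (maximize, importance 0.1).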

I thought an intuitive way to encode this would be a linear loss function, Loss = sum(loss_weights * output) (an element-wise product summed over the samples), and I tried it in my minimal reproducible example:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Small MLP mapping a scalar input to a scalar output in (0, 1)
model = nn.Sequential(
    nn.Linear(1, 100),
    nn.Sigmoid(),
    nn.Linear(100, 100),
    nn.Sigmoid(),
    nn.Linear(100, 1),
    nn.Sigmoid(),
)

sample_data = torch.tensor([
    [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7]  # dummy data
])
# Positive weight: minimize output[i]; negative weight: maximize output[i]
loss_weights = torch.tensor([
    [2.0], [-0.1], [0.5], [-1.0], [0.2], [1.0], [-0.4]  # Loss = loss_weights * output
])

optim = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1000):
    output = model(sample_data)
    optim.zero_grad()
    # Linear loss: element-wise product, summed over all samples
    loss = torch.sum(loss_weights * output)
    # Gradient of the loss with respect to the model's outputs
    print(torch.autograd.grad(loss, output, retain_graph=True)[0], "dLoss/dOutput")
    loss.backward()
    optim.step()
    print(loss)

# Green: the per-sample weights; orange: the trained model's outputs
plt.plot(loss_weights.numpy(), color="green")
plt.plot(model(sample_data).detach().numpy(), color="orange")
plt.show()
```
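One property of this loss is worth spelling out, because it is exactly what the dLoss/dOutput print in the example shows: for Loss = sum(loss_weights * output), the gradient with respect to output is loss_weights itself, whatever values output currently holds. A minimal autograd check, using random stand-in outputs in place of model(sample_data) (an assumption; any tensor of the same shape behaves the same):

```python
import torch

loss_weights = torch.tensor([[2.0], [-0.1], [0.5], [-1.0], [0.2], [1.0], [-0.4]])

# Stand-in for model(sample_data): random values in [0, 1), same shape
output = torch.rand(7, 1, requires_grad=True)

loss = torch.sum(loss_weights * output)
(grad,) = torch.autograd.grad(loss, output)

# The gradient is the weight tensor itself, independent of output's values
print(grad.flatten().tolist())
assert torch.equal(grad, loss_weights)
```

So the loss never signals that an output has reached a target; it keeps pushing each output[i] in the direction of -loss_weights[i], and only the final nn.Sigmoid() bounds how far that push can go.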

In the example above, we care most about minimizing output[0], since loss_weights[0] = 2 has the largest magnitude, and least about maximizing output[1], since loss_weights[1] = -0.1 has the smallest magnitude.

However, even with such a large model and so few samples (where it should easily be able to overfit), my model only ever outputs 0. I suspected a bug in my code, but when I changed loss_weights[1] to -2, the model instead output only 1 (maximizing all outputs).
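Because every output passes through a final sigmoid, it lies in (0, 1), so the best the linear loss can do is push output[i] toward 0 where loss_weights[i] > 0 and toward 1 where loss_weights[i] < 0. A small plain-Python calculation compares that ideal with the constant outputs described above; it is a sketch that treats the sigmoid's open bounds 0 and 1 as reachable, and linear_loss is a hypothetical helper name.

```python
weights = [2, -0.1, 0.5, -1, 0.2, 1, -0.4]


def linear_loss(weights, outputs):
    # Loss = sum(loss_weights * output), as in the example
    return sum(w * o for w, o in zip(weights, outputs))


all_zero = [0.0] * len(weights)
all_one = [1.0] * len(weights)
# Ideal: 1 where the weight is negative (maximize), 0 otherwise (minimize)
ideal = [1.0 if w < 0 else 0.0 for w in weights]

print("all outputs 0:", linear_loss(weights, all_zero))
print("all outputs 1:", linear_loss(weights, all_one))
print("ideal outputs:", linear_loss(weights, ideal))
```

With the original weights, all-zero outputs give a loss of 0 and all-one outputs give about 2.2 (the sum of all weights), while the ideal outputs give -1.5 (the sum of the negative weights). The constant-0 solution the model settles on is therefore better than constant 1 but still well above the attainable minimum.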

My running theory is that the highest loss_weights entry dominates everything else, so the model outputs a single value instead of learning a mapping from the input. Is this theory correct? If so, how can I solve this problem?
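One way to probe this theory (a sketch, not a diagnosis) is to look at a parameter that every sample shares, such as the bias of the last nn.Linear. By the chain rule its gradient is sum_i loss_weights[i] * σ'(z_i), where z_i is the pre-sigmoid value for sample i and σ'(z) = σ(z)(1 - σ(z)), so all the weights feed into one number. If the network's output barely varies with the input, the σ'(z_i) are nearly equal and this is roughly σ'(z) * sum(loss_weights). The snippet rebuilds the question's model and checks the formula against autograd; the seed and the names predicted and autograd_value are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

# Same architecture and data as the question
model = nn.Sequential(
    nn.Linear(1, 100), nn.Sigmoid(),
    nn.Linear(100, 100), nn.Sigmoid(),
    nn.Linear(100, 1), nn.Sigmoid(),
)
sample_data = torch.tensor([[0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7]])
loss_weights = torch.tensor([[2.0], [-0.1], [0.5], [-1.0], [0.2], [1.0], [-0.4]])

# Pre-sigmoid output z: run every layer except the final nn.Sigmoid
z = model[:-1](sample_data)
output = torch.sigmoid(z)
loss = torch.sum(loss_weights * output)
loss.backward()

# Chain rule for the last Linear layer's bias: sum_i w_i * sigmoid'(z_i)
sig_prime = output * (1 - output)
predicted = torch.sum(loss_weights * sig_prime).detach()
autograd_value = model[4].bias.grad  # model[4] is the last nn.Linear

print("bias grad (autograd):  ", autograd_value.item())
print("bias grad (chain rule):", predicted.item())
print("output spread across inputs:", (output.max() - output.min()).item())
```

A very small output spread would mean the shared parameters see mostly the aggregate of the weights rather than anything input-specific.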

I realize that I could use a weighted average of MSE losses, but my real use case doesn't really lend itself to a loss with only one correct answer per sample, which is why I tried an unorthodox one.
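For comparison, the weighted-average MSE alternative could look like the sketch below. The targets are an assumption made for illustration: 1 where the weight is negative and 0 where it is positive (the ends of the sigmoid's range), with |loss_weights| as the per-sample weight. Unlike the linear loss, its gradient shrinks to zero as each output reaches its target.

```python
import torch


def weighted_mse(output, loss_weights):
    # Assumed targets: 1 where the weight is negative (maximize), else 0 (minimize)
    target = (loss_weights < 0).to(output.dtype)
    abs_w = loss_weights.abs()
    # Weighted average of squared errors, with |loss_weights| as the weights
    return torch.sum(abs_w * (output - target) ** 2) / abs_w.sum()


loss_weights = torch.tensor([[2.0], [-0.1], [0.5], [-1.0], [0.2], [1.0], [-0.4]])
perfect = (loss_weights < 0).float()
print(weighted_mse(perfect, loss_weights).item())  # 0.0: every output hits its target
```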
