Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix reduction mistake in SpectralConvergenceLoss #75

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

renared
Copy link

@renared renared commented May 22, 2024

I noticed that when evaluating the STFT loss over my validation dataset, I obtained different results in function of the batch size. I could isolate the cause to be the spectral convergence term, then came across the comment by @egaznep in issue #69. It does not make sense to average the denominator over all dimensions including the batch dimension, so I believe their suggestion should be used instead.

This snippet shows the difference:

import torch
from auraloss.freq import STFTLoss

batches = [(torch.randn(4, 1, 16384), torch.randn(4, 1, 16384)) for i in range(1024)]
batchall = tuple(torch.concat(u, dim=0) for u in zip(*batches))

print("with spectral convergence enabled")
loss = STFTLoss()
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))
print("over full dataset:", loss(*batchall))

print("with spectral convergence disabled")
loss = STFTLoss(w_sc=0)
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))
print("over full dataset:", loss(*batchall))

Before:

with spectral convergence enabled
mean of losses: tensor(1.3511)
over full dataset: tensor(1.3493)
with spectral convergence disabled
mean of losses: tensor(0.6950)
over full dataset: tensor(0.6950)

After:

with spectral convergence enabled
mean of losses: tensor(1.3726)
over full dataset: tensor(1.3726)
with spectral convergence disabled
mean of losses: tensor(0.7095)
over full dataset: tensor(0.7095)

the denominator was averaged over all dimensions including the batch dimension, see comment by @egaznep in csteinmetz1#69
@cpvlordelo
Copy link

I just stumbled on the exact same problem. Is there any plans on merging this fix? Ping @csteinmetz1?

@@ -16,7 +15,7 @@ def __init__(self):
super(SpectralConvergenceLoss, self).__init__()

def forward(self, x_mag, y_mag):
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
Copy link

@cpvlordelo cpvlordelo Feb 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
return (torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2)) / torch.norm(y_mag, p="fro", dim=(-1, -2))).mean()

Since you removed the reduction, this is now returning a multi-dimensional tensor. It does work with STFTLoss because the reduction is done inside of it as you can see here, but if you instantiate SpectralConvergenceLoss, on the other hand, then your example code there will crash.

import torch
from auraloss.freq import SpectralConvergenceLoss

batches = [(torch.randn(4, 1, 16384), torch.randn(4, 1, 16384)) for i in range(1024)]
batchall = tuple(torch.concat(u, dim=0) for u in zip(*batches))

loss = SpectralConvergenceLoss()
print("Shape of Spectral Convergence Loss over full dataset:", loss(*batchall).shape)
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))

Before:

Shape of Spectral Convergence Loss over full dataset: torch.Size([])
mean of losses: tensor(1.4144)

After:

Shape of Spectral Convergence Loss full dataset: torch.Size([4096, 1, 1])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-45-952d11dbfe6a>](https://localhost:8080/#) in <cell line: 0>()
     23 print("Shape of Spectral Convergence Loss over full dataset:", loss(*batchall).shape)
---> 24 print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))

ValueError: only one element tensors can be converted to Python scalars

This is just a suggestion that will always perform the reduction as mean.

But an even better option, in my opinion, would be to add a new string argument reduction as part of init and call apply_reduction inside this forward method in a similar way done in STFTLoss code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants