Add IQN implementation #1784
Conversation
scale, because the output of `mean` gets squeezed prematurely and thus the `keepdim` has no effect
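For context, a minimal illustration (with hypothetical shapes, not taken from the PR diff) of how squeezing the result right after the reduction defeats `keepdim`:

```python
import torch

x = torch.ones(2, 3, 4)

# keepdim=True keeps the reduced axis around as size 1,
# so the result still broadcasts against x
m = x.mean(dim=-1, keepdim=True)
print(m.shape)  # torch.Size([2, 3, 1])

# squeezing that axis right away discards it again,
# so the earlier keepdim=True has no effect
print(m.squeeze(-1).shape)  # torch.Size([2, 3])
```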
@kashif looks good to me, I would ask just for a few minor changes (see also the comment about pinball loss)
```python
# penalize by tau for over-predicting
# and by 1-tau for under-predicting
return (self.taus - (self.outputs < value).float()) * (
    self.outputs - value
)
```
Sorry, I got confused: it should be

```python
return (self.taus - (value < self.outputs).float()) * (value - self.outputs)
```
Proof:

```python
import numpy as np
import torch

def pinball_loss(obs, pred, alpha):
    # pinball (quantile) loss at level alpha:
    # penalizes under-prediction by alpha, over-prediction by 1-alpha
    return (alpha - (obs < pred).float()) * (obs - pred)

def subgradient_method(f, p, maxit):
    # plain subgradient descent with a 1/(k+1) step size
    for k in range(maxit):
        v = f(p)
        v.backward()
        with torch.no_grad():
            p -= 1 / (k + 1) * p.grad
        p.grad.zero_()
    return p

data = np.random.normal(size=(1000,))
f = lambda q: pinball_loss(torch.from_numpy(data), q, 0.1).mean()
q = subgradient_method(f, torch.tensor(0.0, requires_grad=True), 1000)
print(q)  # should approach the empirical 10% quantile of `data`
```
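As a side note (not part of the original thread), the corrected expression agrees with the familiar "max" form of the pinball loss; a quick numerical check, with hypothetical helper names:

```python
import torch

def pinball(value, outputs, taus):
    # corrected sign convention proposed in the review
    return (taus - (value < outputs).float()) * (value - outputs)

def pinball_max(value, outputs, taus):
    # equivalent max-form of the quantile/pinball loss
    d = value - outputs
    return torch.maximum(taus * d, (taus - 1) * d)

torch.manual_seed(0)
value = torch.randn(1000)
outputs = torch.randn(1000)
taus = torch.rand(1000)

# the two forms coincide elementwise, and the loss is non-negative
print(torch.allclose(pinball(value, outputs, taus),
                     pinball_max(value, outputs, taus)))  # True
```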
Awesome, thanks @lostella, fixed.
🚢 thanks @kashif!
I thank you!
Issue #, if available:
Description of changes:
Implemented the IQN distribution output head from the paper https://arxiv.org/abs/2107.03743
fixes #1643
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup