Add LSTNet #596
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
==========================================
+ Coverage 83.72% 83.89% +0.17%
==========================================
Files 181 184 +3
Lines 10326 10442 +116
==========================================
+ Hits 8645 8760 +115
- Misses 1681 1682 +1
|
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
==========================================
+ Coverage 83.73% 83.91% +0.17%
==========================================
Files 180 183 +3
Lines 10281 10412 +131
==========================================
+ Hits 8609 8737 +128
- Misses 1672 1675 +3
|
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
==========================================
+ Coverage 83.72% 83.91% +0.18%
==========================================
Files 178 183 +5
Lines 10279 10412 +133
==========================================
+ Hits 8606 8737 +131
- Misses 1673 1675 +2
|
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
==========================================
+ Coverage 83.72% 83.90% +0.18%
==========================================
Files 178 181 +3
Lines 10279 10420 +141
==========================================
+ Hits 8606 8743 +137
- Misses 1673 1677 +4
|
Nice, thank you! See some comments and questions inline
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
==========================================
+ Coverage 83.94% 84.11% +0.17%
==========================================
Files 178 181 +3
Lines 10337 10478 +141
==========================================
+ Hits 8677 8814 +137
- Misses 1660 1664 +4
|
@lostella thanks! all the comments have been addressed. |
I think the settings for this model are slightly inconsistent with the other ones: in GluonTS, `prediction_length` indicates the time length of the forecast that the predictor produces; in this case, the length of the forecasts is fixed to 1, and `prediction_length` effectively controls the lead time, i.e. how long past the conditioning range (the “past target”) the forecast starts.
I believe that including the model as it is now could create confusion, for example when evaluating its performance on standard datasets: there, all other models are really tested against a prediction interval of length > 1. We should actually implement a very generic test for all predictors that makes sure that all common concepts (such as `freq` or `prediction_length`) are uniformly adopted, in which case this model wouldn’t fit the story well.
I see two ways to address this issue, which could actually be implemented together:
- Adjusting the LSTNet network in such a way that forecasts for the whole prediction interval are produced. I’m thinking maybe the number of units in `ar_fc` could be increased from 1 to reflect the number of predicted points? Other parts of the network may need to be adjusted similarly, what do you think @ehsanmok?
- Introducing an explicit `lead_time` property in predictors, which could default to 0 for models not exposing this customization (being very careful about off-by-one errors in setting the convention here), and instead be set to the appropriate value here in the LSTNet estimator. This is something that GluonTS could use in general, regardless of LSTNet, and should be addressed in a separate PR. I summon the wise opinions of @vafl and @jaheba on this matter.
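To make the off-by-one concern concrete, here is a minimal sketch of one possible convention. The helper `forecast_steps` is hypothetical and not part of GluonTS; it assumes that `lead_time=0` means the forecast starts at the step right after the last observed one.

```python
from typing import List


def forecast_steps(prediction_length: int, lead_time: int = 0) -> List[int]:
    """Return the offsets, relative to the last observed step t, that a forecast covers.

    Assumed convention (hypothetical, not the GluonTS API):
    lead_time=0 covers t+1, ..., t+prediction_length.
    """
    if prediction_length < 1:
        raise ValueError("prediction_length must be >= 1")
    if lead_time < 0:
        raise ValueError("lead_time must be >= 0")
    first = lead_time + 1
    return list(range(first, first + prediction_length))


# A standard model forecasting 3 steps right after the conditioning range.
standard = forecast_steps(prediction_length=3)

# A point forecast of y_{t+3}: length 1, skipping two steps.
point_at_three = forecast_steps(prediction_length=1, lead_time=2)
```

Under this assumed convention, a single-point forecast of `y_{t+p}` corresponds to `prediction_length=1` and `lead_time=p-1`, which is where an off-by-one slip would most likely occur.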
@lostella thanks for bringing up the difference! yes, I was thinking to add two modes basically, one with |
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
=========================================
+ Coverage 84.63% 84.83% +0.2%
=========================================
Files 178 181 +3
Lines 10401 10565 +164
=========================================
+ Hits 8803 8963 +160
- Misses 1598 1602 +4
|
I have some comments about the network implementation, see inline. I think once these are addressed, everything should be fine. Maybe you could share some results with this estimator, to make sure it's learning correctly and giving meaningful predictions?
assert (
    fct.start_date
    == pd.date_range(
        start=str(test_ds["start"]),
        periods=test_ds["target"].shape[1],  # number of test periods
        freq=freq,
        closed="right",
    )[-(horizon or prediction_length)]
)
I don't think this is the right assertion in case `horizon` is set: if `horizon` or `prediction_length` is set to `p`, then the forecast contains the following time steps:
horizon=p => y_{t+p}
prediction_length=p => y_{t+1}, ..., y_{t+p}
So, when `horizon` is set, `fct.start_date` should be `p-1` time steps later, while according to the assertion it's the same in both cases.
However, at the moment there might be no easy way to get the correct `fct.start_date` out of the predictor in that case: I've opened issue #677 in this regard (in particular, the `ForecastGenerator` needs to be modified). I think it's fine for now to have the `prediction_length` case doing the right thing (and the test looks OK for that), and then fix the `horizon` case once #677 is addressed.
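The expected start dates in the two cases can be sketched as follows. This is an illustration only: `forecast_dates` is a hypothetical helper, it uses the standard library's `datetime` with a fixed step instead of pandas frequencies, and it does not reflect how the `ForecastGenerator` computes dates.

```python
from datetime import datetime, timedelta
from typing import List


def forecast_dates(
    last_observed: datetime, step: timedelta, p: int, use_horizon: bool
) -> List[datetime]:
    """Timestamps covered by a forecast made after `last_observed` (time t).

    prediction_length=p -> t+1, ..., t+p
    horizon=p           -> t+p only
    """
    future = [last_observed + k * step for k in range(1, p + 1)]
    return future[-1:] if use_horizon else future


last = datetime(2020, 1, 10)
one_day = timedelta(days=1)

seq_dates = forecast_dates(last, one_day, p=3, use_horizon=False)
point_dates = forecast_dates(last, one_day, p=3, use_horizon=True)

# The point forecast starts p-1 steps after the sequence forecast.
start_gap = point_dates[0] - seq_dates[0]
```

With `p=3` and a daily step, the sequence forecast starts on 2020-01-11 while the point forecast starts on 2020-01-13, i.e. `p-1 = 2` steps later.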
Right! It seems the correct `start_date` isn't available. Though in the test case, it's actually `p=2` and the `start_date` is one before the last time step.
- Add initial test
- Make smoke test pass
- Fix reshape bug and parameters validation
- Add more tests
- Fix unregistered hybrid rnn layer
- Add docs
- LSTNet scaling support
- Scaling and doc
- Fix dtype
- Cleanup
Codecov Report
@@ Coverage Diff @@
## master #596 +/- ##
==========================================
+ Coverage 84.63% 84.83% +0.19%
==========================================
Files 178 181 +3
Lines 10401 10561 +160
==========================================
+ Hits 8803 8959 +156
- Misses 1598 1602 +4
|
Looks good, thanks! We will probably want to revisit how the `horizon` case is treated once #677 is settled (see this comment). @ehsanmok, I may ping you about that to review the changes.
past_observed_values
    Tensor of shape (batch_size, num_series, context_length)
future_target
    Tensor of shape (batch_size, num_series, 1) if `horizon` was specified
Is the time length 1 in general when `horizon` is set? In the estimator, the “future length” in the instance splitter is set to `horizon` in this case, so it will be > 1 if `horizon` is. Do I understand this right?
Yes, when using `horizon`, since we're predicting a point rather than a sequence, the future length is 1 too, to compute the loss; for `prediction_length` it's a sequence as usual.
I'm expecting the instance splitter to respect it, so you're saying it may not?
I don’t think so: I would set `horizon` to 2 and put a breakpoint here to check the shape of `future_target`. I’m pretty sure the time length will be 2, not 1.
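The shape question can be illustrated with a small stand-alone sketch. It uses plain nested lists of shape (num_series, time) and two hypothetical helpers, `split_window` and `keep_last_step`; it only mimics the idea of a splitter with a "future length" and is not the GluonTS instance splitter.

```python
from typing import List, Tuple

Series = List[List[float]]  # shape: (num_series, time)


def split_window(
    target: Series, split_at: int, context_length: int, future_length: int
) -> Tuple[Series, Series]:
    """Cut a past window and a future window around index `split_at`."""
    past = [row[split_at - context_length:split_at] for row in target]
    future = [row[split_at:split_at + future_length] for row in target]
    return past, future


def keep_last_step(future: Series) -> Series:
    """Reduce the future window to its final time step (time length 1)."""
    return [row[-1:] for row in future]


# Two series with 10 time steps each.
target = [
    [float(t) for t in range(10)],
    [float(10 + t) for t in range(10)],
]

horizon = 2
past, future = split_window(
    target, split_at=6, context_length=4, future_length=horizon
)
point_target = keep_last_step(future)
```

In this sketch, a future length equal to `horizon` yields a window of time length 2; a time length of 1 is only obtained after an explicit slicing step such as `keep_last_step`, which is the kind of step one would look for in the network or estimator.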
Description of changes:
An implementation of LSTNet.
gluonts/model/lstnet/
Note that I haven't replicated all of the paper's results: the paper tuned over quite a large range of hyperparameters, and no final hyperparameters were given!
TODO: Temporal Attention layer.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.