-
Notifications
You must be signed in to change notification settings - Fork 777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add assertion to split function ensuring valid windows #2587
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we probably miss a few tests here @MarcelK1102 do you want to add them? This module is covered by https://github.com/awslabs/gluonts/blob/dev/test/dataset/test_split.py
src/gluonts/dataset/split.py
Outdated
assert ( | ||
base + offset + prediction_length | ||
<= entry[FieldName.TARGET].shape[-1] | ||
), "Offset too short to generate windows" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rhs of the inequality is just label_slice.stop
, you could use that I guess. Also maybe in this case the message should be "Date is too early to generate windows"
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad: "Date is too late to generate windows"
:-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Corrected :)
if self.offset < 0 and offset_ >= 0: | ||
offset_ += len(entry) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was clearly a bug, but I'm wondering whether offset_ += entry[FieldName.TARGET].shape[-1]
is still needed here: if self.offset == -offset
then offset_ == 0
then the input_slice
below may be incorrect? It would be slice(None, 0)
which would be empty I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is not needed. If self.offset == -offset
and offset_ == 0
then the assertion fails since -offset_ < prediction_length
. If we would add offset_ += entry[FieldName.TARGET].shape[-1]
then the assertion would fail as well since -offset_ < prediction_length
again. I can't think of any scenario where it would change the outcome.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right. I'm wondering though whether ensuring offset_
is always positive would render the logic simpler to read.
So,
if self.offset < 0:
offset_ += entry[FieldName.TARGET].shape[-1]
assert offset_ + prediction_length <= entry[FieldName.TARGET].shape[-1], "Offset too short"
(didn't check this ☝️ thoroughly but makes sense to me)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The two inequalities are equal then. I tested it, it works. I updated it accordingly.
@lostella Sure, I will do that! |
# Conflicts: # src/gluonts/dataset/split.py
Looks good now! Let's add the tests and then this can be merged! |
* Fix: avoid automatic device detection via serialized tensors when deserializing. (#2576) * Make itertools Map/Filter dataclasses. (#2579) * serde: Fix encoding of dtypes. (#2586) * Add assertion to split function ensuring valid windows (#2587) * Ensure dtype on feat_time in torch DeepAR. (#2596) * Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598) * Fix version in requirements to comply with stricter setuptools. (#2604) Co-authored-by: Lorenzo Stella <[email protected]> * Add `gluonts.util.safe_extract` (#2606) Co-authored-by: Jasper <[email protected]> Co-authored-by: Lorenzo Stella <[email protected]> * fix requirements further * fix style * remove undesired change --------- Co-authored-by: Shubham Kapoor <[email protected]> Co-authored-by: Jasper <[email protected]> Co-authored-by: MarcelK1102 <[email protected]> Co-authored-by: Lorenzo Stella <[email protected]> Co-authored-by: Jasper <[email protected]>
* Add assertion to split function ensuring valid windows (#2587) * Ensure dtype on feat_time in torch DeepAR. (#2596) * Move NPTS back to `gluonts.model` (#2597) * Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598) * Fix bug with static cardinalities in `PandasDataset` (#2599) * Expose `weight_decay` in torch TFT estimator class (#2603) * Fix version in requirements to comply with stricter setuptools. (#2604) Co-authored-by: Lorenzo Stella <[email protected]> * Add `gluonts.util.safe_extract` (#2606) Co-authored-by: Jasper <[email protected]> Co-authored-by: Lorenzo Stella <[email protected]> * Fix incorrect import in `tsbench`, apply latest black (#2613) * Allow ReduceLROnPlateau to track val_loss when validation set is available (#2614) --------- Co-authored-by: MarcelK1102 <[email protected]> Co-authored-by: Jasper <[email protected]> Co-authored-by: Gerald Woo <[email protected]> Co-authored-by: Lorenzo Stella <[email protected]> Co-authored-by: Jasper <[email protected]>
Issue: #2577
Description of changes:
This PR addresses issue #2577 and added assertions for
OffsetSplitter
andDateSplitter
to ensure that the targets of the label have valid target lengths, i.e. matchprediction_length.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup