Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose weight_decay in torch TFT estimator class #2603

Merged
merged 3 commits into from
Jan 30, 2023

Conversation

gorold
Copy link
Contributor

@gorold gorold commented Jan 29, 2023

Description of changes:
Expose weight_decay parameter and add default value (1e-8) to be more consistent with existing mx/torch models.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@lostella lostella added the enhancement New feature or request label Jan 29, 2023
@lostella lostella requested a review from shchur January 29, 2023 10:37
@lostella lostella added the pending v0.12.x backport This contains a fix to be backported to the v0.12.x branch label Jan 30, 2023
@lostella lostella added this to the v0.12 milestone Jan 30, 2023
Copy link
Contributor

@shchur shchur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you!

@lostella lostella changed the title expose weight_decay in estimator class Expose weight_decay in torch TFT estimator class Jan 30, 2023
@lostella lostella enabled auto-merge (squash) January 30, 2023 10:22
@lostella lostella merged commit 682e1cc into awslabs:dev Jan 30, 2023
@lostella lostella added the torch This concerns the PyTorch side of GluonTS label Jan 30, 2023
lostella pushed a commit to lostella/gluonts that referenced this pull request Jan 30, 2023
@lostella lostella mentioned this pull request Jan 30, 2023
@gorold gorold deleted the tft-expose-weight-decay branch February 1, 2023 05:10
lostella added a commit that referenced this pull request Feb 2, 2023
* Add assertion to split function ensuring valid windows (#2587)

* Ensure dtype on feat_time in torch DeepAR. (#2596)

* Move NPTS back to `gluonts.model` (#2597)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Fix bug with static cardinalities in `PandasDataset` (#2599)

* Expose `weight_decay` in torch TFT estimator class (#2603)

* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* Fix incorrect import in `tsbench`, apply latest black (#2613)

* Allow ReduceLROnPlateau to track val_loss when validation set is available (#2614)

---------

Co-authored-by: MarcelK1102 <[email protected]>
Co-authored-by: Jasper <[email protected]>
Co-authored-by: Gerald Woo <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
@lostella lostella removed the pending v0.12.x backport This contains a fix to be backported to the v0.12.x branch label Feb 7, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request torch This concerns the PyTorch side of GluonTS
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants