Conversation
@epwalsh, you can look at this now, while I'm fixing tests. What do we need to change to make fairscale work?
In a meeting now but I'll take a look afterwards
These are great improvements, but they don't really change the story with FairScale.
One thing that's missing is synchronization across distributed workers when gathering model and training state, since collecting the state associated with sharded parameters requires a distributed gather operation (each worker needs to send its shard of the data to the main process).
Another issue is that the optimizer state actually has to be collected through the `FullyShardedDataParallel` model wrapper (`gather_full_optim_state`).
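To make the distributed-gather requirement concrete, here is a minimal sketch (not code from this PR), assuming fairscale's FSDP wrapper and that `gather_full_optim_state_dict` is the consolidation entry point; the method names reflect my reading of the fairscale API and should be checked against the installed version.

```python
import torch.distributed as dist
from fairscale.nn import FullyShardedDataParallel as FSDP


def gather_training_state(model: FSDP, optimizer) -> dict:
    # state_dict() on an FSDP-wrapped model performs the distributed gather,
    # so every rank must call it even though only rank 0 keeps the result.
    model_state = model.state_dict()

    # Optimizer state is sharded too; it has to be collected through the
    # wrapper rather than by calling optimizer.state_dict() directly.
    optim_state = model.gather_full_optim_state_dict(optimizer)

    if dist.get_rank() == 0:
        return {"model": model_state, "optimizer": optim_state}
    return {}
```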
allennlp/training/checkpointer.py
    save_completed_epochs: bool = True,
    save_every_num_seconds: Optional[int] = None,
    save_every_num_batches: Optional[int] = None,
    keep_most_recent_by_count: Optional[int] = 2,
    keep_most_recent_by_age: Optional[int] = None,
Thank you, I hated the old names 💯
    GradientDescentTrainer,
)
from allennlp.training.trainer import Trainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
Nice. I've been wanting to move this to its own file for a while.
def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass
✅
Deep in the throes of fixing all the tests, I'm wondering whether I approached this backwards. Saving and restoring in the middle of an epoch was added to the checkpointer, but it's completely unsupported by any other part of the system. This is essentially a new piece of functionality.
Tests pass locally. I'm still fighting with mypy and the models repo. We might have to retrain some stuff (or at least patch the model configs), because the … But overall, this is ready to review.
@@ -152,6 +152,7 @@ jobs:
    run: |
      git clone https://github.com/allenai/allennlp-models.git
      cd allennlp-models
      git checkout Checkpointing
git checkout Checkpointing
This will have to be removed before merging
You can also review this one commit at a time. I kept the commits pretty clean and self-contained. That'll let you skip the big copy of …
We should do a minor version bump after this. It changes some public APIs.
@epwalsh, this is ready for a real review now.
This looks great. I just left a few comments.
extra_copy_of_weights_just_for_mypy = Path(weights)
if extra_copy_of_weights_just_for_mypy.is_absolute():
    weights_file = extra_copy_of_weights_just_for_mypy
else:
    weights_file = Path(serialization_dir) / extra_copy_of_weights_just_for_mypy
This is a little confusing. How about just using `typing.cast`?
`serialization_dir` can be a `str` at the time. It's not just to let mypy know what it is.
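For what it's worth, a small sketch of the distinction under discussion: `typing.cast` only changes what the type checker believes, while `Path(...)` actually converts a `str` at runtime. The function name below is made up for illustration, not taken from the PR.

```python
from pathlib import Path
from typing import Union


def resolve_weights(weights: Union[str, Path], serialization_dir: Union[str, Path]) -> Path:
    # typing.cast(Path, weights) would satisfy mypy but leave `weights` as a
    # str at runtime; Path(...) performs the actual conversion.
    weights_path = Path(weights)
    if weights_path.is_absolute():
        return weights_path
    return Path(serialization_dir) / weights_path
```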
save_completed_epochs : `bool`, (default=`True`)
    Saves model and trainer state at the end of each completed epoch.
save_every_num_seconds : `int`, optional (default=`None`)
    If set, makes sure we never go longer than this number of seconds between saving a model.
save_every_num_batches : `int`, optional (default=`None`)
    If set, makes sure we never go longer than this number of batches between saving a model.
keep_most_recent_by_count : `int`, optional (default=`2`)
    Sets the number of model checkpoints to keep on disk. If both `keep_most_recent_by_count` and
    `keep_most_recent_by_age` are set, we'll keep checkpoints that satisfy either criterion.
    If both are `None`, we keep all checkpoints.
keep_most_recent_by_age : `int`, optional (default=`None`)
    Sets the number of seconds we'll keep a checkpoint before deleting it. If both
    `keep_most_recent_by_count` and `keep_most_recent_by_age` are set, we'll keep checkpoints
    that satisfy either criterion. If both are `None`, we keep all checkpoints.
Nice, this is much more clear.
Unfortunately it breaks backwards compatibility. Worth it, I think, but not great.
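For context, a hedged usage sketch of the options documented above. The argument names mirror the docstring, but the exact `Checkpointer` constructor signature (including `serialization_dir` as the first argument and the example values) is an assumption, not copied from the PR.

```python
from allennlp.training.checkpointer import Checkpointer

# Keep the two newest checkpoints, saving at the end of every epoch and at
# least once an hour. (Assumed constructor arguments; see the docstring above.)
checkpointer = Checkpointer(
    serialization_dir="/tmp/my_run",
    save_completed_epochs=True,
    save_every_num_seconds=3600,
    keep_most_recent_by_count=2,
)
```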
CHANGELOG.md
@@ -40,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
- Fixed documentation for `GradientDescentTrainer.cuda_device`.
- Re-starting a training run from a checkpoint in the middle of an epoch now works correctly.
- When using the "moving average" weights smoothing feature of the trainer, training checkpoints would also get smoothed, with strange results for resuming a training job. This has been fixed.
- When re-starting an interrupted training job, the trainer will now read out the data loader even for epochs and batches that can be skipped. This ensures that any random number generators used by the reader or data loader are in the same state as they were the first time the training job ran.
This sounds good, in theory, but there are probably other things that affect the random number generators used by the reader and data loader. I don't think we can guarantee the same order.
Hmm. I wrote it this way because in Quark it worked out that way. I had good enough control over the RNGs that it was deterministic.
In AllenNLP, we can't guarantee that none of the things we're skipping when restoring from a checkpoint (the `forward()` method, for example) modify the RNG state. I guess I'll say that this is an attempt to ensure deterministic randomness, but does not guarantee it. At the same time, we should encourage components to use their own RNG instead of using the global one, so they don't affect each other.
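A small sketch of what "use their own RNG" could look like for a component. The class and its `state_dict` hooks are hypothetical, not part of the AllenNLP API.

```python
import random
from typing import Any, Dict, List


class ShufflingReader:
    """Hypothetical component that owns a private RNG instead of the global one."""

    def __init__(self, seed: int = 13370) -> None:
        self._rng = random.Random(seed)  # independent of random.seed() elsewhere

    def shuffle(self, items: List[Any]) -> List[Any]:
        items = list(items)
        self._rng.shuffle(items)
        return items

    # The private RNG state can be checkpointed alongside trainer state, so a
    # resumed run sees the same sequence of random numbers.
    def state_dict(self) -> Dict[str, Any]:
        return {"rng_state": self._rng.getstate()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._rng.setstate(state_dict["rng_state"])
```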
It's actually quite bad if this doesn't work. If we don't guarantee the order of instances, and we stop training 10 times in the middle of an epoch and restart it, we might end up training on the same instance 10 times.
11 times even
* Removes unused variable
* Formatting
* Make sure we always restore the model's weights properly
* Give TrainerCallbacks the ability to save and load state dicts
* Give MovingAverage the ability to save and load state dicts
* Do not set gradients to None
* Typo
* Remove unused variable
* Typo
* Entirely new checkpointing code
* Formatting
* Make mypy happy lol
* Makes the no-op trainer work with the new checkpointer
* Mark epochs as completed when they're skipped
* Changelog
* Fixes how we get the best weights after a training run
* Mypy is annoying
* Callback fixes
* Fix the no op trainer
* Simplify
* Assorted checkpointer fixes
* Mypy is now happy
* Fixed all the tests except for one
* Removed unused variable
* Fix trainer restore logic
* Fix test for trainer restore logic
* Check the Checkpointing branch of the models repo
* Help mypy along
* Fixed finalizing logic
* More mypy stuff
* Update allennlp/training/checkpointer.py

Co-authored-by: Pete <[email protected]>

* Make weaker claims

Co-authored-by: Pete <[email protected]>
`GradientDescentTrainer` now lives in its own file. I had to do this to break a circular dependency between `Checkpointer` and `GradientDescentTrainer`.