
EMA decay is not updated when restarting from checkpoint #13

Open · jianganbai opened this issue Mar 21, 2025 · 2 comments
Labels: pending (This problem is yet to be addressed)

Comments

@jianganbai

Thank you for your excellent work and for open-sourcing all these codes.

I've recently found that when restarting the pre-training task from an intermediate checkpoint (e.g. checkpoint_last.pt), the EMA decay stays at its initial value and is not updated in the model.

```python
def set_num_updates(self, num_updates):
    super().set_num_updates(num_updates)
    if self.ema is not None and (
        (self.num_updates == 0 and num_updates > 1)
        or self.num_updates >= num_updates
    ):
        pass
    elif self.training and self.ema is not None:
        ema_weight_decay = None
        if self.cfg.ema_decay != self.cfg.ema_end_decay:
            if num_updates >= self.cfg.ema_anneal_end_step:
                decay = self.cfg.ema_end_decay
            else:
                decay = get_annealed_rate(
                    self.cfg.ema_decay,
                    self.cfg.ema_end_decay,
                    num_updates,
                    self.cfg.ema_anneal_end_step,
                )
            self.ema.set_decay(decay, weight_decay=ema_weight_decay)
        if self.ema.get_decay() < 1:
            self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
    self.num_updates = num_updates
```

It seems that the `if` condition on line 360 (the first branch in the snippet above) does not behave correctly during the restart procedure. When restarting, this function is called while the checkpoint is being loaded, at which point self.num_updates == 0 and num_updates > 1, so the whole EMA update is skipped. As a result, the EMA decay remains at the initial value 0.9998, i.e. the value used at the very beginning of the pre-training process.
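
For reference, here is a minimal sketch of the linear annealing helper, following the fairseq-style get_annealed_rate (the exact implementation in this repo may differ). The config values below are purely hypothetical and only illustrate how far the stale initial decay can be from the decay that should be in effect at the checkpoint's update count:

```python
def get_annealed_rate(start: float, end: float, curr_step: int, total_steps: int) -> float:
    """Linearly anneal the EMA decay from `start` to `end` over `total_steps` updates."""
    if curr_step >= total_steps:
        return end
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


# Hypothetical config values for illustration only:
# ema_decay=0.9998, ema_end_decay=0.99999, ema_anneal_end_step=100_000.
# Restarting at update 50_000 should use roughly 0.999895, not the initial 0.9998.
print(get_annealed_rate(0.9998, 0.99999, 50_000, 100_000))  # ~0.999895
```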

I'm not sure how bad the consequences are, since the EMA decay is corrected after the first batch. But after a few trials, it seems to me that restarting from a checkpoint performs slightly worse than training without interruption.

Hope this will help.

@jianganbai (Author)

By the way, I think only the EMA decay should be updated when loading the checkpoint; self.ema.step should not be called there, since it will be called again when the first batch is fed into the model.
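
A minimal sketch of what that could look like, assuming the method names used in the snippet above (an untested illustration, not a patch): recompute and set the decay when the call comes from checkpoint loading, but only step the EMA during real training updates.

```python
def set_num_updates(self, num_updates):
    super().set_num_updates(num_updates)
    # True only for the call made while a checkpoint is being loaded.
    restoring = self.num_updates == 0 and num_updates > 1
    # Stale/repeated call with no new updates: leave the EMA untouched.
    stale = self.num_updates >= num_updates
    if self.ema is not None and not stale and (restoring or self.training):
        if self.cfg.ema_decay != self.cfg.ema_end_decay:
            if num_updates >= self.cfg.ema_anneal_end_step:
                decay = self.cfg.ema_end_decay
            else:
                decay = get_annealed_rate(
                    self.cfg.ema_decay,
                    self.cfg.ema_end_decay,
                    num_updates,
                    self.cfg.ema_anneal_end_step,
                )
            self.ema.set_decay(decay, weight_decay=None)
        # Step the EMA only on real training updates, not when restoring.
        if not restoring and self.ema.get_decay() < 1:
            self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
    self.num_updates = num_updates
```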

@cwx-worst-one (Owner)

Thank you so much for pointing this out @jianganbai.
I'll look into the issue with EMA updates when resuming training from a checkpoint (it might take some time, as I'm currently tied up with other deadlines 😢).
Really appreciate your findings—I'll keep the issue open until it's resolved.

@cwx-worst-one added the pending (This problem is yet to be addressed) label on Mar 23, 2025