Why T>=S constraint? #20
Why do we need this constraint? In a regular RNN-T, the joint network may emit many blank symbols, and in that case T > S. But it is also possible that S > T, e.g. when we emit at least one non-blank symbol for each encoder frame.

Actually, I have hit this:

File "/rnnt_related/rnnt-mlperf-training/model_rnnt.py", line 203, in fast_joint
  simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
File "/anaconda3/envs/fast-rnnt/lib/python3.8/site-packages/fast_rnnt-1.2-py3.8-linux-x86_64.egg/fast_rnnt/rnnt_loss.py", line 282, in rnnt_loss_simple
  px, py = get_rnnt_logprobs(
File "/anaconda3/envs/fast-rnnt/lib/python3.8/site-packages/fast_rnnt-1.2-py3.8-linux-x86_64.egg/fast_rnnt/rnnt_loss.py", line 149, in get_rnnt_logprobs
  assert T >= S, (T, S)
AssertionError: (272, 274)
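A minimal sketch that reproduces this assertion, following the calling convention shown in the fast_rnnt README; the batch size, vocabulary size C, and termination_symbol=0 here are illustrative assumptions:

import torch
import fast_rnnt

B, T, S, C = 1, 272, 274, 100  # T < S, matching the (272, 274) in the traceback
lm = torch.randn(B, S + 1, C)  # label-side (decoder) output
am = torch.randn(B, T, C)      # acoustic (encoder) output
symbols = torch.randint(1, C, (B, S), dtype=torch.int64)

# Raises AssertionError: (272, 274) inside get_rnnt_logprobs,
# because this version of the loss requires T >= S.
loss = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
)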
As you have mentioned, that holds for the regular RNN-T. The version we are using is not regular: it has the same condition as CTC training, i.e., S <= T.
Here is the paper about fast_rnnt:
Here is the code to filter out data that don't satisfy this condition:

import logging

def remove_short_utt(c) -> bool:
    # Wrapper added for context: c is a lhotse Cut, and sp is a trained
    # sentencepiece.SentencePieceProcessor available in the enclosing scope.
    # In ./conformer.py, the conv module uses the following expression
    # for subsampling
    T = ((c.num_frames - 1) // 2 - 1) // 2
    tokens = sp.encode(c.supervisions[0].text, out_type=str)
    if T < len(tokens):
        logging.warning(
            f"Exclude cut with ID {c.id} from training. "
            f"Number of frames (before subsampling): {c.num_frames}. "
            f"Number of frames (after subsampling): {T}. "
            f"Text: {c.supervisions[0].text}. "
            f"Tokens: {tokens}. "
            f"Number of tokens: {len(tokens)}"
        )
        return False
    return True
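For context, here is how such a predicate might be wired up; a sketch assuming lhotse's CutSet.filter and a trained sentencepiece model (the file paths are placeholders):

import sentencepiece as spm
from lhotse import CutSet

sp = spm.SentencePieceProcessor()
sp.load("path/to/bpe.model")  # placeholder path

cuts = CutSet.from_file("path/to/cuts.jsonl.gz")  # placeholder path
cuts = cuts.filter(remove_short_utt)  # drops cuts with T < len(tokens)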
Thanks for your fast reply.
@BuaaAlban as you noted, this constraint is indeed not required for the "regular" RNNT topology. Only if you train with the "modified" topology, where you are constrained to emit exactly 1 symbol per time frame, will this constraint be required. We have a PR here (k2-fsa/k2#1149) to remove this constraint from k2. I will also make a similar PR for fast_rnnt.
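To make the difference concrete, here is a toy model of the alignment lattice (a sketch for intuition, not code from fast_rnnt or k2). In the regular topology a non-blank advances s without consuming a frame, so S > T is reachable; in the modified topology every frame emits exactly one symbol, so a path to (T, S) exists only when S <= T:

def num_alignments(T: int, S: int, modified: bool = False) -> int:
    # Count monotonic RNN-T alignment paths from (t=0, s=0) to (t=T, s=S).
    # Regular:  blank advances t; a non-blank advances s (t stays put).
    # Modified: every frame emits exactly one symbol, so t always advances;
    #           a non-blank additionally advances s.
    dp = [[0] * (S + 1) for _ in range(T + 1)]
    dp[0][0] = 1
    for t in range(T + 1):
        for s in range(S + 1):
            if t == 0 and s == 0:
                continue
            blank = dp[t - 1][s] if t > 0 else 0
            if modified:
                sym = dp[t - 1][s - 1] if (t > 0 and s > 0) else 0
            else:
                sym = dp[t][s - 1] if s > 0 else 0
            dp[t][s] = blank + sym
    return dp[T][S]

print(num_alignments(2, 3))                 # regular: 10 paths, S > T is fine
print(num_alignments(2, 3, modified=True))  # modified: 0 paths, needs S <= T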
@desh2608 are you still planning to make this PR? This would be very useful for my work!
@arkadyark sorry, I forgot to actually push the changes. BTW, I believe Dan fixed some OOM issues in the pruned transducer loss in k2, which haven't yet been merged into fast_rnnt. So you may want to make those changes yourself.
Thanks! Which changes are you referring to? Looking through recent changes to rnnt_loss.py, I don't see anything there.
Check k2-fsa/k2#1177 and k2-fsa/k2#1183
Ah yes. Arkady, it would be great if you could make a PR to fast_rnnt with those changes; I had forgotten about that. If not, LMK and I'll ask someone here.
I would love to contribute those back, but unfortunately there's a fairly involved open-source contribution process at my organization that would take a while; it'd probably be best to find someone else to do so. However, I did test this out locally and re-ran the benchmarking at https://github.com/csukuangfj/transducer-loss-benchmarking. The results look really good: peak memory usage goes from 3820 all the way down to 1182 (!), and from 2647 to 835 when sorting utterances. Step time (on my hardware) went from 343k to 280k us. Pretty cool! Always gotta be careful with those torch.gathers.
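As a generic illustration of why the torch.gather usage matters here (an assumption-laden sketch, not the actual change from those PRs): this kind of loss code needs per-symbol log-probs picked out of a large (B, T, S+1, V) tensor, and gathering with an expanded index view keeps the index itself free of copies while the output stays a small (B, T, S) tensor:

import torch

B, T, S, V = 8, 250, 60, 500
logits = torch.randn(B, T, S + 1, V)  # joiner output over the vocabulary
symbols = torch.randint(0, V, (B, S), dtype=torch.int64)

logprobs = logits.log_softmax(dim=-1)

# expand() returns a view, so the (B, T, S, 1) index tensor costs no
# extra memory before the gather; the result is only (B, T, S).
index = symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
py = logprobs[:, :, :S, :].gather(dim=-1, index=index).squeeze(-1)
print(py.shape)  # torch.Size([8, 250, 60])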
Hey @danpovey, just wanted to follow up: is anybody able to make those changes here?
@pkufool could you please have a look at this?
Closed by #29.