I wanted to revisit my HeAR model distillation because I was under a time constraint for the hackathon. I could not fully address all aspects of my training with the attention they deserved. Specifically, after the first full run, I noted many questions and concerns in my write-up. You can read about the initial conception and work here: https://uror.io/posts/heardi-tillationdatalake
After addressing the problems, the new model is equivalent to the teacher within noise (98% student/teacher mAP).
The problems:
EMA Batch - LR Scheduler: I used an LR scheduler which adapted my learning rate based on the critical batch size. Literature supports LR - Batch dynamics around critical batch sizing. Specifically, you can either increase the batch size to improve training dynamics (https://arxiv.org/abs/1711.00489) or, as is more traditional, decrease the learning rate. I did not use either Cosine or WSD in the hackathon run.
Warmup: I used no warmup during training.
LR determination: I used a learning rate that was determined by grid search. Literature supports a Hessian-determined learning rate (https://arxiv.org/abs/1708.07120).
Validation split: In my first run there were oscillations in my validation. I wanted to determine if these were data artifacts or instability artifacts.
Speed: My training was slow! I wanted faster training.
Datalake Workers: The workers would toggle off and on rapidly trying to match the consumption. My base workers were fine but I required a fractional number of workers more than the minimum to match the consumption. The toggling was an artifact from improper hysteresis of my workers.
Data (Datalake Workers, Validation Split)
I started upstream of the model for my work. In the hackathon run, the orchestrator streamed all my data to one lake, then pulled off a one-time validation split. There was no decay split. Now, I pull a stream into an intermediate lake, which then routes each item to the decay, validation, or train split. Each split has its own limitations and data caps. This streaming setup prevents leakage across splits.
When the validation or decay split hits its cap, I begin randomly replacing entries so that these splits remain up to date with the data. In the old datalake, I assumed that the stream of data was well mixed. This was not the case. Updating the validation split allowed me to track representative values.
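The routing and cap-then-replace behavior described above can be sketched roughly as follows. This is a minimal sketch, not the actual orchestrator: `CappedSplit`, `route`, `holdout_frac`, and `replace_prob` are hypothetical names and values, and the choice to discard evicted clips (rather than recycle them into train) is my assumption.

```python
import random


class CappedSplit:
    """A held-out split with a hard cap; once full, new clips randomly replace old ones."""

    def __init__(self, cap, replace_prob, rng):
        self.cap = cap
        self.replace_prob = replace_prob  # assumed refresh rate once the cap is hit
        self.rng = rng
        self.items = []

    def offer(self, clip):
        """Return True if the split kept the clip, False if it should go elsewhere."""
        if len(self.items) < self.cap:
            self.items.append(clip)
            return True
        if self.rng.random() < self.replace_prob:
            # Evict a random old clip; the evicted clip is dropped, never reused.
            self.items[self.rng.randrange(self.cap)] = clip
            return True
        return False


def route(stream, val, decay, train, rng, holdout_frac=0.1):
    """Send every clip to exactly one of the three splits."""
    for clip in stream:
        if rng.random() < holdout_frac:
            target = val if rng.random() < 0.5 else decay
            if target.offer(clip):
                continue
        train.append(clip)


rng = random.Random(0)
val = CappedSplit(cap=5, replace_prob=0.2, rng=rng)
decay = CappedSplit(cap=5, replace_prob=0.2, rng=rng)
train = []
route(range(1000), val, decay, train, rng)
```

Because each clip is offered to a single split and evicted clips are dropped, no clip can end up in two splits, while the held-out splits keep refreshing as the stream drifts.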
I reduced worker churn by retiring unneeded workers only at chunk boundaries. Each worker writes a chunk (a group of audio clips) to a data lake, so once a worker activates, it writes a full chunk before it can deactivate. Previously, I had worker activation tied to write rates, which led to rapid activation-deactivation cycles. Additionally, I switched from clip accounting to byte accounting. In the prior version, the accounting was inconsistent even across script args.
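A minimal sketch of the chunk-boundary rule, assuming a simple tick-based controller. `Worker`, `WorkerPool`, `tick`, and the byte counts are hypothetical; the real orchestrator is not shown in this post.

```python
from dataclasses import dataclass


@dataclass
class Worker:
    active: bool = False
    bytes_in_chunk: int = 0  # bytes written to the current, unfinished chunk


class WorkerPool:
    def __init__(self, min_workers, max_workers, chunk_bytes):
        self.min_workers = min_workers
        self.chunk_bytes = chunk_bytes
        self.workers = [Worker(active=i < min_workers) for i in range(max_workers)]

    def active_count(self):
        return sum(w.active for w in self.workers)

    def tick(self, wanted, bytes_per_tick):
        """Advance one tick and return the number of chunks completed."""
        wanted = max(self.min_workers, min(wanted, len(self.workers)))
        # Scale up immediately when demand rises.
        for w in self.workers:
            if self.active_count() >= wanted:
                break
            if not w.active:
                w.active = True
        chunks_done = 0
        # Walk newest-first so surplus workers are the ones that retire.
        for w in reversed(self.workers):
            if not w.active:
                continue
            w.bytes_in_chunk += bytes_per_tick  # byte accounting, not clip counts
            if w.bytes_in_chunk >= self.chunk_bytes:
                chunks_done += 1
                w.bytes_in_chunk = 0
                # Retirement is only considered here, at a chunk boundary.
                if self.active_count() > wanted:
                    w.active = False
        return chunks_done


pool = WorkerPool(min_workers=1, max_workers=3, chunk_bytes=100)
pool.tick(wanted=2, bytes_per_tick=25)  # demand rises: a second worker starts
pool.tick(wanted=1, bytes_per_tick=25)  # demand drops, but the worker is mid-chunk
active_mid_chunk = pool.active_count()
pool.tick(wanted=1, bytes_per_tick=25)
done = pool.tick(wanted=1, bytes_per_tick=25)  # chunks finish; the extra worker retires
active_after_boundary = pool.active_count()
```

The design choice is that a drop in demand never interrupts a chunk in progress, which removes the rapid on/off cycling at the cost of occasionally writing one chunk more than strictly needed.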
In the most recent training run I no longer had faulty workers nor oscillations in my validation split. This confirms that my possible instability was a data artifact and not a model instability artifact. I processed a total of 3.1 TB of data during the stable phase. The streaming performed as intended across the full training run.
Learning Rate Scheduling (Warmup, Batch - LR Dynamics, LR Determination)
I want my training to be adaptive to the model and the data. The practice of grid searching learning rate should die. Here I will address what I changed with brief reasoning. In another post, I will fully address the history and evolution of batches and learning rates.
The largest change was the introduction of the warmup. I implemented something akin to warmup-stable-decay (WSD) scheduling. Traditionally, the stable in WSD would be a non-changing learning rate. I pegged my learning rate proportional to the critical batch size similar to the prior run but I used a square root association.
Initial learning rate and batch size were determined in the warmup phase. I warmed up for 1000 steps. During the first half, I tracked the Hessian sharpness to determine an initial learning rate for the stable phase. The learning rate during warmup started very low, then climbed toward a sharpness-defined target as my sharpness stabilized. Each step's learning rate was set so that the ramp would reach this target by the end of the full warmup period. In the second half of warmup, I froze the target learning rate to prevent a dynamic target. The batch size was not varied throughout warmup, but I tracked the critical batch size as derived from the gradient noise scale. This was the batch size that was used in the stable and decay phases.
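One way to read the warmup above is sketched below. The exact target rule is not stated in this post, so I assume a fraction of the classical 2/λ_max stability bound, with the sharpness smoothed by an EMA. `SharpnessWarmup`, `safety`, `ema_beta`, and `lr_floor` are hypothetical names and values, and the sharpness measurements are expected to come from elsewhere (for example, power iteration on Hessian-vector products).

```python
class SharpnessWarmup:
    """Linear ramp toward a sharpness-defined target that is frozen at the halfway point."""

    def __init__(self, warmup_steps=1000, lr_floor=1e-7, safety=0.5, ema_beta=0.9):
        self.warmup_steps = warmup_steps
        self.lr_floor = lr_floor
        self.safety = safety      # assumed fraction of the 2 / sharpness bound
        self.ema_beta = ema_beta  # assumed smoothing for the noisy sharpness signal
        self.sharpness_ema = None
        self.target_lr = lr_floor

    def lr_at(self, step, sharpness):
        # First half: keep updating the target from the tracked sharpness.
        if step < self.warmup_steps // 2:
            if self.sharpness_ema is None:
                self.sharpness_ema = sharpness
            else:
                self.sharpness_ema = (self.ema_beta * self.sharpness_ema
                                      + (1.0 - self.ema_beta) * sharpness)
            self.target_lr = self.safety * 2.0 / self.sharpness_ema
        # Second half: the target stays frozen; only the ramp keeps advancing.
        frac = min(1.0, (step + 1) / self.warmup_steps)
        return self.lr_floor + frac * (self.target_lr - self.lr_floor)


warmup = SharpnessWarmup()
# Illustrative sharpness trace: the change after step 500 must not move the target.
lrs = [warmup.lr_at(t, 100.0 if t < 500 else 1.0) for t in range(1000)]
```

With this illustrative trace the ramp ends at 0.5 * 2 / 100 = 0.01, regardless of what the sharpness does in the second half.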
Warmup caused my critical batch size to decrease compared to prior runs. I also ended with a larger learning rate than what was determined in my prior distillation. The sharpness derived learning rate is a noisy process. I believe that the selection criteria can be improved further. However, this process resulted in a batch size and learning rate which were within a couple multiples of my prior selected values. In the future, I intend to use a larger batch size during the warmup to have a better estimate of these values.
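For context on why a larger warmup batch should help, here is a sketch of the two-batch estimator for the simple gradient noise scale from "An Empirical Model of Large-Batch Training" (McCandlish et al.). I am assuming an estimator of this kind; the one actually used in this run is not shown. It relies on E[|G_B|²] = |G|² + tr(Σ)/B, and the function name is hypothetical.

```python
def gradient_noise_scale(g2_small, g2_big, b_small, b_big):
    """Estimate tr(Sigma) / |G|^2 from squared gradient norms at two batch sizes.

    g2_small, g2_big: squared norms of gradients averaged over b_small / b_big examples.
    """
    # Unbiased estimate of the true squared gradient norm |G|^2.
    true_g2 = (b_big * g2_big - b_small * g2_small) / (b_big - b_small)
    # Unbiased estimate of the per-example gradient variance tr(Sigma).
    trace_sigma = (g2_small - g2_big) / (1.0 / b_small - 1.0 / b_big)
    return trace_sigma / true_g2


# Synthetic check: build the two norms from known |G|^2 = 2 and tr(Sigma) = 150.
g2_small = 2.0 + 150.0 / 8
g2_big = 2.0 + 150.0 / 64
cbs_estimate = gradient_noise_scale(g2_small, g2_big, b_small=8, b_big=64)  # 75.0
```

Both estimates are differences of noisy quantities, so small batches make the ratio jumpy; in practice the numerator and denominator are usually smoothed separately before dividing.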
During my ‘stable’ phase, I used what would be considered ‘stable noise’ as opposed to a stable learning rate. In practice, my critical batch size initially continued to increase before stabilizing. This stabilization led to a quasi-stable learning rate. My stable noise was effectively a stable learning rate for the majority of training. The final learning rate sat around 80% of the initially selected rate. My training remained stable with good loss curves.
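The 'stable noise' idea can be sketched as a square-root rule on the ratio of the fixed batch size to the tracked critical batch size. The exact functional form is my assumption, and `stable_lr` is a hypothetical name.

```python
import math


def stable_lr(base_lr, batch_size, critical_batch_size):
    """Scale the learning rate down as the critical batch size grows past the batch size."""
    scale = math.sqrt(batch_size / critical_batch_size)
    # Never exceed the rate selected during warmup.
    return base_lr * min(1.0, scale)


# With the batch size of 48 and a critical batch size near 75 (see the results section):
ratio = stable_lr(1.0, 48, 75)  # sqrt(0.64) = 0.8
```

Under this reading, a fixed batch of 48 with the critical batch size settling near 75 gives a factor of 0.8, which is at least consistent with the final learning rate sitting around 80% of the initially selected rate.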
The decay phase was 10% of my stable phase in both training examples and steps. I used a linear ramp to zero for the learning rate. As you will be able to see in the results section, the validation and decay agreed in terms of data mix. However, the train still had a different mix as indicated by the loss curve jump. Overall, this did not prevent my student from matching my teacher performance. My final model is the best I’ve trained.
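The decay schedule itself is simple enough to sketch directly. The step counts and the starting learning rate below are illustrative placeholders, not the values from this run, and `decay_lr` is a hypothetical name.

```python
def decay_lr(step_in_decay, decay_steps, lr_at_decay_start):
    """Linear ramp from the last stable-phase learning rate down to zero."""
    frac = min(1.0, step_in_decay / decay_steps)
    return lr_at_decay_start * (1.0 - frac)


stable_steps = 400_000                  # illustrative only
decay_steps = round(0.1 * stable_steps)  # decay is 10% of the stable phase
start_lr = 8e-4                          # illustrative only

lr_first = decay_lr(0, decay_steps, start_lr)
lr_mid = decay_lr(decay_steps // 2, decay_steps, start_lr)
lr_last = decay_lr(decay_steps, decay_steps, start_lr)
```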
Speed (Speed)
The final consideration was making my model faster. I implemented diagnostic tracking with a console summary. During the initial testing phase I would get readouts like the following:
`perf@step=... step_ms=... teacher_ms=... student_ms=... input_ms=... aux_ms=... bottleneck=teacher_bound`
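A readout in that shape could be produced with something like the sketch below. The helper names, the timings, and the rule for picking the bottleneck (the slowest stage wins) are my assumptions, not the code from this run.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(stage_ms, name):
    """Accumulate wall-clock milliseconds for one stage into stage_ms[name]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        stage_ms[name] = stage_ms.get(name, 0.0) + elapsed_ms


def perf_line(step, step_ms, stage_ms):
    """Format a one-line summary and name the slowest stage as the bottleneck."""
    slowest = max(stage_ms, key=stage_ms.get)
    fields = " ".join(f"{name}_ms={ms:.1f}" for name, ms in stage_ms.items())
    return f"perf@step={step} step_ms={step_ms:.1f} {fields} bottleneck={slowest}_bound"


# Made-up timings, only to show the output format.
example = perf_line(100, 1200.0, {"teacher": 810.0, "student": 240.0,
                                  "input": 90.0, "aux": 60.0})

# Timing a real stage:
stage_ms = {}
with timed(stage_ms, "input"):
    sum(range(1000))
```

Note that GPU work is launched asynchronously, so wall-clock timers like this only reflect device time if the device is synchronized before each reading.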
There were a couple of obvious changes, like compiling the PyTorch models during the stable/decay phase, fusing AdamW, and increasing the teacher batches so the student could pull from a cache. These changes drastically improved my throughput.
My final batch was smaller (and as such each step finished faster). My steps per second increased from 0.82 to 4.66, about 5.7x overall. The smaller batch does not account for the entirety of this speedup. Suppose the speedup from the batch decrease were linear, giving 3x from the batch alone (it is not, and is likely closer to 2x). My other changes would then account for the remaining 1.89x, an 89.4% throughput improvement. Assuming the more realistic 2x from the batch size decrease, the improvement is ~180%.
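The arithmetic behind those two figures, as a quick check. The function name is hypothetical; the inputs are the numbers quoted above.

```python
def residual_speedup(old_steps_per_s, new_steps_per_s, batch_factor):
    """Speedup left over after dividing out the part attributed to the smaller batch."""
    total = new_steps_per_s / old_steps_per_s  # 4.66 / 0.82 ≈ 5.68x overall
    return total / batch_factor


for factor in (3.0, 2.0):
    r = residual_speedup(0.82, 4.66, factor)
    print(f"batch contributes {factor:.0f}x -> other changes: {r:.2f}x "
          f"({(r - 1) * 100:.1f}% faster)")
# batch contributes 3x -> other changes: 1.89x (89.4% faster)
# batch contributes 2x -> other changes: 2.84x (184.1% faster)
```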
The training is now faster and I could push it further by increasing batch over my critical batch size but that yields diminishing speed returns.
Results
| Dataset | Metric | Full Run | Full Run (Full Clip) | Stable Phase End (Full Clip) | Stable Phase End |
|---|---|---|---|---|---|
| FSD50K + FluSense | mAP | 0.645 | 0.588 | 0.438 | 0.511 |

| Task | Metric | Full Run | Full Run (Full Clip) | Stable Phase End (Full Clip) | Stable Phase End |
|---|---|---|---|---|---|
| Breathing | AP | 0.298 | 0.247 | 0.177 | 0.182 |
| Cough | AP | 0.665 | 0.442 | 0.292 | 0.618 |
| Laughter | AP | 0.604 | 0.480 | 0.310 | 0.481 |
| Sneeze | AP | 0.759 | 0.504 | 0.264 | 0.257 |
| Speech | AP | 0.712 | 0.661 | 0.581 | 0.614 |

| Task | Metric | Full Run | Full Run (Full Clip) | Stable Phase End (Full Clip) | Stable Phase End |
|---|---|---|---|---|---|
| Breathing | AP | 0.335 | 0.375 | 0.262 | 0.227 |
| Cough | AP | 0.896 | 0.899 | 0.820 | 0.856 |
| Gasp | AP | 0.457 | 0.492 | 0.382 | 0.331 |
| Sneeze | AP | 0.735 | 0.697 | 0.431 | 0.534 |
| Sniffle | AP | 0.849 | 0.769 | 0.578 | 0.691 |
| Speech | AP | 0.906 | 0.907 | 0.856 | 0.860 |
| Throat-Clearing | AP | 0.371 | 0.382 | 0.308 | 0.222 |
Final student model: mAP = 0.645 [0.526, 0.764]
HeAR (Teacher): mAP = 0.658 [0.550, 0.766]
The model after stable is significantly worse than at the end of my full run! The length generalization is still significantly worse than the model from my hackathon run. However, my model after the decay phase is effectively as good as the full sized teacher model. That is parity with a model 5x its size. The context extension is also equivalent (without the need for position encoding). Let’s look at some of the training dynamics:
My learning rate was fairly consistent at 80% of my chosen rate after the critical batch driven decrease early in training. I can see the decay phase clearly (morning-monkey). Both of the sun runs were from my original hackathon run where I only used a CBS driven decay with a larger batch size (due to the lack of warmup).
I am using just the MSE loss because the contrastive losses are not comparable between training runs. The smaller batch size of the most recent run makes the contrastive objective easier, which pollutes the full loss. MSE loss is more comparable. It can be seen that I trained for 440k steps as opposed to 200k. This is 2.2x longer in steps but shorter in wall clock time. My most recent run is also shorter in terms of data. I only used ~85% as much data across the entire run. I expected a higher CBS, which did not materialize on the final run. This is partly due to the noise in estimating my CBS metric. I pre-selected the step length, which led to the data lag.
I can see that my selected optimal batch was noisy but stabilized around 75 examples. I did not increase from my batch size of 48. Instead, the learning rate was modulated in accordance with the CBS to retain efficient training. In the future, I would train multiples above CBS. I also believe the lower batch sizes during warmup increased the noise in estimating CBS. I will increase the warmup batch size in the future.
My train MSE loss looks similar to the validation except the decay data appears to initially have a higher loss. This is likely driven by dataset drift. The decay and validation set were updated as I streamed data to account for drift. Since training consumed the majority of the data, it did not have a mechanism to account for that drift. Despite this, I still ended with the lowest training loss seen after my decay.
Discussion
Most of the problems I outlined in my initial training run for the hackathon have been addressed. The newest model is now as performant as the teacher and sits on the Pareto frontier. There are still a number of peculiarities, like my decreased context extension and the data drift from a pure stream. These can be dissected and addressed in future experiments.
The decreased context extension is interesting. I did train at a higher learning rate with less data. I would be interested to see how this evolves over the course of a longer training run. As is evident, before the decay the gap between extended and trained clip length was proportionally just as bad. The decay phase led to a large increase in my capacity (which is expected). Yet, it did not close the gap on context extension compared to my hackathon model. The higher learning rate I used wasn’t intentional per se but was the result of my learning rate selection mechanism. One issue with the warmup may be the limited batch size leading to increased noise in my CBS and learning rate selection methods.
Since the wall clock time is significantly faster, I should train this model for longer. In addition, I had space on the GPU to increase the batch size. After this experiment I am more convinced than ever that there is a limited need to change the learning rate or batch size during the stable portion of training. For instance, I saw that my learning rate was steady in mean value (if incredibly variable). Simplifying the stable phase has the added benefit of decreasing the computational burden of training. It can be even faster. There are studies which argue against this but I believe they are explainable in relation to the CBS point (https://arxiv.org/abs/2408.13359v2).
The data drift is still an issue in this run. The model ends up powerful, but it could be better with a better selection of data. I look forward to addressing this in the future. The data mix, train mix filtering, and decay phase filtering are all topics for future work. For instance, many training runs select the best samples to anneal/decay at the end of training. A better training mix likely yields better benchmarks. Additionally, including context extension within my decay will likely help generalize context lengths better.
The larger intent of this project was to flesh out learning rate scheduling and learning rate setting. I believe I have met both of those goals. I look forward to improving the selection methods and testing them on different modalities and data mixes.