🎰 Deterministic `flash-attn`
Sam Foreman 2024-06-17
[NOTE]: For additional details, refer to the W&B Report.
Simple tests to confirm the loss is exactly reproducible across independent runs (when launched with the same seed).
- In particular, we set `output = flash_attn_func(q, k, v, None, self.causal, deterministic=True)` in all the `flash_attn_func(...)` calls from `megatron/model/transformer.py`.
- All experiments ran on Polaris @ ALCF, using:
  - `machine: Polaris`
  - `args.zero_stage: 1`
  - `args.num_layers: 32`
  - `args.micro_batch_size: 1`
  - `args.optimizer: "adamw"`
  - `args.use_flash_attn: true`
  - `env.DFL_STEM: "books"`
  - `env.GRAD_ACC_STEPS: 8`
  - `env.WORLD_SIZE: 8`
Figure 1: Plot of the loss curve for 3 independent runs with `deterministic=True`.
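
Overlapping curves in a plot like Figure 1 do not by themselves prove that the runs are exactly identical, so the check is better done numerically. Below is a minimal sketch of such a comparison, assuming each run's per-step loss has already been exported as a list of Python floats. The helper names `float_bits` and `first_divergence` are hypothetical, and the loss values are made-up placeholders, not measurements from these experiments.

```python
import struct


def float_bits(x: float) -> bytes:
    """Return the IEEE-754 double encoding of x, for bit-exact comparison."""
    return struct.pack("<d", x)


def first_divergence(runs):
    """Return the first step at which the runs disagree, or None if identical.

    `runs` is a sequence of per-step loss lists, one list per run.
    Values are compared bitwise rather than with `==`, so two NaNs with the
    same encoding count as equal and 0.0 vs. -0.0 count as different.
    """
    if not runs:
        raise ValueError("need at least one run")
    reference = runs[0]
    n_common = min(len(r) for r in runs)
    for step in range(n_common):
        ref_bits = float_bits(reference[step])
        if any(float_bits(r[step]) != ref_bits for r in runs[1:]):
            return step
    # All shared steps match; a length mismatch still counts as a divergence.
    if any(len(r) != n_common for r in runs):
        return n_common
    return None


if __name__ == "__main__":
    # Placeholder values for illustration only.
    run_a = [10.5, 9.25, 8.0]
    run_b = [10.5, 9.25, 8.0]
    run_c = [10.5, 9.25, 8.0]
    step = first_divergence([run_a, run_b, run_c])
    print("identical" if step is None else f"diverged at step {step}")
```

Reporting the first diverging step, rather than a plain yes or no, makes it easier to tell whether a mismatch appears at the very start (for example, different seeds or data order) or only accumulates later in training.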