[FSDP][torch.compile] accelerator.unwrap_model and trainer._save work incorrectly when FSDP + torch.compile
Issue Details
System Info
transformers: 4.51.3
accelerate: 1.6.0
Who can help?
@zach-huggingface @SunMarc
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
To use torch.compile, you need to either uninstall the `kernels` library or set the environment variable `DISABLE_KERNEL_MAPPING=1`.
train.py
```python
from typing import cast

import torch
from transformers import HfArgumentParser, Trainer, TrainingArguments, LlamaForCausalLM, LlamaConfig

args = HfArgumentParser(TrainingArguments)
training_args = cast(TrainingArguments, args.parse_args_into_dataclasses())[0]
print(training_args, flush=True)

config = LlamaConfig(
    vocab_size=128,
    hidden_size=128,
    intermediate_size=128 * 2,
    num_hidden_layers=2,
)
model = LlamaForCausalLM(config).cuda().bfloat16()

train_dataset = [
    {"input_ids": torch.randint(0, 128, (128,)), "labels": torch.randint(0, 128, (128,))}
    for i in range(16)
]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
)
trainer.train()
trainer.save_state()
```
fsdp.yaml
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
fsdp_config:
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_activation_checkpointing: false
  fsdp_use_orig_params: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer,Embedding
mixed_precision: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
launch script
```bash
export CUDA_VISIBLE_DEVICES=0,1
export DISABLE_KERNEL_MAPPING=1

OUTPUT_DIR=test_fsdp
mkdir -p $OUTPUT_DIR

OMP_NUM_THREADS=8 accelerate launch --main_process_port 40129 --config_file fsdp.yaml \
    train.py \
    --torch_compile_mode default \
    --do_train \
    --optim adamw_torch_fused \
    --learning_rate 1e-3 \
    --weight_decay 0 \
    --lr_scheduler_type constant_with_warmup \
    --warmup_ratio 0.1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --eval_on_start 0 \
    --eval_strategy epoch \
    --eval_steps 1 \
    --save_strategy epoch \
    --save_only_model 1 \
    --greater_is_better False \
    --logging_strategy steps \
    --logging_steps 1 \
    --include_tokens_per_second \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --seed 0 \
    --report_to none \
    > $OUTPUT_DIR/training.log 2>&1
```
Expected behavior
The file `test_fsdp/checkpoint-2/config.json` exists.

Run:
```python
from safetensors import safe_open

path = "test_fsdp/checkpoint-2/model.safetensors"
file = safe_open(path, framework="pt")
print(file.keys())

# Fall back to the "_orig_mod."-prefixed key if the plain key is missing
lm_head = "lm_head.weight"
if lm_head not in file.keys():
    lm_head = "_orig_mod." + lm_head
print(file.get_tensor(lm_head).shape)
```
I expect to get:

```
['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.norm.weight']
torch.Size([128, 128])
```
but instead I get:

```
['_orig_mod.lm_head.weight', '_orig_mod.model.embed_tokens.weight', '_orig_mod.model.layers.0.input_layernorm.weight', '_orig_mod.model.layers.0.mlp.down_proj.weight', '_orig_mod.model.layers.0.mlp.gate_proj.weight', '_orig_mod.model.layers.0.mlp.up_proj.weight', '_orig_mod.model.layers.0.post_attention_layernorm.weight', '_orig_mod.model.layers.0.self_attn.k_proj.weight', '_orig_mod.model.layers.0.self_attn.o_proj.weight', '_orig_mod.model.layers.0.self_attn.q_proj.weight', '_orig_mod.model.layers.0.self_attn.v_proj.weight', '_orig_mod.model.layers.1.input_layernorm.weight', '_orig_mod.model.layers.1.mlp.down_proj.weight', '_orig_mod.model.layers.1.mlp.gate_proj.weight', '_orig_mod.model.layers.1.mlp.up_proj.weight', '_orig_mod.model.layers.1.post_attention_layernorm.weight', '_orig_mod.model.layers.1.self_attn.k_proj.weight', '_orig_mod.model.layers.1.self_attn.o_proj.weight', '_orig_mod.model.layers.1.self_attn.q_proj.weight', '_orig_mod.model.layers.1.self_attn.v_proj.weight', '_orig_mod.model.norm.weight']
torch.Size([8128])
```
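For reference, the `_orig_mod.` prefix is what a compiled model's state dict looks like when the compile wrapper is not unwrapped: `torch.compile` returns an `OptimizedModule` that registers the original model as the submodule `_orig_mod`. A minimal sketch outside FSDP/Trainer (the `Linear` module is just an illustration, not part of the reproduction):

```python
import torch

# torch.compile wraps the module in an OptimizedModule that holds the
# original model as the `_orig_mod` submodule, so state_dict keys pick
# up the "_orig_mod." prefix unless the wrapper is unwrapped first.
model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

print(list(compiled.state_dict().keys()))  # typically ['_orig_mod.weight', '_orig_mod.bias']
print(compiled._orig_mod is model)         # True: unwrapping should reach the original module
```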
If `--eval_strategy epoch` in the launch script is changed to `--eval_strategy no`, the output becomes:

```
['_orig_mod.lm_head.weight', '_orig_mod.model.embed_tokens.weight', '_orig_mod.model.layers.0.input_layernorm.weight', '_orig_mod.model.layers.0.mlp.down_proj.weight', '_orig_mod.model.layers.0.mlp.gate_proj.weight', '_orig_mod.model.layers.0.mlp.up_proj.weight', '_orig_mod.model.layers.0.post_attention_layernorm.weight', '_orig_mod.model.layers.0.self_attn.k_proj.weight', '_orig_mod.model.layers.0.self_attn.o_proj.weight', '_orig_mod.model.layers.0.self_attn.q_proj.weight', '_orig_mod.model.layers.0.self_attn.v_proj.weight', '_orig_mod.model.layers.1.input_layernorm.weight', '_orig_mod.model.layers.1.mlp.down_proj.weight', '_orig_mod.model.layers.1.mlp.gate_proj.weight', '_orig_mod.model.layers.1.mlp.up_proj.weight', '_orig_mod.model.layers.1.post_attention_layernorm.weight', '_orig_mod.model.layers.1.self_attn.k_proj.weight', '_orig_mod.model.layers.1.self_attn.o_proj.weight', '_orig_mod.model.layers.1.self_attn.q_proj.weight', '_orig_mod.model.layers.1.self_attn.v_proj.weight', '_orig_mod.model.norm.weight']
torch.Size([128, 128])
```
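For context on the 1-D `torch.Size([8128])` above: it appears to be a local, flattened FSDP shard of `lm_head.weight` being written out rather than the gathered `[128, 128]` tensor. As a point of comparison (a minimal sketch, not Transformers' actual save path), this is how a full, unflattened state dict is normally collected under FSDP with `FULL_STATE_DICT`:

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def gather_full_state_dict(fsdp_model: FSDP) -> dict:
    # Gather full (unsharded, unflattened) parameters, materialized on rank 0 only.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()
```

Even with a correct gather, the keys would still carry the `_orig_mod.` prefix unless the compiled wrapper is unwrapped (e.g. through its `_orig_mod` attribute) before the state dict is taken, which is why both `accelerator.unwrap_model` and `trainer._save` are implicated here.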