[Prim][CINN] Use cpp flag store prim states to ensure consistency#71837
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Pull Request Overview
This PR addresses a bug fix in the prim flag handling for CINN by unifying the management of prim states using C++ flags to ensure consistency. Key changes include updating internal API functions (set_prim* functions) to accept an optional print_flag parameter, replacing usages of the legacy check_and_set_prim_all_enabled with a new __check_and_set_prim_all_enabled, and adjusting test cases (including subprocess tests and deprecated tests) to reflect these refactorings.
Reviewed Changes
Copilot reviewed 5 out of 13 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| test/prim/pir_prim/test_pir_prim_flags.py | Added comprehensive subprocess tests to verify prim flag states based on environment variables and API calls. |
| python/paddle/base/core.py | Refactored internal prim flag functions, adding a new __check_and_set_prim_all_enabled and modifying API signatures to include an optional print_flag parameter. |
| python/paddle/jit/dy2static/utils.py | Commented out calls to the legacy check_and_set_prim_all_enabled to avoid unintended side effects. |
| test/deprecated/prim/prim/flags/test_prim_flags_case_deprecated.py | Updated deprecated tests to use the new __check_and_set_prim_all_enabled function. |
| test/deprecated/prim/prim/flags/test_prim_flags_deprecated.py (second file) | Reduced redundant prim flag tests and focused on eager prim enabling. |
Files not reviewed (8)
- paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.cc: Language not supported
- paddle/common/flags.cc: Language not supported
- paddle/fluid/prim/utils/static/static_global_utils.cc: Language not supported
- paddle/fluid/prim/utils/static/static_global_utils.h: Language not supported
- paddle/fluid/prim/utils/utils.cc: Language not supported
- paddle/fluid/prim/utils/utils.h: Language not supported
- paddle/fluid/pybind/pybind.cc: Language not supported
- test/prim/pir_prim/CMakeLists.txt: Language not supported
Comments suppressed due to low confidence (3)
test/deprecated/prim/prim/flags/test_prim_flags_case_deprecated.py:57
- Using a private function '__check_and_set_prim_all_enabled' directly in tests may lead to brittleness; consider using a public API or introducing a dedicated testing wrapper for flag synchronization.
core.__check_and_set_prim_all_enabled()
python/paddle/jit/dy2static/utils.py:679
- Ensure that deprecating 'check_and_set_prim_all_enabled' is fully reflected in the API usage; if removal is intended, verify that commenting out this call does not impact the overall prim flag consistency.
# core.check_and_set_prim_all_enabled(True)
python/paddle/base/core.py:587
- In '__check_and_set_prim_all_enabled', the sequential handling of FLAGS_prim_all then FLAGS_prim_forward and FLAGS_prim_backward may lead to conflicting flag states if both group and individual flags are set; please confirm that this ordering is the intended behavior.
prim_bwd_env = os.getenv("FLAGS_prim_backward")
40e173b to
802dc5a
Compare
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
prim 相关 flag 的数据统一使用 C++ flag,避免数据不同步导致出问题
FLAGS_prim_all,这是冗余数据,它应该只是FLAGS_prim_forward && FLAGS_prim_backward,使用时应该重新计算,而不是考虑什么时候怎么同步之类的,不过只是这个 C++ FLAG 被移除,环境变量仍然生效StaticCompositeContext::enable_fwd_prim_和StaticCompositeContext::enable_bwd_prim_,使用 flag 可以确保 CINN 那边获取的是一致的FLAGS_prim_all向 C++ flag 同步一次,确保FLAGS_prim_all环境变量生效,为了能够支持FLAGS_prim_forward=False && FLAGS_prim_all=True这种 case,同步时也需要把FLAGS_prim_forward和FLAGS_prim_backward同步一下,但,只应该同步这一次,其他地方不允许使用这个 API(__check_and_set_prim_all_enabled),否则将会导致通过 API 设置的效果失效!本 PR 暂时禁掉了如下两个单测,之前因为机制问题导致对比的两者跑了相同模式,进而隐藏了的精度问题,本 PR 暂时禁用,不阻塞合入
test/ir/pir/cinn/test_cinn_group_norm.py之前跑的都是拆解过的test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_5_st.py之前跑的都是未拆解过的Ecosystem updates
Related links
CINN 切默认系列 PR
build_strategy.build_cinn_passand replace it withbackendoption #71815FLAGS_use_cinnviaos.getenv#71817