mirror of
https://git.savannah.gnu.org/git/guix.git
synced 2025-01-20 06:37:08 +01:00
gnu: Add qnnpack-pytorch.
This is an internal fork of QNNPACK in the PyTorch source tree. * gnu/packages/machine-learning.scm (%python-pytorch-version): New variable. (%python-pytorch-src): New variable. (qnnpack-pytorch): New variable.
This commit is contained in:
parent
0c313244e0
commit
b77c772a3e
1 changed files with 127 additions and 0 deletions
|
@ -4334,6 +4334,133 @@ (define-public ideep-pytorch
|
|||
PyTorch.")
|
||||
(license license:expat)))
|
||||
|
||||
(define %python-pytorch-version "2.2.1")
|
||||
|
||||
(define %python-pytorch-src
|
||||
(origin
|
||||
(method git-fetch)
|
||||
(uri (git-reference
|
||||
(url "https://github.com/pytorch/pytorch")
|
||||
(commit (string-append "v" %python-pytorch-version))))
|
||||
(file-name (git-file-name "python-pytorch" %python-pytorch-version))
|
||||
(sha256
|
||||
(base32
|
||||
"03mm0pwwb5lxdsmmiw3cch9fijgjw81kmmc4ln9rlyazkm7l1r48"))
|
||||
(modules '((guix build utils)))
|
||||
(snippet
|
||||
'(begin
|
||||
;; Bundled or unused code
|
||||
(for-each
|
||||
(lambda (dir)
|
||||
(when (file-exists? dir)
|
||||
(delete-file-recursively dir)))
|
||||
'("android"
|
||||
"aten/src/ATen/native/cuda/cutlass_extensions"
|
||||
"aten/src/ATen/native/quantized/cpu/qnnpack"
|
||||
"caffe2/mobile/contrib/libopencl-stub"
|
||||
"caffe2/mobile/contrib/libvulkan-stub"
|
||||
"third_party"))
|
||||
|
||||
;; Autogenerated files
|
||||
(for-each
|
||||
delete-file
|
||||
'("aten/src/ATen/nnapi/nnapi_wrapper.cpp"
|
||||
"aten/src/ATen/nnapi/nnapi_wrapper.h"
|
||||
"caffe2/mobile/contrib/ios/mpscnn/mpscnn_kernels.h"
|
||||
"caffe2/proto/caffe2_legacy_pb2.pyi"
|
||||
"caffe2/proto/caffe2_pb2.pyi"
|
||||
"caffe2/proto/hsm_pb2.pyi"
|
||||
"caffe2/proto/metanet_pb2.pyi"
|
||||
"caffe2/proto/predictor_consts_pb2.pyi"
|
||||
"caffe2/proto/prof_dag_pb2.pyi"
|
||||
"caffe2/proto/torch_pb2.pyi"
|
||||
;; These files contain just lists of floating point values and
|
||||
;; might be as well hand-written.
|
||||
;; "test/cpp/api/init_baseline.h"
|
||||
;; "test/cpp/api/optim_baseline.h"
|
||||
"test/mobile/test_upgrader_bytecode_table_example.cpp"
|
||||
"torch/csrc/jit/mobile/upgrader_mobile.cpp"
|
||||
"torch/csrc/jit/runtime/decomposition_registry_util.cpp"
|
||||
"torch/csrc/jit/runtime/serialized_shape_function_registry.cpp"
|
||||
"torch/csrc/jit/tensorexpr/external_functions_codegen.cpp"
|
||||
"torch/csrc/jit/serialization/mobile_bytecode_generated.h"))
|
||||
(delete-file-recursively ".github")
|
||||
(for-each
|
||||
(lambda (dir)
|
||||
(for-each
|
||||
delete-file
|
||||
(find-files dir "\\.cu$")))
|
||||
'("aten/src/ATen/native/transformers/cuda/flash_attn/kernels"
|
||||
"aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels"))))))
|
||||
|
||||
(define-public qnnpack-pytorch
|
||||
(package
|
||||
(inherit qnnpack)
|
||||
(name "qnnpack-pytorch")
|
||||
(version (string-append "pytorch-" %python-pytorch-version))
|
||||
(source
|
||||
(origin
|
||||
(inherit %python-pytorch-src)
|
||||
(patches '())
|
||||
(modules '((guix build utils)
|
||||
(srfi srfi-26)
|
||||
(ice-9 ftw)))
|
||||
(snippet
|
||||
'(begin
|
||||
(rename-file "aten/src/ATen/native/quantized/cpu/qnnpack"
|
||||
"../qnnpack")
|
||||
(let ((outdir (getcwd)))
|
||||
(chdir "..")
|
||||
(rename-file outdir "dummy")
|
||||
(rename-file "qnnpack" outdir)
|
||||
(chdir outdir)
|
||||
(delete-file-recursively "deps"))))))
|
||||
(arguments
|
||||
(substitute-keyword-arguments (package-arguments qnnpack)
|
||||
((#:phases phases #~%standard-phases)
|
||||
#~(modify-phases %standard-phases
|
||||
(add-after 'unpack 'patch-cmake
|
||||
(lambda _
|
||||
(substitute* "CMakeLists.txt"
|
||||
(("project\\(.*" orig)
|
||||
(apply
|
||||
string-append
|
||||
orig "\n"
|
||||
(map (lambda (name)
|
||||
(string-append
|
||||
"option(" name " \"\" ON)\n"))
|
||||
'("USE_SYSTEM_CPUINFO" "USE_SYSTEM_FP16" "USE_SYSTEM_FXDIV"
|
||||
"USE_SYSTEM_PSIMD" "USE_SYSTEM_PTHREADPOOL"))))
|
||||
(("if.*SOURCE_DIR.*")
|
||||
"if(FALSE)\n")
|
||||
(("if\\(NOT TARGET (clog|gtest|benchmark).*")
|
||||
"if(FALSE)\n")
|
||||
(("target_link_libraries.*(fxdiv|psimd|fp16)\\).*")
|
||||
"")
|
||||
(("(target_link_libraries.*) fp16 (.*)" _ before after)
|
||||
(string-append before " " after)))))
|
||||
(add-after 'unpack 'fix-cstring-include
|
||||
(lambda _
|
||||
(substitute* "include/pack_block_sparse.h"
|
||||
(("#include.*<vector>.*" orig)
|
||||
(string-append orig "\n#include <cstring>\n")))))
|
||||
(add-after 'install 'install-missing-headers
|
||||
(lambda _
|
||||
(for-each
|
||||
(lambda (name)
|
||||
(install-file (string-append "../source/include/" name)
|
||||
(string-append #$output "/include")))
|
||||
'("pack_block_sparse.h"
|
||||
"pytorch_qnnpack.h"
|
||||
"qnnpack_func.h"))
|
||||
(copy-recursively
|
||||
"../source/src/qnnpack"
|
||||
(string-append #$output "/include/qnnpack"))))))
|
||||
;; Some tests occasionally fail on i686 due to floating point rounding.
|
||||
((#:tests? _ #t)
|
||||
(not (string-prefix? "i686" (or (%current-target-system)
|
||||
(%current-system)))))))))
|
||||
|
||||
;; Please also update python-torchvision when updating this package.
|
||||
(define-public python-pytorch
|
||||
(package
|
||||
|
|
Loading…
Reference in a new issue