| """
|
| Setup script for BitLinear PyTorch extension.
|
|
|
| This script builds the C++/CUDA extension using PyTorch's built-in
|
| cpp_extension utilities. It handles:
|
| - CPU-only builds (development)
|
| - CUDA builds (production)
|
| - Conditional compilation based on CUDA availability
|
| """
|
|
|
| import os
|
| import torch
|
| from setuptools import setup, find_packages
|
| from torch.utils.cpp_extension import (
|
| BuildExtension,
|
| CppExtension,
|
| CUDAExtension,
|
| CUDA_HOME,
|
| )
|
|
|
|
|
# Package metadata shared by the setup() call below.
VERSION = "0.1.0"
DESCRIPTION = "BitLinear: Ultra-Low-Precision Linear Layers for PyTorch"

# Markdown long description shown on the package index page. Built by
# implicit string concatenation; each piece ends with its own newline so the
# result starts and ends with "\n", matching a triple-quoted literal.
LONG_DESCRIPTION = (
    "\n"
    "A research-grade PyTorch extension for ultra-low-precision (1.58-bit) ternary\n"
    "linear layers inspired by BitNet and recent JMLR work on ternary representations\n"
    "of neural networks.\n"
    "\n"
    "Features:\n"
    "- Drop-in replacement for nn.Linear with ternary weights\n"
    "- 20x memory compression\n"
    "- Optimized CUDA kernels for GPU acceleration\n"
    "- Greedy ternary decomposition for improved expressiveness\n"
)
|
|
|
|
|
def cuda_is_available():
    """Return True if the CUDA version of the extension should be built.

    A CUDA toolkit must be locatable (``CUDA_HOME`` as resolved by
    ``torch.utils.cpp_extension``). In addition, either a GPU must be visible
    to PyTorch (``torch.cuda.is_available()``), or the ``FORCE_CUDA``
    environment variable must be set to ``1``/``true``/``yes``/``on``
    (case-insensitive).

    The override lets GPU builds run on machines that have the toolkit but
    no visible GPU, such as CI runners or ``docker build``. Without
    ``FORCE_CUDA`` the result is the same as before.
    """
    if CUDA_HOME is None:
        # No toolkit means nvcc is unavailable, so a CUDA build cannot work.
        return False
    forced = os.environ.get("FORCE_CUDA", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    return forced or torch.cuda.is_available()
|
|
|
|
|
def _parse_cuda_version(version):
    """Parse a ``"major.minor[.patch]"`` CUDA version string into ``(major, minor)``.

    Returns ``None`` when *version* is ``None`` or not in that form, so
    callers can fall back to "no filtering".
    """
    try:
        major, minor = version.split(".")[:2]
        return int(major), int(minor)
    except (AttributeError, ValueError):
        return None


def _default_gencode_flags(cuda_version=None):
    """Return nvcc ``-gencode`` flags for the default GPU architecture set.

    SASS is emitted for every architecture in the table below that the given
    CUDA version can target. PTX is also emitted for the newest one, so the
    driver can JIT-compile the kernels on GPUs newer than the table. Without
    PTX, those GPUs fail at launch with "no kernel image is available".

    Args:
        cuda_version: CUDA version string such as ``"12.1"``. Defaults to
            ``torch.version.cuda``, the toolkit PyTorch was built with, which
            is used here as a proxy for the local nvcc. If it cannot be
            parsed, every architecture in the table is emitted.

    Returns:
        List of ``-gencode=...`` flag strings.
    """
    # (compute capability, first CUDA release able to target it,
    #  first CUDA release that removed it, or None if still supported)
    arch_table = (
        ("70", (9, 0), (13, 0)),  # Volta: offline compilation removed in CUDA 13
        ("75", (10, 0), None),
        ("80", (11, 0), None),
        ("86", (11, 1), None),
        ("89", (11, 8), None),
        ("90", (11, 8), None),
    )
    if cuda_version is None:
        cuda_version = torch.version.cuda
    parsed = _parse_cuda_version(cuda_version)

    archs = [
        arch
        for arch, first, removed in arch_table
        if parsed is None
        or (parsed >= first and (removed is None or parsed < removed))
    ]
    flags = [f"-gencode=arch=compute_{arch},code=sm_{arch}" for arch in archs]
    if archs:
        newest = archs[-1]
        # PTX for forward compatibility with GPUs newer than the table.
        flags.append(f"-gencode=arch=compute_{newest},code=compute_{newest}")
    return flags


def get_extensions():
    """Build the extension module list for the current build environment.

    When ``cuda_is_available()`` returns true, a ``CUDAExtension`` is built
    from ``bitlinear.cpp`` and ``bitlinear_kernel.cu`` with ``WITH_CUDA``
    defined. Otherwise a CPU-only ``CppExtension`` is built from
    ``bitlinear.cpp``.

    GPU architectures: if ``TORCH_CUDA_ARCH_LIST`` is set, no ``-gencode``
    flags are added, so PyTorch's own arch selection applies. PyTorch skips
    that selection whenever explicit arch flags are already present in the
    nvcc arguments. Otherwise the default set from
    ``_default_gencode_flags()`` is used.

    Returns:
        A single-element list containing the extension to compile.
    """
    source_dir = os.path.join("bitlinear", "cpp")
    sources = [os.path.join(source_dir, "bitlinear.cpp")]
    extra_compile_args = {"cxx": ["-O3", "-std=c++17"]}
    define_macros = []

    if not cuda_is_available():
        print("CUDA not detected, building CPU-only version")
        return [
            CppExtension(
                name="bitlinear_cpp",
                sources=sources,
                extra_compile_args=extra_compile_args["cxx"],
                define_macros=define_macros,
            )
        ]

    print("CUDA detected, building with GPU support")
    sources.append(os.path.join(source_dir, "bitlinear_kernel.cu"))

    nvcc_flags = ["-O3", "-std=c++17", "--use_fast_math"]
    if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
        nvcc_flags.extend(_default_gencode_flags())
    extra_compile_args["nvcc"] = nvcc_flags

    define_macros.append(("WITH_CUDA", None))

    return [
        CUDAExtension(
            name="bitlinear_cpp",
            sources=sources,
            extra_compile_args=extra_compile_args,
            define_macros=define_macros,
        )
    ]
|
|
|
|
|
|
|
def read_requirements(req_file="requirements.txt"):
    """Read requirement specifiers from a pip requirements file.

    The following lines are skipped, because none of them are valid
    ``install_requires`` entries:

    - blank lines;
    - full-line comments, including indented ones;
    - trailing inline comments (whitespace followed by ``#``, as in pip);
    - pip option lines such as ``-r``, ``-e`` or ``--index-url``.

    Args:
        req_file: Path to the requirements file. Relative paths are resolved
            against the current working directory. Defaults to
            ``"requirements.txt"``.

    Returns:
        List of requirement strings, or an empty list if the file does not
        exist.
    """
    import re

    if not os.path.exists(req_file):
        return []

    # Same comment rule pip uses: '#' at line start or after whitespace.
    comment_re = re.compile(r"(^|\s+)#.*$")
    requirements = []
    with open(req_file, encoding="utf-8") as f:
        for raw_line in f:
            line = comment_re.sub("", raw_line.strip()).strip()
            if not line or line.startswith("-"):
                continue
            requirements.append(line)
    return requirements
|
|
|
|
|
|
|
# Package definition. Arguments are grouped by concern: identity and index
# metadata, build configuration, then dependencies. find_packages(),
# get_extensions() and BuildExtension.with_options() are still evaluated in
# the same relative order.
setup(
    # --- Identity and index metadata ---------------------------------------
    name="bitlinear",
    version=VERSION,
    author="BitLinear Contributors",
    url="https://github.com/yourusername/bitlinear",
    description=DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    keywords="pytorch deep-learning quantization ternary bitnet transformer",
    project_urls={
        "Bug Reports": "https://github.com/yourusername/bitlinear/issues",
        "Source": "https://github.com/yourusername/bitlinear",
        "Documentation": "https://github.com/yourusername/bitlinear/blob/main/README.md",
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
        "Programming Language :: Python :: Implementation :: CPython",
    ],
    # --- Build configuration -----------------------------------------------
    packages=find_packages(),
    ext_modules=get_extensions(),
    # Build the native module without the Python ABI tag in its filename.
    cmdclass=dict(
        build_ext=BuildExtension.with_options(no_python_abi_suffix=True),
    ),
    # --- Dependencies ------------------------------------------------------
    python_requires=">=3.8",
    install_requires=["torch>=2.0.0", "numpy>=1.20.0"],
    extras_require={
        "dev": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
            "black>=22.0.0",
            "flake8>=5.0.0",
            "mypy>=0.990",
        ],
        "test": ["pytest>=7.0.0", "pytest-cov>=4.0.0"],
    },
)
|
|
|