
Commit b5742d3

kashif, carmocca, and awaelchli authored and committed
Relax bitsandbytes requirements (#946)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
1 parent 4b2e02e commit b5742d3

File tree

  litgpt/utils.py
  pyproject.toml
  tests/test_model.py
  tests/test_utils.py
  tutorials/quantize.md

5 files changed: +11 -12 lines changed

litgpt/utils.py
+1 -1

@@ -39,7 +39,7 @@ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
         if requires_grad is None or p.requires_grad == requires_grad:
             if hasattr(p, "quant_state"):
                 # bitsandbytes 4bit layer support
-                total += math.prod(p.quant_state[1])
+                total += math.prod(p.quant_state.shape)
             else:
                 total += p.numel()
     return total
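The one-line change tracks a bitsandbytes API change: the old code indexed the original weight shape out of a tuple-like quantization state (`quant_state[1]`), while the 0.42.0 release pinned in this commit exposes a state object with a `.shape` attribute. Below is a minimal sketch of the patched counting logic, with the `requires_grad` filter dropped and a defensive `None` check added; the `Linear4bit` layer, its dimensions, and the CUDA requirement are illustrative assumptions, not taken from the repository.

```python
import math

import torch.nn as nn
import bitsandbytes as bnb


def num_parameters(module: nn.Module) -> int:
    """Count parameters, expanding bitsandbytes 4-bit weights to their unquantized shape."""
    total = 0
    for p in module.parameters():
        if hasattr(p, "quant_state") and p.quant_state is not None:
            # bitsandbytes packs 4-bit weights into a flat buffer;
            # quant_state.shape (bitsandbytes >= 0.42) records the original shape.
            total += math.prod(p.quant_state.shape)
        else:
            total += p.numel()
    return total


# Illustrative usage: the weight is quantized when the layer moves to a CUDA device.
layer = bnb.nn.Linear4bit(128, 256, bias=False).cuda()
assert num_parameters(layer) == 128 * 256
```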

pyproject.toml

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ license = { file = "LICENSE" }
1010

1111
dependencies = [
1212
"torch>=2.2.0",
13-
"lightning @ git+https://github.com/Lightning-AI/lightning@b19c3a961c79028d7c39a4f1ff1c2df991406d1d",
13+
"lightning @ git+https://github.com/Lightning-AI/lightning@75553845c6bbcc305fbae38a46ef4e532e4ac85a",
1414
# TODO: install from PyPI when https://github.com/omni-us/jsonargparse/pull/466 is released
1515
"jsonargparse[signatures] @ git+https://github.com/omni-us/jsonargparse",
1616
]
@@ -32,8 +32,7 @@ test = [
3232
"protobuf",
3333
]
3434
all = [
35-
"bitsandbytes==0.41.0", # quantization
36-
"scipy", # required by bitsandbytes
35+
"bitsandbytes==0.42.0", # quantization
3736
"sentencepiece", # llama-based models
3837
"tokenizers", # pythia, falcon, redpajama
3938
"datasets", # eval

tests/test_model.py
+3 -2

@@ -1,6 +1,7 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

 import sys
+from copy import deepcopy
 from functools import partial
 from pathlib import Path
 from urllib.request import urlretrieve
@@ -698,7 +699,7 @@ def test_model_kv_cache_amp():


 @RunIf(min_cuda_gpus=1)
-@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
+@pytest.mark.parametrize("config", deepcopy(config_module.configs), ids=[c["name"] for c in config_module.configs])
 @torch.inference_mode()
 def test_sdpa_choice(config):
     from torch.backends.cuda import (
@@ -754,7 +755,7 @@ def assert_sdpa_backend(original_fn, q, k, v, mask):


 @RunIf(min_cuda_gpus=1)
-@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
+@pytest.mark.parametrize("config", deepcopy(config_module.configs), ids=[c["name"] for c in config_module.configs])
 @torch.inference_mode()
 def test_sdpa_choice_kv_cache(config):
     from torch.backends.cuda import (
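Wrapping the shared `config_module.configs` list in `deepcopy` keeps the parametrized SDPA tests from observing mutations made by other tests: `pytest.mark.parametrize` hands the very same dict objects to every test that references the module-level list. Here is a self-contained sketch of the failure mode and the fix, using an illustrative config list rather than the real one:

```python
from copy import deepcopy

import pytest

# Illustrative stand-in for the module-level config_module.configs list.
configs = [{"name": "tiny", "block_size": 128}]


@pytest.mark.parametrize("config", configs, ids=[c["name"] for c in configs])
def test_mutating_shared_config(config):
    # This mutation hits the module-level dict itself ...
    config["block_size"] = 4
    assert configs[0]["block_size"] == 4


@pytest.mark.parametrize("config", deepcopy(configs), ids=[c["name"] for c in configs])
def test_isolated_by_deepcopy(config):
    # ... but a deepcopy'd parametrization, snapshotted at collection time,
    # still sees the original value even though the earlier test ran first.
    assert config["block_size"] == 128
```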

tests/test_utils.py
-1

@@ -158,7 +158,6 @@ def test_num_parameters():

 @RunIf(min_cuda_gpus=1)
 @pytest.mark.parametrize("mode", ["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"])
-@pytest.mark.skip("To be fixed")
 def test_num_parameters_bitsandbytes(mode):
     from lightning.fabric.plugins import BitsandbytesPrecision
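Removing the skip marker re-enables the GPU test that counts parameters under Lightning Fabric's bitsandbytes plugin, which works again thanks to the `quant_state.shape` fix above. Roughly the kind of check it performs, sketched with an illustrative single-Linear model and the "nf4" mode (one of the parametrized values); this assumes a CUDA GPU and the bitsandbytes 0.42.0 pin:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

from litgpt.utils import num_parameters

# "nf4" is one of the parametrized modes; the others are exercised the same way.
plugin = BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16)
fabric = Fabric(devices=1, accelerator="cuda", plugins=plugin)

model = torch.nn.Linear(16, 32, bias=False)
model = fabric.setup(model)  # converts the Linear to a 4-bit bitsandbytes layer

# The packed 4-bit weight stores fewer elements, but num_parameters
# reports the logical count via quant_state.shape.
assert num_parameters(model) == 16 * 32
```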

tutorials/quantize.md
+5 -5

@@ -46,7 +46,7 @@ Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check
 Uses the normalized float 4 (nf4) data type. This is recommended over "fp4" based on the paper's experimental results and theoretical analysis.

 ```bash
-pip install scipy bitsandbytes # scipy is required until https://github.com/TimDettmers/bitsandbytes/pull/525 is released
+pip install bitsandbytes

 litgpt generate base --quantize bnb.nf4 --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256
 ...
@@ -62,7 +62,7 @@ Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check
 In average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model).

 ```bash
-pip install scipy bitsandbytes # scipy is required until https://github.com/TimDettmers/bitsandbytes/pull/525 is released
+pip install bitsandbytes

 litgpt generate base --quantize bnb.nf4-dq --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256
 ...
@@ -77,7 +77,7 @@ Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check
 Uses pure FP4 quantization.

 ```bash
-pip install scipy bitsandbytes # scipy is required until https://github.com/TimDettmers/bitsandbytes/pull/525 is released
+pip install bitsandbytes

 litgpt generate base --quantize bnb.fp4 --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256
 ...
@@ -93,7 +93,7 @@ Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check
 In average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model).

 ```bash
-pip install scipy bitsandbytes # scipy is required until https://github.com/TimDettmers/bitsandbytes/pull/525 is released
+pip install bitsandbytes

 litgpt generate base --quantize bnb.fp4-dq --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256
 ...
@@ -106,7 +106,7 @@ Memory used: 5.38 GB
 Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2110.02861) to learn more about how it works.

 ```bash
-pip install scipy bitsandbytes # scipy is required until https://github.com/TimDettmers/bitsandbytes/pull/525 is released
+pip install bitsandbytes

 litgpt generate base --quantize bnb.int8 --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision 16-true --max_new_tokens 256
 ...
