Low-Rank Continual Personalization of Diffusion Models

arXiv: https://arxiv.org/abs/2410.04891

Official code implementation of "Low-Rank Continual Personalization of Diffusion Models".

We study the effectiveness of LoRA weight initialization and merging under a strict continual learning regime, in which only the model, or the model with a single adapter, is passed between tasks. In a series of experiments, we compare (1) naïve continual fine-tuning of the low-rank adapter against three simple merging baselines that mitigate forgetting of concepts and styles: (2) consecutive merging of task-specific adapters, (3) merging of LoRAs initialized with orthogonal weights, and (4) merging through selection of the weights with the highest magnitude for each task. Our experiments indicate that naïvely adding multiple adapters can cause the model's performance to drift back towards that of its base form, while all of the evaluated techniques mitigate this issue.
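
As a rough illustration, merging baselines (2) and (4) can be sketched in a few lines of PyTorch. The snippet below is a minimal sketch, not the repository's implementation; it assumes each task's adapter update has already been materialized as a dense tensor tau_i = B_i @ A_i:

import torch

def consecutive_merge(deltas: list[torch.Tensor]) -> torch.Tensor:
    # (2) Consecutive merging: accumulate the task-specific updates by
    # summation, in the spirit of task arithmetic (task_vectors).
    return torch.stack(deltas).sum(dim=0)

def max_magnitude_merge(deltas: list[torch.Tensor]) -> torch.Tensor:
    # (4) Magnitude-based selection: for each individual weight, keep the
    # value with the largest absolute magnitude across tasks (magmax-style).
    stacked = torch.stack(deltas)                     # (num_tasks, *w_shape)
    winners = stacked.abs().argmax(dim=0, keepdim=True)
    return stacked.gather(0, winners).squeeze(0)

# Example: merge two hypothetical 4x4 adapter updates; the merged delta
# would then be added back to the corresponding base weight W_0.
merged = max_magnitude_merge([torch.randn(4, 4), torch.randn(4, 4)])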

🔥 Updates

  • [2025.02.16] Code released.
  • [2024.10.07] Paper released on arXiv.

⚙️ Environment setup

Create a conda environment and activate it:

conda create -p <ENV_PATH> python=3.11
conda activate <ENV_PATH>

Install all dependencies:

pip install -r requirements.txt

Note ⚠️ Change the save settings in the file python3.11/site-packages/diffusers/pipelines/pipeline_utils.py, around line 278 (inside DiffusionPipeline.save_pretrained()):

save_kwargs = {"max_shard_size": "15GB"}
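
This passes max_shard_size to each component's save method, so save_pretrained() writes weights in shards of up to 15 GB and large components (e.g., the UNet) should end up in a single file; the exact line number may differ between diffusers versions.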

📥 Download

CSD Model

Download the CSD model used for the style metric:

pip install gdown
gdown 1FX0xs8p-C7Ob-h5Y4cUhTeOepHzXv_46
mv checkpoint.pth res/csd_checkpoint.pth

Style dataset

Download the Unlearn Dataset and place it in the data/style_unlearn directory.

To generate style datasets, run:

python3 preprocess/generate_style_dataset.py

Object dataset

Clone the DreamBooth repository into the data/dreambooth directory:

git clone https://github.com/google/dreambooth data/dreambooth

To generate object datasets, run:

python3 preprocess/generate_object_dataset.py

🚀 Model training

To train LoRA models for subjects and styles across all task orders and object/style seeds, run:

sh slurm/run_all_objects.sh
sh slurm/run_all_styles.sh

To train a specific model for either objects or styles, use the following commands:

sbatch slurm/sbatch_train_obj.sh
sbatch slurm/sbatch_train_style.sh

🎨 Sampling

Run sampling for all trained object or style models:

sh slurm/run_sampling_all_objects.sh
sh slurm/run_sampling_all_styles.sh

📊 Evaluation

Evaluate all models (by default, all style models):

sh slurm/run_eval.sh

📚 Citation

If you find this work useful, please cite:

@misc{staniszewski2024lowrankcontinualpersonalizationdiffusion,
      title={Low-Rank Continual Personalization of Diffusion Models}, 
      author={Łukasz Staniszewski and Katarzyna Zaleska and Kamil Deja},
      year={2024},
      eprint={2410.04891},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.04891}, 
}

🙏 Credits

The repository contains code from task_vectors and magmax.
