Skip to content

Instantly share code, notes, and snippets.

@f0ster
Created April 28, 2024 14:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save f0ster/e8f94535b34119d86044ebed0becc8cf to your computer and use it in GitHub Desktop.
Save f0ster/e8f94535b34119d86044ebed0becc8cf to your computer and use it in GitHub Desktop.
CLI for sharding and publishing models to Hugging Face.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator
import os
import argparse
def main():
    """Shard a Hugging Face causal-LM locally and publish the sharded copy to the Hub.

    Steps:
      1. Parse CLI arguments and read the ``HF_TOKEN`` environment variable.
      2. Load the tokenizer and the model (``load_tokenizer`` / ``load_model``).
      3. Save the model weights into ``save_directory`` in shards of at most
         ``--max-shard-size`` via ``Accelerator.save_model``.
      4. Reload the weights from the checkpoint directory (by default the
         directory written in step 3, so this is a round-trip check).
      5. Push the tokenizer and the model to ``<model_name>-sharded`` on the Hub.

    Raises:
        EnvironmentError: if ``HF_TOKEN`` is unset or empty.
    """
    args = parse_args()

    # Check credentials before the expensive model download/load.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise EnvironmentError("HF_TOKEN is not set in environment variables.")

    model_name = args.model_name
    safe_name = model_name.replace("/", "_")  # "org/model" -> "org_model" for paths
    save_directory = args.save_directory or f"./saved_models/{safe_name}"
    # Reload from where the shards were just written unless the user points
    # elsewhere; defaulting to a separate, never-written directory made the
    # reload step fail every time.
    checkpoint_directory = args.checkpoint_directory or save_directory
    max_shard_size = args.max_shard_size
    new_model_name = f"{model_name}-sharded"

    # Only the output directory is created; the checkpoint directory is a
    # read source and must already contain a checkpoint.
    os.makedirs(save_directory, exist_ok=True)

    tokenizer = load_tokenizer(model_name)
    model = load_model(model_name)

    accelerator = Accelerator()
    save_model_shards(accelerator, model, save_directory, max_shard_size)

    model = load_model_from_checkpoint(model, checkpoint_directory)

    push_to_hub(tokenizer, new_model_name, hf_token)
    # model.push_to_hub re-serializes the weights itself; without an explicit
    # max_shard_size it uses the library's default (multi-GB) shard size and
    # the published repo would not honour --max-shard-size.
    model.push_to_hub(new_model_name, token=hf_token, max_shard_size=max_shard_size)
def parse_args(argv=None):
    """Parse command-line arguments for the shard-and-push CLI.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; useful for tests and programmatic callers.
            ``None`` (the default) keeps the original behaviour.

    Returns:
        argparse.Namespace with ``model_name``, ``save_directory``,
        ``checkpoint_directory`` and ``max_shard_size`` attributes.

    Raises:
        SystemExit: if arguments are invalid (e.g. ``--model-name`` missing),
            as raised by argparse.
    """
    parser = argparse.ArgumentParser(description="Shard and push model to Hugging Face Hub.")
    parser.add_argument("--model-name", type=str, required=True, help="Model identifier on Hugging Face.")
    parser.add_argument("--save-directory", type=str, default=None, help="Directory to save sharded model. Defaults to a directory based on model name.")
    parser.add_argument("--checkpoint-directory", type=str, default=None, help="Directory to load model checkpoint. Defaults to a directory based on model name.")
    parser.add_argument("--max-shard-size", type=str, default="200MB", help="Maximum shard size (e.g., '200MB').")
    return parser.parse_args(argv)
def load_tokenizer(model_name):
    """Fetch the tokenizer that belongs to *model_name*.

    Args:
        model_name: Hub model identifier or local path, forwarded unchanged
            to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer instance that ``AutoTokenizer.from_pretrained`` resolves.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
def load_model(model_name, torch_dtype=torch.float16):
    """Load a causal language model from the Hugging Face Hub or a local path.

    Args:
        model_name: Hub model identifier or local directory.
        torch_dtype: dtype the weights are loaded in. Defaults to
            ``torch.float16`` (the previous hard-coded value). Down-casting a
            bfloat16-trained model to float16 can overflow; pass
            ``torch.bfloat16`` or ``"auto"`` for such models.

    Returns:
        The ``AutoModelForCausalLM`` instance produced by ``from_pretrained``.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        # Per transformers docs, keeps peak CPU RAM close to 1x model size
        # while loading.
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
    )
def save_model_shards(accelerator, model, directory, max_shard_size):
    """Write *model*'s weights into *directory* as shard files.

    Each shard is capped at *max_shard_size* (a size string such as
    ``"200MB"``). The work is delegated to ``accelerator.save_model``;
    nothing is returned.
    """
    save_kwargs = dict(
        model=model,
        save_directory=directory,
        max_shard_size=max_shard_size,
    )
    accelerator.save_model(**save_kwargs)
def load_model_from_checkpoint(model, checkpoint, device_map=None, no_split_module_classes=None):
    """Load (possibly sharded) weights from *checkpoint* into *model* and place it on devices.

    Args:
        model: An already-instantiated model whose parameters are overwritten.
        checkpoint: Path to a checkpoint file, a sharded-checkpoint index
            file, or a directory containing either.
        device_map: Mapping of module names to devices. Defaults to
            ``{"": "cpu"}``, i.e. the whole model on CPU.
        no_split_module_classes: Module class names that must not be split
            across devices. Defaults to ``["Block"]``.

    Returns:
        The model returned by ``accelerate.load_checkpoint_and_dispatch``.
    """
    # load_checkpoint_and_dispatch was referenced without ever being imported
    # (NameError at runtime); import it locally so this function is self-contained.
    from accelerate import load_checkpoint_and_dispatch

    # Build fresh defaults per call rather than sharing mutable defaults.
    if device_map is None:
        device_map = {"": "cpu"}
    if no_split_module_classes is None:
        no_split_module_classes = ["Block"]

    return load_checkpoint_and_dispatch(
        model,
        checkpoint=checkpoint,
        device_map=device_map,
        no_split_module_classes=no_split_module_classes,
    )
def push_to_hub(obj, model_name, token, **kwargs):
    """Push *obj* (a model or tokenizer) to the Hugging Face Hub.

    Args:
        obj: Any object exposing a ``push_to_hub(repo_id, token=..., ...)``
            method.
        model_name: Target repository id on the Hub.
        token: Hugging Face access token used for the upload.
        **kwargs: Extra keyword arguments forwarded to ``obj.push_to_hub``
            (e.g. ``max_shard_size``, ``private``, ``commit_message``).

    Returns:
        Whatever ``obj.push_to_hub`` returns (previously discarded).
    """
    return obj.push_to_hub(
        model_name,
        token=token,
        **kwargs,
    )
# Run the CLI only when this file is executed as a script, not when imported.
if "__main__" == __name__:
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment