import argparse
import os
from contextlib import nullcontext
import torch
from PIL import Image
from tqdm import tqdm
from transparent_background import Remover
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
from spar3d.utils import foreground_crop, get_device, remove_background
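
# Example invocation (assuming this script is saved as run.py; adjust the name and image path to your setup):
#   python run.py examples/my_image.png --output-dir output/ --texture-resolution 1024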
def check_positive(value):
    ivalue = int(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
    return ivalue
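
# Entry point: parse CLI arguments, load SPAR3D, then reconstruct each input image into a textured mesh.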
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image", type=str, nargs="+", help="Path to input image(s) or folder."
    )
    parser.add_argument(
        "--device",
        default=get_device(),
        type=str,
        help=f"Device to use. If no CUDA/MPS-compatible device is found, texture baking will fail. Default: '{get_device()}'",
    )
    parser.add_argument(
        "--pretrained-model",
        default="stabilityai/stable-point-aware-3d",
        type=str,
        help="Path to the pretrained model. Can be either a Hugging Face model id or a local path. Default: 'stabilityai/stable-point-aware-3d'",
    )
    parser.add_argument(
        "--foreground-ratio",
        default=1.3,
        type=float,
        help="Ratio of the foreground size to the image size. Default: 1.3",
    )
    parser.add_argument(
        "--output-dir",
        default="output/",
        type=str,
        help="Output directory to save the results. Default: 'output/'",
    )
    parser.add_argument(
        "--texture-resolution",
        default=1024,
        type=int,
        help="Texture atlas resolution. Default: 1024",
    )
    parser.add_argument(
        "--low-vram-mode",
        action="store_true",
        help=(
            "Use low VRAM mode. SPAR3D consumes 10.5GB of VRAM by default. "
            "This mode will reduce the VRAM consumption to roughly 7GB but in exchange "
            "the model will be slower. Default: False"
        ),
    )
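
    # Only expose remeshing modes whose optional dependencies are installed in this environment.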
    remesh_choices = ["none"]
    if TRIANGLE_REMESH_AVAILABLE:
        remesh_choices.append("triangle")
    if QUAD_REMESH_AVAILABLE:
        remesh_choices.append("quad")
    parser.add_argument(
        "--remesh_option",
        choices=remesh_choices,
        default="none",
        help="Remeshing option",
    )
    if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
        parser.add_argument(
            "--reduction_count_type",
            choices=["keep", "vertex", "faces"],
            default="keep",
            help="Whether --target_count is a vertex or face count; 'keep' leaves the mesh resolution unchanged.",
        )
        parser.add_argument(
            "--target_count",
            type=check_positive,
            help="Target vertex or face count, interpreted according to --reduction_count_type.",
            default=2000,
        )
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    args = parser.parse_args()
    # Ensure args.device names a supported backend.
    devices = ["cuda", "mps", "cpu"]
    if args.device not in devices:
        raise ValueError("Invalid device. Use cuda, mps or cpu")
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    device = args.device
    # Fall back to CPU when neither CUDA nor MPS is actually available.
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        device = "cpu"
    print("Device used: ", device)
    model = SPAR3D.from_pretrained(
        args.pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
        low_vram_mode=args.low_vram_mode,
    )
    model.to(device)
    model.eval()
    bg_remover = Remover(device=device)
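
    # Preprocess every input: remove the background, crop to the foreground, and stage the image for reconstruction.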
    images = []
    idx = 0

    def handle_image(image_path, idx):
        image = remove_background(
            Image.open(image_path).convert("RGBA"), bg_remover
        )
        image = foreground_crop(image, args.foreground_ratio)
        os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
        image.save(os.path.join(output_dir, str(idx), "input.png"))
        images.append(image)

    for image_path in args.image:
        if os.path.isdir(image_path):
            image_paths = [
                os.path.join(image_path, f)
                for f in os.listdir(image_path)
                if f.endswith((".png", ".jpg", ".jpeg"))
            ]
            for image_path in image_paths:
                handle_image(image_path, idx)
                idx += 1
        else:
            handle_image(image_path, idx)
            idx += 1
    # --reduction_count_type / --target_count are only registered when a remeshing
    # backend is available, so fall back to "keep" (no reduction) otherwise.
    reduction_count_type = getattr(args, "reduction_count_type", "keep")
    target_count = getattr(args, "target_count", 2000)
    vertex_count = (
        -1
        if reduction_count_type == "keep"
        else (
            target_count
            if reduction_count_type == "vertex"
            # A triangle mesh has roughly twice as many faces as vertices,
            # so halve the face target to obtain a vertex target.
            else target_count // 2
        )
    )
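
    # Reconstruct in batches; on CUDA, run under bfloat16 autocast to save memory and time.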
    for i in tqdm(range(0, len(images), args.batch_size)):
        image = images[i : i + args.batch_size]
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            with (
                torch.autocast(device_type=device, dtype=torch.bfloat16)
                if "cuda" in device
                else nullcontext()
            ):
                mesh, glob_dict = model.run_image(
                    image,
                    bake_resolution=args.texture_resolution,
                    remesh=args.remesh_option,
                    vertex_count=vertex_count,
                    return_points=True,
                )
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )
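
        # Export the textured mesh (GLB) and the predicted point cloud (PLY) into each image's output folder.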
        if len(image) == 1:
            out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
            mesh.export(out_mesh_path, include_normals=True)
            out_points_path = os.path.join(output_dir, str(i), "points.ply")
            glob_dict["point_clouds"][0].export(out_points_path)
        else:
            for j in range(len(mesh)):
                out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
                mesh[j].export(out_mesh_path, include_normals=True)
                out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
                glob_dict["point_clouds"][j].export(out_points_path)