diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..240447c073b0e898646f6ccb5f9945efebb6f9ef 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..765d2f6ac0885a28fb2609b3d6c1d7942946e6b8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,167 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv*/ +env/ +venv*/ +ENV/ +env.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +.vs/ +.idea/ +.vscode/ + +stabilityai/ +output/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..53574e1eb2ac72750a2f4299764e93695e3c506b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: check-ast + - id: check-merge-conflict + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.3.5 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..a889fd6b2adf60bbbe8d687e011803892b828eab --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,51 @@ +STABILITY AI COMMUNITY LICENSE AGREEMENT +Last Updated: July 5, 2024 + + +I. INTRODUCTION + +This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below. + + +This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent). + + +By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf. + +II. 
RESEARCH & NON-COMMERCIAL USE LICENSE + +Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing. + +III. COMMERCIAL USE LICENSE + +Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations. +If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you. + +IV. GENERAL TERMS + +Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms. +a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. 
If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified. +b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works). +c. Intellectual Property. +(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein. +(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. +(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law. +(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement. +(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback. +d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. +e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. +f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement. +g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement. + +V. DEFINITIONS + +"Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity. +"Agreement" means this Stability AI Community License Agreement. +"AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. +"Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model. +"Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models. +"Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time. +"Stability AI" or "we" means Stability AI Ltd. and its Affiliates. +"Software" means Stability AI's proprietary software made available under this Agreement now or in the future. +"Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement. +"Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations. 
diff --git a/README.md b/README.md index b79407426591f191ee9fb7d83d2af0c0f349d785..2c2b39e183fd23979da1671c4772b82201d57422 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ --- -title: Stable Point Aware 3d +title: Stable Point-Aware 3D emoji: ⚡ colorFrom: yellow colorTo: yellow sdk: gradio -sdk_version: 5.9.1 -app_file: app.py +sdk_version: 4.43.0 +app_file: gradio_app.py pinned: false --- diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..178313e4c85c81aa81cc72de1527483e3e8afbf7 --- /dev/null +++ b/__init__.py @@ -0,0 +1,358 @@ +import base64 +import logging +import os +import random +import sys + +import comfy.model_management +import folder_paths +import numpy as np +import torch +import trimesh +from PIL import Image +from trimesh.exchange import gltf + +sys.path.append(os.path.dirname(__file__)) +from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE +from spar3d.system import SPAR3D +from spar3d.utils import foreground_crop + +SPAR3D_CATEGORY = "SPAR3D" +SPAR3D_MODEL_NAME = "stabilityai/spar3d" + + +class SPAR3DLoader: + CATEGORY = SPAR3D_CATEGORY + FUNCTION = "load" + RETURN_NAMES = ("spar3d_model",) + RETURN_TYPES = ("SPAR3D_MODEL",) + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def load(self): + device = comfy.model_management.get_torch_device() + model = SPAR3D.from_pretrained( + SPAR3D_MODEL_NAME, + config_name="config.yaml", + weight_name="model.safetensors", + ) + model.to(device) + model.eval() + + return (model,) + + +class SPAR3DPreview: + CATEGORY = SPAR3D_CATEGORY + FUNCTION = "preview" + OUTPUT_NODE = True + RETURN_TYPES = () + + @classmethod + def INPUT_TYPES(s): + return {"required": {"mesh": ("MESH",)}} + + def preview(self, mesh): + glbs = [] + for m in mesh: + scene = trimesh.Scene(m) + glb_data = gltf.export_glb(scene, include_normals=True) + glb_base64 = base64.b64encode(glb_data).decode("utf-8") + glbs.append(glb_base64) + return {"ui": {"glbs": glbs}} + + +class SPAR3DSampler: + CATEGORY = SPAR3D_CATEGORY + FUNCTION = "predict" + RETURN_NAMES = ("mesh", "pointcloud") + RETURN_TYPES = ("MESH", "POINTCLOUD") + + @classmethod + def INPUT_TYPES(s): + remesh_choices = ["none"] + if TRIANGLE_REMESH_AVAILABLE: + remesh_choices.append("triangle") + if QUAD_REMESH_AVAILABLE: + remesh_choices.append("quad") + + opt_dict = { + "mask": ("MASK",), + "pointcloud": ("POINTCLOUD",), + "target_type": (["none", "vertex", "face"],), + "target_count": ( + "INT", + {"default": 1000, "min": 3, "max": 20000, "step": 1}, + ), + "guidance_scale": ( + "FLOAT", + {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05}, + ), + "seed": ( + "INT", + {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1}, + ), + } + if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE: + opt_dict["remesh"] = (remesh_choices,) + + return { + "required": { + "model": ("SPAR3D_MODEL",), + "image": ("IMAGE",), + "foreground_ratio": ( + "FLOAT", + {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01}, + ), + "texture_resolution": ( + "INT", + {"default": 1024, "min": 512, "max": 2048, "step": 256}, + ), + }, + "optional": opt_dict, + } + + def predict( + s, + model, + image, + mask, + foreground_ratio, + texture_resolution, + pointcloud=None, + remesh="none", + target_type="none", + target_count=1000, + guidance_scale=3.0, + seed=42, + ): + if image.shape[0] != 1: + raise ValueError("Only one image can be processed at a time") + + vertex_count = ( + -1 + if target_type == "none" + else (target_count 
// 2 if target_type == "face" else target_count) + ) + + pil_image = Image.fromarray( + torch.clamp(torch.round(255.0 * image[0]), 0, 255) + .type(torch.uint8) + .cpu() + .numpy() + ) + + if mask is not None: + print("Using Mask") + mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype( + np.uint8 + ) + mask_pil = Image.fromarray(mask_np, mode="L") + pil_image.putalpha(mask_pil) + else: + if image.shape[3] != 4: + print("No mask or alpha channel detected, Converting to RGBA") + pil_image = pil_image.convert("RGBA") + + pil_image = foreground_crop(pil_image, foreground_ratio) + + model.cfg.guidance_scale = guidance_scale + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + + print(remesh) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.float16): + if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle": + raise ImportError( + "Triangle remeshing requires gpytoolbox to be installed" + ) + if not QUAD_REMESH_AVAILABLE and remesh == "quad": + raise ImportError("Quad remeshing requires pynim to be installed") + mesh, glob_dict = model.run_image( + pil_image, + bake_resolution=texture_resolution, + pointcloud=pointcloud, + remesh=remesh, + vertex_count=vertex_count, + ) + + if mesh.vertices.shape[0] == 0: + raise ValueError("No subject detected in the image") + + return ( + [mesh], + glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(), + ) + + +class SPAR3DSave: + CATEGORY = SPAR3D_CATEGORY + FUNCTION = "save" + OUTPUT_NODE = True + RETURN_TYPES = () + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mesh": ("MESH",), + "filename_prefix": ("STRING", {"default": "SPAR3D"}), + } + } + + def __init__(self): + self.type = "output" + + def save(self, mesh, filename_prefix): + output_dir = folder_paths.get_output_directory() + glbs = [] + for idx, m in enumerate(mesh): + scene = trimesh.Scene(m) + glb_data = gltf.export_glb(scene, include_normals=True) + logging.info(f"Generated GLB model with {len(glb_data)} bytes") + + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(filename_prefix, output_dir) + ) + filename = filename.replace("%batch_num%", str(idx)) + out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb") + with open(out_path, "wb") as f: + f.write(glb_data) + glbs.append(base64.b64encode(glb_data).decode("utf-8")) + return {"ui": {"glbs": glbs}} + + +class SPAR3DPointCloudLoader: + CATEGORY = SPAR3D_CATEGORY + FUNCTION = "load_pointcloud" + RETURN_TYPES = ("POINTCLOUD",) + RETURN_NAMES = ("pointcloud",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "file": ("STRING", {"default": None}), + } + } + + def load_pointcloud(self, file): + if file is None or file == "": + return (None,) + # Load the mesh using trimesh + mesh = trimesh.load(file) + + # Extract vertices and colors + vertices = mesh.vertices + + # Get vertex colors, defaulting to white if none exist + if mesh.visual.vertex_colors is not None: + colors = ( + mesh.visual.vertex_colors[:, :3] / 255.0 + ) # Convert 0-255 to 0-1 range + else: + colors = np.ones((len(vertices), 3)) + + # Interleave XYZ and RGB values + point_cloud = [] + for vertex, color in zip(vertices, colors): + point_cloud.extend( + [ + float(vertex[0]), + float(vertex[1]), + float(vertex[2]), + float(color[0]), + float(color[1]), + float(color[2]), + ] + ) + + return (point_cloud,) + + +class SPAR3DPointCloudSaver: + CATEGORY = SPAR3D_CATEGORY + FUNCTION = 
"save_pointcloud" + OUTPUT_NODE = True + RETURN_TYPES = () + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "pointcloud": ("POINTCLOUD",), + "filename_prefix": ("STRING", {"default": "SPAR3D"}), + } + } + + def save_pointcloud(self, pointcloud, filename_prefix): + if pointcloud is None: + return {"ui": {"text": "No point cloud data to save"}} + + # Reshape the flat list into points with XYZ and RGB + points = np.array(pointcloud).reshape(-1, 6) + + # Create vertex array for PLY + vertex_array = np.zeros( + len(points), + dtype=[ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ], + ) + + # Fill vertex array + vertex_array["x"] = points[:, 0] + vertex_array["y"] = points[:, 1] + vertex_array["z"] = points[:, 2] + # Convert RGB from 0-1 to 0-255 range + vertex_array["red"] = (points[:, 3] * 255).astype(np.uint8) + vertex_array["green"] = (points[:, 4] * 255).astype(np.uint8) + vertex_array["blue"] = (points[:, 5] * 255).astype(np.uint8) + + # Create PLY object + ply_data = trimesh.PointCloud( + vertices=points[:, :3], colors=points[:, 3:] * 255 + ) + + # Save to file + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(filename_prefix, output_dir) + ) + out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.ply") + + ply_data.export(out_path) + + return {"ui": {"text": f"Saved point cloud to {out_path}"}} + + +NODE_DISPLAY_NAME_MAPPINGS = { + "SPAR3DLoader": "SPAR3D Loader", + "SPAR3DPreview": "SPAR3D Preview", + "SPAR3DSampler": "SPAR3D Sampler", + "SPAR3DSave": "SPAR3D Save", + "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader", + "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver", +} + +NODE_CLASS_MAPPINGS = { + "SPAR3DLoader": SPAR3DLoader, + "SPAR3DPreview": SPAR3DPreview, + "SPAR3DSampler": SPAR3DSampler, + "SPAR3DSave": SPAR3DSave, + "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader, + "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver, +} + +WEB_DIRECTORY = "./comfyui" + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] diff --git a/demo_files/comp.gif b/demo_files/comp.gif new file mode 100644 index 0000000000000000000000000000000000000000..4d976d4302325570e9b71a621b5525af53ea92f6 --- /dev/null +++ b/demo_files/comp.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6190ca0c3bd164d37152ba985abea53e642fe5e434ca0a932a3b2c4dce698f6b +size 1782375 diff --git a/demo_files/examples/bird.png b/demo_files/examples/bird.png new file mode 100644 index 0000000000000000000000000000000000000000..ff6c56a49390aed0ca22467a670c0e26cb1806b7 --- /dev/null +++ b/demo_files/examples/bird.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83373e2b75ebaad76b6fe093973ea1dc96c92527c8376062cf520ed9215f3e82 +size 560004 diff --git a/demo_files/examples/castle.png b/demo_files/examples/castle.png new file mode 100644 index 0000000000000000000000000000000000000000..5913b8030711f476f9eb73f3eb8e1965c768b578 --- /dev/null +++ b/demo_files/examples/castle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ededd2fe4c122cadfb4f2a485dfd82f83dc1ec6446c7a799d5fc1e1f103ae4b1 +size 204100 diff --git a/demo_files/examples/chest.png b/demo_files/examples/chest.png new file mode 100644 index 0000000000000000000000000000000000000000..660faf908426bfbcabd83e15f0620caf9c3edd75 --- /dev/null +++ b/demo_files/examples/chest.png @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1eec59b35c63aa50942edff37f0cbdea7d8360cd036a4b7eb9460afdfcbabd9 +size 1499873 diff --git a/demo_files/examples/doll.png b/demo_files/examples/doll.png new file mode 100644 index 0000000000000000000000000000000000000000..5599dd5746914b6704ff3d469ff3135f478f7b8e --- /dev/null +++ b/demo_files/examples/doll.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc5af86defd0a4fd7285e17a0eb8a108b9f33774408c194a594964d8d6e66c26 +size 154688 diff --git a/demo_files/examples/excavator.png b/demo_files/examples/excavator.png new file mode 100644 index 0000000000000000000000000000000000000000..377e6419f35a1012eafb1db40edfd48ab11a62f7 --- /dev/null +++ b/demo_files/examples/excavator.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f68c6ba4a9dc884d3786d98c4f0d835682bad02e85716d3a60fd2feedcb03d8 +size 189762 diff --git a/demo_files/examples/fish.png b/demo_files/examples/fish.png new file mode 100644 index 0000000000000000000000000000000000000000..badfa85df9426c496addb39472ecfcc5dd8d9120 --- /dev/null +++ b/demo_files/examples/fish.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd623d8b654de81e022e3741576a0d08dd26d6ba92ee1989605347ef26c399bb +size 838058 diff --git a/demo_files/examples/horse-statue.png b/demo_files/examples/horse-statue.png new file mode 100644 index 0000000000000000000000000000000000000000..60e2edea4f9ba7cb2074061740f92ec678c74e9c --- /dev/null +++ b/demo_files/examples/horse-statue.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9c00f726efe9490b02d4c232293b629e0146dad6ce1ff8e22da8102345c5fe9 +size 222266 diff --git a/demo_files/examples/penguin.png b/demo_files/examples/penguin.png new file mode 100644 index 0000000000000000000000000000000000000000..e3469f2f8b6164543a49e68cb569753b1581378a --- /dev/null +++ b/demo_files/examples/penguin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a1667d874e9379a8d36e676fb80327bd7b5d3673cb77d7d4cf27bb53408fb98 +size 659119 diff --git a/demo_files/examples/pot.png b/demo_files/examples/pot.png new file mode 100644 index 0000000000000000000000000000000000000000..0bb9f60f9493c91b46eb8e8570f4c05b367df533 --- /dev/null +++ b/demo_files/examples/pot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32d5d8c110646a46ca24a4d6994cb848ef79cc7ad78dcc7419be0e6f02476a86 +size 1205124 diff --git a/demo_files/examples/raccoon_wizard.png b/demo_files/examples/raccoon_wizard.png new file mode 100644 index 0000000000000000000000000000000000000000..0adbb4ec54a94c70dcebaef5d4981ad1be959ef2 --- /dev/null +++ b/demo_files/examples/raccoon_wizard.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32cc3850d9f48548882c7b148e508e8ab149bc4f363611e9739adcbd38e8b16d +size 774200 diff --git a/demo_files/examples/stylized-rocks.png b/demo_files/examples/stylized-rocks.png new file mode 100644 index 0000000000000000000000000000000000000000..b847cfbad8f80a6286aae660f2273c0bc963ac3a --- /dev/null +++ b/demo_files/examples/stylized-rocks.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:386c3be3a6f24ee52e13f130c1ebc02a1bc46eb2c0ebe90d79ce6f38751f0fc6 +size 439074 diff --git a/demo_files/hdri/abandoned_tiled_room_1k.hdr b/demo_files/hdri/abandoned_tiled_room_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..cae9a42077b3883152061c8ca79f2520fc8bfae3 Binary files /dev/null and 
b/demo_files/hdri/abandoned_tiled_room_1k.hdr differ diff --git a/demo_files/hdri/metro_noord_1k.hdr b/demo_files/hdri/metro_noord_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..58deb543ef46e3edc329430b7670d96bd66fe387 Binary files /dev/null and b/demo_files/hdri/metro_noord_1k.hdr differ diff --git a/demo_files/hdri/neon_photostudio_1k.hdr b/demo_files/hdri/neon_photostudio_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..e3cab8074fa1047c7ce30dfa1df3c43354b4813f Binary files /dev/null and b/demo_files/hdri/neon_photostudio_1k.hdr differ diff --git a/demo_files/hdri/peppermint_powerplant_1k.hdr b/demo_files/hdri/peppermint_powerplant_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..632d5368b86cb992905fed73d2e9ee2e7302d597 Binary files /dev/null and b/demo_files/hdri/peppermint_powerplant_1k.hdr differ diff --git a/demo_files/hdri/rainforest_trail_1k.hdr b/demo_files/hdri/rainforest_trail_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..6ba95544ad9523f2959c134af86bbd58cc9cf2a1 Binary files /dev/null and b/demo_files/hdri/rainforest_trail_1k.hdr differ diff --git a/demo_files/hdri/studio_small_08_1k.hdr b/demo_files/hdri/studio_small_08_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..7d9e8c400cc487e99e13ec97916c0bb134acb423 Binary files /dev/null and b/demo_files/hdri/studio_small_08_1k.hdr differ diff --git a/demo_files/hdri/urban_alley_01_1k.hdr b/demo_files/hdri/urban_alley_01_1k.hdr new file mode 100644 index 0000000000000000000000000000000000000000..cc2abf89a4b01ca3e4c203bacf09f15510774cfe Binary files /dev/null and b/demo_files/hdri/urban_alley_01_1k.hdr differ diff --git a/demo_files/turntable.gif b/demo_files/turntable.gif new file mode 100644 index 0000000000000000000000000000000000000000..e9e5013dc71339f2f8c0a224b1c9a1073e6af82b --- /dev/null +++ b/demo_files/turntable.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffb5cfca3da84a569de41535781dfc6103834b99207136eb6cbf72d097799c6c +size 7578581 diff --git a/demo_files/workflows/spar3d_example.json b/demo_files/workflows/spar3d_example.json new file mode 100644 index 0000000000000000000000000000000000000000..4af1c0fd42f306e8c8c05c0e11364877a8c7e9ea --- /dev/null +++ b/demo_files/workflows/spar3d_example.json @@ -0,0 +1,263 @@ +{ + "last_node_id": 17, + "last_link_id": 18, + "nodes": [ + { + "id": 10, + "type": "SPAR3DLoader", + "pos": [ + 52.92446517944336, + 394.328369140625 + ], + "size": [ + 210, + 26 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "spar3d_model", + "type": "SPAR3D_MODEL", + "links": [ + 10 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "SPAR3DLoader" + }, + "widgets_values": [] + }, + { + "id": 13, + "type": "LoadImage", + "pos": [ + -43.437347412109375, + 482.89678955078125 + ], + "size": [ + 315, + 314 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 11 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": [ + 16 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "cat1.png", + "image" + ] + }, + { + "id": 16, + "type": "InvertMask", + "pos": [ + 377.1180419921875, + 605.384765625 + ], + "size": [ + 210, + 26 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + 
"link": 16 + } + ], + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 17 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "InvertMask" + }, + "widgets_values": [] + }, + { + "id": 17, + "type": "SPAR3DSave", + "pos": [ + 1133.669921875, + 439.6551513671875 + ], + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "mesh", + "type": "MESH", + "link": 18 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "SPAR3DSave" + }, + "widgets_values": [ + "SPAR3D" + ] + }, + { + "id": 11, + "type": "SPAR3DSampler", + "pos": [ + 673.0637817382812, + 441.2229309082031 + ], + "size": [ + 315, + 286 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "SPAR3D_MODEL", + "link": 10 + }, + { + "name": "image", + "type": "IMAGE", + "link": 11 + }, + { + "name": "mask", + "type": "MASK", + "link": 17, + "shape": 7 + }, + { + "name": "pointcloud", + "type": "POINTCLOUD", + "link": null, + "shape": 7 + } + ], + "outputs": [ + { + "name": "mesh", + "type": "MESH", + "links": [ + 18 + ], + "slot_index": 0 + }, + { + "name": "pointcloud", + "type": "POINTCLOUD", + "links": null + } + ], + "properties": { + "Node name for S&R": "SPAR3DSampler" + }, + "widgets_values": [ + 1.3, + 1024, + "none", + 1000, + 3, + 3727502160, + "randomize", + "none" + ] + } + ], + "links": [ + [ + 10, + 10, + 0, + 11, + 0, + "SPAR3D_MODEL" + ], + [ + 11, + 13, + 0, + 11, + 1, + "IMAGE" + ], + [ + 16, + 13, + 1, + 16, + 0, + "MASK" + ], + [ + 17, + 16, + 0, + 11, + 2, + "MASK" + ], + [ + 18, + 11, + 0, + 17, + 0, + "MESH" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.953502721998243, + "offset": [ + 266.21995970220667, + 116.75398112171928 + ] + } + }, + "version": 0.4 +} diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..4a768df00c73d517bb5b99376a2bac47e4b266de --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,792 @@ +import os + +os.system("pip install ./texture_baker/ ./uv_unwrapper/") + +import random +import tempfile +import time +from contextlib import nullcontext +from functools import lru_cache +from typing import Any + +import gradio as gr +import numpy as np +import torch +import trimesh +from gradio_litmodel3d import LitModel3D +from gradio_pointcloudeditor import PointCloudEditor +from PIL import Image +from transparent_background import Remover + +import spar3d.utils as spar3d_utils +from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE +from spar3d.system import SPAR3D + +os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio") + +bg_remover = Remover() # default setting + +COND_WIDTH = 512 +COND_HEIGHT = 512 +COND_DISTANCE = 2.2 +COND_FOVY = 0.591627 +BACKGROUND_COLOR = [0.5, 0.5, 0.5] + +# Cached. 
Doesn't change +c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE) +intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad( + COND_FOVY, COND_HEIGHT, COND_WIDTH +) + +generated_files = [] + +# Delete previous gradio temp dir folder +if os.path.exists(os.environ["GRADIO_TEMP_DIR"]): + print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}") + import shutil + + shutil.rmtree(os.environ["GRADIO_TEMP_DIR"]) + +device = spar3d_utils.get_device() + +model = SPAR3D.from_pretrained( + "stabilityai/stable-point-aware-3d", + config_name="config.yaml", + weight_name="model.safetensors", +) +model.eval() +model = model.to(device) + +example_files = [ + os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples") +] + + +def forward_model( + batch, + system, + guidance_scale=3.0, + seed=0, + device="cuda", + remesh_option="none", + vertex_count=-1, + texture_resolution=1024, +): + batch_size = batch["rgb_cond"].shape[0] + + # prepare the condition for point cloud generation + # set seed + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + cond_tokens = system.forward_pdiff_cond(batch) + + if "pc_cond" not in batch: + sample_iter = system.sampler.sample_batch_progressive( + batch_size, + cond_tokens, + guidance_scale=guidance_scale, + device=device, + ) + for x in sample_iter: + samples = x["xstart"] + batch["pc_cond"] = samples.permute(0, 2, 1).float() + batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"]) + + # subsample to the 512 points + batch["pc_cond"] = batch["pc_cond"][ + :, torch.randperm(batch["pc_cond"].shape[1])[:512] + ] + + # get the point cloud + xyz = batch["pc_cond"][0, :, :3].cpu().numpy() + color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8) + pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb) + + # forward for the final mesh + trimesh_mesh, _glob_dict = model.generate_mesh( + batch, texture_resolution, remesh=remesh_option, vertex_count=vertex_count + ) + trimesh_mesh = trimesh_mesh[0] + + return trimesh_mesh, pc_rgb_trimesh + + +def run_model( + input_image, + guidance_scale, + random_seed, + pc_cond, + remesh_option, + vertex_count, + texture_resolution, +): + start = time.time() + with torch.no_grad(): + with ( + torch.autocast(device_type=device, dtype=torch.float16) + if "cuda" in device + else nullcontext() + ): + model_batch = create_batch(input_image) + model_batch = {k: v.to(device) for k, v in model_batch.items()} + + if pc_cond is not None: + # Check if pc_cond is a list + if isinstance(pc_cond, list): + cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6) + xyz = cond_tensor[:, :3] + color_rgb = cond_tensor[:, 3:] + elif isinstance(pc_cond, dict): + xyz = torch.tensor(pc_cond["positions"]).float().cuda() + color_rgb = torch.tensor(pc_cond["colors"]).float().cuda() + else: + xyz = torch.tensor(pc_cond.vertices).float().cuda() + color_rgb = ( + torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0 + ) + model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze( + 0 + ) + # sub-sample the point cloud to the target number of points + if model_batch["pc_cond"].shape[1] > 512: + idx = torch.randperm(model_batch["pc_cond"].shape[1])[:512] + model_batch["pc_cond"] = model_batch["pc_cond"][:, idx] + elif model_batch["pc_cond"].shape[1] < 512: + num_points = model_batch["pc_cond"].shape[1] + gr.Warning( + f"The uploaded point cloud should have at least 512 points. This point cloud only has {num_points}. Results may be worse." 
+ ) + pad = 512 - num_points + sampled_idx = torch.randint( + 0, model_batch["pc_cond"].shape[1], (pad,) + ) + model_batch["pc_cond"] = torch.cat( + [ + model_batch["pc_cond"], + model_batch["pc_cond"][:, sampled_idx], + ], + dim=1, + ) + + trimesh_mesh, trimesh_pc = forward_model( + model_batch, + model, + guidance_scale=guidance_scale, + seed=random_seed, + device="cuda", + remesh_option=remesh_option.lower(), + vertex_count=vertex_count, + texture_resolution=texture_resolution, + ) + + # Create new tmp file + temp_dir = tempfile.mkdtemp() + tmp_file = os.path.join(temp_dir, "mesh.glb") + + trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True) + generated_files.append(tmp_file) + + tmp_file_pc = os.path.join(temp_dir, "points.ply") + trimesh_pc.export(tmp_file_pc) + generated_files.append(tmp_file_pc) + + print("Generation took:", time.time() - start, "s") + + return tmp_file, tmp_file_pc, trimesh_pc + + +def create_batch(input_image: Image) -> dict[str, Any]: + img_cond = ( + torch.from_numpy( + np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) + / 255.0 + ) + .float() + .clip(0, 1) + ) + mask_cond = img_cond[:, :, -1:] + rgb_cond = torch.lerp( + torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond + ) + + batch_elem = { + "rgb_cond": rgb_cond, + "mask_cond": mask_cond, + "c2w_cond": c2w_cond.unsqueeze(0), + "intrinsic_cond": intrinsic.unsqueeze(0), + "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), + } + # Add batch dim + batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()} + return batched + + +@lru_cache +def checkerboard(squares: int, size: int, min_value: float = 0.5): + base = np.zeros((squares, squares)) + min_value + base[1::2, ::2] = 1 + base[::2, 1::2] = 1 + + repeat_mult = size // squares + return ( + base.repeat(repeat_mult, axis=0) + .repeat(repeat_mult, axis=1)[:, :, None] + .repeat(3, axis=-1) + ) + + +def remove_background(input_image: Image) -> Image: + return bg_remover.process(input_image.convert("RGB")) + + +def show_mask_img(input_image: Image) -> Image: + img_numpy = np.array(input_image) + alpha = img_numpy[:, :, 3] / 255.0 + chkb = checkerboard(32, 512) * 255 + new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None]) + return Image.fromarray(new_img.astype(np.uint8), mode="RGB") + + +def process_model_run( + background_state, + guidance_scale, + random_seed, + pc_cond, + remesh_option, + vertex_count_type, + vertex_count, + texture_resolution, +): + # Adjust vertex count based on selection + final_vertex_count = ( + -1 + if vertex_count_type == "Keep Vertex Count" + else ( + vertex_count // 2 + if vertex_count_type == "Target Face Count" + else vertex_count + ) + ) + print( + f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}" + ) + + glb_file, pc_file, pc_plot = run_model( + background_state, + guidance_scale, + random_seed, + pc_cond, + remesh_option, + final_vertex_count, + texture_resolution, + ) + # Create a single float list of x y z r g b + point_list = [] + for i in range(pc_plot.vertices.shape[0]): + point_list.extend( + [ + pc_plot.vertices[i, 0], + pc_plot.vertices[i, 1], + pc_plot.vertices[i, 2], + pc_plot.colors[i, 0] / 255.0, + pc_plot.colors[i, 1] / 255.0, + pc_plot.colors[i, 2] / 255.0, + ] + ) + + return glb_file, pc_file, point_list + + +def regenerate_run( + background_state, + guidance_scale, + random_seed, + pc_cond, + remesh_option, + vertex_count_type, + vertex_count, + 
texture_resolution, +): + glb_file, pc_file, point_list = process_model_run( + background_state, + guidance_scale, + random_seed, + pc_cond, + remesh_option, + vertex_count_type, + vertex_count, + texture_resolution, + ) + return ( + gr.update(), # run_btn + gr.update(), # img_proc_state + gr.update(), # background_remove_state + gr.update(), # preview_removal + gr.update(value=glb_file, visible=True), # output_3d + gr.update(visible=True), # hdr_row + gr.update(visible=True), # point_cloud_row + gr.update(value=point_list), # point_cloud_editor + gr.update(value=pc_file), # pc_download + gr.update(visible=False), # regenerate_btn + ) + + +def run_button( + run_btn, + input_image, + background_state, + foreground_ratio, + no_crop, + guidance_scale, + random_seed, + pc_upload, + pc_cond_file, + remesh_option, + vertex_count_type, + vertex_count, + texture_resolution, +): + if run_btn == "Run": + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + if pc_upload: + # make sure the pc_cond_file has been uploaded + try: + pc_cond = trimesh.load(pc_cond_file.name) + except Exception: + raise gr.Error( + "Please upload a valid point cloud ply file as condition." + ) + else: + pc_cond = None + + glb_file, pc_file, pc_list = process_model_run( + background_state, + guidance_scale, + random_seed, + pc_cond, + remesh_option, + vertex_count_type, + vertex_count, + texture_resolution, + ) + + if torch.cuda.is_available(): + print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB") + elif torch.backends.mps.is_available(): + print( + "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB" + ) + + return ( + gr.update(), # run_btn + gr.update(), # img_proc_state + gr.update(), # background_remove_state + gr.update(), # preview_removal + gr.update(value=glb_file, visible=True), # output_3d + gr.update(visible=True), # hdr_row + gr.update(visible=True), # point_cloud_row + gr.update(value=pc_list), # point_cloud_editor + gr.update(value=pc_file), # pc_download + gr.update(visible=False), # regenerate_btn + ) + + elif run_btn == "Remove Background": + rem_removed = remove_background(input_image) + + fr_res = spar3d_utils.foreground_crop( + rem_removed, + crop_ratio=foreground_ratio, + newsize=(COND_WIDTH, COND_HEIGHT), + no_crop=no_crop, + ) + + return ( + gr.update(value="Run", visible=True), # run_btn + rem_removed, # img_proc_state, + fr_res, # background_remove_state + gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal + gr.update(value=None, visible=False), # output_3d + gr.update(visible=False), # hdr_row + gr.update(visible=False), # point_cloud_row + gr.update(value=None), # point_cloud_editor + gr.update(value=None), # pc_download + gr.update(visible=False), # regenerate_btn + ) + + +def requires_bg_remove(image, fr, no_crop): + if image is None: + return ( + gr.update(visible=False, value="Run"), # run_Btn + None, # img_proc_state + None, # background_remove_state + gr.update(value=None, visible=False), # preview_removal + gr.update(value=None, visible=False), # output_3d + gr.update(visible=False), # hdr_row + gr.update(visible=False), # point_cloud_row + gr.update(value=None), # point_cloud_editor + gr.update(value=None), # pc_download + gr.update(visible=False), # regenerate_btn + ) + alpha_channel = np.array(image.getchannel("A")) + min_alpha = alpha_channel.min() + + if min_alpha == 0: + print("Already has alpha") + fr_res = spar3d_utils.foreground_crop( + image, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop 
+ ) + return ( + gr.update(value="Run", visible=True), # run_Btn + image, # img_proc_state + fr_res, # background_remove_state + gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal + gr.update(value=None, visible=False), # output_3d + gr.update(visible=False), # hdr_row + gr.update(visible=False), # point_cloud_row + gr.update(value=None), # point_cloud_editor + gr.update(value=None), # pc_download + gr.update(visible=False), # regenerate_btn + ) + return ( + gr.update(value="Remove Background", visible=True), # run_Btn + None, # img_proc_state + None, # background_remove_state + gr.update(value=None, visible=False), # preview_removal + gr.update(value=None, visible=False), # output_3d + gr.update(visible=False), # hdr_row + gr.update(visible=False), # point_cloud_row + gr.update(value=None), # point_cloud_editor + gr.update(value=None), # pc_download + gr.update(visible=False), # regenerate_btn + ) + + +def update_foreground_ratio(img_proc, fr, no_crop): + foreground_res = spar3d_utils.foreground_crop( + img_proc, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop + ) + return ( + foreground_res, + gr.update(value=show_mask_img(foreground_res)), + ) + + +def update_resolution_controls(remesh_choice, vertex_count_type): + show_controls = remesh_choice.lower() != "none" + show_vertex_count = vertex_count_type != "Keep Vertex Count" + return ( + gr.update(visible=show_controls), # vertex_count_type + gr.update(visible=show_controls and show_vertex_count), # vertex_count_slider + ) + + +with gr.Blocks() as demo: + img_proc_state = gr.State() + background_remove_state = gr.State() + gr.Markdown( + """ + # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images + + SPAR3D is a state-of-the-art method for 3D mesh reconstruction from a single image. This demo allows you to upload an image and generate a 3D mesh model from it. A feature of SPAR3D is that it generates a point cloud as an intermediate representation before producing the mesh. You can edit the point cloud to adjust the final mesh. We provide a simple point cloud editor in this demo, where you can drag, recolor and rescale the points. If you have more advanced editing needs (e.g. box selection, duplication, local stretching, etc.), you can download the point cloud and edit it in software such as MeshLab or Blender. The edited point cloud can then be uploaded to this demo to generate a new 3D model by checking the "Point cloud upload" box. + + **Tips** + + 1. If the image does not have a valid alpha channel, it will go through the background removal step. Our built-in background removal can sometimes be inaccurate, which will result in poor mesh quality. In such cases, you can use external background removal tools to obtain an RGBA image before uploading here. + 2. You can adjust the foreground ratio to control the size of the foreground object. This may have a major impact on the final mesh. + 3. Guidance scale controls the strength of the image condition in the point cloud generation process. A higher value may result in higher mesh fidelity, but the variation obtained by changing the random seed will be lower. Note that the guidance scale and the seed are not effective when the point cloud is manually uploaded. + 4. Our online editor supports multi-selection by holding down the shift key. This allows you to recolor multiple points at once. + 5. The editing should mainly alter the unseen parts of the object. Visible parts can be edited, but the edits should be consistent with the image.
Editing the visible parts in a way that contradicts the image may result in poor mesh quality. + 6. You can upload your own HDR environment map to light the 3D model. + """ + ) + with gr.Row(variant="panel"): + with gr.Column(): + with gr.Row(): + input_img = gr.Image( + type="pil", label="Input Image", sources="upload", image_mode="RGBA" + ) + preview_removal = gr.Image( + label="Preview Background Removal", + type="pil", + image_mode="RGB", + interactive=False, + visible=False, + ) + + gr.Markdown("### Input Controls") + with gr.Group(): + with gr.Row(): + no_crop = gr.Checkbox(label="No cropping", value=False) + pc_upload = gr.Checkbox(label="Point cloud upload", value=False) + + pc_cond_file = gr.File( + label="Point Cloud Upload", + file_types=[".ply"], + file_count="single", + visible=False, + ) + + foreground_ratio = gr.Slider( + label="Padding Ratio", + minimum=1.0, + maximum=2.0, + value=1.3, + step=0.05, + ) + + pc_upload.change( + lambda x: gr.update(visible=x), + inputs=pc_upload, + outputs=[pc_cond_file], + ) + + no_crop.change( + update_foreground_ratio, + inputs=[img_proc_state, foreground_ratio, no_crop], + outputs=[background_remove_state, preview_removal], + ) + + foreground_ratio.change( + update_foreground_ratio, + inputs=[img_proc_state, foreground_ratio, no_crop], + outputs=[background_remove_state, preview_removal], + ) + + gr.Markdown("### Point Diffusion Controls") + with gr.Group(): + guidance_scale = gr.Slider( + label="Guidance Scale", + minimum=1.0, + maximum=10.0, + value=3.0, + step=1.0, + ) + + random_seed = gr.Slider( + label="Seed", + minimum=0, + maximum=10000, + value=0, + step=1, + ) + + no_remesh = not TRIANGLE_REMESH_AVAILABLE and not QUAD_REMESH_AVAILABLE + gr.Markdown( + "### Texture Controls" + if no_remesh + else "### Meshing and Texture Controls" + ) + with gr.Group(): + remesh_choices = ["None"] + if TRIANGLE_REMESH_AVAILABLE: + remesh_choices.append("Triangle") + if QUAD_REMESH_AVAILABLE: + remesh_choices.append("Quad") + + remesh_option = gr.Radio( + choices=remesh_choices, + label="Remeshing", + value="None", + visible=not no_remesh, + ) + + vertex_count_type = gr.Radio( + choices=[ + "Keep Vertex Count", + "Target Vertex Count", + "Target Face Count", + ], + label="Mesh Resolution Control", + value="Keep Vertex Count", + visible=False, + ) + + vertex_count_slider = gr.Slider( + label="Target Count", + minimum=0, + maximum=20000, + value=2000, + visible=False, + ) + + texture_size = gr.Slider( + label="Texture Size", + minimum=512, + maximum=2048, + value=1024, + step=256, + visible=True, + ) + + remesh_option.change( + update_resolution_controls, + inputs=[remesh_option, vertex_count_type], + outputs=[vertex_count_type, vertex_count_slider], + ) + + vertex_count_type.change( + update_resolution_controls, + inputs=[remesh_option, vertex_count_type], + outputs=[vertex_count_type, vertex_count_slider], + ) + + run_btn = gr.Button("Run", variant="primary", visible=False) + + with gr.Column(): + with gr.Group(visible=False) as point_cloud_row: + point_size_slider = gr.Slider( + label="Point Size", + minimum=0.01, + maximum=1.0, + value=0.2, + step=0.01, + ) + point_cloud_editor = PointCloudEditor( + up_axis="Z", + forward_axis="X", + lock_scale_z=True, + lock_scale_y=True, + visible=True, + ) + + pc_download = gr.File( + label="Point Cloud Download", + file_types=[".ply"], + file_count="single", + ) + point_size_slider.change( + fn=lambda x: gr.update(point_size=x), + inputs=point_size_slider, + outputs=point_cloud_editor, + ) + + regenerate_btn 
= gr.Button( + "Re-run with point cloud", variant="primary", visible=False + ) + + output_3d = LitModel3D( + label="3D Model", + visible=False, + clear_color=[0.0, 0.0, 0.0, 0.0], + tonemapping="aces", + contrast=1.0, + scale=1.0, + ) + with gr.Column(visible=False, scale=1.0) as hdr_row: + gr.Markdown( + """## HDR Environment Map + + Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps. + """ + ) + + with gr.Row(): + hdr_illumination_file = gr.File( + label="HDR Env Map", + file_types=[".hdr"], + file_count="single", + ) + example_hdris = [ + os.path.join("demo_files/hdri", f) + for f in os.listdir("demo_files/hdri") + ] + hdr_illumination_example = gr.Examples( + examples=example_hdris, + inputs=hdr_illumination_file, + ) + + hdr_illumination_file.change( + lambda x: gr.update(env_map=x.name if x is not None else None), + inputs=hdr_illumination_file, + outputs=[output_3d], + ) + + examples = gr.Examples( + examples=example_files, inputs=input_img, examples_per_page=11 + ) + + input_img.change( + requires_bg_remove, + inputs=[input_img, foreground_ratio, no_crop], + outputs=[ + run_btn, + img_proc_state, + background_remove_state, + preview_removal, + output_3d, + hdr_row, + point_cloud_row, + point_cloud_editor, + pc_download, + regenerate_btn, + ], + ) + + point_cloud_editor.edit( + fn=lambda _x: gr.update(visible=True), + inputs=point_cloud_editor, + outputs=regenerate_btn, + ) + + regenerate_btn.click( + regenerate_run, + inputs=[ + background_remove_state, + guidance_scale, + random_seed, + point_cloud_editor, + remesh_option, + vertex_count_type, + vertex_count_slider, + texture_size, + ], + outputs=[ + run_btn, + img_proc_state, + background_remove_state, + preview_removal, + output_3d, + hdr_row, + point_cloud_row, + point_cloud_editor, + pc_download, + regenerate_btn, + ], + ) + + run_btn.click( + run_button, + inputs=[ + run_btn, + input_img, + background_remove_state, + foreground_ratio, + no_crop, + guidance_scale, + random_seed, + pc_upload, + pc_cond_file, + remesh_option, + vertex_count_type, + vertex_count_slider, + texture_size, + ], + outputs=[ + run_btn, + img_proc_state, + background_remove_state, + preview_removal, + output_3d, + hdr_row, + point_cloud_row, + point_cloud_editor, + pc_download, + regenerate_btn, + ], + ) + +demo.queue().launch() diff --git a/load/tets/160_tets.npz b/load/tets/160_tets.npz new file mode 100644 index 0000000000000000000000000000000000000000..021722f535100c50bdb584957303a1483136929c --- /dev/null +++ b/load/tets/160_tets.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9 +size 15408790 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5afd7f3b246f72f62a70d1335db18e0c03361ea3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +einops==0.7.0 +jaxtyping==0.2.31 +omegaconf==2.3.0 +transformers==4.42.3 +loralib==0.1.2 +git+https://github.com/openai/CLIP.git +git+https://github.com/SunzeY/AlphaCLIP.git +trimesh==4.4.1 +numpy==1.26.4 +huggingface-hub==0.23.4 +transparent-background==1.3.3 +gradio==4.43.0 +gradio-litmodel3d==0.0.1 +gradio-pointcloudeditor==0.0.9 +gpytoolbox==0.2.0 +# ./texture_baker/ +# ./uv_unwrapper/ diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000000000000000000000000000000000000..36a85dc5e2359d3a3829d9260001e7ef4b02bce5 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,3 @@ +[lint] +ignore = ["F722", 
"F821"] +extend-select = ["I"] diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..a334784ff109aaf0b36d482c3a9fc9d2c1bf73ce --- /dev/null +++ b/run.py @@ -0,0 +1,180 @@ +import argparse +import os +from contextlib import nullcontext + +import torch +from PIL import Image +from tqdm import tqdm +from transparent_background import Remover + +from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE +from spar3d.system import SPAR3D +from spar3d.utils import foreground_crop, get_device, remove_background + + +def check_positive(value): + ivalue = int(value) + if ivalue <= 0: + raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value) + return ivalue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "image", type=str, nargs="+", help="Path to input image(s) or folder." + ) + parser.add_argument( + "--device", + default=get_device(), + type=str, + help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'", + ) + parser.add_argument( + "--pretrained-model", + default="stabilityai/spar3d", + type=str, + help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/spar3d'", + ) + parser.add_argument( + "--foreground-ratio", + default=1.3, + type=float, + help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85", + ) + parser.add_argument( + "--output-dir", + default="output/", + type=str, + help="Output directory to save the results. Default: 'output/'", + ) + parser.add_argument( + "--texture-resolution", + default=1024, + type=int, + help="Texture atlas resolution. Default: 1024", + ) + + remesh_choices = ["none"] + if TRIANGLE_REMESH_AVAILABLE: + remesh_choices.append("triangle") + if QUAD_REMESH_AVAILABLE: + remesh_choices.append("quad") + parser.add_argument( + "--remesh_option", + choices=remesh_choices, + default="none", + help="Remeshing option", + ) + if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE: + parser.add_argument( + "--reduction_count_type", + choices=["keep", "vertex", "faces"], + default="keep", + help="Vertex count type", + ) + parser.add_argument( + "--target_count", + type=check_positive, + help="Selected target count.", + default=2000, + ) + parser.add_argument( + "--batch_size", default=1, type=int, help="Batch size for inference" + ) + args = parser.parse_args() + + # Ensure args.device contains cuda + devices = ["cuda", "mps", "cpu"] + if not any(args.device in device for device in devices): + raise ValueError("Invalid device. 
Use cuda, mps or cpu") + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + device = args.device + if not (torch.cuda.is_available() or torch.backends.mps.is_available()): + device = "cpu" + + print("Device used: ", device) + + model = SPAR3D.from_pretrained( + args.pretrained_model, + config_name="config.yaml", + weight_name="model.safetensors", + ) + model.to(device) + model.eval() + + bg_remover = Remover(device=device) + images = [] + idx = 0 + for image_path in args.image: + + def handle_image(image_path, idx): + image = remove_background( + Image.open(image_path).convert("RGBA"), bg_remover + ) + image = foreground_crop(image, args.foreground_ratio) + os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True) + image.save(os.path.join(output_dir, str(idx), "input.png")) + images.append(image) + + if os.path.isdir(image_path): + image_paths = [ + os.path.join(image_path, f) + for f in os.listdir(image_path) + if f.endswith((".png", ".jpg", ".jpeg")) + ] + for image_path in image_paths: + handle_image(image_path, idx) + idx += 1 + else: + handle_image(image_path, idx) + idx += 1 + + vertex_count = ( + -1 + if args.reduction_count_type == "keep" + else ( + args.target_count + if args.reduction_count_type == "vertex" + else args.target_count // 2 + ) + ) + + for i in tqdm(range(0, len(images), args.batch_size)): + image = images[i : i + args.batch_size] + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + with torch.no_grad(): + with ( + torch.autocast(device_type=device, dtype=torch.float16) + if "cuda" in device + else nullcontext() + ): + mesh, glob_dict = model.run_image( + image, + bake_resolution=args.texture_resolution, + remesh=args.remesh_option, + vertex_count=args.target_vertex_count, + return_points=True, + ) + if torch.cuda.is_available(): + print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB") + elif torch.backends.mps.is_available(): + print( + "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB" + ) + + if len(image) == 1: + out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb") + mesh.export(out_mesh_path, include_normals=True) + out_points_path = os.path.join(output_dir, str(i), "points.ply") + glob_dict["point_clouds"][0].export(out_points_path) + else: + for j in range(len(mesh)): + out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb") + mesh[j].export(out_mesh_path, include_normals=True) + out_points_path = os.path.join(output_dir, str(i + j), "points.ply") + glob_dict["point_clouds"][j].export(out_points_path) diff --git a/spar3d/models/camera.py b/spar3d/models/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad9c932ed4812f624ea8b5c6e9c75a4bf0dd483 --- /dev/null +++ b/spar3d/models/camera.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.nn as nn + +from spar3d.models.utils import BaseModule + + +class LinearCameraEmbedder(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 25 + out_channels: int = 768 + conditions: List[str] = field(default_factory=list) + + cfg: Config + + def configure(self) -> None: + self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels) + + def forward(self, **kwargs): + cond_tensors = [] + for cond_name in self.cfg.conditions: + assert cond_name in kwargs + cond = kwargs[cond_name] + # cond in shape (B, Nv, ...) 
+ cond_tensors.append(cond.view(*cond.shape[:2], -1)) + cond_tensor = torch.cat(cond_tensors, dim=-1) + assert cond_tensor.shape[-1] == self.cfg.in_channels + embedding = self.linear(cond_tensor) + return embedding diff --git a/spar3d/models/diffusion/gaussian_diffusion.py b/spar3d/models/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..83c4060375a625da84ca7950d776a51eb0959f36 --- /dev/null +++ b/spar3d/models/diffusion/gaussian_diffusion.py @@ -0,0 +1,524 @@ +# -------------------------------------------------------- +# Adapted from: https://github.com/openai/point-e +# Licensed under the MIT License +# Copyright (c) 2022 OpenAI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# -------------------------------------------------------- + +import math +from typing import Any, Dict, Iterable, Optional, Sequence, Union + +import numpy as np +import torch as th + + +def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9): + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + v_start = sigmoid(start / tau) + v_end = sigmoid(end / tau) + output = sigmoid((t * (end - start) + start) / tau) + output = (v_end - output) / (v_end - v_start) + return np.clip(output, clip_min, 1.0) + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, exp_p=12): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
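        # Concretely: with 1000 steps the scale is 1, so the betas run linearly
        # from 1e-4 to 0.02 (the original DDPM values); with 100 steps the scale
        # is 10 and the same overall noise budget is spread over fewer steps
        # (betas from 1e-3 to 0.2).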
+ scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + elif schedule_name == "sigmoid": + # Sigmoid schedule passed through betas_for_alpha_bar + return betas_for_alpha_bar( + num_diffusion_timesteps, lambda t: sigmoid_schedule(t) + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
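    Illustrative example: space_timesteps(1000, "ddim50") searches for an
    integer stride whose range covers exactly 50 steps and returns
    set(range(0, 1000, 20)), i.e. {0, 20, 40, ..., 980}.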
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + elif section_counts.startswith("exact"): + res = set(int(x) for x in section_counts[len("exact") :].split(",")) + for x in res: + if x < 0 or x >= num_timesteps: + raise ValueError(f"timestep out of bounds: {x}") + return res + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """Extract values from a 1-D numpy array for a batch of indices.""" + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) + + +class GaussianDiffusion: + """ + Utilities for sampling from Gaussian diffusion models. + """ + + def __init__( + self, + *, + betas: Sequence[float], + model_mean_type: str, + model_var_type: str, + channel_scales: Optional[np.ndarray] = None, + channel_biases: Optional[np.ndarray] = None, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.channel_scales = channel_scales + self.channel_biases = channel_biases + + # Use float64 for accuracy + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def scale_channels(self, x: th.Tensor) -> th.Tensor: + """Apply channel-wise 
scaling.""" + if self.channel_scales is not None: + x = x * th.from_numpy(self.channel_scales).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + if self.channel_biases is not None: + x = x + th.from_numpy(self.channel_biases).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + return x + + def unscale_channels(self, x: th.Tensor) -> th.Tensor: + """Remove channel-wise scaling.""" + if self.channel_biases is not None: + x = x - th.from_numpy(self.channel_biases).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + if self.channel_scales is not None: + x = x / th.from_numpy(self.channel_scales).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + return x + + def unscale_out_dict( + self, out: Dict[str, Union[th.Tensor, Any]] + ) -> Dict[str, Union[th.Tensor, Any]]: + return { + k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) + for k, v in out.items() + } + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t). + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + + # Direct prediction of eps + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, prev_latent = model_output + model_kwargs["prev_latent"] = prev_latent + + # Convert model output to mean and variance + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ "fixed_large": ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + "fixed_small": ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + x = x.clamp( + -self.channel_scales[0] * 0.67, self.channel_scales[0] * 0.67 + ) + x[:, 3:] = x[:, 3:].clamp( + -self.channel_scales[3] * 0.5, self.channel_scales[3] * 0.5 + ) + return x + return x + + if self.model_mean_type == "x_prev": + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in ["x_start", "epsilon"]: + if self.model_mean_type == "x_start": + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + # print('p_mean_variance:', pred_xstart.min(), pred_xstart.max()) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples. 
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + indices = list(range(self.num_timesteps))[::-1] + + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield self.unscale_out_dict(out) + img = out["sample"] + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. 
+ """ + + def __init__(self, use_timesteps: Iterable[int], **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + +class _WrappedModel: + """Helper class to wrap models for SpacedDiffusion.""" + + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + return self.model(x, new_ts, **kwargs) diff --git a/spar3d/models/diffusion/sampler.py b/spar3d/models/diffusion/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..df5abaf1dd81158b4f68d26ba53a11bf40277744 --- /dev/null +++ b/spar3d/models/diffusion/sampler.py @@ -0,0 +1,134 @@ +# -------------------------------------------------------- +# Adapted from: https://github.com/openai/point-e +# Licensed under the MIT License +# Copyright (c) 2022 OpenAI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# -------------------------------------------------------- + +from typing import Dict, Iterator + +import torch +import torch.nn as nn + +from .gaussian_diffusion import GaussianDiffusion + + +class PointCloudSampler: + """ + A wrapper around a model that produces conditional sample tensors. 
+ """ + + def __init__( + self, + model: nn.Module, + diffusion: GaussianDiffusion, + num_points: int, + point_dim: int = 3, + guidance_scale: float = 3.0, + clip_denoised: bool = True, + sigma_min: float = 1e-3, + sigma_max: float = 120, + s_churn: float = 3, + ): + self.model = model + self.num_points = num_points + self.point_dim = point_dim + self.guidance_scale = guidance_scale + self.clip_denoised = clip_denoised + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.s_churn = s_churn + + self.diffusion = diffusion + + def sample_batch_progressive( + self, + batch_size: int, + condition: torch.Tensor, + noise=None, + device=None, + guidance_scale=None, + ) -> Iterator[Dict[str, torch.Tensor]]: + """ + Generate samples progressively using classifier-free guidance. + + Args: + batch_size: Number of samples to generate + condition: Conditioning tensor + noise: Optional initial noise tensor + device: Device to run on + guidance_scale: Optional override for guidance scale + + Returns: + Iterator of dicts containing intermediate samples + """ + if guidance_scale is None: + guidance_scale = self.guidance_scale + + sample_shape = (batch_size, self.point_dim, self.num_points) + + # Double the batch for classifier-free guidance + if guidance_scale != 1 and guidance_scale != 0: + condition = torch.cat([condition, torch.zeros_like(condition)], dim=0) + if noise is not None: + noise = torch.cat([noise, noise], dim=0) + model_kwargs = {"condition": condition} + + internal_batch_size = batch_size + if guidance_scale != 1 and guidance_scale != 0: + model = self._uncond_guide_model(self.model, guidance_scale) + internal_batch_size *= 2 + else: + model = self.model + + samples_it = self.diffusion.ddim_sample_loop_progressive( + model, + shape=(internal_batch_size, *sample_shape[1:]), + model_kwargs=model_kwargs, + device=device, + clip_denoised=self.clip_denoised, + noise=noise, + ) + + for x in samples_it: + samples = { + "xstart": x["pred_xstart"][:batch_size], + "xprev": x["sample"][:batch_size] if "sample" in x else x["x"], + } + yield samples + + def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module: + """ + Wraps the model for classifier-free guidance. 
+ """ + + def model_fn(x_t, ts, **kwargs): + half = x_t[: len(x_t) // 2] + combined = torch.cat([half, half], dim=0) + model_out = model(combined, ts, **kwargs) + + eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :] + cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0) + half_eps = uncond_eps + scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + return model_fn diff --git a/spar3d/models/global_estimator/reni_estimator.py b/spar3d/models/global_estimator/reni_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd2ac5286ec6d52160824420179d28082d7d72d --- /dev/null +++ b/spar3d/models/global_estimator/reni_estimator.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + +from spar3d.models.illumination.reni.env_map import RENIEnvMap +from spar3d.models.utils import BaseModule + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)" + + def proj_u2a(u, a): + r""" + u: batch x 3 + a: batch x 3 + """ + inner_prod = torch.sum(u * a, dim=-1, keepdim=True) + norm2 = torch.sum(u**2, dim=-1, keepdim=True) + norm2 = torch.clamp(norm2, min=1e-8) + factor = inner_prod / (norm2 + 1e-10) + return factor * u + + x_raw, y_raw = d6[..., :3], d6[..., 3:] + + x = F.normalize(x_raw, dim=-1) + y = F.normalize(y_raw - proj_u2a(x, y_raw), dim=-1) + z = torch.cross(x, y, dim=-1) + + return torch.stack((x, y, z), dim=-1) + + +class ReniLatentCodeEstimator(BaseModule): + @dataclass + class Config(BaseModule.Config): + triplane_features: int = 40 + + n_layers: int = 5 + hidden_features: int = 512 + activation: str = "relu" + + pool: str = "mean" + + reni_env_config: dict = field(default_factory=dict) + + cfg: Config + + def configure(self): + layers = [] + cur_features = self.cfg.triplane_features * 3 + for _ in range(self.cfg.n_layers): + layers.append( + nn.Conv2d( + cur_features, + self.cfg.hidden_features, + kernel_size=3, + padding=0, + stride=2, + ) + ) + layers.append(self.make_activation(self.cfg.activation)) + + cur_features = self.cfg.hidden_features + + self.layers = nn.Sequential(*layers) + + self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config) + self.latent_dim = self.reni_env_map.field.latent_dim + + self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3) + nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3) + + self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6) + nn.init.constant_(self.fc_rotations.bias, 0.0) + nn.init.normal_( + self.fc_rotations.weight, mean=0.0, std=0.01 + ) # Small variance here + + self.fc_scale = nn.Linear(self.cfg.hidden_features, 1) + nn.init.constant_(self.fc_scale.bias, 0.0) + nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01) # Small variance here + + def make_activation(self, activation): + if activation == "relu": + return nn.ReLU(inplace=True) + elif activation == "silu": + return nn.SiLU(inplace=True) + else: + raise NotImplementedError + + def forward( + self, + triplane: Float[Tensor, "B 3 F Ht Wt"], + ) -> dict[str, Any]: + x = self.layers( + triplane.reshape( + triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1] + ) + ) + x = x.mean(dim=[-2, -1]) + + latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3) + rotations = self.fc_rotations(x) + scale = 
self.fc_scale(x) + + env_map = self.reni_env_map(latents, rotation_6d_to_matrix(rotations), scale) + + return {"illumination": env_map["rgb"]} diff --git a/spar3d/models/illumination/reni/components/film_siren.py b/spar3d/models/illumination/reni/components/film_siren.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef78a368f9f4b9653365ee07adf1f39eaec227b --- /dev/null +++ b/spar3d/models/illumination/reni/components/film_siren.py @@ -0,0 +1,148 @@ +"""FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/.""" + +from typing import Optional + +import numpy as np +import torch +from torch import nn + + +def kaiming_leaky_init(m): + classname = m.__class__.__name__ + if classname.find("Linear") != -1: + torch.nn.init.kaiming_normal_( + m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu" + ) + + +def frequency_init(freq): + def init(m): + with torch.no_grad(): + if isinstance(m, nn.Linear): + num_input = m.weight.size(-1) + m.weight.uniform_( + -np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq + ) + + return init + + +def first_layer_film_sine_init(m): + with torch.no_grad(): + if isinstance(m, nn.Linear): + num_input = m.weight.size(-1) + m.weight.uniform_(-1 / num_input, 1 / num_input) + + +class CustomMappingNetwork(nn.Module): + def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim): + super().__init__() + + self.network = [] + + for _ in range(map_hidden_layers): + self.network.append(nn.Linear(in_features, map_hidden_dim)) + self.network.append(nn.LeakyReLU(0.2, inplace=True)) + in_features = map_hidden_dim + + self.network.append(nn.Linear(map_hidden_dim, map_output_dim)) + + self.network = nn.Sequential(*self.network) + + self.network.apply(kaiming_leaky_init) + with torch.no_grad(): + self.network[-1].weight *= 0.25 + + def forward(self, z): + frequencies_offsets = self.network(z) + frequencies = frequencies_offsets[ + ..., : torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") + ] + phase_shifts = frequencies_offsets[ + ..., torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") : + ] + + return frequencies, phase_shifts + + +class FiLMLayer(nn.Module): + def __init__(self, input_dim, hidden_dim): + super().__init__() + self.layer = nn.Linear(input_dim, hidden_dim) + + def forward(self, x, freq, phase_shift): + x = self.layer(x) + freq = freq.expand_as(x) + phase_shift = phase_shift.expand_as(x) + return torch.sin(freq * x + phase_shift) + + +class FiLMSiren(nn.Module): + """FiLM Conditioned Siren network.""" + + def __init__( + self, + in_dim: int, + hidden_layers: int, + hidden_features: int, + mapping_network_in_dim: int, + mapping_network_layers: int, + mapping_network_features: int, + out_dim: int, + outermost_linear: bool = False, + out_activation: Optional[nn.Module] = None, + ) -> None: + super().__init__() + self.in_dim = in_dim + assert self.in_dim > 0 + self.out_dim = out_dim if out_dim is not None else hidden_features + self.hidden_layers = hidden_layers + self.hidden_features = hidden_features + self.mapping_network_in_dim = mapping_network_in_dim + self.mapping_network_layers = mapping_network_layers + self.mapping_network_features = mapping_network_features + self.outermost_linear = outermost_linear + self.out_activation = out_activation + + self.net = nn.ModuleList() + + self.net.append(FiLMLayer(self.in_dim, self.hidden_features)) + + for _ in range(self.hidden_layers - 1): + self.net.append(FiLMLayer(self.hidden_features, self.hidden_features)) + + 
self.final_layer = None + if self.outermost_linear: + self.final_layer = nn.Linear(self.hidden_features, self.out_dim) + self.final_layer.apply(frequency_init(25)) + else: + final_layer = FiLMLayer(self.hidden_features, self.out_dim) + self.net.append(final_layer) + + self.mapping_network = CustomMappingNetwork( + in_features=self.mapping_network_in_dim, + map_hidden_layers=self.mapping_network_layers, + map_hidden_dim=self.mapping_network_features, + map_output_dim=(len(self.net)) * self.hidden_features * 2, + ) + + self.net.apply(frequency_init(25)) + self.net[0].apply(first_layer_film_sine_init) + + def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts): + """Get conditiional frequencies and phase shifts from mapping network.""" + frequencies = frequencies * 15 + 30 + + for index, layer in enumerate(self.net): + start = index * self.hidden_features + end = (index + 1) * self.hidden_features + x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) + + x = self.final_layer(x) if self.final_layer is not None else x + output = self.out_activation(x) if self.out_activation is not None else x + return output + + def forward(self, x, conditioning_input): + """Forward pass.""" + frequencies, phase_shifts = self.mapping_network(conditioning_input) + return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts) diff --git a/spar3d/models/illumination/reni/components/siren.py b/spar3d/models/illumination/reni/components/siren.py new file mode 100644 index 0000000000000000000000000000000000000000..dee2431058efc31aab9a9c5e7da68d0e15965461 --- /dev/null +++ b/spar3d/models/illumination/reni/components/siren.py @@ -0,0 +1,118 @@ +"""Siren MLP https://www.vincentsitzmann.com/siren/""" + +from typing import Optional + +import numpy as np +import torch +from torch import nn + + +class SineLayer(nn.Module): + """ + Sine layer for the SIREN network. + """ + + def __init__( + self, in_features, out_features, bias=True, is_first=False, omega_0=30.0 + ): + super().__init__() + self.omega_0 = omega_0 + self.is_first = is_first + + self.in_features = in_features + self.linear = nn.Linear(in_features, out_features, bias=bias) + + self.init_weights() + + def init_weights(self): + with torch.no_grad(): + if self.is_first: + self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features) + else: + self.linear.weight.uniform_( + -np.sqrt(6 / self.in_features) / self.omega_0, + np.sqrt(6 / self.in_features) / self.omega_0, + ) + + def forward(self, x): + return torch.sin(self.omega_0 * self.linear(x)) + + +class Siren(nn.Module): + """Siren network. + + Args: + in_dim: Input layer dimension + num_layers: Number of network layers + layer_width: Width of each MLP layer + out_dim: Output layer dimension. Uses layer_width if None. + activation: intermediate layer activation function. + out_activation: output activation function. 
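    Every SineLayer computes sin(omega_0 * (W x + b)). Following Sitzmann et
    al., the first layer's weights are drawn from U(-1/in_features,
    1/in_features) and the remaining layers from
    U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0), which the
    init_weights method above implements.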
+ """ + + def __init__( + self, + in_dim: int, + hidden_layers: int, + hidden_features: int, + out_dim: Optional[int] = None, + outermost_linear: bool = False, + first_omega_0: float = 30, + hidden_omega_0: float = 30, + out_activation: Optional[nn.Module] = None, + ) -> None: + super().__init__() + self.in_dim = in_dim + assert self.in_dim > 0 + self.out_dim = out_dim if out_dim is not None else hidden_features + self.outermost_linear = outermost_linear + self.first_omega_0 = first_omega_0 + self.hidden_omega_0 = hidden_omega_0 + self.hidden_layers = hidden_layers + self.layer_width = hidden_features + self.out_activation = out_activation + + self.net = [] + self.net.append( + SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0) + ) + + for _ in range(hidden_layers): + self.net.append( + SineLayer( + hidden_features, + hidden_features, + is_first=False, + omega_0=hidden_omega_0, + ) + ) + + if outermost_linear: + final_layer = nn.Linear(hidden_features, self.out_dim) + + with torch.no_grad(): + final_layer.weight.uniform_( + -np.sqrt(6 / hidden_features) / hidden_omega_0, + np.sqrt(6 / hidden_features) / hidden_omega_0, + ) + + self.net.append(final_layer) + else: + self.net.append( + SineLayer( + hidden_features, + self.out_dim, + is_first=False, + omega_0=hidden_omega_0, + ) + ) + + if self.out_activation is not None: + self.net.append(self.out_activation) + + self.net = nn.Sequential(*self.net) + + def forward(self, model_input): + """Forward pass through the network""" + output = self.net(model_input) + return output diff --git a/spar3d/models/illumination/reni/components/transformer_decoder.py b/spar3d/models/illumination/reni/components/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..44646e06caeefeb5b13461c8af9e35a1993b0ed6 --- /dev/null +++ b/spar3d/models/illumination/reni/components/transformer_decoder.py @@ -0,0 +1,189 @@ +from typing import Optional + +import torch +from torch import nn + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + direction_input_dim: int, + conditioning_input_dim: int, + latent_dim: int, + num_heads: int, + ): + """ + Multi-Head Attention module. + + Args: + direction_input_dim (int): The input dimension of the directional input. + conditioning_input_dim (int): The input dimension of the conditioning input. + latent_dim (int): The latent dimension of the module. + num_heads (int): The number of heads to use in the attention mechanism. + """ + super().__init__() + assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = latent_dim // num_heads + self.scale = self.head_dim**-0.5 + + self.query = nn.Linear(direction_input_dim, latent_dim) + self.key = nn.Linear(conditioning_input_dim, latent_dim) + self.value = nn.Linear(conditioning_input_dim, latent_dim) + self.fc_out = nn.Linear(latent_dim, latent_dim) + + def forward( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + """ + Forward pass of the Multi-Head Attention module. + + Args: + query (torch.Tensor): The directional input tensor. + key (torch.Tensor): The conditioning input tensor for the keys. + value (torch.Tensor): The conditioning input tensor for the values. + + Returns: + torch.Tensor: The output tensor of the Multi-Head Attention module. 
+ """ + batch_size = query.size(0) + + Q = ( + self.query(query) + .view(batch_size, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + K = ( + self.key(key) + .view(batch_size, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + V = ( + self.value(value) + .view(batch_size, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + attention = ( + torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale + ) + attention = torch.softmax(attention, dim=-1) + + out = torch.einsum("bnqh,bnhv->bnqv", [attention, V]) + out = ( + out.transpose(1, 2) + .contiguous() + .view(batch_size, -1, self.num_heads * self.head_dim) + ) + + out = self.fc_out(out).squeeze(1) + return out + + +class AttentionLayer(nn.Module): + def __init__( + self, + direction_input_dim: int, + conditioning_input_dim: int, + latent_dim: int, + num_heads: int, + ): + """ + Attention Layer module. + + Args: + direction_input_dim (int): The input dimension of the directional input. + conditioning_input_dim (int): The input dimension of the conditioning input. + latent_dim (int): The latent dimension of the module. + num_heads (int): The number of heads to use in the attention mechanism. + """ + super().__init__() + self.mha = MultiHeadAttention( + direction_input_dim, conditioning_input_dim, latent_dim, num_heads + ) + self.norm1 = nn.LayerNorm(latent_dim) + self.norm2 = nn.LayerNorm(latent_dim) + self.fc = nn.Sequential( + nn.Linear(latent_dim, latent_dim), + nn.ReLU(), + nn.Linear(latent_dim, latent_dim), + ) + + def forward( + self, directional_input: torch.Tensor, conditioning_input: torch.Tensor + ) -> torch.Tensor: + """ + Forward pass of the Attention Layer module. + + Args: + directional_input (torch.Tensor): The directional input tensor. + conditioning_input (torch.Tensor): The conditioning input tensor. + + Returns: + torch.Tensor: The output tensor of the Attention Layer module. + """ + attn_output = self.mha( + directional_input, conditioning_input, conditioning_input + ) + out1 = self.norm1(attn_output + directional_input) + fc_output = self.fc(out1) + out2 = self.norm2(fc_output + out1) + return out2 + + +class Decoder(nn.Module): + def __init__( + self, + in_dim: int, + conditioning_input_dim: int, + hidden_features: int, + num_heads: int, + num_layers: int, + out_activation: Optional[nn.Module], + ): + """ + Decoder module. + + Args: + in_dim (int): The input dimension of the module. + conditioning_input_dim (int): The input dimension of the conditioning input. + hidden_features (int): The number of hidden features in the module. + num_heads (int): The number of heads to use in the attention mechanism. + num_layers (int): The number of layers in the module. + out_activation (nn.Module): The activation function to use on the output tensor. + """ + super().__init__() + self.residual_projection = nn.Linear( + in_dim, hidden_features + ) # projection for residual connection + self.layers = nn.ModuleList( + [ + AttentionLayer( + hidden_features, conditioning_input_dim, hidden_features, num_heads + ) + for i in range(num_layers) + ] + ) + self.fc = nn.Linear(hidden_features, 3) # 3 for RGB + self.out_activation = out_activation + + def forward( + self, x: torch.Tensor, conditioning_input: torch.Tensor + ) -> torch.Tensor: + """ + Forward pass of the Decoder module. + + Args: + x (torch.Tensor): The input tensor. + conditioning_input (torch.Tensor): The conditioning input tensor. + + Returns: + torch.Tensor: The output tensor of the Decoder module. 
+ """ + x = self.residual_projection(x) + for layer in self.layers: + x = layer(x, conditioning_input) + x = self.fc(x) + if self.out_activation is not None: + x = self.out_activation(x) + return x diff --git a/spar3d/models/illumination/reni/components/vn_layers.py b/spar3d/models/illumination/reni/components/vn_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8388de9a31325a6dd9822104b7d533dff2b6113d --- /dev/null +++ b/spar3d/models/illumination/reni/components/vn_layers.py @@ -0,0 +1,548 @@ +# MIT License + +# Copyright (c) 2022 Phil Wang + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""All code taken from https://github.com/lucidrains/VN-transformer""" + +from collections import namedtuple +from functools import wraps + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce +from einops.layers.torch import Rearrange +from packaging import version +from torch import einsum, nn + +# constants + +FlashAttentionConfig = namedtuple( + "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] +) + +# helpers + + +def exists(val): + return val is not None + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +print_once = once(print) + +# main class + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, flash=False, l2_dist=False): + super().__init__() + assert not ( + flash and l2_dist + ), "flash attention is not compatible with l2 distance" + self.l2_dist = l2_dist + + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not ( + flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once( + "A100 GPU detected, using flash attention if input tensor is on cuda" + ) + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once( + "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda" + ) + self.cuda_config = FlashAttentionConfig(False, True, 
True) + + def flash_attn(self, q, k, v, mask=None): + _, heads, q_len, _, _, is_cuda = ( + *q.shape, + k.shape[-2], + q.is_cuda, + ) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + + return out + + def forward(self, q, k, v, mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + scale = q.shape[-1] ** -0.5 + + if exists(mask) and mask.ndim != 4: + mask = rearrange(mask, "b j -> b 1 1 j") + + if self.flash: + return self.flash_attn(q, k, v, mask=mask) + + # similarity + + sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale + + # l2 distance + + if self.l2_dist: + # -cdist squared == (-q^2 + 2qk - k^2) + # so simply work off the qk above + q_squared = reduce(q**2, "b h i d -> b h i 1", "sum") + k_squared = reduce(k**2, "b h j d -> b h 1 j", "sum") + sim = sim * 2 - q_squared - k_squared + + # key padding mask + + if exists(mask): + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + return out + + +# helper + + +def exists(val): # noqa: F811 + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def inner_dot_product(x, y, *, dim=-1, keepdim=True): + return (x * y).sum(dim=dim, keepdim=keepdim) + + +# layernorm + + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + + +# equivariant modules + + +class VNLinear(nn.Module): + def __init__(self, dim_in, dim_out, bias_epsilon=0.0): + super().__init__() + self.weight = nn.Parameter(torch.randn(dim_out, dim_in)) + + self.bias = None + self.bias_epsilon = bias_epsilon + + # in this paper, they propose going for quasi-equivariance with a small bias, controllable with epsilon, which they claim lead to better stability and results + + if bias_epsilon > 0.0: + self.bias = nn.Parameter(torch.randn(dim_out)) + + def forward(self, x): + out = einsum("... i c, o i -> ... o c", x, self.weight) + + if exists(self.bias): + bias = F.normalize(self.bias, dim=-1) * self.bias_epsilon + out = out + rearrange(bias, "... -> ... 1") + + return out + + +class VNReLU(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.eps = eps + self.W = nn.Parameter(torch.randn(dim, dim)) + self.U = nn.Parameter(torch.randn(dim, dim)) + + def forward(self, x): + q = einsum("... i c, o i -> ... o c", x, self.W) + k = einsum("... i c, o i -> ... 
o c", x, self.U) + + qk = inner_dot_product(q, k) + + k_norm = k.norm(dim=-1, keepdim=True).clamp(min=self.eps) + q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k + + out = torch.where(qk >= 0.0, q, q_projected_on_k) + + return out + + +class VNAttention(nn.Module): + def __init__( + self, + dim, + dim_head=64, + heads=8, + dim_coor=3, + bias_epsilon=0.0, + l2_dist_attn=False, + flash=False, + num_latents=None, # setting this would enable perceiver-like cross attention from latents to sequence, with the latents derived from VNWeightedPool + ): + super().__init__() + assert not ( + l2_dist_attn and flash + ), "l2 distance attention is not compatible with flash attention" + + self.scale = (dim_coor * dim_head) ** -0.5 + dim_inner = dim_head * heads + self.heads = heads + + self.to_q_input = None + if exists(num_latents): + self.to_q_input = VNWeightedPool( + dim, num_pooled_tokens=num_latents, squeeze_out_pooled_dim=False + ) + + self.to_q = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon) + self.to_k = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon) + self.to_v = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon) + self.to_out = VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon) + + if l2_dist_attn and not exists(num_latents): + # tied queries and keys for l2 distance attention, and not perceiver-like attention + self.to_k = self.to_q + + self.attend = Attend(flash=flash, l2_dist=l2_dist_attn) + + def forward(self, x, mask=None): + """ + einstein notation + b - batch + n - sequence + h - heads + d - feature dimension (channels) + c - coordinate dimension (3 for 3d space) + i - source sequence dimension + j - target sequence dimension + """ + + c = x.shape[-1] + + if exists(self.to_q_input): + q_input = self.to_q_input(x, mask=mask) + else: + q_input = x + + q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x) + q, k, v = map( + lambda t: rearrange(t, "b n (h d) c -> b h n (d c)", h=self.heads), + (q, k, v), + ) + + out = self.attend(q, k, v, mask=mask) + + out = rearrange(out, "b h n (d c) -> b n (h d) c", c=c) + return self.to_out(out) + + +def VNFeedForward(dim, mult=4, bias_epsilon=0.0): + dim_inner = int(dim * mult) + return nn.Sequential( + VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon), + VNReLU(dim_inner), + VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon), + ) + + +class VNLayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.eps = eps + self.ln = LayerNorm(dim) + + def forward(self, x): + norms = x.norm(dim=-1) + x = x / rearrange(norms.clamp(min=self.eps), "... -> ... 1") + ln_out = self.ln(norms) + return x * rearrange(ln_out, "... -> ... 
1") + + +class VNWeightedPool(nn.Module): + def __init__( + self, dim, dim_out=None, num_pooled_tokens=1, squeeze_out_pooled_dim=True + ): + super().__init__() + dim_out = default(dim_out, dim) + self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out)) + self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim + + def forward(self, x, mask=None): + if exists(mask): + mask = rearrange(mask, "b n -> b n 1 1") + x = x.masked_fill(~mask, 0.0) + numer = reduce(x, "b n d c -> b d c", "sum") + denom = mask.sum(dim=1) + mean_pooled = numer / denom.clamp(min=1e-6) + else: + mean_pooled = reduce(x, "b n d c -> b d c", "mean") + + out = einsum("b d c, m d e -> b m e c", mean_pooled, self.weight) + + if not self.squeeze_out_pooled_dim: + return out + + out = rearrange(out, "b 1 d c -> b d c") + return out + + +# equivariant VN transformer encoder + + +class VNTransformerEncoder(nn.Module): + def __init__( + self, + dim, + *, + depth, + dim_head=64, + heads=8, + dim_coor=3, + ff_mult=4, + final_norm=False, + bias_epsilon=0.0, + l2_dist_attn=False, + flash_attn=False, + ): + super().__init__() + self.dim = dim + self.dim_coor = dim_coor + + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + VNAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + bias_epsilon=bias_epsilon, + l2_dist_attn=l2_dist_attn, + flash=flash_attn, + ), + VNLayerNorm(dim), + VNFeedForward(dim=dim, mult=ff_mult, bias_epsilon=bias_epsilon), + VNLayerNorm(dim), + ] + ) + ) + + self.norm = VNLayerNorm(dim) if final_norm else nn.Identity() + + def forward(self, x, mask=None): + *_, d, c = x.shape + + assert ( + x.ndim == 4 and d == self.dim and c == self.dim_coor + ), "input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))" + + for attn, attn_post_ln, ff, ff_post_ln in self.layers: + x = attn_post_ln(attn(x, mask=mask)) + x + x = ff_post_ln(ff(x)) + x + + return self.norm(x) + + +# invariant layers + + +class VNInvariant(nn.Module): + def __init__( + self, + dim, + dim_coor=3, + ): + super().__init__() + self.mlp = nn.Sequential( + VNLinear(dim, dim_coor), VNReLU(dim_coor), Rearrange("... d e -> ... e d") + ) + + def forward(self, x): + return einsum("b n d i, b n i o -> b n o", x, self.mlp(x)) + + +# main class + + +class VNTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + num_tokens=None, + dim_feat=None, + dim_head=64, + heads=8, + dim_coor=3, + reduce_dim_out=True, + bias_epsilon=0.0, + l2_dist_attn=False, + flash_attn=False, + translation_equivariance=False, + translation_invariant=False, + ): + super().__init__() + self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None + + dim_feat = default(dim_feat, 0) + self.dim_feat = dim_feat + self.dim_coor_total = dim_coor + dim_feat + + assert (int(translation_equivariance) + int(translation_invariant)) <= 1 + self.translation_equivariance = translation_equivariance + self.translation_invariant = translation_invariant + + self.vn_proj_in = nn.Sequential( + Rearrange("... c -> ... 1 c"), VNLinear(1, dim, bias_epsilon=bias_epsilon) + ) + + self.encoder = VNTransformerEncoder( + dim=dim, + depth=depth, + dim_head=dim_head, + heads=heads, + bias_epsilon=bias_epsilon, + dim_coor=self.dim_coor_total, + l2_dist_attn=l2_dist_attn, + flash_attn=flash_attn, + ) + + if reduce_dim_out: + self.vn_proj_out = nn.Sequential( + VNLayerNorm(dim), + VNLinear(dim, 1, bias_epsilon=bias_epsilon), + Rearrange("... 1 c -> ... 
c"), + ) + else: + self.vn_proj_out = nn.Identity() + + def forward( + self, coors, *, feats=None, mask=None, return_concatted_coors_and_feats=False + ): + if self.translation_equivariance or self.translation_invariant: + coors_mean = reduce(coors, "... c -> c", "mean") + coors = coors - coors_mean + + x = coors # [batch, num_points, 3] + + if exists(feats): + if feats.dtype == torch.long: + assert exists( + self.token_emb + ), "num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices" + feats = self.token_emb(feats) + + assert ( + feats.shape[-1] == self.dim_feat + ), f"dim_feat should be set to {feats.shape[-1]}" + x = torch.cat((x, feats), dim=-1) # [batch, num_points, 3 + dim_feat] + + assert x.shape[-1] == self.dim_coor_total + + x = self.vn_proj_in(x) # [batch, num_points, hidden_dim, 3 + dim_feat] + x = self.encoder(x, mask=mask) # [batch, num_points, hidden_dim, 3 + dim_feat] + x = self.vn_proj_out(x) # [batch, num_points, 3 + dim_feat] + + coors_out, feats_out = ( + x[..., :3], + x[..., 3:], + ) # [batch, num_points, 3], [batch, num_points, dim_feat] + + if self.translation_equivariance: + coors_out = coors_out + coors_mean + + if not exists(feats): + return coors_out + + if return_concatted_coors_and_feats: + return torch.cat((coors_out, feats_out), dim=-1) + + return coors_out, feats_out diff --git a/spar3d/models/illumination/reni/env_map.py b/spar3d/models/illumination/reni/env_map.py new file mode 100644 index 0000000000000000000000000000000000000000..fa569df2fb5e60769ea29f2b3fb74f1a2b32cc63 --- /dev/null +++ b/spar3d/models/illumination/reni/env_map.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import torch +from jaxtyping import Float +from torch import Tensor + +from spar3d.models.utils import BaseModule + +from .field import RENIField + + +def _direction_from_coordinate( + coordinate: Float[Tensor, "*B 2"], +) -> Float[Tensor, "*B 3"]: + # OpenGL Convention + # +X Right + # +Y Up + # +Z Backward + + u, v = coordinate.unbind(-1) + theta = (2 * torch.pi * u) - torch.pi + phi = torch.pi * v + + dir = torch.stack( + [ + theta.sin() * phi.sin(), + phi.cos(), + -1 * theta.cos() * phi.sin(), + ], + -1, + ) + return dir + + +def _get_sample_coordinates( + resolution: List[int], device: Optional[torch.device] = None +) -> Float[Tensor, "H W 2"]: + return torch.stack( + torch.meshgrid( + (torch.arange(resolution[1], device=device) + 0.5) / resolution[1], + (torch.arange(resolution[0], device=device) + 0.5) / resolution[0], + indexing="xy", + ), + -1, + ) + + +class RENIEnvMap(BaseModule): + @dataclass + class Config(BaseModule.Config): + reni_config: dict = field(default_factory=dict) + resolution: int = 128 + + cfg: Config + + def configure(self): + self.field = RENIField(self.cfg.reni_config) + resolution = (self.cfg.resolution, self.cfg.resolution * 2) + sample_directions = _direction_from_coordinate( + _get_sample_coordinates(resolution) + ) + self.img_shape = sample_directions.shape[:-1] + + sample_directions_flat = sample_directions.view(-1, 3) + # Lastly these have y up but reni expects z up. 
Rotate 90 degrees on x axis + sample_directions_flat = torch.stack( + [ + sample_directions_flat[:, 0], + -sample_directions_flat[:, 2], + sample_directions_flat[:, 1], + ], + -1, + ) + self.sample_directions = torch.nn.Parameter( + sample_directions_flat, requires_grad=False + ) + + def forward( + self, + latent_codes: Float[Tensor, "B latent_dim 3"], + rotation: Optional[Float[Tensor, "B 3 3"]] = None, + scale: Optional[Float[Tensor, "B"]] = None, + ) -> Dict[str, Tensor]: + return { + k: v.view(latent_codes.shape[0], *self.img_shape, -1) + for k, v in self.field( + self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1), + latent_codes, + rotation=rotation, + scale=scale, + ).items() + } diff --git a/spar3d/models/illumination/reni/field.py b/spar3d/models/illumination/reni/field.py new file mode 100644 index 0000000000000000000000000000000000000000..e0232a11fc3d94aae81b8a4cf549769f0876b1e8 --- /dev/null +++ b/spar3d/models/illumination/reni/field.py @@ -0,0 +1,736 @@ +# Copyright 2023 The University of York. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified by Mark Boss + +"""RENI field""" + +import contextlib +from dataclasses import dataclass +from typing import Dict, Literal, Optional + +import torch +from einops.layers.torch import Rearrange +from jaxtyping import Float +from torch import Tensor, nn + +from spar3d.models.network import get_activation_module, trunc_exp +from spar3d.models.utils import BaseModule + +from .components.film_siren import FiLMSiren +from .components.siren import Siren +from .components.transformer_decoder import Decoder +from .components.vn_layers import VNInvariant, VNLinear + +# from nerfstudio.cameras.rays import RaySamples + + +def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor: + """Computes the expected value of sin(y) where y ~ N(x_means, x_vars) + + Args: + x_means: Mean values. + x_vars: Variance of values. + + Returns: + torch.Tensor: The expected value of sin. + """ + + return torch.exp(-0.5 * x_vars) * torch.sin(x_means) + + +class NeRFEncoding(torch.nn.Module): + """Multi-scale sinousoidal encodings. Support ``integrated positional encodings`` if covariances are provided. + Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp. 
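+    For an input value x and frequency f = 2^k the encoding contains sin(f * x) and
+    sin(f * x + pi/2) (i.e. cos(f * x)), concatenated over all frequencies.
+
+    Example (arbitrary illustrative arguments):
+    >>> enc = NeRFEncoding(in_dim=3, num_frequencies=2, min_freq_exp=0.0, max_freq_exp=2.0, include_input=True)
+    >>> enc.get_out_dim()
+    15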
+ + Args: + in_dim: Input dimension of tensor + num_frequencies: Number of encoded frequencies per axis + min_freq_exp: Minimum frequency exponent + max_freq_exp: Maximum frequency exponent + include_input: Append the input coordinate to the encoding + """ + + def __init__( + self, + in_dim: int, + num_frequencies: int, + min_freq_exp: float, + max_freq_exp: float, + include_input: bool = False, + off_axis: bool = False, + ) -> None: + super().__init__() + + self.in_dim = in_dim + self.num_frequencies = num_frequencies + self.min_freq = min_freq_exp + self.max_freq = max_freq_exp + self.include_input = include_input + + self.off_axis = off_axis + + self.P = torch.tensor( + [ + [0.8506508, 0, 0.5257311], + [0.809017, 0.5, 0.309017], + [0.5257311, 0.8506508, 0], + [1, 0, 0], + [0.809017, 0.5, -0.309017], + [0.8506508, 0, -0.5257311], + [0.309017, 0.809017, -0.5], + [0, 0.5257311, -0.8506508], + [0.5, 0.309017, -0.809017], + [0, 1, 0], + [-0.5257311, 0.8506508, 0], + [-0.309017, 0.809017, -0.5], + [0, 0.5257311, 0.8506508], + [-0.309017, 0.809017, 0.5], + [0.309017, 0.809017, 0.5], + [0.5, 0.309017, 0.809017], + [0.5, -0.309017, 0.809017], + [0, 0, 1], + [-0.5, 0.309017, 0.809017], + [-0.809017, 0.5, 0.309017], + [-0.809017, 0.5, -0.309017], + ] + ).T + + def get_out_dim(self) -> int: + if self.in_dim is None: + raise ValueError("Input dimension has not been set") + out_dim = self.in_dim * self.num_frequencies * 2 + + if self.off_axis: + out_dim = self.P.shape[1] * self.num_frequencies * 2 + + if self.include_input: + out_dim += self.in_dim + return out_dim + + def forward( + self, + in_tensor: Float[Tensor, "*b input_dim"], + covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None, + ) -> Float[Tensor, "*b output_dim"]: + """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed + in mip-NeRF. + + Args: + in_tensor: For best performance, the input tensor should be between 0 and 1. + covs: Covariances of input points. 
+ Returns: + Output values will be between -1 and 1 + """ + # TODO check scaling here but just comment it for now + # in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi] + freqs = 2 ** torch.linspace( + self.min_freq, self.max_freq, self.num_frequencies + ).to(in_tensor.device) + # freqs = 2 ** ( + # torch.sin(torch.linspace(self.min_freq, torch.pi / 2.0, self.num_frequencies)) * self.max_freq + # ).to(in_tensor.device) + # freqs = 2 ** ( + # torch.linspace(self.min_freq, 1.0, self.num_frequencies).to(in_tensor.device) ** 0.2 * self.max_freq + # ) + + if self.off_axis: + scaled_inputs = ( + torch.matmul(in_tensor, self.P.to(in_tensor.device))[..., None] * freqs + ) + else: + scaled_inputs = ( + in_tensor[..., None] * freqs + ) # [..., "input_dim", "num_scales"] + scaled_inputs = scaled_inputs.view( + *scaled_inputs.shape[:-2], -1 + ) # [..., "input_dim" * "num_scales"] + + if covs is None: + encoded_inputs = torch.sin( + torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1) + ) + else: + input_var = ( + torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None] + * freqs[None, :] ** 2 + ) + input_var = input_var.reshape((*input_var.shape[:-2], -1)) + encoded_inputs = expected_sin( + torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1), + torch.cat(2 * [input_var], dim=-1), + ) + + if self.include_input: + encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1) + return encoded_inputs + + +class RENIField(BaseModule): + @dataclass + class Config(BaseModule.Config): + """Configuration for model instantiation""" + + fixed_decoder: bool = False + """Whether to fix the decoder weights""" + equivariance: str = "SO2" + """Type of equivariance to use: None, SO2, SO3""" + axis_of_invariance: str = "y" + """Which axis should SO2 equivariance be invariant to: x, y, z""" + invariant_function: str = "GramMatrix" + """Type of invariant function to use: GramMatrix, VN""" + conditioning: str = "Concat" + """Type of conditioning to use: FiLM, Concat, Attention""" + positional_encoding: str = "NeRF" + """Type of positional encoding to use. 
Currently only NeRF is supported""" + encoded_input: str = "Directions" + """Type of input to encode: None, Directions, Conditioning, Both""" + latent_dim: int = 36 + """Dimensionality of latent code, N for a latent code size of (N x 3)""" + hidden_layers: int = 3 + """Number of hidden layers""" + hidden_features: int = 128 + """Number of hidden features""" + mapping_layers: int = 3 + """Number of mapping layers""" + mapping_features: int = 128 + """Number of mapping features""" + num_attention_heads: int = 8 + """Number of attention heads""" + num_attention_layers: int = 3 + """Number of attention layers""" + out_features: int = 3 # RGB + """Number of output features""" + last_layer_linear: bool = False + """Whether to use a linear layer as the last layer""" + output_activation: str = "exp" + """Activation function for output layer: sigmoid, tanh, relu, exp, None""" + first_omega_0: float = 30.0 + """Omega_0 for first layer""" + hidden_omega_0: float = 30.0 + """Omega_0 for hidden layers""" + fixed_decoder: bool = False + """Whether to fix the decoder weights""" + old_implementation: bool = False + """Whether to match implementation of old RENI, when using old checkpoints""" + + cfg: Config + + def configure(self): + self.equivariance = self.cfg.equivariance + self.conditioning = self.cfg.conditioning + self.latent_dim = self.cfg.latent_dim + self.hidden_layers = self.cfg.hidden_layers + self.hidden_features = self.cfg.hidden_features + self.mapping_layers = self.cfg.mapping_layers + self.mapping_features = self.cfg.mapping_features + self.out_features = self.cfg.out_features + self.last_layer_linear = self.cfg.last_layer_linear + self.output_activation = self.cfg.output_activation + self.first_omega_0 = self.cfg.first_omega_0 + self.hidden_omega_0 = self.cfg.hidden_omega_0 + self.old_implementation = self.cfg.old_implementation + self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance) + + self.fixed_decoder = self.cfg.fixed_decoder + if self.cfg.invariant_function == "GramMatrix": + self.invariant_function = self.gram_matrix_invariance + else: + self.vn_proj_in = nn.Sequential( + Rearrange("... c -> ... 
1 c"), + VNLinear(dim_in=1, dim_out=1, bias_epsilon=0), + ) + dim_coor = 2 if self.cfg.equivariance == "SO2" else 3 + self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor) + self.invariant_function = self.vn_invariance + + self.network = self.setup_network() + + if self.fixed_decoder: + for param in self.network.parameters(): + param.requires_grad = False + + if self.cfg.invariant_function == "VN": + for param in self.vn_proj_in.parameters(): + param.requires_grad = False + for param in self.vn_invar.parameters(): + param.requires_grad = False + + @contextlib.contextmanager + def hold_decoder_fixed(self): + """Context manager to fix the decoder weights + + Example usage: + ``` + with instance_of_RENIField.hold_decoder_fixed(): + # do stuff + ``` + """ + prev_state_network = { + name: p.requires_grad for name, p in self.network.named_parameters() + } + for param in self.network.parameters(): + param.requires_grad = False + if self.cfg.invariant_function == "VN": + prev_state_proj_in = { + k: p.requires_grad for k, p in self.vn_proj_in.named_parameters() + } + prev_state_invar = { + k: p.requires_grad for k, p in self.vn_invar.named_parameters() + } + for param in self.vn_proj_in.parameters(): + param.requires_grad = False + for param in self.vn_invar.parameters(): + param.requires_grad = False + + prev_decoder_state = self.fixed_decoder + self.fixed_decoder = True + try: + yield + finally: + # Restore the previous requires_grad state + for name, param in self.network.named_parameters(): + param.requires_grad = prev_state_network[name] + if self.cfg.invariant_function == "VN": + for name, param in self.vn_proj_in.named_parameters(): + param.requires_grad_(prev_state_proj_in[name]) + for name, param in self.vn_invar.named_parameters(): + param.requires_grad_(prev_state_invar[name]) + self.fixed_decoder = prev_decoder_state + + def vn_invariance( + self, + Z: Float[Tensor, "B latent_dim 3"], + D: Float[Tensor, "B num_rays 3"], + equivariance: Literal["None", "SO2", "SO3"] = "SO2", + axis_of_invariance: int = 1, + ): + """Generates a batched invariant representation from latent code Z and direction coordinates D. + + Args: + Z: [B, latent_dim, 3] - Latent code. + D: [B num_rays, 3] - Direction coordinates. + equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'. + axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis). + + Returns: + Tuple[Tensor, Tensor]: directional_input, conditioning_input + """ + assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2." 
+ other_axes = [i for i in range(3) if i != axis_of_invariance] + + B, latent_dim, _ = Z.shape + _, num_rays, _ = D.shape + + if equivariance == "None": + # get inner product between latent code and direction coordinates + innerprod = torch.sum( + Z.unsqueeze(1) * D.unsqueeze(2), dim=-1 + ) # [B, num_rays, latent_dim] + z_input = ( + Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3) + ) # [B, num_rays, latent_dim * 3] + return innerprod, z_input + + if equivariance == "SO2": + z_other = torch.stack( + (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1 + ) # [B, latent_dim, 2] + d_other = torch.stack( + (D[..., other_axes[0]], D[..., other_axes[1]]), -1 + ).unsqueeze(2) # [B, num_rays, 1, 2] + d_other = d_other.expand( + B, num_rays, latent_dim, 2 + ) # [B, num_rays, latent_dim, 2] + + z_other_emb = self.vn_proj_in(z_other) # [B, latent_dim, 1, 2] + z_other_invar = self.vn_invar(z_other_emb) # [B, latent_dim, 2] + + # Get invariant component of Z along the axis of invariance + z_invar = Z[..., axis_of_invariance].unsqueeze(-1) # [B, latent_dim, 1] + + # Innerproduct between projection of Z and D on the plane orthogonal to the axis of invariance. + # This encodes the rotational information. This is rotation-equivariant to rotations of either Z + # or D and is invariant to rotations of both Z and D. + innerprod = (z_other.unsqueeze(1) * d_other).sum( + dim=-1 + ) # [B, num_rays, latent_dim] + + # Compute norm along the axes orthogonal to the axis of invariance + d_other_norm = torch.sqrt( + D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2 + ).unsqueeze(-1) # [B num_rays, 1] + + # Get invariant component of D along the axis of invariance + d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1] + + directional_input = torch.cat( + (innerprod, d_invar, d_other_norm), -1 + ) # [B, num_rays, latent_dim + 2] + conditioning_input = ( + torch.cat((z_other_invar, z_invar), dim=-1) + .flatten(1) + .unsqueeze(1) + .expand(B, num_rays, latent_dim * 3) + ) # [B, num_rays, latent_dim * 3] + + return directional_input, conditioning_input + + if equivariance == "SO3": + z = self.vn_proj_in(Z) # [B, latent_dim, 1, 3] + z_invar = self.vn_invar(z) # [B, latent_dim, 3] + conditioning_input = ( + z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, latent_dim) + ) # [B, num_rays, latent_dim * 3] + # D [B, num_rays, 3] -> [B, num_rays, 1, 3] + # Z [B, latent_dim, 3] -> [B, 1, latent_dim, 3] + innerprod = torch.sum( + Z.unsqueeze(1) * D.unsqueeze(2), dim=-1 + ) # [B, num_rays, latent_dim] + return innerprod, conditioning_input + + def gram_matrix_invariance( + self, + Z: Float[Tensor, "B latent_dim 3"], + D: Float[Tensor, "B num_rays 3"], + equivariance: Literal["None", "SO2", "SO3"] = "SO2", + axis_of_invariance: int = 1, + ): + """Generates an invariant representation from latent code Z and direction coordinates D. + + Args: + Z (torch.Tensor): Latent code (B x latent_dim x 3) + D (torch.Tensor): Direction coordinates (B x num_rays x 3) + equivariance (str): Type of equivariance to use. Options are 'none', 'SO2', and 'SO3' + axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis). + Default is 1 (y-axis). + Returns: + torch.Tensor: Invariant representation + """ + assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2." 
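+        # As in vn_invariance, other_axes selects the two components orthogonal to the invariance
+        # axis. Here the conditioning signal uses the Gram matrix G = z_other @ z_other^T, which is
+        # unchanged by any rotation of z_other in that plane ((Z R)(Z R)^T = Z Z^T), so it is
+        # invariant by construction.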
+ other_axes = [i for i in range(3) if i != axis_of_invariance] + + B, latent_dim, _ = Z.shape + _, num_rays, _ = D.shape + + if equivariance == "None": + # get inner product between latent code and direction coordinates + innerprod = torch.sum( + Z.unsqueeze(1) * D.unsqueeze(2), dim=-1 + ) # [B, num_rays, latent_dim] + z_input = ( + Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3) + ) # [B, num_rays, latent_dim * 3] + return innerprod, z_input + + if equivariance == "SO2": + # Select components along axes orthogonal to the axis of invariance + z_other = torch.stack( + (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1 + ) # [B, latent_dim, 2] + d_other = torch.stack( + (D[..., other_axes[0]], D[..., other_axes[1]]), -1 + ).unsqueeze(2) # [B, num_rays, 1, 2] + d_other = d_other.expand( + B, num_rays, latent_dim, 2 + ) # size becomes [B, num_rays, latent_dim, 2] + + # Invariant representation of Z, gram matrix G=Z*Z' is size num_rays x latent_dim x latent_dim + G = torch.bmm(z_other, torch.transpose(z_other, 1, 2)) + + # Flatten G to be size B x latent_dim^2 + z_other_invar = G.flatten(start_dim=1) + + # Get invariant component of Z along the axis of invariance + z_invar = Z[..., axis_of_invariance] # [B, latent_dim] + + # Innerprod is size num_rays x latent_dim + innerprod = (z_other.unsqueeze(1) * d_other).sum( + dim=-1 + ) # [B, num_rays, latent_dim] + + # Compute norm along the axes orthogonal to the axis of invariance + d_other_norm = torch.sqrt( + D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2 + ).unsqueeze(-1) # [B, num_rays, 1] + + # Get invariant component of D along the axis of invariance + d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1] + + if not self.old_implementation: + directional_input = torch.cat( + (innerprod, d_invar, d_other_norm), -1 + ) # [B, num_rays, latent_dim + 2] + conditioning_input = ( + torch.cat((z_other_invar, z_invar), -1) + .unsqueeze(1) + .expand(B, num_rays, latent_dim * 3) + ) # [B, num_rays, latent_dim^2 + latent_dim] + else: + # this is matching the previous implementation of RENI, needed if using old checkpoints + z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1) + z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1) + return torch.cat( + (innerprod, z_other_invar, d_other_norm, z_invar, d_invar), 1 + ) + + return directional_input, conditioning_input + + if equivariance == "SO3": + G = Z @ torch.transpose(Z, 1, 2) # [B, latent_dim, latent_dim] + innerprod = torch.sum( + Z.unsqueeze(1) * D.unsqueeze(2), dim=-1 + ) # [B, num_rays, latent_dim] + z_invar = ( + G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1) + ) # [B, num_rays, latent_dim^2] + return innerprod, z_invar + + def setup_network(self): + """Sets up the network architecture""" + base_input_dims = { + "VN": { + "None": { + "direction": self.latent_dim, + "conditioning": self.latent_dim * 3, + }, + "SO2": { + "direction": self.latent_dim + 2, + "conditioning": self.latent_dim * 3, + }, + "SO3": { + "direction": self.latent_dim, + "conditioning": self.latent_dim * 3, + }, + }, + "GramMatrix": { + "None": { + "direction": self.latent_dim, + "conditioning": self.latent_dim * 3, + }, + "SO2": { + "direction": self.latent_dim + 2, + "conditioning": self.latent_dim**2 + self.latent_dim, + }, + "SO3": { + "direction": self.latent_dim, + "conditioning": self.latent_dim**2, + }, + }, + } + + # Extract the necessary input dimensions + input_types = ["direction", "conditioning"] + input_dims = { + key: 
base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][ + key + ] + for key in input_types + } + + # Helper function to create NeRF encoding + def create_nerf_encoding(in_dim): + return NeRFEncoding( + in_dim=in_dim, + num_frequencies=2, + min_freq_exp=0.0, + max_freq_exp=2.0, + include_input=True, + ) + + # Dictionary-based encoding setup + encoding_setup = { + "None": [], + "Conditioning": ["conditioning"], + "Directions": ["direction"], + "Both": ["direction", "conditioning"], + } + + # Setting up the required encodings + for input_type in encoding_setup.get(self.cfg.encoded_input, []): + # create self.{input_type}_encoding and update input_dims + setattr( + self, + f"{input_type}_encoding", + create_nerf_encoding(input_dims[input_type]), + ) + input_dims[input_type] = getattr( + self, f"{input_type}_encoding" + ).get_out_dim() + + output_activation = get_activation_module(self.cfg.output_activation) + + network = None + if self.conditioning == "Concat": + network = Siren( + in_dim=input_dims["direction"] + input_dims["conditioning"], + hidden_layers=self.hidden_layers, + hidden_features=self.hidden_features, + out_dim=self.out_features, + outermost_linear=self.last_layer_linear, + first_omega_0=self.first_omega_0, + hidden_omega_0=self.hidden_omega_0, + out_activation=output_activation, + ) + elif self.conditioning == "FiLM": + network = FiLMSiren( + in_dim=input_dims["direction"], + hidden_layers=self.hidden_layers, + hidden_features=self.hidden_features, + mapping_network_in_dim=input_dims["conditioning"], + mapping_network_layers=self.mapping_layers, + mapping_network_features=self.mapping_features, + out_dim=self.out_features, + outermost_linear=True, + out_activation=output_activation, + ) + elif self.conditioning == "Attention": + # transformer where K, V is from conditioning input and Q is from pos encoded directional input + network = Decoder( + in_dim=input_dims["direction"], + conditioning_input_dim=input_dims["conditioning"], + hidden_features=self.cfg.hidden_features, + num_heads=self.cfg.num_attention_heads, + num_layers=self.cfg.num_attention_layers, + out_activation=output_activation, + ) + assert network is not None, "unknown conditioning type" + return network + + def apply_positional_encoding(self, directional_input, conditioning_input): + # conditioning on just invariant directional input + if self.cfg.encoded_input == "Conditioning": + conditioning_input = self.conditioning_encoding( + conditioning_input + ) # [num_rays, embedding_dim] + elif self.cfg.encoded_input == "Directions": + directional_input = self.direction_encoding( + directional_input + ) # [num_rays, embedding_dim] + elif self.cfg.encoded_input == "Both": + directional_input = self.direction_encoding(directional_input) + conditioning_input = self.conditioning_encoding(conditioning_input) + + return directional_input, conditioning_input + + def get_outputs( + self, + rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore + latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore + rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore + scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore + ) -> Dict[str, Tensor]: + """Returns the outputs of the field. 
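+        The latent codes are (optionally) rotated, mapped together with the ray directions to an
+        invariant representation, positionally encoded, and decoded to per-ray RGB by the
+        conditioned network.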
+ + Args: + ray_samples: [batch_size num_rays 3] + latent_codes: [batch_size, latent_dim, 3] + rotation: [batch_size, 3, 3] + scale: [batch_size] + """ + if rotation is not None: + if len(rotation.shape) == 3: # [batch_size, 3, 3] + # Expand latent_codes to match [batch_size, latent_dim, 3] + latent_codes = torch.einsum( + "bik,blk->bli", + rotation, + latent_codes, + ) + else: + raise NotImplementedError( + "Unsupported rotation shape. Expected [batch_size, 3, 3]." + ) + + B, num_rays, _ = rays_d.shape + _, latent_dim, _ = latent_codes.shape + + if not self.old_implementation: + directional_input, conditioning_input = self.invariant_function( + latent_codes, + rays_d, + equivariance=self.equivariance, + axis_of_invariance=self.axis_of_invariance, + ) # [B, num_rays, 3] + + if self.cfg.positional_encoding == "NeRF": + directional_input, conditioning_input = self.apply_positional_encoding( + directional_input, conditioning_input + ) + + if self.conditioning == "Concat": + model_outputs = self.network( + torch.cat((directional_input, conditioning_input), dim=-1).reshape( + B * num_rays, -1 + ) + ).view(B, num_rays, 3) # returns -> [B num_rays, 3] + elif self.conditioning == "FiLM": + model_outputs = self.network( + directional_input.reshape(B * num_rays, -1), + conditioning_input.reshape(B * num_rays, -1), + ).view(B, num_rays, 3) # returns -> [B num_rays, 3] + elif self.conditioning == "Attention": + model_outputs = self.network( + directional_input.reshape(B * num_rays, -1), + conditioning_input.reshape(B * num_rays, -1), + ).view(B, num_rays, 3) # returns -> [B num_rays, 3] + else: + # in the old implementation directions were sampled with y-up not z-up so need to swap y and z in directions + directions = torch.stack( + (rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1 + ) + model_input = self.invariant_function( + latent_codes, + directions, + equivariance=self.equivariance, + axis_of_invariance=self.axis_of_invariance, + ) # [B, num_rays, 3] + + model_outputs = self.network(model_input.view(B * num_rays, -1)).view( + B, num_rays, 3 + ) + + outputs = {} + + if scale is not None: + scale = trunc_exp(scale) # [num_rays] exp to ensure positive + model_outputs = model_outputs * scale.view(-1, 1, 1) # [num_rays, 3] + + outputs["rgb"] = model_outputs + + return outputs + + def forward( + self, + rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore + latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore + rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore + scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore + ) -> Dict[str, Tensor]: + """Evaluates spherical field for a given ray bundle and rotation. + + Args: + ray_samples: [B num_rays 3] + latent_codes: [B, num_rays, latent_dim, 3] + rotation: [batch_size, 3, 3] + scale: [batch_size] + + Returns: + Dict[str, Tensor]: A dictionary containing the outputs of the field. 
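+            At present this is a single key "rgb" of shape [batch_size, num_rays, 3]; when
+            scale is given it is passed through trunc_exp and multiplied into the output.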
+ """ + return self.get_outputs( + rays_d=rays_d, + latent_codes=latent_codes, + rotation=rotation, + scale=scale, + ) diff --git a/spar3d/models/image_estimator/clip_based_estimator.py b/spar3d/models/image_estimator/clip_based_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..54f3a5e44b46cf9749adbb142aa381b7489269e3 --- /dev/null +++ b/spar3d/models/image_estimator/clip_based_estimator.py @@ -0,0 +1,184 @@ +from dataclasses import dataclass, field +from typing import Any, List, Optional + +import alpha_clip +import torch +import torch.nn as nn +from jaxtyping import Float +from torch import Tensor +from torchvision.transforms import Normalize + +from spar3d.models.network import get_activation +from spar3d.models.utils import BaseModule + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + + +@dataclass +class HeadSpec: + name: str + out_channels: int + n_hidden_layers: int + output_activation: Optional[str] = None + output_bias: float = 0.0 + add_to_decoder_features: bool = False + shape: Optional[list[int]] = None + distribution_eval: str = "sample" + + +class ClipBasedHeadEstimator(BaseModule): + @dataclass + class Config(BaseModule.Config): + model: str = "ViT-L/14@336px" + + distribution: str = "beta" + + # ["mean", "mode", "sample", "sample_mean"] + distribution_eval: str = "mode" + + activation: str = "relu" + hidden_features: int = 512 + heads: List[HeadSpec] = field(default_factory=lambda: []) + + cfg: Config + + def configure(self): + self.model, _ = alpha_clip.load( + self.cfg.model, + ) # change to your own ckpt path + self.model.eval() + + if not hasattr(self.model.visual, "input_resolution"): + self.img_size = 224 + else: + self.img_size = self.model.visual.input_resolution + # Check if img_size is subscribable and pick the first element + if hasattr(self.img_size, "__getitem__"): + self.img_size = self.img_size[0] + + # Do not add the weights in self.model to the optimizer + for param in self.model.parameters(): + param.requires_grad = False + + assert len(self.cfg.heads) > 0 + heads = {} + for head in self.cfg.heads: + head_layers = [] + in_feature = self.model.visual.output_dim + + for i in range(head.n_hidden_layers): + head_layers += [ + nn.Linear( + in_feature if i == 0 else self.cfg.hidden_features, + self.cfg.hidden_features, + ), + self.make_activation(self.cfg.activation), + ] + + head_layers = [nn.Sequential(*head_layers)] + head_layers += [ + nn.Sequential( + nn.Linear( + self.cfg.hidden_features, + self.cfg.hidden_features, + ), + self.make_activation(self.cfg.activation), + nn.Linear(self.cfg.hidden_features, 1), + ) + for _ in range(2) + ] + heads[head.name] = nn.ModuleList(head_layers) + self.heads = nn.ModuleDict(heads) + + def make_activation(self, activation): + if activation == "relu": + return nn.ReLU(inplace=True) + elif activation == "silu": + return nn.SiLU(inplace=True) + else: + raise NotImplementedError + + def forward( + self, + cond_image: Float[Tensor, "B 1 H W 4"], + sample: bool = True, + ) -> dict[str, Any]: + # Run the model + # Resize cond_image to 224 + cond_image = cond_image.flatten(0, 1) + cond_image = nn.functional.interpolate( + cond_image.permute(0, 3, 1, 2), + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + mask = cond_image[:, 3:4] + cond_image = cond_image[:, :3] * mask + cond_image = Normalize( + mean=OPENAI_DATASET_MEAN, + std=OPENAI_DATASET_STD, + )(cond_image) + mask = Normalize(0.5, 
0.26)(mask).half() + image_features = self.model.visual(cond_image.half(), mask).float() + + # Run the heads + outputs = {} + + for head_dict in self.cfg.heads: + head_name = head_dict.name + shared_head, d1_h, d2_h = self.heads[head_name] + shared_features = shared_head(image_features) + d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]] + if self.cfg.distribution == "normal": + mean = d1 + var = d2 + if mean.shape[-1] == 1: + outputs[head_name] = torch.distributions.Normal( + mean + head_dict.output_bias, + torch.nn.functional.softplus(var), + ) + else: + outputs[head_name] = torch.distributions.MultivariateNormal( + mean + head_dict.output_bias, + torch.nn.functional.softplus(var).diag_embed(), + ) + elif self.cfg.distribution == "beta": + outputs[head_name] = torch.distributions.Beta( + torch.nn.functional.softplus(d1 + head_dict.output_bias), + torch.nn.functional.softplus(d2 + head_dict.output_bias), + ) + else: + raise NotImplementedError + + if sample: + for head_dict in self.cfg.heads: + head_name = head_dict.name + dist = outputs[head_name] + + if head_dict.distribution_eval == "mean": + out = dist.mean + elif head_dict.distribution_eval == "mode": + out = dist.mode + elif head_dict.distribution_eval == "sample_mean": + out = dist.sample([10]).mean(-1) + else: + # use rsample if gradient is needed + out = dist.rsample() if self.training else dist.sample() + + outputs[head_name] = get_activation(head_dict.output_activation)(out) + outputs[f"{head_name}_dist"] = dist + + for head in self.cfg.heads: + if head.shape: + if not sample: + raise ValueError( + "Cannot reshape non-sampled probabilisitic outputs" + ) + outputs[head.name] = outputs[head.name].reshape(*head.shape) + + if head.add_to_decoder_features: + outputs[f"decoder_{head.name}"] = outputs[head.name] + del outputs[head.name] + + return outputs diff --git a/spar3d/models/isosurface.py b/spar3d/models/isosurface.py new file mode 100644 index 0000000000000000000000000000000000000000..aad0d345cabef10f6f0613767fad789943368dd2 --- /dev/null +++ b/spar3d/models/isosurface.py @@ -0,0 +1,229 @@ +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from jaxtyping import Float, Integer +from torch import Tensor + +from .mesh import Mesh + + +class IsosurfaceHelper(nn.Module): + points_range: Tuple[float, float] = (0, 1) + + @property + def grid_vertices(self) -> Float[Tensor, "N 3"]: + raise NotImplementedError + + @property + def requires_instance_per_batch(self) -> bool: + return False + + +class MarchingTetrahedraHelper(IsosurfaceHelper): + def __init__(self, resolution: int, tets_path: str): + super().__init__() + self.resolution = resolution + self.tets_path = tets_path + + self.triangle_table: Float[Tensor, "..."] + self.register_buffer( + "triangle_table", + torch.as_tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], + ], + dtype=torch.long, + ), + persistent=False, + ) + self.num_triangles_table: Integer[Tensor, "..."] + self.register_buffer( + "num_triangles_table", + torch.as_tensor( + [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long + ), + persistent=False, + ) + self.base_tet_edges: Integer[Tensor, 
"..."] + self.register_buffer( + "base_tet_edges", + torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long), + persistent=False, + ) + + tets = np.load(self.tets_path) + self._grid_vertices: Float[Tensor, "..."] + self.register_buffer( + "_grid_vertices", + torch.from_numpy(tets["vertices"]).float(), + persistent=False, + ) + self.indices: Integer[Tensor, "..."] + self.register_buffer( + "indices", torch.from_numpy(tets["indices"]).long(), persistent=False + ) + + self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None + + center_indices, boundary_indices = self.get_center_boundary_index( + self._grid_vertices + ) + self.center_indices: Integer[Tensor, "..."] + self.register_buffer("center_indices", center_indices, persistent=False) + self.boundary_indices: Integer[Tensor, "..."] + self.register_buffer("boundary_indices", boundary_indices, persistent=False) + + def get_center_boundary_index(self, verts): + magn = torch.sum(verts**2, dim=-1) + + center_idx = torch.argmin(magn) + boundary_neg = verts == verts.max() + boundary_pos = verts == verts.min() + + boundary = torch.bitwise_or(boundary_pos, boundary_neg) + boundary = torch.sum(boundary.float(), dim=-1) + + boundary_idx = torch.nonzero(boundary) + return center_idx, boundary_idx.squeeze(dim=-1) + + def normalize_grid_deformation( + self, grid_vertex_offsets: Float[Tensor, "Nv 3"] + ) -> Float[Tensor, "Nv 3"]: + return ( + (self.points_range[1] - self.points_range[0]) + / self.resolution # half tet size is approximately 1 / self.resolution + * torch.tanh(grid_vertex_offsets) + ) # FIXME: hard-coded activation + + @property + def grid_vertices(self) -> Float[Tensor, "Nv 3"]: + return self._grid_vertices + + @property + def all_edges(self) -> Integer[Tensor, "Ne 2"]: + if self._all_edges is None: + # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation) + edges = torch.tensor( + [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], + dtype=torch.long, + device=self.indices.device, + ) + _all_edges = self.indices[:, edges].reshape(-1, 2) + _all_edges_sorted = torch.sort(_all_edges, dim=1)[0] + _all_edges = torch.unique(_all_edges_sorted, dim=0) + self._all_edges = _all_edges + return self._all_edges + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + + return torch.stack([a, b], -1) + + def _forward(self, pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = ( + torch.ones( + (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device + ) + * -1 + ) + mapping[mask_edges] = torch.arange( + mask_edges.sum(), dtype=torch.long, device=pos_nx3.device + ) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + 
edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device)) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 1]][:, :3], + ).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 2]][:, :6], + ).reshape(-1, 3), + ), + dim=0, + ) + + return verts, faces + + def forward( + self, + level: Float[Tensor, "N3 1"], + deformation: Optional[Float[Tensor, "N3 3"]] = None, + ) -> Mesh: + if deformation is not None: + grid_vertices = self.grid_vertices + self.normalize_grid_deformation( + deformation + ) + else: + grid_vertices = self.grid_vertices + + v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices) + + mesh = Mesh( + v_pos=v_pos, + t_pos_idx=t_pos_idx, + # extras + grid_vertices=grid_vertices, + tet_edges=self.all_edges, + grid_level=level, + grid_deformation=deformation, + ) + + return mesh diff --git a/spar3d/models/mesh.py b/spar3d/models/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..1612d5ebbf742738729c91285034185d8e2641d5 --- /dev/null +++ b/spar3d/models/mesh.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import math +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torch.nn.functional as F +import trimesh +from jaxtyping import Float, Integer +from torch import Tensor + +from spar3d.models.utils import dot + +try: + from uv_unwrapper import Unwrapper +except ImportError: + import logging + + logging.warning( + "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`" + ) + # Exit early to avoid further errors + raise ImportError("uv_unwrapper not found") + +try: + import gpytoolbox + + TRIANGLE_REMESH_AVAILABLE = True +except ImportError: + TRIANGLE_REMESH_AVAILABLE = False + import logging + + logging.warning( + "Could not import gpytoolbox. Triangle remeshing functionality will be disabled. " + "Install via `pip install gpytoolbox`" + ) + +try: + import pynim + + QUAD_REMESH_AVAILABLE = True +except ImportError: + QUAD_REMESH_AVAILABLE = False + import logging + + logging.warning( + "Could not import pynim. Quad remeshing functionality will be disabled. 
" + "Install via `pip install git+https://github.com/vork/PyNanoInstantMeshes.git@v0.0.3`" + ) + + +class Mesh: + def __init__( + self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs + ) -> None: + self.v_pos: Float[Tensor, "Nv 3"] = v_pos + self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx + self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None + self._edges: Optional[Integer[Tensor, "Ne 2"]] = None + self.extras: Dict[str, Any] = {} + for k, v in kwargs.items(): + self.add_extra(k, v) + + self.unwrapper = Unwrapper() + + def add_extra(self, k, v) -> None: + self.extras[k] = v + + @property + def requires_grad(self): + return self.v_pos.requires_grad + + @property + def v_nrm(self): + if self._v_nrm is None: + self._v_nrm = self._compute_vertex_normal() + return self._v_nrm + + @property + def v_tng(self): + if self._v_tng is None: + self._v_tng = self._compute_vertex_tangent() + return self._v_tng + + @property + def v_tex(self): + if self._v_tex is None: + self.unwrap_uv() + return self._v_tex + + @property + def edges(self): + if self._edges is None: + self._edges = self._compute_edges() + return self._edges + + def _compute_vertex_normal(self): + i0 = self.t_pos_idx[:, 0] + i1 = self.t_pos_idx[:, 1] + i2 = self.t_pos_idx[:, 2] + + v0 = self.v_pos[i0, :] + v1 = self.v_pos[i1, :] + v2 = self.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(self.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + def _compute_vertex_tangent(self): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0, 3): + pos[i] = self.v_pos[self.t_pos_idx[:, i]] + tex[i] = self.v_tex[self.t_pos_idx[:, i]] + # t_nrm_idx is always the same as t_pos_idx + vn_idx[i] = self.t_pos_idx[:, i] + + tangents = torch.zeros_like(self.v_nrm) + tansum = torch.zeros_like(self.v_nrm) + + # Compute tangent space for each triangle + duv1 = tex[1] - tex[0] + duv2 = tex[2] - tex[0] + dpos1 = pos[1] - pos[0] + dpos2 = pos[2] - pos[0] + + tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2] + + denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1] + + # Avoid division by zero for degenerated texture coordinates + denom_safe = denom.clip(1e-6) + tang = tng_nom / denom_safe + + # Update all 3 vertices + for i in range(0, 3): + idx = vn_idx[i][:, None].repeat(1, 3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_( + 0, idx, torch.ones_like(tang) + ) # tansum[n_i] = tansum[n_i] + 1 + # Also normalize it. 
Here we do not normalize the individual triangles first so larger area + # triangles influence the tangent space more + tangents = tangents / tansum + + # Normalize and make sure tangent is perpendicular to normal + tangents = F.normalize(tangents, dim=1) + tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return tangents + + def quad_remesh( + self, + quad_vertex_count: int = -1, + quad_rosy: int = 4, + quad_crease_angle: float = -1.0, + quad_smooth_iter: int = 2, + quad_align_to_boundaries: bool = False, + ) -> Mesh: + if not QUAD_REMESH_AVAILABLE: + raise ImportError("Quad remeshing requires pynim to be installed") + if quad_vertex_count < 0: + quad_vertex_count = self.v_pos.shape[0] + v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32) + t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32) + + new_vert, new_faces = pynim.remesh( + v_pos, + t_pos_idx, + quad_vertex_count // 4, + rosy=quad_rosy, + posy=4, + creaseAngle=quad_crease_angle, + align_to_boundaries=quad_align_to_boundaries, + smooth_iter=quad_smooth_iter, + deterministic=False, + ) + + # Briefly load in trimesh + mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32)) + + v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous() + t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous() + + # Create new mesh + return Mesh(v_pos, t_pos_idx) + + def triangle_remesh( + self, + triangle_average_edge_length_multiplier: Optional[float] = None, + triangle_remesh_steps: int = 10, + triangle_vertex_count=-1, + ): + if not TRIANGLE_REMESH_AVAILABLE: + raise ImportError("Triangle remeshing requires gpytoolbox to be installed") + if triangle_vertex_count > 0: + reduction = triangle_vertex_count / self.v_pos.shape[0] + print("Triangle reduction:", reduction) + v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32) + t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32) + if reduction > 1.0: + subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2))) + print("Subdivide iters:", subdivide_iters) + v_pos, t_pos_idx = gpytoolbox.subdivide( + v_pos, + t_pos_idx, + iters=subdivide_iters, + ) + reduction = triangle_vertex_count / v_pos.shape[0] + + # Simplify + points_out, faces_out, _, _ = gpytoolbox.decimate( + v_pos, + t_pos_idx, + face_ratio=reduction, + ) + + # Convert back to torch + self.v_pos = torch.from_numpy(points_out).to(self.v_pos) + self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx) + self._edges = None + triangle_average_edge_length_multiplier = None + + edges = self.edges + if triangle_average_edge_length_multiplier is None: + h = None + else: + h = float( + torch.linalg.norm( + self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1 + ) + .mean() + .item() + * triangle_average_edge_length_multiplier + ) + + # Convert to numpy + v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64) + t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32) + + # Remesh + v_remesh, f_remesh = gpytoolbox.remesh_botsch( + v_pos, + t_pos_idx, + triangle_remesh_steps, + h, + ) + + # Convert back to torch + v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous() + t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous() + + # Create new mesh + return Mesh(v_pos, t_pos_idx) + + @torch.no_grad() + def unwrap_uv( + self, + island_padding: float = 0.02, + ) -> Mesh: + uv, indices = self.unwrapper( + 
self.v_pos, self.v_nrm, self.t_pos_idx, island_padding + ) + + # Do store per vertex UVs. + # This means we need to duplicate some vertices at the seams + individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3) + individual_faces = torch.arange( + individual_vertices.shape[0], + device=individual_vertices.device, + dtype=self.t_pos_idx.dtype, + ).reshape(-1, 3) + uv_flat = uv[indices].reshape((-1, 2)) + # uv_flat[:, 1] = 1 - uv_flat[:, 1] + + self.v_pos = individual_vertices + self.t_pos_idx = individual_faces + self._v_tex = uv_flat + self._v_nrm = self._compute_vertex_normal() + self._v_tng = self._compute_vertex_tangent() + + def _compute_edges(self): + # Compute edges + edges = torch.cat( + [ + self.t_pos_idx[:, [0, 1]], + self.t_pos_idx[:, [1, 2]], + self.t_pos_idx[:, [2, 0]], + ], + dim=0, + ) + edges = edges.sort()[0] + edges = torch.unique(edges, dim=0) + return edges diff --git a/spar3d/models/network.py b/spar3d/models/network.py new file mode 100644 index 0000000000000000000000000000000000000000..22f183a391cc0ae8f0148706a8737776ba30200d --- /dev/null +++ b/spar3d/models/network.py @@ -0,0 +1,223 @@ +from dataclasses import dataclass, field +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Float +from torch import Tensor +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +from spar3d.models.utils import BaseModule, normalize +from spar3d.utils import get_device + + +def conditional_decorator(decorator_with_args, condition, *args, **kwargs): + def wrapper(fn): + if condition: + if len(kwargs) == 0: + return decorator_with_args + return decorator_with_args(*args, **kwargs)(fn) + else: + return fn + + return wrapper + + +class PixelShuffleUpsampleNetwork(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 1024 + out_channels: int = 40 + scale_factor: int = 4 + + conv_layers: int = 4 + conv_kernel_size: int = 3 + + cfg: Config + + def configure(self) -> None: + layers = [] + output_channels = self.cfg.out_channels * self.cfg.scale_factor**2 + + in_channels = self.cfg.in_channels + for i in range(self.cfg.conv_layers): + cur_out_channels = ( + in_channels if i != self.cfg.conv_layers - 1 else output_channels + ) + layers.append( + nn.Conv2d( + in_channels, + cur_out_channels, + self.cfg.conv_kernel_size, + padding=(self.cfg.conv_kernel_size - 1) // 2, + ) + ) + if i != self.cfg.conv_layers - 1: + layers.append(nn.ReLU(inplace=True)) + + layers.append(nn.PixelShuffle(self.cfg.scale_factor)) + + self.upsample = nn.Sequential(*layers) + + def forward( + self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"] + ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]: + return rearrange( + self.upsample( + rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3) + ), + "(B Np) Co Hp Wp -> B Np Co Hp Wp", + Np=3, + ) + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @conditional_decorator( + custom_fwd, "cuda" in get_device(), cast_inputs=torch.float32 + ) + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @conditional_decorator(custom_bwd, "cuda" in get_device()) + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * 
torch.exp(torch.clamp(x, max=15)) + + +trunc_exp = _TruncExp.apply + + +def get_activation(name) -> Callable: + if name is None: + return lambda x: x + name = name.lower() + if name == "none" or name == "linear" or name == "identity": + return lambda x: x + elif name == "lin2srgb": + return lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0) + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "shifted_exp": + return lambda x: torch.exp(x - 1.0) + elif name == "trunc_exp": + return trunc_exp + elif name == "shifted_trunc_exp": + return lambda x: trunc_exp(x - 1.0) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "shifted_softplus": + return lambda x: F.softplus(x - 1.0) + elif name == "scale_-11_01": + return lambda x: x * 0.5 + 0.5 + elif name == "negative": + return lambda x: -x + elif name == "normalize_channel_last": + return lambda x: normalize(x) + elif name == "normalize_channel_first": + return lambda x: normalize(x, dim=1) + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + + +class LambdaModule(torch.nn.Module): + def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]): + super().__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + +def get_activation_module(name) -> torch.nn.Module: + return LambdaModule(get_activation(name)) + + +@dataclass +class HeadSpec: + name: str + out_channels: int + n_hidden_layers: int + output_activation: Optional[str] = None + out_bias: float = 0.0 + + +class MaterialMLP(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 120 + n_neurons: int = 64 + activation: str = "silu" + heads: List[HeadSpec] = field(default_factory=lambda: []) + + cfg: Config + + def configure(self) -> None: + assert len(self.cfg.heads) > 0 + heads = {} + for head in self.cfg.heads: + head_layers = [] + for i in range(head.n_hidden_layers): + head_layers += [ + nn.Linear( + self.cfg.in_channels if i == 0 else self.cfg.n_neurons, + self.cfg.n_neurons, + ), + self.make_activation(self.cfg.activation), + ] + head_layers += [ + nn.Linear( + self.cfg.n_neurons, + head.out_channels, + ), + ] + heads[head.name] = nn.Sequential(*head_layers) + self.heads = nn.ModuleDict(heads) + + def make_activation(self, activation): + if activation == "relu": + return nn.ReLU(inplace=True) + elif activation == "silu": + return nn.SiLU(inplace=True) + else: + raise NotImplementedError + + def keys(self): + return self.heads.keys() + + def forward( + self, x, include: Optional[List] = None, exclude: Optional[List] = None + ): + if include is not None and exclude is not None: + raise ValueError("Cannot specify both include and exclude.") + if include is not None: + heads = [h for h in self.cfg.heads if h.name in include] + elif exclude is not None: + heads = [h for h in self.cfg.heads if h.name not in exclude] + else: + heads = self.cfg.heads + + out = { + head.name: get_activation(head.output_activation)( + self.heads[head.name](x) + head.out_bias + ) + for head in heads + } + + return out diff --git a/spar3d/models/tokenizers/dinov2.py b/spar3d/models/tokenizers/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b1c34cacfe52675ac04dc778c9cc5a774192a2 --- /dev/null +++ b/spar3d/models/tokenizers/dinov2.py @@ -0,0 +1,1196 @@ +# coding=utf-8 
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DINOv2 model.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.dinov2.configuration_dinov2 import Dinov2Config +from transformers.pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.backbone_utils import BackboneMixin + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Dinov2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2-base" +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base" + + +DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/dinov2-base", + # See all DINOv2 models at https://huggingface.co/models?filter=dinov2 +] + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # register as mask token as it's not used in optimization + # to avoid the use of find_unused_parameters_true + # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_buffer("mask_token", torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
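+        Concretely, the patch position embeddings are reshaped to a sqrt(num_positions) x
+        sqrt(num_positions) grid, resized with bicubic interpolation to the new patch grid, and
+        re-concatenated with the class token embedding.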
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeddings = self.patch_embeddings(pixel_values) + embeddings = patch_embeddings + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), + self.mask_token.to(embeddings.dtype).unsqueeze(0), + embeddings, + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
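+
+    Note: the channel-count check in `forward` is disabled in this copy (it is
+    wrapped in a string literal below), presumably so that
+    `Dinov2Model.expand_input_channels` can widen the patch projection to
+    accept extra input channels.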
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + """ + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + self.query = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.key = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.value = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + if hasattr(F, "scaled_dot_product_attention"): + assert head_mask is None and not output_attentions + new_size = hidden_states.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2) + value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2) + query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2) + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + dropout_p=self.attention_probs_dropout_prob, + is_causal=False, + ) + context_layer = context_layer.transpose(1, 2).reshape( + *hidden_states.size()[:-1], -1 + ) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = 
self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len( + heads + ) + self.attention.all_head_size = ( + self.attention.attention_head_size * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class Dinov2LayerScale(nn.Module): + 
def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter( + config.layerscale_value * torch.ones(config.hidden_size) + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path( + input: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * ( + input.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=input.dtype, device=input.device + ) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class Dinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2Layer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm1_modulation = None + 
self.attention = Dinov2Attention(config) + self.layer_scale1 = Dinov2LayerScale(config) + self.drop_path1 = ( + Dinov2DropPath(config.drop_path_rate) + if config.drop_path_rate > 0.0 + else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm2_modulation = None + + if config.use_swiglu_ffn: + self.mlp = Dinov2SwiGLUFFN(config) + else: + self.mlp = Dinov2MLP(config) + self.layer_scale2 = Dinov2LayerScale(config) + self.drop_path2 = ( + Dinov2DropPath(config.drop_path_rate) + if config.drop_path_rate > 0.0 + else nn.Identity() + ) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + modulation_cond: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + hidden_states_norm = self.norm1(hidden_states) + if self.norm1_modulation is not None: + assert modulation_cond is not None + hidden_states_norm = self.norm1_modulation( + hidden_states_norm, modulation_cond + ) + self_attention_outputs = self.attention( + hidden_states_norm, # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + if self.norm2_modulation is not None: + assert modulation_cond is not None + layer_output = self.norm2_modulation(layer_output, modulation_cond) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = layer_output + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module): + self.norm1_modulation = norm1_mod + self.norm2_modulation = norm2_mod + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class Dinov2Encoder(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Dinov2Layer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + modulation_cond: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + modulation_cond, + use_reentrant=False, + ) + else: + layer_outputs = layer_module( + hidden_states, 
layer_head_mask, modulation_cond, output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + def _set_gradient_checkpointing( + self, module: Dinov2Encoder, value: bool = False + ) -> None: + if isinstance(module, Dinov2Encoder): + module.gradient_checkpointing = value + + +DINOV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@dataclass +class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling): + patch_embeddings: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class Dinov2Model(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def expand_input_channels(self, extra_input_channels: int) -> None: + if extra_input_channels == 0: + return + conv_old = self.embeddings.patch_embeddings.projection + conv_new = nn.Conv2d( + self.config.num_channels + extra_input_channels, + self.config.hidden_size, + kernel_size=self.config.patch_size, + stride=self.config.patch_size, + ).to(self.device) + with torch.no_grad(): + conv_new.weight[:, :3] = conv_old.weight + conv_new.bias = conv_old.bias + self.embeddings.patch_embeddings.projection = conv_new + del conv_old + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + modulation_cond: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + modulation_cond=modulation_cond, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return CustomBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + patch_embeddings=embedding_output, + ) + + def set_gradient_checkpointing(self, value: bool = False) -> None: + self._set_gradient_checkpointing(self.encoder, value) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. 
+ """, + DINOV2_START_DOCSTRING, +) +class Dinov2ForImageClassification(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2 = Dinov2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.dinov2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. 
+ """, + DINOV2_START_DOCSTRING, +) +class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [ + config.hidden_size for _ in range(config.num_hidden_layers + 1) + ] + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state[:, 1:, :].reshape( + batch_size, width // patch_size, height // patch_size, -1 + ) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +class CustomPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
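+
+    Unlike `Dinov2PatchEmbeddings` above, this standalone variant is configured
+    with explicit `image_size`, `patch_size`, `num_channels` and `hidden_size`
+    arguments instead of a `Dinov2Config`, and it keeps the channel-count check
+    in `forward` enabled.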
+ """ + + def __init__( + self, image_size: int, patch_size: int, num_channels: int, hidden_size: int + ): + super().__init__() + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class CustomEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__( + self, image_size: int, patch_size: int, num_channels: int, hidden_size: int + ) -> None: + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.hidden_size = hidden_size + + self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size)) + + self.patch_embeddings = CustomPatchEmbeddings( + image_size, patch_size, num_channels, hidden_size + ) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, self.hidden_size) + ) + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeddings = self.patch_embeddings(pixel_values) + embeddings = patch_embeddings + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + return embeddings diff --git a/spar3d/models/tokenizers/image.py b/spar3d/models/tokenizers/image.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1a29686666a8ab66bcb7634703b54eb3016bbe --- /dev/null +++ b/spar3d/models/tokenizers/image.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from jaxtyping import Float +from torch import Tensor + +from spar3d.models.tokenizers.dinov2 import Dinov2Model +from spar3d.models.transformers.attention import Modulation +from spar3d.models.utils import BaseModule + + +class DINOV2SingleImageTokenizer(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: str = "facebook/dinov2-large" + width: int = 512 + height: int = 512 + modulation_cond_dim: int = 768 + + cfg: Config + + def configure(self) -> None: + self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path) + + for p in self.model.parameters(): + p.requires_grad_(False) + self.model.eval() + + self.model.set_gradient_checkpointing(False) + + # add modulation + modulations = [] + for layer in self.model.encoder.layer: + norm1_modulation = Modulation( + self.model.config.hidden_size, + self.cfg.modulation_cond_dim, + zero_init=True, + single_layer=True, + ) + norm2_modulation = Modulation( + self.model.config.hidden_size, + self.cfg.modulation_cond_dim, + zero_init=True, + single_layer=True, + ) + layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation) + 
modulations += [norm1_modulation, norm2_modulation] + self.modulations = nn.ModuleList(modulations) + + self.register_buffer( + "image_mean", + torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1), + persistent=False, + ) + self.register_buffer( + "image_std", + torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1), + persistent=False, + ) + + def forward( + self, + images: Float[Tensor, "B *N C H W"], + modulation_cond: Optional[Float[Tensor, "B *N Cc"]], + **kwargs, + ) -> Float[Tensor, "B *N Ct Nt"]: + model = self.model + + packed = False + if images.ndim == 4: + packed = True + images = images.unsqueeze(1) + if modulation_cond is not None: + assert modulation_cond.ndim == 2 + modulation_cond = modulation_cond.unsqueeze(1) + + batch_size, n_input_views = images.shape[:2] + images = (images - self.image_mean) / self.image_std + out = model( + rearrange(images, "B N C H W -> (B N) C H W"), + modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc") + if modulation_cond is not None + else None, + ) + local_features = out.last_hidden_state + local_features = local_features.permute(0, 2, 1) + local_features = rearrange( + local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size + ) + if packed: + local_features = local_features.squeeze(1) + + return local_features + + def detokenize(self, *args, **kwargs): + raise NotImplementedError diff --git a/spar3d/models/tokenizers/point.py b/spar3d/models/tokenizers/point.py new file mode 100644 index 0000000000000000000000000000000000000000..94b582a1294e272ca866cc7fcac32a23620f2d58 --- /dev/null +++ b/spar3d/models/tokenizers/point.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from jaxtyping import Float +from torch import Tensor + +from spar3d.models.transformers.transformer_1d import Transformer1D +from spar3d.models.utils import BaseModule + + +class TransformerPointTokenizer(BaseModule): + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 16 + attention_head_dim: int = 64 + in_channels: Optional[int] = 6 + out_channels: Optional[int] = 1024 + num_layers: int = 16 + norm_num_groups: int = 32 + attention_bias: bool = False + activation_fn: str = "geglu" + norm_elementwise_affine: bool = True + + cfg: Config + + def configure(self) -> None: + transformer_cfg = dict(self.cfg.copy()) + # remove the non-transformer configs + transformer_cfg["in_channels"] = ( + self.cfg.num_attention_heads * self.cfg.attention_head_dim + ) + self.model = Transformer1D(transformer_cfg) + self.linear_in = torch.nn.Linear( + self.cfg.in_channels, transformer_cfg["in_channels"] + ) + self.linear_out = torch.nn.Linear( + transformer_cfg["in_channels"], self.cfg.out_channels + ) + + def forward( + self, points: Float[Tensor, "B N Ci"], **kwargs + ) -> Float[Tensor, "B N Cp"]: + assert points.ndim == 3 + inputs = self.linear_in(points).permute(0, 2, 1) # B N Ci -> B Ci N + out = self.model(inputs).permute(0, 2, 1) # B Ci N -> B N Ci + out = self.linear_out(out) # B N Ci -> B N Co + return out + + def detokenize(self, *args, **kwargs): + raise NotImplementedError diff --git a/spar3d/models/tokenizers/triplane.py b/spar3d/models/tokenizers/triplane.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bb98b5c526d20fbbd8cdc131512ccef0755157 --- /dev/null +++ b/spar3d/models/tokenizers/triplane.py @@ -0,0 +1,49 @@ +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from jaxtyping 
import Float +from torch import Tensor + +from spar3d.models.utils import BaseModule + + +class TriplaneLearnablePositionalEmbedding(BaseModule): + @dataclass + class Config(BaseModule.Config): + plane_size: int = 96 + num_channels: int = 1024 + + cfg: Config + + def configure(self) -> None: + self.embeddings = nn.Parameter( + torch.randn( + (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size), + dtype=torch.float32, + ) + * 1 + / math.sqrt(self.cfg.num_channels) + ) + + def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]: + return rearrange( + repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size), + "B Np Ct Hp Wp -> B Ct (Np Hp Wp)", + ) + + def detokenize( + self, tokens: Float[Tensor, "B Ct Nt"] + ) -> Float[Tensor, "B 3 Ct Hp Wp"]: + batch_size, Ct, Nt = tokens.shape + assert Nt == self.cfg.plane_size**2 * 3 + assert Ct == self.cfg.num_channels + return rearrange( + tokens, + "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", + Np=3, + Hp=self.cfg.plane_size, + Wp=self.cfg.plane_size, + ) diff --git a/spar3d/models/transformers/attention.py b/spar3d/models/transformers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d4471197e20a12a57ddd7b85e4dd14f554001ecd --- /dev/null +++ b/spar3d/models/transformers/attention.py @@ -0,0 +1,292 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Modulation(nn.Module): + def __init__( + self, + embedding_dim: int, + condition_dim: int, + zero_init: bool = False, + single_layer: bool = False, + ): + super().__init__() + self.silu = nn.SiLU() + if single_layer: + self.linear1 = nn.Identity() + else: + self.linear1 = nn.Linear(condition_dim, condition_dim) + + self.linear2 = nn.Linear(condition_dim, embedding_dim * 2) + + # Only zero init the last linear layer + if zero_init: + nn.init.zeros_(self.linear2.weight) + nn.init.zeros_(self.linear2.bias) + + def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor: + emb = self.linear2(self.silu(self.linear1(condition))) + scale, shift = torch.chunk(emb, 2, dim=1) + x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 
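+
+    Example (an illustrative sketch; the argument values below are arbitrary,
+    with `mult` left at its default of 4):
+
+    ```python
+    >>> import torch
+    >>> ff = FeedForward(dim=64, activation_fn="geglu")
+    >>> ff(torch.randn(2, 16, 64)).shape
+    torch.Size([2, 16, 64])
+    ```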
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.num_heads = heads + self.scale = dim_head**-0.5 + self.dropout = dropout + + # Linear projections + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + # Output projection + self.to_out = nn.ModuleList( + [ + nn.Linear(self.inner_dim, query_dim, bias=out_bias), + nn.Dropout(dropout), + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + + # Project queries, keys, and values + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # Reshape for multi-head attention + query = query.reshape( + batch_size, sequence_length, self.num_heads, -1 + ).transpose(1, 2) + key = key.reshape(batch_size, sequence_length, self.num_heads, -1).transpose( + 1, 2 + ) + value = value.reshape( + batch_size, sequence_length, self.num_heads, -1 + ).transpose(1, 2) + + # Compute scaled dot product attention + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + scale=self.scale, + ) + + # Reshape and project output + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, sequence_length, self.inner_dim + ) + + # Apply output projection and dropout + for module in self.to_out: + hidden_states = module(hidden_states) + + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "geglu", + attention_bias: bool = False, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + ): + super().__init__() + + # Self-Attn + self.norm1 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + # Feed-forward + self.norm3 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 
+ ) + self.ff = FeedForward( + dim, + activation_fn=activation_fn, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = ( + self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + ) + + hidden_states + ) + + # Feed-forward + ff_output = self.ff(self.norm3(hidden_states)) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to( + dtype=gate.dtype + ) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + linear_cls = nn.Linear + + self.proj = linear_cls(dim_in, dim_out * 2) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, scale: float = 1.0): + args = () + hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2: + https://arxiv.org/abs/1606.08415. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. 
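+
+    Concretely, the forward pass applies the linear projection and then
+    returns `x * torch.sigmoid(1.702 * x)`.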
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) diff --git a/spar3d/models/transformers/backbone.py b/spar3d/models/transformers/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..1f735227a6303f97da2a1c8c91616008a2a3edf7 --- /dev/null +++ b/spar3d/models/transformers/backbone.py @@ -0,0 +1,467 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from spar3d.models.transformers.attention import FeedForward +from spar3d.models.utils import BaseModule + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + kv_dim=None, + num_heads=16, + qkv_bias=False, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + kv_dim = dim if not kv_dim else kv_dim + self.wq = nn.Linear(dim, dim, bias=qkv_bias) + self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias) + self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x_q, x_kv): + B, N_q, C = x_q.shape + B, N_kv, _ = x_kv.shape + # [B, N_q, C] -> [B, N_q, H, C/H] + q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads) + # [B, N_kv, C] -> [B, N_kv, H, C/H] + k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads) + v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads) + + # attention + x = torch.nn.functional.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=None, + dropout_p=self.attn_drop, + scale=self.scale, + ).transpose(1, 2) + + # [B, N_q, H, C/H] -> [B, N_q, C] + x = x.reshape(B, N_q, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BasicBlock(nn.Module): + def __init__( + self, + dim: int, + kv_dim: Optional[int] = None, + num_heads: int = 16, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ff_drop: float = 0.0, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn1 = CrossAttention( + dim, + kv_dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.norm2 = nn.LayerNorm(dim) + self.attn2 = CrossAttention( + dim, + kv_dim=kv_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.norm3 = nn.LayerNorm(dim) + self.ff = FeedForward(dim, dropout=ff_drop) + + def forward(self, z, x): + z_norm = self.norm1(z) + z = z + self.attn1(z_norm, z_norm) + # TODO: do we need to have the second attention when x is None? 
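+        # Cross-attention step: queries come from the latent tokens z and
+        # keys/values from x; when x is None this reduces to a second round of
+        # self-attention over z.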
+ z_norm = self.norm2(z) + z = z + self.attn2(z_norm, x if x is not None else z_norm) + z_norm = self.norm3(z) + z = z + self.ff(z_norm) + return z + + +class SingleStreamTransformer(BaseModule): + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + in_channels: Optional[int] = None + out_channels: Optional[int] = None + num_layers: int = 16 + dropout: float = 0.0 + norm_num_groups: int = 32 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + + cfg: Config + + def configure(self) -> None: + self.num_attention_heads = self.cfg.num_attention_heads + self.attention_head_dim = self.cfg.attention_head_dim + inner_dim = self.num_attention_heads * self.attention_head_dim + + # Define input layers + self.norm = torch.nn.GroupNorm( + num_groups=self.cfg.norm_num_groups, + num_channels=self.cfg.in_channels, + eps=1e-6, + affine=True, + ) + self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicBlock( + inner_dim, + kv_dim=self.cfg.cross_attention_dim, + num_heads=self.num_attention_heads, + qkv_bias=self.cfg.attention_bias, + proj_drop=self.cfg.dropout, + ff_drop=self.cfg.dropout, + ) + for d in range(self.cfg.num_layers) + ] + ) + + # 4. Define output layers + self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.proj_in(hidden_states) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states) + hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous() + # TODO: do we really need to add the residual? + hidden_states = hidden_states + residual + return hidden_states + + +class FuseBlock(nn.Module): + """ + Fuse X in to Z with cross attention + """ + + def __init__( + self, + dim_z: int, + dim_x: int, + num_heads: int = 16, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ff_drop: float = 0.0, + norm_x_input: bool = True, + ): + super().__init__() + self.norm_x_input = norm_x_input + if self.norm_x_input: + self.norm_x = nn.LayerNorm(dim_x) + self.attn = CrossAttention( + dim_z, + kv_dim=dim_x, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.norm_z1 = nn.LayerNorm(dim_z) + self.norm_z2 = nn.LayerNorm(dim_z) + self.ff = FeedForward(dim_z, dropout=ff_drop) + + def forward(self, z, x): + # TODO: do we need to normalize x? 
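+        # z attends to x (layer-normalized when norm_x_input is set), then a
+        # feed-forward block refines z; both steps use pre-norm residuals.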
+ z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x) + z = z + self.ff(self.norm_z2(z)) + return z + + +@torch.no_grad() +def get_triplane_attention_mask(res): + N = 3 * res * res + attn_mask = torch.zeros(3, res, res, 3, res, res) + + i, j = torch.meshgrid(torch.arange(res), torch.arange(res)) + + attn_mask[0, i, j, 1, i, :] = 1.0 + attn_mask[0, i, j, 2, j, :] = 1.0 + attn_mask[1, i, j, 0, i, :] = 1.0 + attn_mask[1, i, j, 2, :, j] = 1.0 + attn_mask[2, i, j, 0, :, i] = 1.0 + attn_mask[2, i, j, 1, :, j] = 1.0 + attn_mask = attn_mask.bool() + + attn_bias = torch.empty_like(attn_mask, dtype=torch.float) + attn_bias.masked_fill_(attn_mask, 0.0) + attn_bias.masked_fill_(~attn_mask, float("-inf")) + + return attn_bias.reshape(N, N) + + +class TriplaneAttention(nn.Module): + def __init__( + self, + dim: int, + resolution: int, + num_heads: int = 16, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + full_attention: bool = False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.wq = nn.Linear(dim, dim, bias=qkv_bias) + self.wk = nn.Linear(dim, dim, bias=qkv_bias) + self.wv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.resolution = resolution + self.full_attention = full_attention + self.attn_mask = ( + get_triplane_attention_mask(resolution) if not full_attention else None + ) + + def forward(self, x): + B, N, C = x.shape + # [B, N, C] -> [B, N, H, C/H] + q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads) + v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads) + + # detokenize the planes + assert N == self.resolution**2 * 3 + attn_bias = ( + self.attn_mask.to(q) + .unsqueeze(0) + .unsqueeze(0) + .expand(B, self.num_heads, -1, -1) + if not self.full_attention + else None + ) + + # full attention + x = torch.nn.functional.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=attn_bias, + dropout_p=self.attn_drop, + scale=self.scale, + ).transpose(1, 2) + + # [B, N_q, H, C/H] -> [B, N_q, C] + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class TwoStreamBlock(nn.Module): + def __init__( + self, + dim_latent: int, + dim_input: int, + num_basic_blocks: int = 4, + num_heads: int = 16, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ff_drop: float = 0.0, + norm_x_input: bool = True, + dim_cross: Optional[int] = None, + ): + super().__init__() + + # Define the fuse block that fuse the input into the latent + self.fuse_block_in = FuseBlock( + dim_latent, + dim_input, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ff_drop=ff_drop, + norm_x_input=norm_x_input, + ) + + # Define the transformer block that process the latent + self.transformer_block = nn.ModuleList( + [ + BasicBlock( + dim_latent, + kv_dim=dim_cross, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_drop=proj_drop, + ff_drop=ff_drop, + ) + for _ in range(num_basic_blocks) + ] + ) + + # Define the fuse block that fuse the latent into the input + self.fuse_block_out = FuseBlock( + dim_input, + dim_latent, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ff_drop=ff_drop, + norm_x_input=norm_x_input, + ) + + def 
forward(self, latent, input, cross_input): + latent = self.fuse_block_in(latent, input) + for block in self.transformer_block: + latent = block(latent, cross_input) + input = self.fuse_block_out(input, latent) + return latent, input + + +class TwoStreamInterleaveTransformer(BaseModule): + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 16 + attention_head_dim: int = 64 + raw_triplane_channels: int = 1024 + triplane_channels: int = 1024 + raw_image_channels: int = 1024 + num_latents: int = 1792 + num_blocks: int = 4 + num_basic_blocks: int = 3 + dropout: float = 0.0 + latent_init_std: float = 0.02 + norm_num_groups: int = 32 + attention_bias: bool = False + norm_x_input: bool = False + cross_attention_dim: int = 1024 + mix_latent: bool = True + + cfg: Config + + def configure(self) -> None: + self.mix_latent = self.cfg.mix_latent + + # Define the dimensions + self.num_attention_heads = self.cfg.num_attention_heads + self.attention_head_dim = self.cfg.attention_head_dim + self.num_latents = self.cfg.num_latents + self.latent_dim = self.num_attention_heads * self.attention_head_dim + + # Define input layers + if self.cfg.norm_num_groups > 0: + self.norm_triplane = torch.nn.GroupNorm( + num_groups=self.cfg.norm_num_groups, + num_channels=self.cfg.raw_triplane_channels, + eps=1e-6, + affine=True, + ) + else: + self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels) + self.proj_triplane = nn.Linear( + self.cfg.raw_triplane_channels, self.cfg.triplane_channels + ) + if self.mix_latent: + self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels) + self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim) + self.norm_latent = nn.LayerNorm(self.latent_dim) + self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim) + + # Define the latents + self.latent_init = nn.Parameter( + torch.zeros(1, self.num_latents, self.latent_dim) + ) + nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std) + + # Define the transformer blocks + self.main_blocks = nn.ModuleList( + [ + TwoStreamBlock( + self.latent_dim, + self.cfg.triplane_channels, + num_basic_blocks=self.cfg.num_basic_blocks, + num_heads=self.num_attention_heads, + qkv_bias=self.cfg.attention_bias, + proj_drop=self.cfg.dropout, + ff_drop=self.cfg.dropout, + norm_x_input=self.cfg.norm_x_input, + dim_cross=self.cfg.cross_attention_dim, + ) + for _ in range(self.cfg.num_blocks) + ] + ) + + # 4. 
Define output layers + self.proj_out = nn.Linear( + self.cfg.triplane_channels, self.cfg.raw_triplane_channels + ) + + def forward(self, hidden_states, encoder_hidden_states, **kwargs): + # hidden_states: [B, triplane_dim, N_triplane] is triplane tokens + # encoder_hidden_states: [B, N_image, image_dim] is the image tokens + if isinstance(self.norm_triplane, nn.GroupNorm): + triplane_tokens = self.norm_triplane(hidden_states) + triplane_tokens = triplane_tokens.permute( + 0, 2, 1 + ) # [B, N_triplane, triplane_dim] + elif isinstance(self.norm_triplane, nn.LayerNorm): + triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1)) + else: + raise ValueError("Unknown normalization layer") + triplane_tokens = self.proj_triplane(triplane_tokens) + if self.mix_latent: + image_tokens = self.norm_image( + encoder_hidden_states + ) # [B, N_image, image_dim] + image_tokens = self.proj_image(image_tokens) + init_latents = self.latent_init.expand( + hidden_states.shape[0], -1, -1 + ) # [B, N_latent_init, latent_dim] + init_latents = self.norm_latent(init_latents) + init_latents = self.proj_latent(init_latents) + if self.mix_latent: + latent_tokens = torch.cat( + [image_tokens, init_latents], dim=1 + ) # [B, N_latent, latent_dim] + else: + latent_tokens = init_latents + + # forward the main blocks + for block in self.main_blocks: + latent_tokens, triplane_tokens = block( + latent_tokens, triplane_tokens, encoder_hidden_states + ) + + # project the triplane tokens back to the original dimension + triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous() + triplane_tokens = triplane_tokens + hidden_states + return triplane_tokens diff --git a/spar3d/models/transformers/point_diffusion.py b/spar3d/models/transformers/point_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0f3b07a97e12c459f7c121c10d3498bc991207 --- /dev/null +++ b/spar3d/models/transformers/point_diffusion.py @@ -0,0 +1,278 @@ +# -------------------------------------------------------- +# Adapted from: https://github.com/openai/point-e +# Licensed under the MIT License +# Copyright (c) 2022 OpenAI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# -------------------------------------------------------- + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from spar3d.models.utils import BaseModule + + +def init_linear(layer, stddev): + nn.init.normal_(layer.weight, std=stddev) + if layer.bias is not None: + nn.init.constant_(layer.bias, 0.0) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + init_scale: float, + ): + super().__init__() + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3) + self.c_proj = nn.Linear(width, width) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + bs, n_ctx, width = x.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(attn_ch) + x = x.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(x, attn_ch, dim=-1) + + x = ( + torch.nn.functional.scaled_dot_product_attention( + q.permute(0, 2, 1, 3), + k.permute(0, 2, 1, 3), + v.permute(0, 2, 1, 3), + scale=scale, + ) + .permute(0, 2, 1, 3) + .reshape(bs, n_ctx, -1) + ) + + x = self.c_proj(x) + return x + + +class MLP(nn.Module): + def __init__(self, *, width: int, init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4) + self.c_proj = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, *, width: int, heads: int, init_scale: float = 1.0): + super().__init__() + + self.attn = MultiheadAttention( + width=width, + heads=heads, + init_scale=init_scale, + ) + self.ln_1 = nn.LayerNorm(width) + self.mlp = MLP(width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + ): + super().__init__() + self.width = width + self.layers = layers + init_scale = init_scale * math.sqrt(1.0 / width) + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width=width, + heads=heads, + init_scale=init_scale, + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class PointDiffusionTransformer(nn.Module): + def __init__( + self, + *, + input_channels: int = 3, + output_channels: int = 3, + width: int = 512, + layers: int = 12, + heads: int = 8, + init_scale: float = 0.25, + time_token_cond: bool = False, + ): + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + self.time_token_cond = time_token_cond + self.time_embed = MLP( + width=width, + init_scale=init_scale * math.sqrt(1.0 / width), + ) + self.ln_pre = nn.LayerNorm(width) + self.backbone = Transformer( + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + ) + self.ln_post = nn.LayerNorm(width) + self.input_proj = nn.Linear(input_channels, width) + self.output_proj = nn.Linear(width, output_channels) + with torch.no_grad(): + self.output_proj.weight.zero_() + self.output_proj.bias.zero_() + + def forward(self, x: torch.Tensor, t: torch.Tensor): + """ + :param x: an [N x C x T] tensor. 
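+            (here C equals `input_channels` and T is the sequence length, i.e. the
+            number of points)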
+ :param t: an [N] tensor. + :return: an [N x C' x T] tensor. + """ + t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) + return self._forward_with_cond(x, [(t_embed, self.time_token_cond)]) + + def _forward_with_cond( + self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]] + ) -> torch.Tensor: + h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC + for emb, as_token in cond_as_token: + if not as_token: + h = h + emb[:, None] + extra_tokens = [ + (emb[:, None] if len(emb.shape) == 2 else emb) + for emb, as_token in cond_as_token + if as_token + ] + if len(extra_tokens): + h = torch.cat(extra_tokens + [h], dim=1) + + h = self.ln_pre(h) + h = self.backbone(h) + h = self.ln_post(h) + if len(extra_tokens): + h = h[:, sum(h.shape[1] for h in extra_tokens) :] + h = self.output_proj(h) + return h.permute(0, 2, 1) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class PointEDenoiser(BaseModule): + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 8 + in_channels: Optional[int] = None + out_channels: Optional[int] = None + num_layers: int = 12 + width: int = 512 + cond_dim: Optional[int] = None + + cfg: Config + + def configure(self) -> None: + self.denoiser = PointDiffusionTransformer( + input_channels=self.cfg.in_channels, + output_channels=self.cfg.out_channels, + width=self.cfg.width, + layers=self.cfg.num_layers, + heads=self.cfg.num_attention_heads, + init_scale=0.25, + time_token_cond=True, + ) + + self.cond_embed = nn.Sequential( + nn.LayerNorm(self.cfg.cond_dim), + nn.Linear(self.cfg.cond_dim, self.cfg.width), + ) + + def forward( + self, + x, + t, + condition=None, + ): + # renormalize with the per-sample standard deviation + x_std = torch.std(x.reshape(x.shape[0], -1), dim=1, keepdim=True) + x = x / x_std.reshape(-1, *([1] * (len(x.shape) - 1))) + + t_embed = self.denoiser.time_embed( + timestep_embedding(t, self.denoiser.backbone.width) + ) + condition = self.cond_embed(condition) + + cond = [(t_embed, True), (condition, True)] + x_denoised = self.denoiser._forward_with_cond(x, cond) + return x_denoised diff --git a/spar3d/models/transformers/transformer_1d.py b/spar3d/models/transformers/transformer_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7394a748955b49344eb4fccef97212a77cfacd --- /dev/null +++ b/spar3d/models/transformers/transformer_1d.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from spar3d.models.transformers.attention import BasicTransformerBlock +from spar3d.models.utils import BaseModule + + +class Transformer1D(BaseModule): + """ + A 1D Transformer model for sequence data. 
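+    Operates on `[batch, channels, length]` inputs and returns a tensor of the
+    same shape; a residual connection to the input is added when `residual` is
+    enabled.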
+ + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + in_channels: Optional[int] = None + out_channels: Optional[int] = None + num_layers: int = 1 + norm_num_groups: int = 32 + attention_bias: bool = False + activation_fn: str = "geglu" + norm_elementwise_affine: bool = True + residual: bool = True + input_layer_norm: bool = True + norm_eps: float = 1e-5 + + cfg: Config + + def configure(self) -> None: + self.num_attention_heads = self.cfg.num_attention_heads + self.attention_head_dim = self.cfg.attention_head_dim + inner_dim = self.num_attention_heads * self.attention_head_dim + + linear_cls = nn.Linear + + # 2. Define input layers + self.in_channels = self.cfg.in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=self.cfg.norm_num_groups, + num_channels=self.cfg.in_channels, + eps=self.cfg.norm_eps, + affine=True, + ) + self.proj_in = linear_cls(self.cfg.in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + self.num_attention_heads, + self.attention_head_dim, + activation_fn=self.cfg.activation_fn, + attention_bias=self.cfg.attention_bias, + norm_elementwise_affine=self.cfg.norm_elementwise_affine, + norm_eps=self.cfg.norm_eps, + ) + for d in range(self.cfg.num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = ( + self.cfg.in_channels + if self.cfg.out_channels is None + else self.cfg.out_channels + ) + + self.proj_out = linear_cls(inner_dim, self.cfg.in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer1DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. 
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch, _, seq_len = hidden_states.shape + residual = hidden_states + + if self.cfg.input_layer_norm: + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 1).reshape( + batch, seq_len, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, seq_len, inner_dim) + .permute(0, 2, 1) + .contiguous() + ) + + if self.cfg.residual: + output = hidden_states + residual + else: + output = hidden_states + + return output diff --git a/spar3d/models/utils.py b/spar3d/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07125c6b57e238c6454e63884ef84289566b487b --- /dev/null +++ b/spar3d/models/utils.py @@ -0,0 +1,240 @@ +import dataclasses +import importlib +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +import torch.nn.functional as F +from jaxtyping import Float, Int, Num +from omegaconf import DictConfig, OmegaConf +from torch import Tensor + + +class BaseModule(nn.Module): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of BaseModule to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + raise NotImplementedError + + +def find_class(cls_string): + module_string = ".".join(cls_string.split(".")[:-1]) + cls_name = cls_string.split(".")[-1] + module = importlib.import_module(module_string, package=None) + cls = getattr(module, cls_name) + return cls + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + # Check if cfg.keys are in fields + cfg_ = cfg.copy() + keys = list(cfg_.keys()) + + field_names = {f.name for f in dataclasses.fields(fields)} + for key in keys: + # This is helpful when swapping out modules from CLI + if key not in field_names: + print(f"Ignoring {key} as it's not supported by {fields}") + cfg_.pop(key) + scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_) + return scfg + + +EPS_DTYPE = { + torch.float16: 1e-4, + torch.bfloat16: 1e-4, + torch.float32: 1e-7, + torch.float64: 1e-8, +} + + +def dot(x, y, dim=-1): + return torch.sum(x * y, dim, keepdim=True) + + +def reflect(x, n): + return x - 2 * dot(x, n) * n + + +def normalize(x, dim=-1, eps=None): + if eps is None: + eps = EPS_DTYPE[x.dtype] + return F.normalize(x, dim=dim, p=2, eps=eps) + + +ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] + + +def scale_tensor( + dat: Num[Tensor, "... 
D"], inp_scale: ValidScale, tgt_scale: ValidScale +): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, Tensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +def dilate_fill(img, mask, iterations=10): + oldMask = mask.float() + oldImg = img + + mask_kernel = torch.ones( + (1, 1, 3, 3), + dtype=oldMask.dtype, + device=oldMask.device, + ) + + for i in range(iterations): + newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1) + + # Fill the extension with mean color of old valid regions + img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1) + mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1) + new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1) + + # Average color of the valid region + mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze( + 2 + ) + # Extend it to the new region + fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1) + + mask_conv = F.conv2d( + newMask, mask_kernel, padding=1 + ) # Get the sum for each kernel patch + newImg = F.fold( + fill_color, (img.shape[-2], img.shape[-1]), (3, 3) + ) / mask_conv.clamp(1) + + diffMask = newMask - oldMask + + oldMask = newMask + oldImg = torch.lerp(oldImg, newImg, diffMask) + + return oldImg + + +def float32_to_uint8_np( + x: Float[np.ndarray, "*B H W C"], + dither: bool = True, + dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None, + dither_strength: float = 1.0, +) -> Int[np.ndarray, "*B H W C"]: + if dither: + dither = ( + dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5 + ) + if dither_mask is not None: + dither = dither * dither_mask + return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8) + return np.clip(np.floor((256.0 * x)), 0, 255).astype(torch.uint8) + + +def convert_data(data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + if data.dtype in [torch.float16, torch.bfloat16]: + data = data.float() + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + +class ImageProcessor: + def convert_and_resize( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + size: int, + ): + if isinstance(image, PIL.Image.Image): + image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0) + elif isinstance(image, np.ndarray): + if image.dtype == np.uint8: + image = torch.from_numpy(image.astype(np.float32) / 255.0) + else: + image = torch.from_numpy(image) + elif isinstance(image, torch.Tensor): + pass + + batched = image.ndim >= 4 + view_batch = image.ndim >= 5 + + if view_batch: + image = image.view(-1, *image.shape[2:]) + elif not batched: + image = image[None, ...] 
+ + image = F.interpolate( + image.permute(0, 3, 1, 2), + (size, size), + mode="bilinear", + align_corners=False, + antialias=True, + ).permute(0, 2, 3, 1) + if not batched: + image = image[0] + return image + + def __call__( + self, + image: Union[ + PIL.Image.Image, + np.ndarray, + torch.FloatTensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.FloatTensor], + ], + size: int, + ) -> Any: + if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4: + image = self.convert_and_resize(image, size) + else: + if not isinstance(image, list): + image = [image] + image = [self.convert_and_resize(im, size) for im in image] + image = torch.stack(image, dim=0) + return image + + +def get_intrinsic_from_fov(fov, H, W, bs=-1): + focal_length = 0.5 * H / np.tan(0.5 * fov) + intrinsic = np.identity(3, dtype=np.float32) + intrinsic[0, 0] = focal_length + intrinsic[1, 1] = focal_length + intrinsic[0, 2] = W / 2.0 + intrinsic[1, 2] = H / 2.0 + + if bs > 0: + intrinsic = intrinsic[None].repeat(bs, axis=0) + + return torch.from_numpy(intrinsic) diff --git a/spar3d/system.py b/spar3d/system.py new file mode 100644 index 0000000000000000000000000000000000000000..348781287581bb3ca3a8a0999f9f921e24db1c5d --- /dev/null +++ b/spar3d/system.py @@ -0,0 +1,717 @@ +import os +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import trimesh +from einops import rearrange +from huggingface_hub import hf_hub_download +from jaxtyping import Float +from omegaconf import OmegaConf +from PIL import Image +from safetensors.torch import load_model +from torch import Tensor + +from spar3d.models.diffusion.gaussian_diffusion import ( + SpacedDiffusion, + get_named_beta_schedule, + space_timesteps, +) +from spar3d.models.diffusion.sampler import PointCloudSampler +from spar3d.models.isosurface import MarchingTetrahedraHelper +from spar3d.models.mesh import Mesh +from spar3d.models.utils import ( + BaseModule, + ImageProcessor, + convert_data, + dilate_fill, + find_class, + float32_to_uint8_np, + normalize, + scale_tensor, +) +from spar3d.utils import ( + create_intrinsic_from_fov_rad, + default_cond_c2w, + get_device, + normalize_pc_bbox, +) + +try: + from texture_baker import TextureBaker +except ImportError: + import logging + + logging.warning( + "Could not import texture_baker. 
Please install it via `pip install texture-baker/`" + ) + # Exit early to avoid further errors + raise ImportError("texture_baker not found") + + +class SPAR3D(BaseModule): + @dataclass + class Config(BaseModule.Config): + cond_image_size: int + isosurface_resolution: int + isosurface_threshold: float = 10.0 + radius: float = 1.0 + background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5]) + default_fovy_rad: float = 0.591627 + default_distance: float = 2.2 + + camera_embedder_cls: str = "" + camera_embedder: dict = field(default_factory=dict) + + image_tokenizer_cls: str = "" + image_tokenizer: dict = field(default_factory=dict) + + point_embedder_cls: str = "" + point_embedder: dict = field(default_factory=dict) + + tokenizer_cls: str = "" + tokenizer: dict = field(default_factory=dict) + + backbone_cls: str = "" + backbone: dict = field(default_factory=dict) + + post_processor_cls: str = "" + post_processor: dict = field(default_factory=dict) + + decoder_cls: str = "" + decoder: dict = field(default_factory=dict) + + image_estimator_cls: str = "" + image_estimator: dict = field(default_factory=dict) + + global_estimator_cls: str = "" + global_estimator: dict = field(default_factory=dict) + + # Point diffusion modules + pdiff_camera_embedder_cls: str = "" + pdiff_camera_embedder: dict = field(default_factory=dict) + + pdiff_image_tokenizer_cls: str = "" + pdiff_image_tokenizer: dict = field(default_factory=dict) + + pdiff_backbone_cls: str = "" + pdiff_backbone: dict = field(default_factory=dict) + + scale_factor_xyz: float = 1.0 + scale_factor_rgb: float = 1.0 + bias_xyz: float = 0.0 + bias_rgb: float = 0.0 + train_time_steps: int = 1024 + inference_time_steps: int = 64 + + mean_type: str = "epsilon" + var_type: str = "fixed_small" + diffu_sched: str = "cosine" + diffu_sched_exp: float = 12.0 + guidance_scale: float = 3.0 + sigma_max: float = 120.0 + s_churn: float = 3.0 + + cfg: Config + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str + ): + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + if os.path.isdir(os.path.join(base_dir, pretrained_model_name_or_path)): + config_path = os.path.join( + base_dir, pretrained_model_name_or_path, config_name + ) + weight_path = os.path.join( + base_dir, pretrained_model_name_or_path, weight_name + ) + else: + config_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=config_name + ) + weight_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=weight_name + ) + + cfg = OmegaConf.load(config_path) + OmegaConf.resolve(cfg) + model = cls(cfg) + load_model(model, weight_path, strict=False) + return model + + @property + def device(self): + return next(self.parameters()).device + + def configure(self): + self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)( + self.cfg.image_tokenizer + ) + self.point_embedder = find_class(self.cfg.point_embedder_cls)( + self.cfg.point_embedder + ) + self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer) + self.camera_embedder = find_class(self.cfg.camera_embedder_cls)( + self.cfg.camera_embedder + ) + self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone) + self.post_processor = find_class(self.cfg.post_processor_cls)( + self.cfg.post_processor + ) + self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder) + self.image_estimator = find_class(self.cfg.image_estimator_cls)( + self.cfg.image_estimator + ) + 
self.global_estimator = find_class(self.cfg.global_estimator_cls)( + self.cfg.global_estimator + ) + + # point diffusion modules + self.pdiff_image_tokenizer = find_class(self.cfg.pdiff_image_tokenizer_cls)( + self.cfg.pdiff_image_tokenizer + ) + self.pdiff_camera_embedder = find_class(self.cfg.pdiff_camera_embedder_cls)( + self.cfg.pdiff_camera_embedder + ) + self.pdiff_backbone = find_class(self.cfg.pdiff_backbone_cls)( + self.cfg.pdiff_backbone + ) + + self.bbox: Float[Tensor, "2 3"] + self.register_buffer( + "bbox", + torch.as_tensor( + [ + [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], + [self.cfg.radius, self.cfg.radius, self.cfg.radius], + ], + dtype=torch.float32, + ), + ) + self.isosurface_helper = MarchingTetrahedraHelper( + self.cfg.isosurface_resolution, + os.path.join( + os.path.dirname(__file__), + "..", + "load", + "tets", + f"{self.cfg.isosurface_resolution}_tets.npz", + ), + ) + + self.baker = TextureBaker() + self.image_processor = ImageProcessor() + + channel_scales = [self.cfg.scale_factor_xyz] * 3 + channel_scales += [self.cfg.scale_factor_rgb] * 3 + channel_biases = [self.cfg.bias_xyz] * 3 + channel_biases += [self.cfg.bias_rgb] * 3 + channel_scales = np.array(channel_scales) + channel_biases = np.array(channel_biases) + + betas = get_named_beta_schedule( + self.cfg.diffu_sched, self.cfg.train_time_steps, self.cfg.diffu_sched_exp + ) + + diffusion_kwargs = dict( + betas=betas, + model_mean_type=self.cfg.mean_type, + model_var_type=self.cfg.var_type, + channel_scales=channel_scales, + channel_biases=channel_biases, + ) + self.diffusion_spaced = SpacedDiffusion( + use_timesteps=space_timesteps( + self.cfg.train_time_steps, + "ddim" + str(self.cfg.inference_time_steps), + ), + **diffusion_kwargs, + ) + self.sampler = PointCloudSampler( + model=self.pdiff_backbone, + diffusion=self.diffusion_spaced, + num_points=512, + point_dim=6, + guidance_scale=self.cfg.guidance_scale, + clip_denoised=True, + sigma_min=1e-3, + sigma_max=self.cfg.sigma_max, + s_churn=self.cfg.s_churn, + ) + + def triplane_to_meshes( + self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"] + ) -> list[Mesh]: + meshes = [] + for i in range(triplanes.shape[0]): + triplane = triplanes[i] + grid_vertices = scale_tensor( + self.isosurface_helper.grid_vertices.to(triplanes.device), + self.isosurface_helper.points_range, + self.bbox, + ) + + values = self.query_triplane(grid_vertices, triplane) + decoded = self.decoder(values, include=["vertex_offset", "density"]) + sdf = decoded["density"] - self.cfg.isosurface_threshold + + deform = decoded["vertex_offset"].squeeze(0) + + mesh: Mesh = self.isosurface_helper( + sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None + ) + mesh.v_pos = scale_tensor( + mesh.v_pos, self.isosurface_helper.points_range, self.bbox + ) + + meshes.append(mesh) + + return meshes + + def query_triplane( + self, + positions: Float[Tensor, "*B N 3"], + triplanes: Float[Tensor, "*B 3 Cp Hp Wp"], + ) -> Float[Tensor, "*B N F"]: + batched = positions.ndim == 3 + if not batched: + # no batch dimension + triplanes = triplanes[None, ...] + positions = positions[None, ...] 
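+        # Each query point is projected onto the three axis-aligned planes
+        # ((x, y), (x, z) and (y, z)); features are then sampled from the matching
+        # plane with grid_sample and concatenated along the channel dimension.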
+ assert triplanes.ndim == 5 and positions.ndim == 3 + + positions = scale_tensor( + positions, (-self.cfg.radius, self.cfg.radius), (-1, 1) + ) + + indices2D: Float[Tensor, "B 3 N 2"] = torch.stack( + (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]), + dim=-3, + ).to(triplanes.dtype) + out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample( + rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(), + rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(), + align_corners=True, + mode="bilinear", + ) + out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3) + + return out + + def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]: + # if batch[rgb_cond] is only one view, add a view dimension + if len(batch["rgb_cond"].shape) == 4: + batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1) + batch["mask_cond"] = batch["mask_cond"].unsqueeze(1) + batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1) + batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1) + batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1) + + batch_size, n_input_views = batch["rgb_cond"].shape[:2] + + camera_embeds: Optional[Float[Tensor, "B Nv Cc"]] + camera_embeds = self.camera_embedder(**batch) + + pc_embeds = self.point_embedder(batch["pc_cond"]) + + input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer( + rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"), + modulation_cond=camera_embeds, + ) + + input_image_tokens = rearrange( + input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views + ) + + tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size) + + cross_tokens = input_image_tokens + cross_tokens = torch.cat([cross_tokens, pc_embeds], dim=1) + + tokens = self.backbone( + tokens, + encoder_hidden_states=cross_tokens, + modulation_cond=None, + ) + + direct_codes = self.tokenizer.detokenize(tokens) + scene_codes = self.post_processor(direct_codes) + return scene_codes, direct_codes + + def forward_pdiff_cond(self, batch: Dict[str, Any]) -> Dict[str, Any]: + if len(batch["rgb_cond"].shape) == 4: + batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1) + batch["mask_cond"] = batch["mask_cond"].unsqueeze(1) + batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1) + batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1) + batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1) + + _batch_size, n_input_views = batch["rgb_cond"].shape[:2] + + # Camera modulation + camera_embeds: Float[Tensor, "B Nv Cc"] = self.pdiff_camera_embedder(**batch) + + input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.pdiff_image_tokenizer( + rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"), + modulation_cond=camera_embeds, + ) + + input_image_tokens = rearrange( + input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views + ) + + return input_image_tokens + + def run_image( + self, + image: Union[Image.Image, List[Image.Image]], + bake_resolution: int, + pointcloud: Optional[Union[List[np.ndarray], np.ndarray, Tensor]] = None, + remesh: Literal["none", "triangle", "quad"] = "none", + vertex_count: int = -1, + estimate_illumination: bool = False, + return_points: bool = False, + ) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]: + if isinstance(image, list): + rgb_cond = [] + mask_cond = [] + for img in image: + mask, rgb = self.prepare_image(img) + mask_cond.append(mask) + rgb_cond.append(rgb) + rgb_cond = torch.stack(rgb_cond, 0) + mask_cond = 
torch.stack(mask_cond, 0) + batch_size = rgb_cond.shape[0] + else: + mask_cond, rgb_cond = self.prepare_image(image) + batch_size = 1 + + c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device) + intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_rad( + self.cfg.default_fovy_rad, + self.cfg.cond_image_size, + self.cfg.cond_image_size, + ) + + batch = { + "rgb_cond": rgb_cond, + "mask_cond": mask_cond, + "c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1), + "intrinsic_cond": intrinsic.to(self.device) + .view(1, 1, 3, 3) + .repeat(batch_size, 1, 1, 1), + "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device) + .view(1, 1, 3, 3) + .repeat(batch_size, 1, 1, 1), + } + + meshes, global_dict = self.generate_mesh( + batch, + bake_resolution, + pointcloud, + remesh, + vertex_count, + estimate_illumination, + ) + + if return_points: + point_clouds = [] + for i in range(batch_size): + xyz = batch["pc_cond"][i, :, :3].cpu().numpy() + color_rgb = ( + (batch["pc_cond"][i, :, 3:6] * 255).cpu().numpy().astype(np.uint8) + ) + pc_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb) + point_clouds.append(pc_trimesh) + global_dict["point_clouds"] = point_clouds + + if batch_size == 1: + return meshes[0], global_dict + else: + return meshes, global_dict + + def prepare_image(self, image): + if image.mode != "RGBA": + raise ValueError("Image must be in RGBA mode") + img_cond = ( + torch.from_numpy( + np.asarray( + image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size)) + ).astype(np.float32) + / 255.0 + ) + .float() + .clip(0, 1) + .to(self.device) + ) + mask_cond = img_cond[:, :, -1:] + rgb_cond = torch.lerp( + torch.tensor(self.cfg.background_color, device=self.device)[None, None, :], + img_cond[:, :, :3], + mask_cond, + ) + + return mask_cond, rgb_cond + + def generate_mesh( + self, + batch, + bake_resolution: int, + pointcloud: Optional[Union[List[float], np.ndarray, Tensor]] = None, + remesh: Literal["none", "triangle", "quad"] = "none", + vertex_count: int = -1, + estimate_illumination: bool = False, + ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]: + batch["rgb_cond"] = self.image_processor( + batch["rgb_cond"], self.cfg.cond_image_size + ) + batch["mask_cond"] = self.image_processor( + batch["mask_cond"], self.cfg.cond_image_size + ) + + batch_size = batch["rgb_cond"].shape[0] + + if pointcloud is not None: + if isinstance(pointcloud, list): + cond_tensor = torch.tensor(pointcloud).float().cuda().view(-1, 6) + xyz = cond_tensor[:, :3] + color_rgb = cond_tensor[:, 3:] + # Check if point cloud is a numpy array + elif isinstance(pointcloud, np.ndarray): + xyz = torch.tensor(pointcloud[:, :3]).float().cuda() + color_rgb = torch.tensor(pointcloud[:, 3:]).float().cuda() + else: + raise ValueError("Invalid point cloud type") + + pointcloud = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(0) + batch["pc_cond"] = pointcloud + + if "pc_cond" not in batch: + cond_tokens = self.forward_pdiff_cond(batch) + sample_iter = self.sampler.sample_batch_progressive( + batch_size, cond_tokens, device=self.device + ) + for x in sample_iter: + samples = x["xstart"] + + denoised_pc = samples.permute(0, 2, 1).float() # [B, C, N] -> [B, N, C] + denoised_pc = normalize_pc_bbox(denoised_pc) + + # predict the full 3D conditioned on the denoised point cloud + batch["pc_cond"] = denoised_pc + + scene_codes, non_postprocessed_codes = self.get_scene_codes(batch) + + global_dict = {} + if self.image_estimator is not None: + global_dict.update( + self.image_estimator( + 
torch.cat([batch["rgb_cond"], batch["mask_cond"]], dim=-1) + ) + ) + if self.global_estimator is not None and estimate_illumination: + global_dict.update(self.global_estimator(non_postprocessed_codes)) + + global_dict["pointcloud"] = batch["pc_cond"] + + device = get_device() + with torch.no_grad(): + with ( + torch.autocast(device_type=device, enabled=False) + if "cuda" in device + else nullcontext() + ): + meshes = self.triplane_to_meshes(scene_codes) + + rets = [] + for i, mesh in enumerate(meshes): + # Check for empty mesh + if mesh.v_pos.shape[0] == 0: + rets.append(trimesh.Trimesh()) + continue + + if remesh == "triangle": + mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count) + elif remesh == "quad": + mesh = mesh.quad_remesh(quad_vertex_count=vertex_count) + else: + if vertex_count > 0: + print( + "Warning: vertex_count is ignored when remesh is none" + ) + + if remesh != "none": + print( + f"After {remesh} remesh the mesh has {mesh.v_pos.shape[0]} verts and {mesh.t_pos_idx.shape[0]} faces", + ) + mesh.unwrap_uv() + + # Build textures + rast = self.baker.rasterize( + mesh.v_tex, mesh.t_pos_idx, bake_resolution + ) + bake_mask = self.baker.get_mask(rast) + + pos_bake = self.baker.interpolate( + mesh.v_pos, + rast, + mesh.t_pos_idx, + ) + gb_pos = pos_bake[bake_mask] + + tri_query = self.query_triplane(gb_pos, scene_codes[i])[0] + decoded = self.decoder( + tri_query, exclude=["density", "vertex_offset"] + ) + + nrm = self.baker.interpolate( + mesh.v_nrm, + rast, + mesh.t_pos_idx, + ) + gb_nrm = F.normalize(nrm[bake_mask], dim=-1) + decoded["normal"] = gb_nrm + + # Check if any keys in global_dict start with decoded_ + for k, v in global_dict.items(): + if k.startswith("decoder_"): + decoded[k.replace("decoder_", "")] = v[i] + + mat_out = { + "albedo": decoded["features"], + "roughness": decoded["roughness"], + "metallic": decoded["metallic"], + "normal": normalize(decoded["perturb_normal"]), + "bump": None, + } + + for k, v in mat_out.items(): + if v is None: + continue + if v.shape[0] == 1: + # Skip and directly add a single value + mat_out[k] = v[0] + else: + f = torch.zeros( + bake_resolution, + bake_resolution, + v.shape[-1], + dtype=v.dtype, + device=v.device, + ) + if v.shape == f.shape: + continue + if k == "normal": + # Use un-normalized tangents here so that larger smaller tris + # Don't effect the tangents that much + tng = self.baker.interpolate( + mesh.v_tng, + rast, + mesh.t_pos_idx, + ) + gb_tng = tng[bake_mask] + gb_tng = F.normalize(gb_tng, dim=-1) + gb_btng = F.normalize( + torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1 + ) + normal = F.normalize(mat_out["normal"], dim=-1) + + # Create tangent space matrix and transform normal + tangent_matrix = torch.stack( + [gb_tng, gb_btng, gb_nrm], dim=-1 + ) + normal_tangent = torch.bmm( + tangent_matrix.transpose(1, 2), normal.unsqueeze(-1) + ).squeeze(-1) + + # Convert from [-1,1] to [0,1] range for storage + normal_tangent = (normal_tangent * 0.5 + 0.5).clamp( + 0, 1 + ) + + f[bake_mask] = normal_tangent.view(-1, 3) + mat_out["bump"] = f + else: + f[bake_mask] = v.view(-1, v.shape[-1]) + mat_out[k] = f + + def uv_padding(arr): + if arr.ndim == 1: + return arr + return ( + dilate_fill( + arr.permute(2, 0, 1)[None, ...].contiguous(), + bake_mask.unsqueeze(0).unsqueeze(0), + iterations=bake_resolution // 150, + ) + .squeeze(0) + .permute(1, 2, 0) + .contiguous() + ) + + verts_np = convert_data(mesh.v_pos) + faces = convert_data(mesh.t_pos_idx) + uvs = convert_data(mesh.v_tex) + + basecolor_tex = Image.fromarray( + 
float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"]))) + ).convert("RGB") + basecolor_tex.format = "JPEG" + + metallic = mat_out["metallic"].squeeze().cpu().item() + roughness = mat_out["roughness"].squeeze().cpu().item() + + if "bump" in mat_out and mat_out["bump"] is not None: + bump_np = convert_data(uv_padding(mat_out["bump"])) + bump_up = np.ones_like(bump_np) + bump_up[..., :2] = 0.5 + bump_up[..., 2:] = 1 + bump_tex = Image.fromarray( + float32_to_uint8_np( + bump_np, + dither=True, + # Do not dither if something is perfectly flat + dither_mask=np.all( + bump_np == bump_up, axis=-1, keepdims=True + ).astype(np.float32), + ) + ).convert("RGB") + bump_tex.format = ( + "JPEG" # PNG would be better but the assets are larger + ) + else: + bump_tex = None + + material = trimesh.visual.material.PBRMaterial( + baseColorTexture=basecolor_tex, + roughnessFactor=roughness, + metallicFactor=metallic, + normalTexture=bump_tex, + ) + + tmesh = trimesh.Trimesh( + vertices=verts_np, + faces=faces, + visual=trimesh.visual.texture.TextureVisuals( + uv=uvs, material=material + ), + ) + rot = trimesh.transformations.rotation_matrix( + np.radians(-90), [1, 0, 0] + ) + tmesh.apply_transform(rot) + tmesh.apply_transform( + trimesh.transformations.rotation_matrix( + np.radians(90), [0, 1, 0] + ) + ) + + tmesh.invert() + + rets.append(tmesh) + + return rets, global_dict diff --git a/spar3d/utils.py b/spar3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8fdc01762f314bd8d0b0d0e956f90ca43046d52c --- /dev/null +++ b/spar3d/utils.py @@ -0,0 +1,143 @@ +import os + +import numpy as np +import torch +import torchvision.transforms.functional as torchvision_F +from PIL import Image +from transparent_background import Remover + +import spar3d.models.utils as spar3d_utils + + +def get_device(): + if os.environ.get("SF3D_USE_CPU", "0") == "1": + return "cpu" + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + return device + + +def create_intrinsic_from_fov_rad(fov_rad: float, cond_height: int, cond_width: int): + intrinsic = spar3d_utils.get_intrinsic_from_fov( + fov_rad, + H=cond_height, + W=cond_width, + ) + intrinsic_normed_cond = intrinsic.clone() + intrinsic_normed_cond[..., 0, 2] /= cond_width + intrinsic_normed_cond[..., 1, 2] /= cond_height + intrinsic_normed_cond[..., 0, 0] /= cond_width + intrinsic_normed_cond[..., 1, 1] /= cond_height + + return intrinsic, intrinsic_normed_cond + + +def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int): + return create_intrinsic_from_fov_rad(np.deg2rad(fov_deg), cond_height, cond_width) + + +def default_cond_c2w(distance: float): + c2w_cond = torch.as_tensor( + [ + [0, 0, 1, distance], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ] + ).float() + return c2w_cond + + +def normalize_pc_bbox(pc, scale=1.0): + # get the bounding box of the mesh + assert len(pc.shape) in [2, 3] and pc.shape[-1] in [3, 6, 9] + n_dim = len(pc.shape) + device = pc.device + pc = pc.cpu() + if n_dim == 2: + pc = pc.unsqueeze(0) + normalize_pc = [] + for b in range(pc.shape[0]): + xyz = pc[b, :, :3] # [N, 3] + bound_x = (xyz[:, 0].max(), xyz[:, 0].min()) + bound_y = (xyz[:, 1].max(), xyz[:, 1].min()) + bound_z = (xyz[:, 2].max(), xyz[:, 2].min()) + # get the center of the bounding box + center = np.array( + [ + (bound_x[0] + bound_x[1]) / 2, + (bound_y[0] + bound_y[1]) / 2, + (bound_z[0] + bound_z[1]) / 2, + ] + ) + # get the largest dimension of the 
bounding box + scale = max( + bound_x[0] - bound_x[1], bound_y[0] - bound_y[1], bound_z[0] - bound_z[1] + ) + xyz = (xyz - center) / scale + extra = pc[b, :, 3:] + normalize_pc.append(torch.cat([xyz, extra], dim=-1)) + return ( + torch.stack(normalize_pc, dim=0).to(device) + if n_dim == 3 + else normalize_pc[0].to(device) + ) + + +def remove_background( + image: Image, + bg_remover: Remover = None, + force: bool = False, + **transparent_background_kwargs, +) -> Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = bg_remover.process( + image.convert("RGB"), **transparent_background_kwargs + ) + return image + + +def get_1d_bounds(arr): + nz = np.flatnonzero(arr) + return nz[0], nz[-1] + + +def get_bbox_from_mask(mask, thr=0.5): + masks_for_box = (mask > thr).astype(np.float32) + assert masks_for_box.sum() > 0, "Empty mask!" + x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1)) + return x0, y0, x1, y1 + + +def foreground_crop(image_rgba, crop_ratio=1.3, newsize=None, no_crop=False): + # make sure the image is a PIL image in RGBA mode + assert image_rgba.mode == "RGBA", "Image must be in RGBA mode!" + if not no_crop: + mask_np = np.array(image_rgba)[:, :, -1] + mask_np = (mask_np >= 1).astype(np.float32) + x1, y1, x2, y2 = get_bbox_from_mask(mask_np, thr=0.5) + h, w = y2 - y1, x2 - x1 + yc, xc = (y1 + y2) / 2, (x1 + x2) / 2 + scale = max(h, w) * crop_ratio + image = torchvision_F.crop( + image_rgba, + top=int(yc - scale / 2), + left=int(xc - scale / 2), + height=int(scale), + width=int(scale), + ) + else: + image = image_rgba + # resize if needed + if newsize is not None: + image = image.resize(newsize) + return image diff --git a/texture_baker/README.md b/texture_baker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d7bde117fd1ecd30e193163b50fb825b64f66590 --- /dev/null +++ b/texture_baker/README.md @@ -0,0 +1,26 @@ +# Texture baker + +Small texture baker which rasterizes barycentric coordinates to a tensor. +It also implements an interpolation module which can be used to bake attributes to textures then. + +## Usage + +The baker can quickly bake vertex attributes to the a texture atlas based on the UV coordinates. +It supports baking on the CPU and GPU. + +```python +from texture_baker import TextureBaker + +mesh = ... 
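+# `mesh` is a placeholder for any mesh object exposing per-vertex UVs, face
+# indices and vertex positions as torch tensors.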
+uv = mesh.uv # num_vertex, 2 +triangle_idx = mesh.faces # num_faces, 3 +vertices = mesh.vertices # num_vertex, 3 + +tb = TextureBaker() +# First get the barycentric coordinates +rast = tb.rasterize( + uv=uv, face_indices=triangle_idx, bake_resolution=1024 +) +# Then interpolate vertex attributes +position_bake = tb.interpolate(attr=vertices, rast=rast, face_indices=triangle_idx) +``` diff --git a/texture_baker/requirements.txt b/texture_baker/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..af3149eb479c47955cf0d40d253890baa18d2f54 --- /dev/null +++ b/texture_baker/requirements.txt @@ -0,0 +1,2 @@ +torch +numpy diff --git a/texture_baker/setup.py b/texture_baker/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..88741298360d0a22a982309e8167af36dca0a5ce --- /dev/null +++ b/texture_baker/setup.py @@ -0,0 +1,127 @@ +import glob +import os +import platform + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import ( + CUDA_HOME, + BuildExtension, + CppExtension, + CUDAExtension, +) + +library_name = "texture_baker" + + +def get_extensions(): + debug_mode = os.getenv("DEBUG", "0") == "1" + use_cuda = os.getenv("USE_CUDA", "1" if torch.cuda.is_available() else "0") == "1" + use_metal = ( + os.getenv("USE_METAL", "1" if torch.backends.mps.is_available() else "0") == "1" + ) + use_native_arch = os.getenv("USE_NATIVE_ARCH", "1") == "1" + if debug_mode: + print("Compiling in debug mode") + + use_cuda = use_cuda and CUDA_HOME is not None + extension = CUDAExtension if use_cuda else CppExtension + + extra_link_args = [] + extra_compile_args = { + "cxx": [ + "-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + "-fopenmp", + ] + + ["-march=native"] + if use_native_arch + else [], + "nvcc": [ + "-O3" if not debug_mode else "-O0", + ], + } + if debug_mode: + extra_compile_args["cxx"].append("-g") + if platform.system() == "Windows": + extra_compile_args["cxx"].append("/Z7") + extra_compile_args["cxx"].append("/Od") + extra_link_args.extend(["/DEBUG"]) + extra_compile_args["cxx"].append("-UNDEBUG") + extra_compile_args["nvcc"].append("-UNDEBUG") + extra_compile_args["nvcc"].append("-g") + extra_link_args.extend(["-O0", "-g"]) + + define_macros = [] + extensions = [] + libraries = [] + + this_dir = os.path.dirname(os.path.curdir) + sources = glob.glob( + os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True + ) + + if len(sources) == 0: + print("No source files found for extension, skipping extension compilation") + return None + + if use_cuda: + define_macros += [ + ("THRUST_IGNORE_CUB_VERSION_CHECK", None), + ] + sources += glob.glob( + os.path.join(this_dir, library_name, "csrc", "**", "*.cu"), recursive=True + ) + libraries += ["cudart", "c10_cuda"] + + if use_metal: + define_macros += [ + ("WITH_MPS", None), + ] + sources += glob.glob( + os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True + ) + extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64"]}) + extra_link_args += ["-arch", "arm64"] + + extensions.append( + extension( + name=f"{library_name}._C", + sources=sources, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + libraries=libraries + + [ + "c10", + "torch", + "torch_cpu", + "torch_python", + ], + ) + ) + + for ext in extensions: + ext.libraries = ["cudart_static" if x == "cudart" else x for x in ext.libraries] + + print(extensions) + + return extensions + + +setup( + 
name=library_name, + version="0.0.1", + packages=find_packages(where="."), + package_dir={"": "."}, + ext_modules=get_extensions(), + install_requires=[], + package_data={ + library_name: [os.path.join("csrc", "*.h"), os.path.join("csrc", "*.metal")], + }, + description="Small texture baker which rasterizes barycentric coordinates to a tensor.", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/Stability-AI/texture_baker", + cmdclass={"build_ext": BuildExtension}, +) diff --git a/texture_baker/texture_baker/__init__.py b/texture_baker/texture_baker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6173e6901823c2a00ca4e48b11e9649df977d53 --- /dev/null +++ b/texture_baker/texture_baker/__init__.py @@ -0,0 +1,4 @@ +import torch # noqa: F401 + +from . import _C # noqa: F401 +from .baker import TextureBaker # noqa: F401 diff --git a/texture_baker/texture_baker/baker.py b/texture_baker/texture_baker/baker.py new file mode 100644 index 0000000000000000000000000000000000000000..61e4425a2d18d7aece86e5f12199754aeb4d5bb0 --- /dev/null +++ b/texture_baker/texture_baker/baker.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from torch import Tensor + + +class TextureBaker(nn.Module): + def __init__(self): + super().__init__() + + def rasterize( + self, + uv: Tensor, + face_indices: Tensor, + bake_resolution: int, + ) -> Tensor: + """ + Rasterize the UV coordinates to a barycentric coordinates + & Triangle idxs texture map + + Args: + uv (Tensor, num_vertices 2, float): UV coordinates of the mesh + face_indices (Tensor, num_faces 3, int): Face indices of the mesh + bake_resolution (int): Resolution of the bake + + Returns: + Tensor, bake_resolution bake_resolution 4, float: Rasterized map + """ + return torch.ops.texture_baker_cpp.rasterize( + uv, face_indices.to(torch.int32), bake_resolution + ) + + def get_mask(self, rast: Tensor) -> Tensor: + """ + Get the occupancy mask from the rasterized map + + Args: + rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map + + Returns: + Tensor, bake_resolution bake_resolution, bool: Mask + """ + return rast[..., -1] >= 0 + + def interpolate( + self, + attr: Tensor, + rast: Tensor, + face_indices: Tensor, + ) -> Tensor: + """ + Interpolate the attributes using the rasterized map + + Args: + attr (Tensor, num_vertices 3, float): Attributes of the mesh + rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map + face_indices (Tensor, num_faces 3, int): Face indices of the mesh + uv (Tensor, num_vertices 2, float): UV coordinates of the mesh + + Returns: + Tensor, bake_resolution bake_resolution 3, float: Interpolated attributes + """ + return torch.ops.texture_baker_cpp.interpolate( + attr, face_indices.to(torch.int32), rast + ) + + def forward( + self, + attr: Tensor, + uv: Tensor, + face_indices: Tensor, + bake_resolution: int, + ) -> Tensor: + """ + Bake the texture + + Args: + attr (Tensor, num_vertices 3, float): Attributes of the mesh + uv (Tensor, num_vertices 2, float): UV coordinates of the mesh + face_indices (Tensor, num_faces 3, int): Face indices of the mesh + bake_resolution (int): Resolution of the bake + + Returns: + Tensor, bake_resolution bake_resolution 3, float: Baked texture + """ + rast = self.rasterize(uv, face_indices, bake_resolution) + return self.interpolate(attr, rast, face_indices, uv) diff --git a/texture_baker/texture_baker/csrc/baker.cpp b/texture_baker/texture_baker/csrc/baker.cpp new file 
mode 100644 index 0000000000000000000000000000000000000000..f92766489d4e25b4df76bfd1516b502c2e179aaa --- /dev/null +++ b/texture_baker/texture_baker/csrc/baker.cpp @@ -0,0 +1,548 @@ +#include +#include +#include +#include +#include +#include +#ifndef __ARM_ARCH_ISA_A64 +#include +#endif + +#include "baker.h" + +// #define TIMING +#define BINS 8 + +namespace texture_baker_cpp { +// Calculate the centroid of a triangle +tb_float2 triangle_centroid(const tb_float2 &v0, const tb_float2 &v1, + const tb_float2 &v2) { + return {(v0.x + v1.x + v2.x) * 0.3333f, (v0.y + v1.y + v2.y) * 0.3333f}; +} + +float BVH::find_best_split_plane(const BVHNode &node, int &best_axis, + int &best_pos, AABB ¢roidBounds) { + float best_cost = std::numeric_limits::max(); + + for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y + { + float boundsMin = centroidBounds.min[axis]; + float boundsMax = centroidBounds.max[axis]; + if (boundsMin == boundsMax) { + continue; + } + + // Populate the bins + float scale = BINS / (boundsMax - boundsMin); + float leftCountArea[BINS - 1], rightCountArea[BINS - 1]; + int leftSum = 0, rightSum = 0; + +#ifndef __ARM_ARCH_ISA_A64 +#ifndef _MSC_VER + if (__builtin_cpu_supports("sse")) +#elif (defined(_M_AMD64) || defined(_M_X64)) + // SSE supported on Windows + if constexpr (true) +#endif + { + __m128 min4[BINS], max4[BINS]; + unsigned int count[BINS]; + for (unsigned int i = 0; i < BINS; i++) + min4[i] = _mm_set_ps1(1e30f), max4[i] = _mm_set_ps1(-1e30f), + count[i] = 0; + for (int i = node.start; i < node.end; i++) { + int tri_idx = triangle_indices[i]; + const Triangle &triangle = triangles[tri_idx]; + + int binIdx = std::min( + BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale)); + count[binIdx]++; + __m128 v0 = _mm_set_ps(triangle.v0.x, triangle.v0.y, 0.0f, 0.0f); + __m128 v1 = _mm_set_ps(triangle.v1.x, triangle.v1.y, 0.0f, 0.0f); + __m128 v2 = _mm_set_ps(triangle.v2.x, triangle.v2.y, 0.0f, 0.0f); + min4[binIdx] = _mm_min_ps(min4[binIdx], v0); + max4[binIdx] = _mm_max_ps(max4[binIdx], v0); + min4[binIdx] = _mm_min_ps(min4[binIdx], v1); + max4[binIdx] = _mm_max_ps(max4[binIdx], v1); + min4[binIdx] = _mm_min_ps(min4[binIdx], v2); + max4[binIdx] = _mm_max_ps(max4[binIdx], v2); + } + // gather data for the 7 planes between the 8 bins + __m128 leftMin4 = _mm_set_ps1(1e30f), rightMin4 = leftMin4; + __m128 leftMax4 = _mm_set_ps1(-1e30f), rightMax4 = leftMax4; + for (int i = 0; i < BINS - 1; i++) { + leftSum += count[i]; + rightSum += count[BINS - 1 - i]; + leftMin4 = _mm_min_ps(leftMin4, min4[i]); + rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]); + leftMax4 = _mm_max_ps(leftMax4, max4[i]); + rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]); + float le[4], re[4]; + _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4)); + _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4)); + // SSE order goes from back to front + leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation + rightCountArea[BINS - 2 - i] = + rightSum * (re[2] * re[3]); // 2D area calculation + } + } +#else + if constexpr (false) { + } +#endif + else { + struct Bin { + AABB bounds; + int triCount = 0; + } bins[BINS]; + + for (int i = node.start; i < node.end; i++) { + int tri_idx = triangle_indices[i]; + const Triangle &triangle = triangles[tri_idx]; + + int binIdx = std::min( + BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale)); + bins[binIdx].triCount++; + bins[binIdx].bounds.grow(triangle.v0); + bins[binIdx].bounds.grow(triangle.v1); + 
bins[binIdx].bounds.grow(triangle.v2); + } + + // Gather data for the planes between the bins + AABB leftBox, rightBox; + + for (int i = 0; i < BINS - 1; i++) { + leftSum += bins[i].triCount; + leftBox.grow(bins[i].bounds); + leftCountArea[i] = leftSum * leftBox.area(); + + rightSum += bins[BINS - 1 - i].triCount; + rightBox.grow(bins[BINS - 1 - i].bounds); + rightCountArea[BINS - 2 - i] = rightSum * rightBox.area(); + } + } + + // Calculate SAH cost for the planes + scale = (boundsMax - boundsMin) / BINS; + for (int i = 0; i < BINS - 1; i++) { + float planeCost = leftCountArea[i] + rightCountArea[i]; + if (planeCost < best_cost) { + best_axis = axis; + best_pos = i + 1; + best_cost = planeCost; + } + } + } + + return best_cost; +} + +void BVH::update_node_bounds(BVHNode &node, AABB ¢roidBounds) { +#ifndef __ARM_ARCH_ISA_A64 +#ifndef _MSC_VER + if (__builtin_cpu_supports("sse")) +#elif (defined(_M_AMD64) || defined(_M_X64)) + // SSE supported on Windows + if constexpr (true) +#endif + { + __m128 min4 = _mm_set_ps1(1e30f), max4 = _mm_set_ps1(-1e30f); + __m128 cmin4 = _mm_set_ps1(1e30f), cmax4 = _mm_set_ps1(-1e30f); + + for (int i = node.start; i < node.end; i += 2) { + int tri_idx1 = triangle_indices[i]; + const Triangle &leafTri1 = triangles[tri_idx1]; + // Check if the second actually exists in the node + __m128 v0, v1, v2, centroid; + if (i + 1 < node.end) { + int tri_idx2 = triangle_indices[i + 1]; + const Triangle leafTri2 = triangles[tri_idx2]; + + v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x, + leafTri2.v0.y); + v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x, + leafTri2.v1.y); + v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x, + leafTri2.v2.y); + centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y, + leafTri2.centroid.x, leafTri2.centroid.y); + } else { + // Otherwise do some duplicated work + v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x, + leafTri1.v0.y); + v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x, + leafTri1.v1.y); + v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x, + leafTri1.v2.y); + centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y, + leafTri1.centroid.x, leafTri1.centroid.y); + } + + min4 = _mm_min_ps(min4, v0); + max4 = _mm_max_ps(max4, v0); + min4 = _mm_min_ps(min4, v1); + max4 = _mm_max_ps(max4, v1); + min4 = _mm_min_ps(min4, v2); + max4 = _mm_max_ps(max4, v2); + cmin4 = _mm_min_ps(cmin4, centroid); + cmax4 = _mm_max_ps(cmax4, centroid); + } + + float min_values[4], max_values[4], cmin_values[4], cmax_values[4]; + _mm_store_ps(min_values, min4); + _mm_store_ps(max_values, max4); + _mm_store_ps(cmin_values, cmin4); + _mm_store_ps(cmax_values, cmax4); + + node.bbox.min.x = std::min(min_values[3], min_values[1]); + node.bbox.min.y = std::min(min_values[2], min_values[0]); + node.bbox.max.x = std::max(max_values[3], max_values[1]); + node.bbox.max.y = std::max(max_values[2], max_values[0]); + + centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]); + centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]); + centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]); + centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]); + } +#else + if constexpr (false) { + } +#endif + { + node.bbox.invalidate(); + centroidBounds.invalidate(); + + // Calculate the bounding box for the node + for (int i = node.start; i < node.end; ++i) { + int tri_idx = triangle_indices[i]; + const Triangle &tri = triangles[tri_idx]; + node.bbox.grow(tri.v0); + 
node.bbox.grow(tri.v1); + node.bbox.grow(tri.v2); + centroidBounds.grow(tri.centroid); + } + } +} + +void BVH::build(const tb_float2 *vertices, const tb_int3 *indices, + const int64_t &num_indices) { +#ifdef TIMING + auto start = std::chrono::high_resolution_clock::now(); +#endif + // Create triangles + for (size_t i = 0; i < num_indices; ++i) { + tb_int3 idx = indices[i]; + triangles.push_back( + {vertices[idx.x], vertices[idx.y], vertices[idx.z], static_cast(i), + triangle_centroid(vertices[idx.x], vertices[idx.y], vertices[idx.z])}); + } + + // Initialize triangle_indices + triangle_indices.resize(triangles.size()); + std::iota(triangle_indices.begin(), triangle_indices.end(), 0); + + // Build BVH nodes + // Reserve extra capacity to fix windows specific crashes + nodes.reserve(triangles.size() * 2 + 1); + nodes.push_back({}); // Create the root node + root = 0; + + // Define a struct for queue entries + struct QueueEntry { + int node_idx; + int start; + int end; + }; + + // Queue for breadth-first traversal + std::queue node_queue; + node_queue.push({root, 0, (int)triangles.size()}); + + // Process each node in the queue + while (!node_queue.empty()) { + QueueEntry current = node_queue.front(); + node_queue.pop(); + + int node_idx = current.node_idx; + int start = current.start; + int end = current.end; + + BVHNode &node = nodes[node_idx]; + node.start = start; + node.end = end; + + // Calculate the bounding box for the node + AABB centroidBounds; + update_node_bounds(node, centroidBounds); + + // Determine the best split using SAH + int best_axis, best_pos; + + float splitCost = + find_best_split_plane(node, best_axis, best_pos, centroidBounds); + float nosplitCost = node.calculate_node_cost(); + + // Stop condition: if the best cost is greater than or equal to the parent's + // cost + if (splitCost >= nosplitCost) { + // Leaf node + node.left = node.right = -1; + continue; + } + + float scale = + BINS / (centroidBounds.max[best_axis] - centroidBounds.min[best_axis]); + int i = node.start; + int j = node.end - 1; + + // Sort the triangle_indices in the range [start, end) based on the best + // axis + while (i <= j) { + // use the exact calculation we used for binning to prevent rare + // inaccuracies + int tri_idx = triangle_indices[i]; + tb_float2 tcentr = triangles[tri_idx].centroid; + int binIdx = std::min( + BINS - 1, + (int)((tcentr[best_axis] - centroidBounds.min[best_axis]) * scale)); + if (binIdx < best_pos) + i++; + else + std::swap(triangle_indices[i], triangle_indices[j--]); + } + int leftCount = i - node.start; + if (leftCount == 0 || leftCount == node.num_triangles()) { + // Leaf node + node.left = node.right = -1; + continue; + } + + int mid = i; + + // Create and set left child + node.left = nodes.size(); + nodes.push_back({}); + node_queue.push({node.left, start, mid}); + + // Create and set right child + node = nodes[node_idx]; // Update the node - Potentially stale reference + node.right = nodes.size(); + nodes.push_back({}); + node_queue.push({node.right, mid, end}); + } +#ifdef TIMING + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "BVH build time: " << elapsed.count() << "s" << std::endl; +#endif +} + +// Utility function to clamp a value between a minimum and a maximum +float clamp(float val, float minVal, float maxVal) { + return std::min(std::max(val, minVal), maxVal); +} + +// Function to check if a point (xy) is inside a triangle defined by vertices +// v1, v2, v3 +bool 
barycentric_coordinates(tb_float2 xy, tb_float2 v1, tb_float2 v2, + tb_float2 v3, float &u, float &v, float &w) { + // Vectors from v1 to v2, v3 and xy + tb_float2 v1v2 = {v2.x - v1.x, v2.y - v1.y}; + tb_float2 v1v3 = {v3.x - v1.x, v3.y - v1.y}; + tb_float2 xyv1 = {xy.x - v1.x, xy.y - v1.y}; + + // Dot products of the vectors + float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y; + float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y; + float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y; + float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y; + float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y; + + // Calculate the barycentric coordinates + float denom = d00 * d11 - d01 * d01; + v = (d11 * d20 - d01 * d21) / denom; + w = (d00 * d21 - d01 * d20) / denom; + u = 1.0f - v - w; + + // Check if the point is inside the triangle + return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f); +} + +bool BVH::intersect(const tb_float2 &point, float &u, float &v, float &w, + int &index) const { + const int max_stack_size = 64; + int node_stack[max_stack_size]; + int stack_size = 0; + + node_stack[stack_size++] = root; + + while (stack_size > 0) { + int node_idx = node_stack[--stack_size]; + const BVHNode &node = nodes[node_idx]; + + if (node.is_leaf()) { + for (int i = node.start; i < node.end; ++i) { + const Triangle &tri = triangles[triangle_indices[i]]; + if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) { + index = tri.index; + return true; + } + } + } else { + if (nodes[node.right].bbox.overlaps(point)) { + if (stack_size < max_stack_size) { + node_stack[stack_size++] = node.right; + } else { + // Handle stack overflow + throw std::runtime_error("Node stack overflow"); + } + } + if (nodes[node.left].bbox.overlaps(point)) { + if (stack_size < max_stack_size) { + node_stack[stack_size++] = node.left; + } else { + // Handle stack overflow + throw std::runtime_error("Node stack overflow"); + } + } + } + } + + return false; +} + +torch::Tensor rasterize_cpu(torch::Tensor uv, torch::Tensor indices, + int64_t bake_resolution) { + int width = bake_resolution; + int height = bake_resolution; + int num_pixels = width * height; + torch::Tensor rast_result = torch::empty( + {bake_resolution, bake_resolution, 4}, + torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU)); + + float *rast_result_ptr = rast_result.contiguous().data_ptr(); + const tb_float2 *vertices = (tb_float2 *)uv.data_ptr(); + const tb_int3 *tris = (tb_int3 *)indices.data_ptr(); + + BVH bvh; + bvh.build(vertices, tris, indices.size(0)); + +#ifdef TIMING + auto start = std::chrono::high_resolution_clock::now(); +#endif + +#pragma omp parallel for + for (int idx = 0; idx < num_pixels; ++idx) { + int x = idx / height; + int y = idx % height; + int idx_ = idx * 4; // Note: *4 because we're storing float4 per pixel + + tb_float2 pixel_coord = {float(y) / height, float(x) / width}; + pixel_coord.x = clamp(pixel_coord.x, 0.0f, 1.0f); + pixel_coord.y = 1.0f - clamp(pixel_coord.y, 0.0f, 1.0f); + + float u, v, w; + int triangle_idx; + if (bvh.intersect(pixel_coord, u, v, w, triangle_idx)) { + rast_result_ptr[idx_ + 0] = u; + rast_result_ptr[idx_ + 1] = v; + rast_result_ptr[idx_ + 2] = w; + rast_result_ptr[idx_ + 3] = static_cast(triangle_idx); + } else { + rast_result_ptr[idx_ + 0] = 0.0f; + rast_result_ptr[idx_ + 1] = 0.0f; + rast_result_ptr[idx_ + 2] = 0.0f; + rast_result_ptr[idx_ + 3] = -1.0f; + } + } + +#ifdef TIMING + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "Rasterization time: 
" << elapsed.count() << "s" << std::endl; +#endif + return rast_result; +} + +torch::Tensor interpolate_cpu(torch::Tensor attr, torch::Tensor indices, + torch::Tensor rast) { +#ifdef TIMING + auto start = std::chrono::high_resolution_clock::now(); +#endif + int height = rast.size(0); + int width = rast.size(1); + torch::Tensor pos_bake = torch::empty( + {height, width, 3}, + torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU)); + + const float *attr_ptr = attr.contiguous().data_ptr(); + const int *indices_ptr = indices.contiguous().data_ptr(); + const float *rast_ptr = rast.contiguous().data_ptr(); + float *output_ptr = pos_bake.contiguous().data_ptr(); + + int num_pixels = width * height; + +#pragma omp parallel for + for (int idx = 0; idx < num_pixels; ++idx) { + int idx_ = idx * 4; // Index into the float4 array (4 floats per pixel) + tb_float3 barycentric = { + rast_ptr[idx_ + 0], + rast_ptr[idx_ + 1], + rast_ptr[idx_ + 2], + }; + int triangle_idx = static_cast(rast_ptr[idx_ + 3]); + + if (triangle_idx < 0) { + output_ptr[idx * 3 + 0] = 0.0f; + output_ptr[idx * 3 + 1] = 0.0f; + output_ptr[idx * 3 + 2] = 0.0f; + continue; + } + + tb_int3 triangle = {indices_ptr[3 * triangle_idx + 0], + indices_ptr[3 * triangle_idx + 1], + indices_ptr[3 * triangle_idx + 2]}; + tb_float3 v1 = {attr_ptr[3 * triangle.x + 0], attr_ptr[3 * triangle.x + 1], + attr_ptr[3 * triangle.x + 2]}; + tb_float3 v2 = {attr_ptr[3 * triangle.y + 0], attr_ptr[3 * triangle.y + 1], + attr_ptr[3 * triangle.y + 2]}; + tb_float3 v3 = {attr_ptr[3 * triangle.z + 0], attr_ptr[3 * triangle.z + 1], + attr_ptr[3 * triangle.z + 2]}; + + tb_float3 interpolated; + interpolated.x = + v1.x * barycentric.x + v2.x * barycentric.y + v3.x * barycentric.z; + interpolated.y = + v1.y * barycentric.x + v2.y * barycentric.y + v3.y * barycentric.z; + interpolated.z = + v1.z * barycentric.x + v2.z * barycentric.y + v3.z * barycentric.z; + + output_ptr[idx * 3 + 0] = interpolated.x; + output_ptr[idx * 3 + 1] = interpolated.y; + output_ptr[idx * 3 + 2] = interpolated.z; + } + +#ifdef TIMING + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "Interpolation time: " << elapsed.count() << "s" << std::endl; +#endif + return pos_bake; +} + +// Registers _C as a Python extension module. 
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} + +// Defines the operators +TORCH_LIBRARY(texture_baker_cpp, m) { + m.def("rasterize(Tensor uv, Tensor indices, int bake_resolution) -> Tensor"); + m.def("interpolate(Tensor attr, Tensor indices, Tensor rast) -> Tensor"); +} + +// Registers CPP implementations +TORCH_LIBRARY_IMPL(texture_baker_cpp, CPU, m) { + m.impl("rasterize", &rasterize_cpu); + m.impl("interpolate", &interpolate_cpu); +} + +} // namespace texture_baker_cpp diff --git a/texture_baker/texture_baker/csrc/baker.h b/texture_baker/texture_baker/csrc/baker.h new file mode 100644 index 0000000000000000000000000000000000000000..a29ce2f6d0dd3e6f6b555cf45278b200b076e821 --- /dev/null +++ b/texture_baker/texture_baker/csrc/baker.h @@ -0,0 +1,203 @@ +#pragma once + +#if defined(__NVCC__) || defined(__HIPCC__) || defined(__METAL__) +#define CUDA_ENABLED +#ifndef __METAL__ +#define CUDA_HOST_DEVICE __host__ __device__ +#define CUDA_DEVICE __device__ +#define METAL_CONSTANT_MEM +#define METAL_THREAD_MEM +#else +#define tb_float2 float2 +#define CUDA_HOST_DEVICE +#define CUDA_DEVICE +#define METAL_CONSTANT_MEM constant +#define METAL_THREAD_MEM thread +#endif +#else +#define CUDA_HOST_DEVICE +#define CUDA_DEVICE +#define METAL_CONSTANT_MEM +#define METAL_THREAD_MEM +#include +#include +#include +#endif + +namespace texture_baker_cpp { +// Structure to represent a 2D point or vector +#ifndef __METAL__ +union alignas(8) tb_float2 { + struct { + float x, y; + }; + + float data[2]; + + float &operator[](size_t idx) { + if (idx > 1) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const float &operator[](size_t idx) const { + if (idx > 1) + throw std::runtime_error("bad index"); + return data[idx]; + } + + bool operator==(const tb_float2 &rhs) const { + return x == rhs.x && y == rhs.y; + } +}; + +union alignas(4) tb_float3 { + struct { + float x, y, z; + }; + + float data[3]; + + float &operator[](size_t idx) { + if (idx > 2) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const float &operator[](size_t idx) const { + if (idx > 2) + throw std::runtime_error("bad index"); + return data[idx]; + } +}; + +union alignas(16) tb_float4 { + struct { + float x, y, z, w; + }; + + float data[4]; + + float &operator[](size_t idx) { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const float &operator[](size_t idx) const { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } +}; +#endif + +union alignas(4) tb_int3 { + struct { + int x, y, z; + }; + + int data[3]; +#ifndef __METAL__ + int &operator[](size_t idx) { + if (idx > 2) + throw std::runtime_error("bad index"); + return data[idx]; + } +#endif +}; + +// BVH structure to accelerate point-triangle intersection +struct alignas(16) AABB { + // Init bounding boxes with max/min + tb_float2 min = {FLT_MAX, FLT_MAX}; + tb_float2 max = {FLT_MIN, FLT_MIN}; + +#ifndef CUDA_ENABLED + // grow the AABB to include a point + void grow(const tb_float2 &p) { + min.x = std::min(min.x, p.x); + min.y = std::min(min.y, p.y); + max.x = std::max(max.x, p.x); + max.y = std::max(max.y, p.y); + } + + void grow(const AABB &b) { + if (b.min.x != FLT_MAX) { + grow(b.min); + grow(b.max); + } + } +#endif + + // Check if two AABBs overlap + bool overlaps(const METAL_THREAD_MEM AABB &other) const { + return min.x <= other.max.x && max.x >= other.min.x && + min.y <= other.max.y && max.y >= other.min.y; + } + + bool overlaps(const METAL_THREAD_MEM tb_float2 &point) const { + return 
point.x >= min.x && point.x <= max.x && point.y >= min.y && + point.y <= max.y; + } + +#if defined(__NVCC__) + CUDA_DEVICE bool overlaps(const float2 &point) const { + return point.x >= min.x && point.x <= max.x && point.y >= min.y && + point.y <= max.y; + } +#endif + + // Initialize AABB to an invalid state + void invalidate() { + min = {FLT_MAX, FLT_MAX}; + max = {FLT_MIN, FLT_MIN}; + } + + // Calculate the area of the AABB + float area() const { + tb_float2 extent = {max.x - min.x, max.y - min.y}; + return extent.x * extent.y; + } +}; + +struct BVHNode { + AABB bbox; + int start, end; + int left, right; + + int num_triangles() const { return end - start; } + + CUDA_HOST_DEVICE bool is_leaf() const { return left == -1 && right == -1; } + + float calculate_node_cost() { + float area = bbox.area(); + return num_triangles() * area; + } +}; + +struct Triangle { + tb_float2 v0, v1, v2; + int index; + tb_float2 centroid; +}; + +#ifndef __METAL__ +struct BVH { + std::vector nodes; + std::vector triangles; + std::vector triangle_indices; + int root; + + void build(const tb_float2 *vertices, const tb_int3 *indices, + const int64_t &num_indices); + bool intersect(const tb_float2 &point, float &u, float &v, float &w, + int &index) const; + + void update_node_bounds(BVHNode &node, AABB ¢roidBounds); + float find_best_split_plane(const BVHNode &node, int &best_axis, + int &best_pos, AABB ¢roidBounds); +}; +#endif + +} // namespace texture_baker_cpp diff --git a/texture_baker/texture_baker/csrc/baker_kernel.cu b/texture_baker/texture_baker/csrc/baker_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f21f0675bf8eb43ceae0115390b7cbec0937b8de --- /dev/null +++ b/texture_baker/texture_baker/csrc/baker_kernel.cu @@ -0,0 +1,301 @@ +#include +#include +#include +#include + +#include "baker.h" + +// #define TIMING + +#define STRINGIFY(x) #x +#define STR(x) STRINGIFY(x) +#define FILE_LINE __FILE__ ":" STR(__LINE__) +#define CUDA_CHECK_THROW(x) \ + do { \ + cudaError_t _result = x; \ + if (_result != cudaSuccess) \ + throw std::runtime_error(std::string(FILE_LINE " check failed " #x " failed: ") + cudaGetErrorString(_result)); \ + } while(0) + +namespace texture_baker_cpp +{ + + __device__ float3 operator+(const float3 &a, const float3 &b) + { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + + // xy: 2D test position + // v1: vertex position 1 + // v2: vertex position 2 + // v3: vertex position 3 + // + __forceinline__ __device__ bool barycentric_coordinates(const float2 &xy, const tb_float2 &v1, const tb_float2 &v2, const tb_float2 &v3, float &u, float &v, float &w) + { + // Return true if the point (xy) is inside the triangle defined by the vertices v1, v2, v3. + // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w. 
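+        // The dot products below form the 2x2 linear system
+        //   [d00 d01] [v]   [d20]
+        //   [d01 d11] [w] = [d21]
+        // which is solved by Cramer's rule; u then follows from u + v + w = 1.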
+ float2 v1v2 = make_float2(v2.x - v1.x, v2.y - v1.y); + float2 v1v3 = make_float2(v3.x - v1.x, v3.y - v1.y); + float2 xyv1 = make_float2(xy.x - v1.x, xy.y - v1.y); + + float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y; + float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y; + float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y; + float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y; + float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y; + + float denom = d00 * d11 - d01 * d01; + v = (d11 * d20 - d01 * d21) / denom; + w = (d00 * d21 - d01 * d20) / denom; + u = 1.0f - v - w; + + return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f); + } + + __global__ void kernel_interpolate(const float3* __restrict__ attr, const int3* __restrict__ indices, const float4* __restrict__ rast, float3* __restrict__ output, int width, int height) + { + // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx) + //int idx = x * width + y; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int x = idx / width; + int y = idx % width; + + if (x >= width || y >= height) + return; + + float4 barycentric = rast[idx]; + int triangle_idx = int(barycentric.w); + + if (triangle_idx < 0) + { + output[idx] = make_float3(0.0f, 0.0f, 0.0f); + return; + } + + float3 v1 = attr[indices[triangle_idx].x]; + float3 v2 = attr[indices[triangle_idx].y]; + float3 v3 = attr[indices[triangle_idx].z]; + + output[idx] = make_float3(v1.x * barycentric.x, v1.y * barycentric.x, v1.z * barycentric.x) + + make_float3(v2.x * barycentric.y, v2.y * barycentric.y, v2.z * barycentric.y) + + make_float3(v3.x * barycentric.z, v3.y * barycentric.z, v3.z * barycentric.z); + } + + __device__ bool bvh_intersect( + const BVHNode* __restrict__ nodes, + const Triangle* __restrict__ triangles, + const int* __restrict__ triangle_indices, + const int root, + const float2 &point, + float &u, float &v, float &w, + int &index) + { + constexpr int max_stack_size = 64; + int node_stack[max_stack_size]; + int stack_size = 0; + + node_stack[stack_size++] = root; + + while (stack_size > 0) + { + int node_idx = node_stack[--stack_size]; + const BVHNode &node = nodes[node_idx]; + + if (node.is_leaf()) + { + for (int i = node.start; i < node.end; ++i) + { + const Triangle &tri = triangles[triangle_indices[i]]; + if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) + { + index = tri.index; + return true; + } + } + } + else + { + if (nodes[node.right].bbox.overlaps(point)) + { + if (stack_size < max_stack_size) + { + node_stack[stack_size++] = node.right; + } + else + { + // Handle stack overflow + // Make sure NDEBUG is not defined (see setup.py) + assert(0 && "Node stack overflow"); + } + } + if (nodes[node.left].bbox.overlaps(point)) + { + if (stack_size < max_stack_size) + { + node_stack[stack_size++] = node.left; + } + else + { + // Handle stack overflow + // Make sure NDEBUG is not defined (see setup.py) + assert(0 && "Node stack overflow"); + } + } + } + } + + return false; + } + + __global__ void kernel_bake_uv( + float2* __restrict__ uv, + int3* __restrict__ indices, + float4* __restrict__ output, + const BVHNode* __restrict__ nodes, + const Triangle* __restrict__ triangles, + const int* __restrict__ triangle_indices, + const int root, + const int width, + const int height, + const int num_indices) + { + //int idx = x * width + y; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int x = idx / width; + int y = idx % width; + + if (y >= width || x >= height) + return; + + // We index x,y but the original coords are HW. 
So swap them + float2 pixel_coord = make_float2(float(y) / height, float(x) / width); + pixel_coord.x = fminf(fmaxf(pixel_coord.x, 0.0f), 1.0f); + pixel_coord.y = 1.0f - fminf(fmaxf(pixel_coord.y, 0.0f), 1.0f); + + float u, v, w; + int triangle_idx; + bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx); + + if (hit) + { + output[idx] = make_float4(u, v, w, float(triangle_idx)); + return; + } + + output[idx] = make_float4(0.0f, 0.0f, 0.0f, -1.0f); + } + + torch::Tensor rasterize_gpu( + torch::Tensor uv, + torch::Tensor indices, + int64_t bake_resolution) + { +#ifdef TIMING + auto start = std::chrono::high_resolution_clock::now(); +#endif + constexpr int block_size = 16 * 16; + int grid_size = bake_resolution * bake_resolution / block_size; + dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(grid_size, 1, 1); + + int num_indices = indices.size(0); + + int width = bake_resolution; + int height = bake_resolution; + + // Step 1: create an empty tensor to store the output. + torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + + auto vertices_cpu = uv.contiguous().cpu(); + auto indices_cpu = indices.contiguous().cpu(); + + const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr(); + const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr(); + + BVH bvh; + bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0)); + + BVHNode *nodes_gpu = nullptr; + Triangle *triangles_gpu = nullptr; + int *triangle_indices_gpu = nullptr; + const int bvh_root = bvh.root; + cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); + + CUDA_CHECK_THROW(cudaMallocAsync(&nodes_gpu, sizeof(BVHNode) * bvh.nodes.size(), cuda_stream)); + CUDA_CHECK_THROW(cudaMallocAsync(&triangles_gpu, sizeof(Triangle) * bvh.triangles.size(), cuda_stream)); + CUDA_CHECK_THROW(cudaMallocAsync(&triangle_indices_gpu, sizeof(int) * bvh.triangle_indices.size(), cuda_stream)); + + CUDA_CHECK_THROW(cudaMemcpyAsync(nodes_gpu, bvh.nodes.data(), sizeof(BVHNode) * bvh.nodes.size(), cudaMemcpyHostToDevice, cuda_stream)); + CUDA_CHECK_THROW(cudaMemcpyAsync(triangles_gpu, bvh.triangles.data(), sizeof(Triangle) * bvh.triangles.size(), cudaMemcpyHostToDevice, cuda_stream)); + CUDA_CHECK_THROW(cudaMemcpyAsync(triangle_indices_gpu, bvh.triangle_indices.data(), sizeof(int) * bvh.triangle_indices.size(), cudaMemcpyHostToDevice, cuda_stream)); + + kernel_bake_uv<<>>( + (float2 *)uv.contiguous().data_ptr(), + (int3 *)indices.contiguous().data_ptr(), + (float4 *)rast_result.contiguous().data_ptr(), + nodes_gpu, + triangles_gpu, + triangle_indices_gpu, + bvh_root, + width, + height, + num_indices); + + CUDA_CHECK_THROW(cudaFreeAsync(nodes_gpu, cuda_stream)); + CUDA_CHECK_THROW(cudaFreeAsync(triangles_gpu, cuda_stream)); + CUDA_CHECK_THROW(cudaFreeAsync(triangle_indices_gpu, cuda_stream)); + +#ifdef TIMING + CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream)); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "Rasterization time (CUDA): " << elapsed.count() << "s" << std::endl; +#endif + return rast_result; + } + + torch::Tensor interpolate_gpu( + torch::Tensor attr, + torch::Tensor indices, + torch::Tensor rast) + { +#ifdef TIMING + auto start = std::chrono::high_resolution_clock::now(); +#endif + constexpr int block_size = 16 * 16; + int grid_size = rast.size(0) * rast.size(0) / block_size; + 
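+        // One thread per output pixel; the integer division assumes a square rast map whose
+        // side length is a multiple of 16 (as produced by rasterize_gpu above).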
dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(grid_size, 1, 1); + + // Step 1: create an empty tensor to store the output. + torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + + int width = rast.size(0); + int height = rast.size(1); + + cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); + + kernel_interpolate<<>>( + (float3 *)attr.contiguous().data_ptr(), + (int3 *)indices.contiguous().data_ptr(), + (float4 *)rast.contiguous().data_ptr(), + (float3 *)pos_bake.contiguous().data_ptr(), + width, + height); +#ifdef TIMING + CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream)); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "Interpolation time (CUDA): " << elapsed.count() << "s" << std::endl; +#endif + return pos_bake; + } + + // Registers CUDA implementations + TORCH_LIBRARY_IMPL(texture_baker_cpp, CUDA, m) + { + m.impl("rasterize", &rasterize_gpu); + m.impl("interpolate", &interpolate_gpu); + } + +} diff --git a/texture_baker/texture_baker/csrc/baker_kernel.metal b/texture_baker/texture_baker/csrc/baker_kernel.metal new file mode 100644 index 0000000000000000000000000000000000000000..dff0a9413d5bfa6e6d03f3c2deec3890c1818bb5 --- /dev/null +++ b/texture_baker/texture_baker/csrc/baker_kernel.metal @@ -0,0 +1,170 @@ +#include +using namespace metal; + +// This header is inlined manually +//#include "baker.h" + +// Use the texture_baker_cpp so it can use the classes from baker.h +using namespace texture_baker_cpp; + +// Utility function to compute barycentric coordinates +bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, thread float &u, thread float &v, thread float &w) { + float2 v1v2 = v2 - v1; + float2 v1v3 = v3 - v1; + float2 xyv1 = xy - v1; + + float d00 = dot(v1v2, v1v2); + float d01 = dot(v1v2, v1v3); + float d11 = dot(v1v3, v1v3); + float d20 = dot(xyv1, v1v2); + float d21 = dot(xyv1, v1v3); + + float denom = d00 * d11 - d01 * d01; + v = (d11 * d20 - d01 * d21) / denom; + w = (d00 * d21 - d01 * d20) / denom; + u = 1.0f - v - w; + + return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f); +} + +// Kernel function for interpolation +kernel void kernel_interpolate(constant packed_float3 *attr [[buffer(0)]], + constant packed_int3 *indices [[buffer(1)]], + constant packed_float4 *rast [[buffer(2)]], + device packed_float3 *output [[buffer(3)]], + constant int &width [[buffer(4)]], + constant int &height [[buffer(5)]], + uint3 blockIdx [[threadgroup_position_in_grid]], + uint3 threadIdx [[thread_position_in_threadgroup]], + uint3 blockDim [[threads_per_threadgroup]]) +{ + // Calculate global position using threadgroup and thread positions + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int idx = y * width + x; + float4 barycentric = rast[idx]; + int triangle_idx = int(barycentric.w); + + if (triangle_idx < 0) { + output[idx] = float3(0.0f, 0.0f, 0.0f); + return; + } + + float3 v1 = attr[indices[triangle_idx].x]; + float3 v2 = attr[indices[triangle_idx].y]; + float3 v3 = attr[indices[triangle_idx].z]; + + output[idx] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z; +} + +bool bvh_intersect( + constant BVHNode* nodes, + constant Triangle* triangles, + constant int* triangle_indices, + const thread int root, + const thread float2 &point, + thread float &u, thread float &v, thread float &w, + 
thread int &index) +{ + const int max_stack_size = 64; + thread int node_stack[max_stack_size]; + int stack_size = 0; + + node_stack[stack_size++] = root; + + while (stack_size > 0) + { + int node_idx = node_stack[--stack_size]; + BVHNode node = nodes[node_idx]; + + if (node.is_leaf()) + { + for (int i = node.start; i < node.end; ++i) + { + constant Triangle &tri = triangles[triangle_indices[i]]; + if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) + { + index = tri.index; + return true; + } + } + } + else + { + BVHNode test_node = nodes[node.right]; + if (test_node.bbox.overlaps(point)) + { + if (stack_size < max_stack_size) + { + node_stack[stack_size++] = node.right; + } + else + { + // Handle stack overflow + // Sadly, metal doesn't support asserts (but you could try enabling metal validation layers) + return false; + } + } + test_node = nodes[node.left]; + if (test_node.bbox.overlaps(point)) + { + if (stack_size < max_stack_size) + { + node_stack[stack_size++] = node.left; + } + else + { + // Handle stack overflow + return false; + } + } + } + } + + return false; +} + + +// Kernel function for baking UV +kernel void kernel_bake_uv(constant packed_float2 *uv [[buffer(0)]], + constant packed_int3 *indices [[buffer(1)]], + device packed_float4 *output [[buffer(2)]], + constant BVHNode *nodes [[buffer(3)]], + constant Triangle *triangles [[buffer(4)]], + constant int *triangle_indices [[buffer(5)]], + constant int &root [[buffer(6)]], + constant int &width [[buffer(7)]], + constant int &height [[buffer(8)]], + constant int &num_indices [[buffer(9)]], + uint3 blockIdx [[threadgroup_position_in_grid]], + uint3 threadIdx [[thread_position_in_threadgroup]], + uint3 blockDim [[threads_per_threadgroup]]) +{ + // Calculate global position using threadgroup and thread positions + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + + if (x >= width || y >= height) return; + + int idx = x * width + y; + + // Swap original coordinates + float2 pixel_coord = float2(float(y) / float(height), float(x) / float(width)); + pixel_coord = clamp(pixel_coord, 0.0f, 1.0f); + pixel_coord.y = 1.0f - pixel_coord.y; + + float u, v, w; + int triangle_idx; + bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx); + + if (hit) { + output[idx] = float4(u, v, w, float(triangle_idx)); + return; + } + + output[idx] = float4(0.0f, 0.0f, 0.0f, -1.0f); +} diff --git a/texture_baker/texture_baker/csrc/baker_kernel.mm b/texture_baker/texture_baker/csrc/baker_kernel.mm new file mode 100644 index 0000000000000000000000000000000000000000..ee406a215cba22ae578f04a9721fb66641a0430f --- /dev/null +++ b/texture_baker/texture_baker/csrc/baker_kernel.mm @@ -0,0 +1,260 @@ +#include +#include +#include +#include "baker.h" + +#import +#import +#include + +// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`. +static inline id getMTLBufferStorage(const torch::Tensor& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +// Helper function to create a compute pipeline state object (PSO). +static inline id createComputePipelineState(id device, NSString* fullSource, std::string kernel_name) { + NSError *error = nil; + + // Load the custom kernel shader. 
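+  // fullSource is baker.h concatenated with baker_kernel.metal, read from the installed csrc
+  // directory; compiling it with __METAL__ defined makes the shared structs in baker.h resolve
+  // to their Metal-side definitions.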
+ MTLCompileOptions *options = [[MTLCompileOptions alloc] init]; + // Add the preprocessor macro "__METAL__" + options.preprocessorMacros = @{@"__METAL__": @""}; + id customKernelLibrary = [device newLibraryWithSource: fullSource options:options error:&error]; + TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String); + + id customKernelFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]]; + TORCH_CHECK(customKernelFunction, "Failed to create function state object for ", kernel_name.c_str()); + + id pso = [device newComputePipelineStateWithFunction:customKernelFunction error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + return pso; +} + +std::filesystem::path get_extension_path() { + // Ensure the GIL is held before calling any Python C API function + PyGILState_STATE gstate = PyGILState_Ensure(); + + const char* module_name = "texture_baker"; + + // Import the module by name + PyObject* module = PyImport_ImportModule(module_name); + if (!module) { + PyGILState_Release(gstate); + throw std::runtime_error("Could not import the module: " + std::string(module_name)); + } + + // Get the filename of the module + PyObject* filename_obj = PyModule_GetFilenameObject(module); + if (filename_obj) { + std::string path = PyUnicode_AsUTF8(filename_obj); + Py_DECREF(filename_obj); + PyGILState_Release(gstate); + + // Get the directory part of the path (removing the __init__.py) + std::filesystem::path module_path = std::filesystem::path(path).parent_path(); + + // Append the 'csrc' directory to the path + module_path /= "csrc"; + + return module_path; + } else { + PyGILState_Release(gstate); + throw std::runtime_error("Could not retrieve the module filename."); + } +} + +NSString *get_shader_sources_as_string() +{ + const std::filesystem::path csrc_path = get_extension_path(); + const std::string shader_path = (csrc_path / "baker_kernel.metal").string(); + const std::string shader_header_path = (csrc_path / "baker.h").string(); + // Load the Metal shader from the specified path + NSError *error = nil; + + NSString* shaderHeaderSource = [ + NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_header_path.c_str()] + encoding:NSUTF8StringEncoding + error:&error]; + if (error) { + throw std::runtime_error("Failed to load baker.h: " + std::string(error.localizedDescription.UTF8String)); + } + + NSString* shaderSource = [ + NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_path.c_str()] + encoding:NSUTF8StringEncoding + error:&error]; + if (error) { + throw std::runtime_error("Failed to load Metal shader: " + std::string(error.localizedDescription.UTF8String)); + } + + NSString *fullSource = [shaderHeaderSource stringByAppendingString:shaderSource]; + + return fullSource; +} + +namespace texture_baker_cpp +{ + torch::Tensor rasterize_gpu( + torch::Tensor uv, + torch::Tensor indices, + int64_t bake_resolution) + { + TORCH_CHECK(uv.device().is_mps(), "uv must be a MPS tensor"); + TORCH_CHECK(uv.is_contiguous(), "uv must be contiguous"); + TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous"); + + TORCH_CHECK(uv.scalar_type() == torch::kFloat32, "Unsupported data type: ", indices.scalar_type()); + TORCH_CHECK(indices.scalar_type() == torch::kInt32, "Unsupported data type: ", indices.scalar_type()); + + torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, 
torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous(); + + @autoreleasepool { + auto vertices_cpu = uv.contiguous().cpu(); + auto indices_cpu = indices.contiguous().cpu(); + + const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr(); + const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr(); + + BVH bvh; + bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0)); + + id device = MTLCreateSystemDefaultDevice(); + + NSString *fullSource = get_shader_sources_as_string(); + + // Create a compute pipeline state object using the helper function + id bake_uv_PSO = createComputePipelineState(device, fullSource, "kernel_bake_uv"); + + // Get a reference to the command buffer for the MPS stream. + id commandBuffer = torch::mps::get_command_buffer(); + TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); + + // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU. + dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); + + dispatch_sync(serialQueue, ^(){ + // Start a compute pass. + id computeEncoder = [commandBuffer computeCommandEncoder]; + TORCH_CHECK(computeEncoder, "Failed to create compute command encoder"); + + // Get Metal buffers directly from PyTorch tensors + auto uv_buf = getMTLBufferStorage(uv.contiguous()); + auto indices_buf = getMTLBufferStorage(indices.contiguous()); + auto rast_result_buf = getMTLBufferStorage(rast_result); + + const int width = bake_resolution; + const int height = bake_resolution; + const int num_indices = indices.size(0); + const int bvh_root = bvh.root; + + // Wrap the existing CPU memory in Metal buffers with shared memory + id nodesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.nodes.data() length:sizeof(BVHNode) * bvh.nodes.size() options:MTLResourceStorageModeShared deallocator:nil]; + id trianglesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangles.data() length:sizeof(Triangle) * bvh.triangles.size() options:MTLResourceStorageModeShared deallocator:nil]; + id triangleIndicesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangle_indices.data() length:sizeof(int) * bvh.triangle_indices.size() options:MTLResourceStorageModeShared deallocator:nil]; + + [computeEncoder setComputePipelineState:bake_uv_PSO]; + [computeEncoder setBuffer:uv_buf offset:uv.storage_offset() * uv.element_size() atIndex:0]; + [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1]; + [computeEncoder setBuffer:rast_result_buf offset:rast_result.storage_offset() * rast_result.element_size() atIndex:2]; + [computeEncoder setBuffer:nodesBuffer offset:0 atIndex:3]; + [computeEncoder setBuffer:trianglesBuffer offset:0 atIndex:4]; + [computeEncoder setBuffer:triangleIndicesBuffer offset:0 atIndex:5]; + [computeEncoder setBytes:&bvh_root length:sizeof(int) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int) atIndex:7]; + [computeEncoder setBytes:&height length:sizeof(int) atIndex:8]; + [computeEncoder setBytes:&num_indices length:sizeof(int) atIndex:9]; + + // Calculate a thread group size. + int block_size = 16; + MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size + MTLSize numThreadgroups = MTLSizeMake(bake_resolution / block_size, bake_resolution / block_size, 1); + + // Encode the compute command. 
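+            // One thread per pixel in 16x16 threadgroups; bake_resolution is assumed to be
+            // a multiple of 16 so the threadgroup grid covers the whole texture.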
+ [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize]; + [computeEncoder endEncoding]; + + // Commit the work. + torch::mps::commit(); + }); + } + + return rast_result; + } + + torch::Tensor interpolate_gpu( + torch::Tensor attr, + torch::Tensor indices, + torch::Tensor rast) + { + TORCH_CHECK(attr.is_contiguous(), "attr must be contiguous"); + TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous"); + TORCH_CHECK(rast.is_contiguous(), "rast must be contiguous"); + + torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous(); + std::filesystem::path csrc_path = get_extension_path(); + + @autoreleasepool { + id device = MTLCreateSystemDefaultDevice(); + + NSString *fullSource = get_shader_sources_as_string(); + // Create a compute pipeline state object using the helper function + id interpolate_PSO = createComputePipelineState(device, fullSource, "kernel_interpolate"); + + // Get a reference to the command buffer for the MPS stream. + id commandBuffer = torch::mps::get_command_buffer(); + TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); + + // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU. + dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); + + dispatch_sync(serialQueue, ^(){ + // Start a compute pass. + id computeEncoder = [commandBuffer computeCommandEncoder]; + TORCH_CHECK(computeEncoder, "Failed to create compute command encoder"); + + // Get Metal buffers directly from PyTorch tensors + auto attr_buf = getMTLBufferStorage(attr.contiguous()); + auto indices_buf = getMTLBufferStorage(indices.contiguous()); + auto rast_buf = getMTLBufferStorage(rast.contiguous()); + auto pos_bake_buf = getMTLBufferStorage(pos_bake); + + int width = rast.size(0); + int height = rast.size(1); + + [computeEncoder setComputePipelineState:interpolate_PSO]; + [computeEncoder setBuffer:attr_buf offset:attr.storage_offset() * attr.element_size() atIndex:0]; + [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1]; + [computeEncoder setBuffer:rast_buf offset:rast.storage_offset() * rast.element_size() atIndex:2]; + [computeEncoder setBuffer:pos_bake_buf offset:pos_bake.storage_offset() * pos_bake.element_size() atIndex:3]; + [computeEncoder setBytes:&width length:sizeof(int) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int) atIndex:5]; + + // Calculate a thread group size. + + int block_size = 16; + MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size + MTLSize numThreadgroups = MTLSizeMake(rast.size(0) / block_size, rast.size(0) / block_size, 1); + + // Encode the compute command. + [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize]; + + [computeEncoder endEncoding]; + + // Commit the work. 
+ torch::mps::commit(); + }); + } + + return pos_bake; + } + + // Registers MPS implementations + TORCH_LIBRARY_IMPL(texture_baker_cpp, MPS, m) + { + m.impl("rasterize", &rasterize_gpu); + m.impl("interpolate", &interpolate_gpu); + } + +} diff --git a/uv_unwrapper/README.md b/uv_unwrapper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/uv_unwrapper/requirements.txt b/uv_unwrapper/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..af3149eb479c47955cf0d40d253890baa18d2f54 --- /dev/null +++ b/uv_unwrapper/requirements.txt @@ -0,0 +1,2 @@ +torch +numpy diff --git a/uv_unwrapper/setup.py b/uv_unwrapper/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..62cc23e2e6953b559502151b230c349163293c74 --- /dev/null +++ b/uv_unwrapper/setup.py @@ -0,0 +1,79 @@ +import glob +import os + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, +) + +library_name = "uv_unwrapper" + + +def get_extensions(): + debug_mode = os.getenv("DEBUG", "0") == "1" + if debug_mode: + print("Compiling in debug mode") + + is_mac = True if torch.backends.mps.is_available() else False + use_native_arch = not is_mac and os.getenv("USE_NATIVE_ARCH", "1") == "1" + extension = CppExtension + + extra_link_args = [] + extra_compile_args = { + "cxx": [ + "-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + ("-Xclang " if is_mac else "") + "-fopenmp", + ] + + ["-march=native"] + if use_native_arch + else [], + } + if debug_mode: + extra_compile_args["cxx"].append("-g") + extra_compile_args["cxx"].append("-UNDEBUG") + extra_link_args.extend(["-O0", "-g"]) + + define_macros = [] + extensions = [] + + this_dir = os.path.dirname(os.path.curdir) + sources = glob.glob( + os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True + ) + + if len(sources) == 0: + print("No source files found for extension, skipping extension compilation") + return None + + extensions.append( + extension( + name=f"{library_name}._C", + sources=sources, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + libraries=["c10", "torch", "torch_cpu", "torch_python"] + ["omp"] + if is_mac + else [], + ) + ) + + print(extensions) + + return extensions + + +setup( + name=library_name, + version="0.0.1", + packages=find_packages(), + ext_modules=get_extensions(), + install_requires=[], + description="Box projection based UV unwrapper", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + cmdclass={"build_ext": BuildExtension}, +) diff --git a/uv_unwrapper/uv_unwrapper/__init__.py b/uv_unwrapper/uv_unwrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f56e161df5d9e5658fe6d87ccf207b9874e94481 --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/__init__.py @@ -0,0 +1,6 @@ +import torch # noqa: F401 + +from . 
import _C # noqa: F401 +from .unwrap import Unwrapper + +__all__ = ["Unwrapper"] diff --git a/uv_unwrapper/uv_unwrapper/csrc/bvh.cpp b/uv_unwrapper/uv_unwrapper/csrc/bvh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25468f38740ce36e9343d13a311eada63a211e56 --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/csrc/bvh.cpp @@ -0,0 +1,381 @@ + + +#include "bvh.h" +#include "common.h" +#include +#include +#include +#include +#include + +namespace UVUnwrapper { +BVH::BVH(Triangle *tri, int *actual_idx, const size_t &num_indices) { + // Copty tri to triangle + triangle = new Triangle[num_indices]; + memcpy(triangle, tri, num_indices * sizeof(Triangle)); + + // Copy actual_idx to actualIdx + actualIdx = new int[num_indices]; + memcpy(actualIdx, actual_idx, num_indices * sizeof(int)); + + triIdx = new int[num_indices]; + triCount = num_indices; + + bvhNode = new BVHNode[triCount * 2 + 64]; + nodesUsed = 2; + memset(bvhNode, 0, triCount * 2 * sizeof(BVHNode)); + + // populate triangle index array + for (int i = 0; i < triCount; i++) + triIdx[i] = i; + + BVHNode &root = bvhNode[0]; + + root.start = 0, root.end = triCount; + AABB centroidBounds; + UpdateNodeBounds(0, centroidBounds); + + // subdivide recursively + Subdivide(0, nodesUsed, centroidBounds); +} + +BVH::BVH(const BVH &other) + : BVH(other.triangle, other.triIdx, other.triCount) {} + +BVH::BVH(BVH &&other) noexcept // move constructor + : triIdx(std::exchange(other.triIdx, nullptr)), + actualIdx(std::exchange(other.actualIdx, nullptr)), + triangle(std::exchange(other.triangle, nullptr)), + bvhNode(std::exchange(other.bvhNode, nullptr)) {} + +BVH &BVH::operator=(const BVH &other) // copy assignment +{ + return *this = BVH(other); +} + +BVH &BVH::operator=(BVH &&other) noexcept // move assignment +{ + std::swap(triIdx, other.triIdx); + std::swap(actualIdx, other.actualIdx); + std::swap(triangle, other.triangle); + std::swap(bvhNode, other.bvhNode); + std::swap(triCount, other.triCount); + std::swap(nodesUsed, other.nodesUsed); + return *this; +} + +BVH::~BVH() { + if (triIdx) + delete[] triIdx; + if (triangle) + delete[] triangle; + if (actualIdx) + delete[] actualIdx; + if (bvhNode) + delete[] bvhNode; +} + +void BVH::UpdateNodeBounds(unsigned int nodeIdx, AABB ¢roidBounds) { + BVHNode &node = bvhNode[nodeIdx]; +#ifndef __ARM_ARCH_ISA_A64 +#ifndef _MSC_VER + if (__builtin_cpu_supports("sse")) +#elif (defined(_M_AMD64) || defined(_M_X64)) + // SSE supported on Windows + if constexpr (true) +#endif + { + __m128 min4 = _mm_set_ps1(FLT_MAX), max4 = _mm_set_ps1(FLT_MIN); + __m128 cmin4 = _mm_set_ps1(FLT_MAX), cmax4 = _mm_set_ps1(FLT_MIN); + for (int i = node.start; i < node.end; i += 2) { + Triangle &leafTri1 = triangle[triIdx[i]]; + __m128 v0, v1, v2, centroid; + if (i + 1 < node.end) { + const Triangle leafTri2 = triangle[triIdx[i + 1]]; + + v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x, + leafTri2.v0.y); + v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x, + leafTri2.v1.y); + v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x, + leafTri2.v2.y); + centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y, + leafTri2.centroid.x, leafTri2.centroid.y); + } else { + // Otherwise do some duplicated work + v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x, + leafTri1.v0.y); + v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x, + leafTri1.v1.y); + v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x, + leafTri1.v2.y); + centroid = _mm_set_ps(leafTri1.centroid.x, 
leafTri1.centroid.y, + leafTri1.centroid.x, leafTri1.centroid.y); + } + + min4 = _mm_min_ps(min4, v0); + max4 = _mm_max_ps(max4, v0); + min4 = _mm_min_ps(min4, v1); + max4 = _mm_max_ps(max4, v1); + min4 = _mm_min_ps(min4, v2); + max4 = _mm_max_ps(max4, v2); + cmin4 = _mm_min_ps(cmin4, centroid); + cmax4 = _mm_max_ps(cmax4, centroid); + } + float min_values[4], max_values[4], cmin_values[4], cmax_values[4]; + _mm_store_ps(min_values, min4); + _mm_store_ps(max_values, max4); + _mm_store_ps(cmin_values, cmin4); + _mm_store_ps(cmax_values, cmax4); + + node.bbox.min.x = std::min(min_values[3], min_values[1]); + node.bbox.min.y = std::min(min_values[2], min_values[0]); + node.bbox.max.x = std::max(max_values[3], max_values[1]); + node.bbox.max.y = std::max(max_values[2], max_values[0]); + + centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]); + centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]); + centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]); + centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]); + } +#else + if constexpr (false) { + } +#endif + else { + node.bbox.invalidate(); + centroidBounds.invalidate(); + + // Calculate the bounding box for the node + for (int i = node.start; i < node.end; ++i) { + const Triangle &tri = triangle[triIdx[i]]; + node.bbox.grow(tri.v0); + node.bbox.grow(tri.v1); + node.bbox.grow(tri.v2); + centroidBounds.grow(tri.centroid); + } + } +} + +void BVH::Subdivide(unsigned int root_idx, unsigned int &nodePtr, + AABB &rootCentroidBounds) { + // Create a queue for the nodes to be subdivided + std::queue> nodeQueue; + nodeQueue.push(std::make_tuple(root_idx, rootCentroidBounds)); + + while (!nodeQueue.empty()) { + // Get the next node to process from the queue + auto [node_idx, centroidBounds] = nodeQueue.front(); + nodeQueue.pop(); + BVHNode &node = bvhNode[node_idx]; + + // Check if left is -1 and right not or vice versa + + int axis, splitPos; + float cost = FindBestSplitPlane(node, axis, splitPos, centroidBounds); + + if (cost >= node.calculate_node_cost()) { + node.left = node.right = -1; + continue; // Move on to the next node in the queue + } + + int i = node.start; + int j = node.end - 1; + float scale = BINS / (centroidBounds.max[axis] - centroidBounds.min[axis]); + while (i <= j) { + int binIdx = + std::min(BINS - 1, (int)((triangle[triIdx[i]].centroid[axis] - + centroidBounds.min[axis]) * + scale)); + if (binIdx < splitPos) + i++; + else + std::swap(triIdx[i], triIdx[j--]); + } + + int leftCount = i - node.start; + if (leftCount == 0 || leftCount == (int)node.num_triangles()) { + node.left = node.right = -1; + continue; // Move on to the next node in the queue + } + + int mid = i; + + // Create child nodes + int leftChildIdx = nodePtr++; + int rightChildIdx = nodePtr++; + bvhNode[leftChildIdx].start = node.start; + bvhNode[leftChildIdx].end = mid; + bvhNode[rightChildIdx].start = mid; + bvhNode[rightChildIdx].end = node.end; + node.left = leftChildIdx; + node.right = rightChildIdx; + + // Update the bounds for the child nodes and push them onto the queue + UpdateNodeBounds(leftChildIdx, centroidBounds); + nodeQueue.push(std::make_tuple(leftChildIdx, centroidBounds)); + + UpdateNodeBounds(rightChildIdx, centroidBounds); + nodeQueue.push(std::make_tuple(rightChildIdx, centroidBounds)); + } +} + +float BVH::FindBestSplitPlane(BVHNode &node, int &best_axis, int &best_pos, + AABB ¢roidBounds) { + float best_cost = FLT_MAX; + + for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y + { + float 
boundsMin = centroidBounds.min[axis]; + float boundsMax = centroidBounds.max[axis]; + // Or floating point precision + if ((boundsMin == boundsMax) || (boundsMax - boundsMin < 1e-8f)) { + continue; + } + + // populate the bins + float scale = BINS / (boundsMax - boundsMin); + float leftCountArea[BINS - 1], rightCountArea[BINS - 1]; + int leftSum = 0, rightSum = 0; +#ifndef __ARM_ARCH_ISA_A64 +#ifndef _MSC_VER + if (__builtin_cpu_supports("sse")) +#elif (defined(_M_AMD64) || defined(_M_X64)) + // SSE supported on Windows + if constexpr (true) +#endif + { + __m128 min4[BINS], max4[BINS]; + unsigned int count[BINS]; + for (unsigned int i = 0; i < BINS; i++) + min4[i] = _mm_set_ps1(FLT_MAX), max4[i] = _mm_set_ps1(FLT_MIN), + count[i] = 0; + for (int i = node.start; i < node.end; i++) { + Triangle &tri = triangle[triIdx[i]]; + int binIdx = + std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale)); + count[binIdx]++; + + __m128 v0 = _mm_set_ps(tri.v0.x, tri.v0.y, 0.0f, 0.0f); + __m128 v1 = _mm_set_ps(tri.v1.x, tri.v1.y, 0.0f, 0.0f); + __m128 v2 = _mm_set_ps(tri.v2.x, tri.v2.y, 0.0f, 0.0f); + min4[binIdx] = _mm_min_ps(min4[binIdx], v0); + max4[binIdx] = _mm_max_ps(max4[binIdx], v0); + min4[binIdx] = _mm_min_ps(min4[binIdx], v1); + max4[binIdx] = _mm_max_ps(max4[binIdx], v1); + min4[binIdx] = _mm_min_ps(min4[binIdx], v2); + max4[binIdx] = _mm_max_ps(max4[binIdx], v2); + } + // gather data for the 7 planes between the 8 bins + __m128 leftMin4 = _mm_set_ps1(FLT_MAX), rightMin4 = leftMin4; + __m128 leftMax4 = _mm_set_ps1(FLT_MIN), rightMax4 = leftMax4; + for (int i = 0; i < BINS - 1; i++) { + leftSum += count[i]; + rightSum += count[BINS - 1 - i]; + leftMin4 = _mm_min_ps(leftMin4, min4[i]); + rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]); + leftMax4 = _mm_max_ps(leftMax4, max4[i]); + rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]); + float le[4], re[4]; + _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4)); + _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4)); + // SSE order goes from back to front + leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation + rightCountArea[BINS - 2 - i] = + rightSum * (re[2] * re[3]); // 2D area calculation + } + } +#else + if constexpr (false) { + } +#endif + else { + struct Bin { + AABB bounds; + int triCount = 0; + } bin[BINS]; + for (int i = node.start; i < node.end; i++) { + Triangle &tri = triangle[triIdx[i]]; + int binIdx = + std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale)); + bin[binIdx].triCount++; + bin[binIdx].bounds.grow(tri.v0); + bin[binIdx].bounds.grow(tri.v1); + bin[binIdx].bounds.grow(tri.v2); + } + // gather data for the 7 planes between the 8 bins + AABB leftBox, rightBox; + for (int i = 0; i < BINS - 1; i++) { + leftSum += bin[i].triCount; + leftBox.grow(bin[i].bounds); + leftCountArea[i] = leftSum * leftBox.area(); + rightSum += bin[BINS - 1 - i].triCount; + rightBox.grow(bin[BINS - 1 - i].bounds); + rightCountArea[BINS - 2 - i] = rightSum * rightBox.area(); + } + } + + // calculate SAH cost for the 7 planes + scale = (boundsMax - boundsMin) / BINS; + for (int i = 0; i < BINS - 1; i++) { + const float planeCost = leftCountArea[i] + rightCountArea[i]; + if (planeCost < best_cost) + best_axis = axis, best_pos = i + 1, best_cost = planeCost; + } + } + return best_cost; +} + +std::vector BVH::Intersect(Triangle &tri_intersect) { + /** + * @brief Intersect a triangle with the BVH + * + * @param triangle the triangle to intersect + * + * @return -1 for no intersection, the index of the intersected 
triangle + * otherwise + */ + + const int max_stack_size = 64; + int node_stack[max_stack_size]; + int stack_size = 0; + std::vector intersected_triangles; + + node_stack[stack_size++] = 0; // Start with the root node (index 0) + while (stack_size > 0) { + int node_idx = node_stack[--stack_size]; + const BVHNode &node = bvhNode[node_idx]; + if (node.is_leaf()) { + for (int i = node.start; i < node.end; ++i) { + const Triangle &tri = triangle[triIdx[i]]; + // Check that the triangle is not the same as the intersected triangle + if (tri == tri_intersect) + continue; + if (tri_intersect.overlaps(tri)) { + intersected_triangles.push_back(actualIdx[triIdx[i]]); + } + } + } else { + // Check right child first + if (bvhNode[node.right].bbox.overlaps(tri_intersect)) { + if (stack_size < max_stack_size) { + node_stack[stack_size++] = node.right; + } else { + throw std::runtime_error("Node stack overflow"); + } + } + + // Check left child + if (bvhNode[node.left].bbox.overlaps(tri_intersect)) { + if (stack_size < max_stack_size) { + node_stack[stack_size++] = node.left; + } else { + throw std::runtime_error("Node stack overflow"); + } + } + } + } + return intersected_triangles; // Return all intersected triangle indices +} + +} // namespace UVUnwrapper diff --git a/uv_unwrapper/uv_unwrapper/csrc/bvh.h b/uv_unwrapper/uv_unwrapper/csrc/bvh.h new file mode 100644 index 0000000000000000000000000000000000000000..91f48f4f735ff7f260e6023044f44137efdcc568 --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/csrc/bvh.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#ifndef __ARM_ARCH_ISA_A64 +#include +#endif +#include +#include + +#include "common.h" +#include "intersect.h" +/** + * Based on https://github.com/jbikker/bvh_article released under the unlicense. + */ + +// bin count for binned BVH building +#define BINS 8 + +namespace UVUnwrapper { +// minimalist triangle struct +struct alignas(32) Triangle { + uv_float2 v0; + uv_float2 v1; + uv_float2 v2; + uv_float2 centroid; + + bool overlaps(const Triangle &other) { + // return tri_tri_overlap_test_2d(v0, v1, v2, other.v0, other.v1, other.v2); + return triangle_triangle_intersection(v0, v1, v2, other.v0, other.v1, + other.v2); + } + + bool operator==(const Triangle &rhs) const { + return v0 == rhs.v0 && v1 == rhs.v1 && v2 == rhs.v2; + } +}; + +// minimalist AABB struct with grow functionality +struct alignas(16) AABB { + // Init bounding boxes with max/min + uv_float2 min = {FLT_MAX, FLT_MAX}; + uv_float2 max = {FLT_MIN, FLT_MIN}; + + void grow(const uv_float2 &p) { + min.x = std::min(min.x, p.x); + min.y = std::min(min.y, p.y); + max.x = std::max(max.x, p.x); + max.y = std::max(max.y, p.y); + } + + void grow(const AABB &b) { + if (b.min.x != FLT_MAX) { + grow(b.min); + grow(b.max); + } + } + + bool overlaps(const Triangle &tri) { + return triangle_aabb_intersection(min, max, tri.v0, tri.v1, tri.v2); + } + + float area() const { + uv_float2 extent = {max.x - min.x, max.y - min.y}; + return extent.x * extent.y; + } + + void invalidate() { + min = {FLT_MAX, FLT_MAX}; + max = {FLT_MIN, FLT_MIN}; + } +}; + +// 32-byte BVH node struct +struct alignas(32) BVHNode { + AABB bbox; // 16 + int start = 0, end = 0; // 8 + int left, right; + + int num_triangles() const { return end - start; } + + bool is_leaf() const { return left == -1 && right == -1; } + + float calculate_node_cost() { + float area = bbox.area(); + return num_triangles() * area; + } +}; + +class BVH { +public: + BVH() = default; + BVH(BVH &&other) noexcept; + BVH(const BVH &other); + BVH 
&operator=(const BVH &other); + BVH &operator=(BVH &&other) noexcept; + BVH(Triangle *tri, int *actual_idx, const size_t &num_indices); + ~BVH(); + + std::vector Intersect(Triangle &triangle); + +private: + void Subdivide(unsigned int node_idx, unsigned int &nodePtr, + AABB ¢roidBounds); + void UpdateNodeBounds(unsigned int nodeIdx, AABB ¢roidBounds); + float FindBestSplitPlane(BVHNode &node, int &axis, int &splitPos, + AABB ¢roidBounds); + +public: + int *triIdx = nullptr; + int *actualIdx = nullptr; + unsigned int triCount; + unsigned int nodesUsed; + BVHNode *bvhNode = nullptr; + Triangle *triangle = nullptr; +}; + +} // namespace UVUnwrapper diff --git a/uv_unwrapper/uv_unwrapper/csrc/common.h b/uv_unwrapper/uv_unwrapper/csrc/common.h new file mode 100644 index 0000000000000000000000000000000000000000..5b6adb02af038acb68d2de5279e2c3b299b9d69e --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/csrc/common.h @@ -0,0 +1,493 @@ +#pragma once + +#include +#include +#include +#include + +const float EPSILON = 1e-7f; + +// Structure to represent a 2D point or vector +union alignas(8) uv_float2 { + struct { + float x, y; + }; + + float data[2]; + + float &operator[](size_t idx) { + if (idx > 1) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const float &operator[](size_t idx) const { + if (idx > 1) + throw std::runtime_error("bad index"); + return data[idx]; + } + + bool operator==(const uv_float2 &rhs) const { + return x == rhs.x && y == rhs.y; + } +}; + +// Do not align as this is specifically tweaked for BVHNode +union uv_float3 { + struct { + float x, y, z; + }; + + float data[3]; + + float &operator[](size_t idx) { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const float &operator[](size_t idx) const { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } + + bool operator==(const uv_float3 &rhs) const { + return x == rhs.x && y == rhs.y && z == rhs.z; + } +}; + +union alignas(16) uv_float4 { + struct { + float x, y, z, w; + }; + + float data[4]; + + float &operator[](size_t idx) { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const float &operator[](size_t idx) const { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } + + bool operator==(const uv_float4 &rhs) const { + return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w; + } +}; + +union alignas(8) uv_int2 { + struct { + int x, y; + }; + + int data[2]; + + int &operator[](size_t idx) { + if (idx > 1) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const int &operator[](size_t idx) const { + if (idx > 1) + throw std::runtime_error("bad index"); + return data[idx]; + } + + bool operator==(const uv_int2 &rhs) const { return x == rhs.x && y == rhs.y; } +}; + +union alignas(4) uv_int3 { + struct { + int x, y, z; + }; + + int data[3]; + + int &operator[](size_t idx) { + if (idx > 2) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const int &operator[](size_t idx) const { + if (idx > 2) + throw std::runtime_error("bad index"); + return data[idx]; + } + + bool operator==(const uv_int3 &rhs) const { + return x == rhs.x && y == rhs.y && z == rhs.z; + } +}; + +union alignas(16) uv_int4 { + struct { + int x, y, z, w; + }; + + int data[4]; + + int &operator[](size_t idx) { + if (idx > 3) + throw std::runtime_error("bad index"); + return data[idx]; + } + + const int &operator[](size_t idx) const { + if (idx > 3) + throw std::runtime_error("bad index"); + 
return data[idx]; + } + + bool operator==(const uv_int4 &rhs) const { + return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w; + } +}; + +inline float calc_mean(float a, float b, float c) { return (a + b + c) / 3; } + +// Create a triangle centroid +inline uv_float2 triangle_centroid(const uv_float2 &v0, const uv_float2 &v1, + const uv_float2 &v2) { + return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y)}; +} + +inline uv_float3 triangle_centroid(const uv_float3 &v0, const uv_float3 &v1, + const uv_float3 &v2) { + return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y), + calc_mean(v0.z, v1.z, v2.z)}; +} + +// Helper functions for vector math +inline uv_float2 operator-(const uv_float2 &a, const uv_float2 &b) { + return {a.x - b.x, a.y - b.y}; +} + +inline uv_float3 operator-(const uv_float3 &a, const uv_float3 &b) { + return {a.x - b.x, a.y - b.y, a.z - b.z}; +} + +inline uv_float2 operator+(const uv_float2 &a, const uv_float2 &b) { + return {a.x + b.x, a.y + b.y}; +} + +inline uv_float3 operator+(const uv_float3 &a, const uv_float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline uv_float2 operator*(const uv_float2 &a, float scalar) { + return {a.x * scalar, a.y * scalar}; +} + +inline uv_float3 operator*(const uv_float3 &a, float scalar) { + return {a.x * scalar, a.y * scalar, a.z * scalar}; +} + +inline float dot(const uv_float2 &a, const uv_float2 &b) { + return a.x * b.x + a.y * b.y; +} + +inline float dot(const uv_float3 &a, const uv_float3 &b) { + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +inline float cross(const uv_float2 &a, const uv_float2 &b) { + return a.x * b.y - a.y * b.x; +} + +inline uv_float3 cross(const uv_float3 &a, const uv_float3 &b) { + return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x}; +} + +inline uv_float2 abs_vec(const uv_float2 &v) { + return {std::abs(v.x), std::abs(v.y)}; +} + +inline uv_float2 min_vec(const uv_float2 &a, const uv_float2 &b) { + return {std::min(a.x, b.x), std::min(a.y, b.y)}; +} + +inline uv_float2 max_vec(const uv_float2 &a, const uv_float2 &b) { + return {std::max(a.x, b.x), std::max(a.y, b.y)}; +} + +inline float distance_to(const uv_float2 &a, const uv_float2 &b) { + return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2)); +} + +inline float distance_to(const uv_float3 &a, const uv_float3 &b) { + return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) + + std::pow(a.z - b.z, 2)); +} + +inline uv_float2 normalize(const uv_float2 &v) { + float len = std::sqrt(v.x * v.x + v.y * v.y); + return {v.x / len, v.y / len}; +} + +inline uv_float3 normalize(const uv_float3 &v) { + float len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z); + return {v.x / len, v.y / len, v.z / len}; +} + +inline float magnitude(const uv_float3 &v) { + return std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z); +} + +struct Matrix4 { + std::array, 4> m; + + Matrix4() { + for (auto &row : m) { + row.fill(0.0f); + } + m[3][3] = 1.0f; // Identity matrix for 4th row and column + } + + void set(float m00, float m01, float m02, float m03, float m10, float m11, + float m12, float m13, float m20, float m21, float m22, float m23, + float m30, float m31, float m32, float m33) { + m[0][0] = m00; + m[0][1] = m01; + m[0][2] = m02; + m[0][3] = m03; + m[1][0] = m10; + m[1][1] = m11; + m[1][2] = m12; + m[1][3] = m13; + m[2][0] = m20; + m[2][1] = m21; + m[2][2] = m22; + m[2][3] = m23; + m[3][0] = m30; + m[3][1] = m31; + m[3][2] = m32; + m[3][3] = m33; + } + + float determinant() const { + return m[0][3] 
* m[1][2] * m[2][1] * m[3][0] - + m[0][2] * m[1][3] * m[2][1] * m[3][0] - + m[0][3] * m[1][1] * m[2][2] * m[3][0] + + m[0][1] * m[1][3] * m[2][2] * m[3][0] + + m[0][2] * m[1][1] * m[2][3] * m[3][0] - + m[0][1] * m[1][2] * m[2][3] * m[3][0] - + m[0][3] * m[1][2] * m[2][0] * m[3][1] + + m[0][2] * m[1][3] * m[2][0] * m[3][1] + + m[0][3] * m[1][0] * m[2][2] * m[3][1] - + m[0][0] * m[1][3] * m[2][2] * m[3][1] - + m[0][2] * m[1][0] * m[2][3] * m[3][1] + + m[0][0] * m[1][2] * m[2][3] * m[3][1] + + m[0][3] * m[1][1] * m[2][0] * m[3][2] - + m[0][1] * m[1][3] * m[2][0] * m[3][2] - + m[0][3] * m[1][0] * m[2][1] * m[3][2] + + m[0][0] * m[1][3] * m[2][1] * m[3][2] + + m[0][1] * m[1][0] * m[2][3] * m[3][2] - + m[0][0] * m[1][1] * m[2][3] * m[3][2] - + m[0][2] * m[1][1] * m[2][0] * m[3][3] + + m[0][1] * m[1][2] * m[2][0] * m[3][3] + + m[0][2] * m[1][0] * m[2][1] * m[3][3] - + m[0][0] * m[1][2] * m[2][1] * m[3][3] - + m[0][1] * m[1][0] * m[2][2] * m[3][3] + + m[0][0] * m[1][1] * m[2][2] * m[3][3]; + } + + Matrix4 operator*(const Matrix4 &other) const { + Matrix4 result; + for (int row = 0; row < 4; ++row) { + for (int col = 0; col < 4; ++col) { + result.m[row][col] = + m[row][0] * other.m[0][col] + m[row][1] * other.m[1][col] + + m[row][2] * other.m[2][col] + m[row][3] * other.m[3][col]; + } + } + return result; + } + + Matrix4 operator*(float scalar) const { + Matrix4 result = *this; + for (auto &row : result.m) { + for (auto &element : row) { + element *= scalar; + } + } + return result; + } + + Matrix4 operator+(const Matrix4 &other) const { + Matrix4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.m[i][j] = m[i][j] + other.m[i][j]; + } + } + return result; + } + + Matrix4 operator-(const Matrix4 &other) const { + Matrix4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.m[i][j] = m[i][j] - other.m[i][j]; + } + } + return result; + } + + float trace() const { return m[0][0] + m[1][1] + m[2][2] + m[3][3]; } + + Matrix4 identity() const { + Matrix4 identity; + identity.set(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1); + return identity; + } + + Matrix4 power(int exp) const { + if (exp == 0) + return identity(); + if (exp == 1) + return *this; + + Matrix4 result = *this; + for (int i = 1; i < exp; ++i) { + result = result * (*this); + } + return result; + } + + void print() { + // Print all entries in 4 rows with 4 columns + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + std::cout << m[i][j] << " "; + } + std::cout << std::endl; + } + } + + bool invert() { + double inv[16], det; + double mArr[16]; + + // Convert the matrix to a 1D array for easier manipulation + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + mArr[i * 4 + j] = static_cast(m[i][j]); + } + } + + inv[0] = mArr[5] * mArr[10] * mArr[15] - mArr[5] * mArr[11] * mArr[14] - + mArr[9] * mArr[6] * mArr[15] + mArr[9] * mArr[7] * mArr[14] + + mArr[13] * mArr[6] * mArr[11] - mArr[13] * mArr[7] * mArr[10]; + + inv[4] = -mArr[4] * mArr[10] * mArr[15] + mArr[4] * mArr[11] * mArr[14] + + mArr[8] * mArr[6] * mArr[15] - mArr[8] * mArr[7] * mArr[14] - + mArr[12] * mArr[6] * mArr[11] + mArr[12] * mArr[7] * mArr[10]; + + inv[8] = mArr[4] * mArr[9] * mArr[15] - mArr[4] * mArr[11] * mArr[13] - + mArr[8] * mArr[5] * mArr[15] + mArr[8] * mArr[7] * mArr[13] + + mArr[12] * mArr[5] * mArr[11] - mArr[12] * mArr[7] * mArr[9]; + + inv[12] = -mArr[4] * mArr[9] * mArr[14] + mArr[4] * mArr[10] * mArr[13] + + mArr[8] * mArr[5] * mArr[14] - mArr[8] * mArr[6] * mArr[13] - + 
mArr[12] * mArr[5] * mArr[10] + mArr[12] * mArr[6] * mArr[9]; + + inv[1] = -mArr[1] * mArr[10] * mArr[15] + mArr[1] * mArr[11] * mArr[14] + + mArr[9] * mArr[2] * mArr[15] - mArr[9] * mArr[3] * mArr[14] - + mArr[13] * mArr[2] * mArr[11] + mArr[13] * mArr[3] * mArr[10]; + + inv[5] = mArr[0] * mArr[10] * mArr[15] - mArr[0] * mArr[11] * mArr[14] - + mArr[8] * mArr[2] * mArr[15] + mArr[8] * mArr[3] * mArr[14] + + mArr[12] * mArr[2] * mArr[11] - mArr[12] * mArr[3] * mArr[10]; + + inv[9] = -mArr[0] * mArr[9] * mArr[15] + mArr[0] * mArr[11] * mArr[13] + + mArr[8] * mArr[1] * mArr[15] - mArr[8] * mArr[3] * mArr[13] - + mArr[12] * mArr[1] * mArr[11] + mArr[12] * mArr[3] * mArr[9]; + + inv[13] = mArr[0] * mArr[9] * mArr[14] - mArr[0] * mArr[10] * mArr[13] - + mArr[8] * mArr[1] * mArr[14] + mArr[8] * mArr[2] * mArr[13] + + mArr[12] * mArr[1] * mArr[10] - mArr[12] * mArr[2] * mArr[9]; + + inv[2] = mArr[1] * mArr[6] * mArr[15] - mArr[1] * mArr[7] * mArr[14] - + mArr[5] * mArr[2] * mArr[15] + mArr[5] * mArr[3] * mArr[14] + + mArr[13] * mArr[2] * mArr[7] - mArr[13] * mArr[3] * mArr[6]; + + inv[6] = -mArr[0] * mArr[6] * mArr[15] + mArr[0] * mArr[7] * mArr[14] + + mArr[4] * mArr[2] * mArr[15] - mArr[4] * mArr[3] * mArr[14] - + mArr[12] * mArr[2] * mArr[7] + mArr[12] * mArr[3] * mArr[6]; + + inv[10] = mArr[0] * mArr[5] * mArr[15] - mArr[0] * mArr[7] * mArr[13] - + mArr[4] * mArr[1] * mArr[15] + mArr[4] * mArr[3] * mArr[13] + + mArr[12] * mArr[1] * mArr[7] - mArr[12] * mArr[3] * mArr[5]; + + inv[14] = -mArr[0] * mArr[5] * mArr[14] + mArr[0] * mArr[6] * mArr[13] + + mArr[4] * mArr[1] * mArr[14] - mArr[4] * mArr[2] * mArr[13] - + mArr[12] * mArr[1] * mArr[6] + mArr[12] * mArr[2] * mArr[5]; + + inv[3] = -mArr[1] * mArr[6] * mArr[11] + mArr[1] * mArr[7] * mArr[10] + + mArr[5] * mArr[2] * mArr[11] - mArr[5] * mArr[3] * mArr[10] - + mArr[9] * mArr[2] * mArr[7] + mArr[9] * mArr[3] * mArr[6]; + + inv[7] = mArr[0] * mArr[6] * mArr[11] - mArr[0] * mArr[7] * mArr[10] - + mArr[4] * mArr[2] * mArr[11] + mArr[4] * mArr[3] * mArr[10] + + mArr[8] * mArr[2] * mArr[7] - mArr[8] * mArr[3] * mArr[6]; + + inv[11] = -mArr[0] * mArr[5] * mArr[11] + mArr[0] * mArr[7] * mArr[9] + + mArr[4] * mArr[1] * mArr[11] - mArr[4] * mArr[3] * mArr[9] - + mArr[8] * mArr[1] * mArr[7] + mArr[8] * mArr[3] * mArr[5]; + + inv[15] = mArr[0] * mArr[5] * mArr[10] - mArr[0] * mArr[6] * mArr[9] - + mArr[4] * mArr[1] * mArr[10] + mArr[4] * mArr[2] * mArr[9] + + mArr[8] * mArr[1] * mArr[6] - mArr[8] * mArr[2] * mArr[5]; + + det = mArr[0] * inv[0] + mArr[1] * inv[4] + mArr[2] * inv[8] + + mArr[3] * inv[12]; + + if (fabs(det) < 1e-6) { + return false; + } + + det = 1.0 / det; + + for (int i = 0; i < 16; i++) { + inv[i] *= det; + } + + // Convert the 1D array back to the 4x4 matrix + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + m[i][j] = static_cast(inv[i * 4 + j]); + } + } + + return true; + } +}; + +inline void apply_matrix4(uv_float3 &v, const Matrix4 matrix) { + float newX = v.x * matrix.m[0][0] + v.y * matrix.m[0][1] + + v.z * matrix.m[0][2] + matrix.m[0][3]; + float newY = v.x * matrix.m[1][0] + v.y * matrix.m[1][1] + + v.z * matrix.m[1][2] + matrix.m[1][3]; + float newZ = v.x * matrix.m[2][0] + v.y * matrix.m[2][1] + + v.z * matrix.m[2][2] + matrix.m[2][3]; + float w = v.x * matrix.m[3][0] + v.y * matrix.m[3][1] + v.z * matrix.m[3][2] + + matrix.m[3][3]; + + if (std::fabs(w) > EPSILON) { + newX /= w; + newY /= w; + newZ /= w; + } + + v.x = newX; + v.y = newY; + v.z = newZ; +} diff --git a/uv_unwrapper/uv_unwrapper/csrc/intersect.cpp 
b/uv_unwrapper/uv_unwrapper/csrc/intersect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d54b2c7f04cb38d79cf1d091393e02f7459fcad --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/csrc/intersect.cpp @@ -0,0 +1,702 @@ +#include "intersect.h" +#include "bvh.h" +#include +#include +#include +#include +#include + +bool triangle_aabb_intersection(const uv_float2 &aabbMin, + const uv_float2 &aabbMax, const uv_float2 &v0, + const uv_float2 &v1, const uv_float2 &v2) { + // Convert the min and max aabb defintion to left, right, top, bottom + float l = aabbMin.x; + float r = aabbMax.x; + float t = aabbMin.y; + float b = aabbMax.y; + + int b0 = ((v0.x > l) ? 1 : 0) | ((v0.y > t) ? 2 : 0) | ((v0.x > r) ? 4 : 0) | + ((v0.y > b) ? 8 : 0); + if (b0 == 3) + return true; + + int b1 = ((v1.x > l) ? 1 : 0) | ((v1.y > t) ? 2 : 0) | ((v1.x > r) ? 4 : 0) | + ((v1.y > b) ? 8 : 0); + if (b1 == 3) + return true; + + int b2 = ((v2.x > l) ? 1 : 0) | ((v2.y > t) ? 2 : 0) | ((v2.x > r) ? 4 : 0) | + ((v2.y > b) ? 8 : 0); + if (b2 == 3) + return true; + + float m, c, s; + + int i0 = b0 ^ b1; + if (i0 != 0) { + if (v1.x != v0.x) { + m = (v1.y - v0.y) / (v1.x - v0.x); + c = v0.y - (m * v0.x); + if (i0 & 1) { + s = m * l + c; + if (s >= t && s <= b) + return true; + } + if (i0 & 2) { + s = (t - c) / m; + if (s >= l && s <= r) + return true; + } + if (i0 & 4) { + s = m * r + c; + if (s >= t && s <= b) + return true; + } + if (i0 & 8) { + s = (b - c) / m; + if (s >= l && s <= r) + return true; + } + } else { + if (l == v0.x || r == v0.x) + return true; + if (v0.x > l && v0.x < r) + return true; + } + } + + int i1 = b1 ^ b2; + if (i1 != 0) { + if (v2.x != v1.x) { + m = (v2.y - v1.y) / (v2.x - v1.x); + c = v1.y - (m * v1.x); + if (i1 & 1) { + s = m * l + c; + if (s >= t && s <= b) + return true; + } + if (i1 & 2) { + s = (t - c) / m; + if (s >= l && s <= r) + return true; + } + if (i1 & 4) { + s = m * r + c; + if (s >= t && s <= b) + return true; + } + if (i1 & 8) { + s = (b - c) / m; + if (s >= l && s <= r) + return true; + } + } else { + if (l == v1.x || r == v1.x) + return true; + if (v1.x > l && v1.x < r) + return true; + } + } + + int i2 = b0 ^ b2; + if (i2 != 0) { + if (v2.x != v0.x) { + m = (v2.y - v0.y) / (v2.x - v0.x); + c = v0.y - (m * v0.x); + if (i2 & 1) { + s = m * l + c; + if (s >= t && s <= b) + return true; + } + if (i2 & 2) { + s = (t - c) / m; + if (s >= l && s <= r) + return true; + } + if (i2 & 4) { + s = m * r + c; + if (s >= t && s <= b) + return true; + } + if (i2 & 8) { + s = (b - c) / m; + if (s >= l && s <= r) + return true; + } + } else { + if (l == v0.x || r == v0.x) + return true; + if (v0.x > l && v0.x < r) + return true; + } + } + + // Bounding box check + float tbb_l = std::min(v0.x, std::min(v1.x, v2.x)); + float tbb_t = std::min(v0.y, std::min(v1.y, v2.y)); + float tbb_r = std::max(v0.x, std::max(v1.x, v2.x)); + float tbb_b = std::max(v0.y, std::max(v1.y, v2.y)); + + if (tbb_l <= l && tbb_r >= r && tbb_t <= t && tbb_b >= b) { + float v0x = v2.x - v0.x; + float v0y = v2.y - v0.y; + float v1x = v1.x - v0.x; + float v1y = v1.y - v0.y; + float v2x, v2y; + + float dot00, dot01, dot02, dot11, dot12, invDenom, u, v; + + // Top-left corner + v2x = l - v0.x; + v2y = t - v0.y; + + dot00 = v0x * v0x + v0y * v0y; + dot01 = v0x * v1x + v0y * v1y; + dot02 = v0x * v2x + v0y * v2y; + dot11 = v1x * v1x + v1y * v1y; + dot12 = v1x * v2x + v1y * v2y; + + invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01); + u = (dot11 * dot02 - dot01 * dot12) * invDenom; + v = (dot00 * dot12 - dot01 * dot02) 
* invDenom; + + if (u >= 0 && v >= 0 && (u + v) <= 1) + return true; + + // Bottom-left corner + v2x = l - v0.x; + v2y = b - v0.y; + + dot02 = v0x * v2x + v0y * v2y; + dot12 = v1x * v2x + v1y * v2y; + + u = (dot11 * dot02 - dot01 * dot12) * invDenom; + v = (dot00 * dot12 - dot01 * dot02) * invDenom; + + if (u >= 0 && v >= 0 && (u + v) <= 1) + return true; + + // Bottom-right corner + v2x = r - v0.x; + v2y = b - v0.y; + + dot02 = v0x * v2x + v0y * v2y; + dot12 = v1x * v2x + v1y * v2y; + + u = (dot11 * dot02 - dot01 * dot12) * invDenom; + v = (dot00 * dot12 - dot01 * dot02) * invDenom; + + if (u >= 0 && v >= 0 && (u + v) <= 1) + return true; + + // Top-right corner + v2x = r - v0.x; + v2y = t - v0.y; + + dot02 = v0x * v2x + v0y * v2y; + dot12 = v1x * v2x + v1y * v2y; + + u = (dot11 * dot02 - dot01 * dot12) * invDenom; + v = (dot00 * dot12 - dot01 * dot02) * invDenom; + + if (u >= 0 && v >= 0 && (u + v) <= 1) + return true; + } + + return false; +} + +void tri_winding(uv_float2 &a, uv_float2 &b, uv_float2 &c) { + float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y)); + + // If the determinant is negative, the triangle is oriented clockwise + if (det < 0) { + // Swap vertices b and c to ensure counter-clockwise winding + std::swap(b, c); + } +} + +struct Triangle { + uv_float3 a, b, c; + + Triangle(const uv_float2 &p1, const uv_float2 &q1, const uv_float2 &r1) + : a({p1.x, p1.y, 0}), b({q1.x, q1.y, 0}), c({r1.x, r1.y, 0}) {} + + Triangle(const uv_float3 &p1, const uv_float3 &q1, const uv_float3 &r1) + : a(p1), b(q1), c(r1) {} + + void getNormal(uv_float3 &normal) const { + uv_float3 u = b - a; + uv_float3 v = c - a; + normal = normalize(cross(u, v)); + } +}; + +bool isTriDegenerated(const Triangle &tri) { + uv_float3 u = tri.a - tri.b; + uv_float3 v = tri.a - tri.c; + uv_float3 cr = cross(u, v); + return fabs(cr.x) < EPSILON && fabs(cr.y) < EPSILON && fabs(cr.z) < EPSILON; +} + +int orient3D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c, + const uv_float3 &d) { + Matrix4 _matrix4; + _matrix4.set(a.x, a.y, a.z, 1, b.x, b.y, b.z, 1, c.x, c.y, c.z, 1, d.x, d.y, + d.z, 1); + float det = _matrix4.determinant(); + + if (det < -EPSILON) + return -1; + else if (det > EPSILON) + return 1; + else + return 0; +} + +int orient2D(const uv_float2 &a, const uv_float2 &b, const uv_float2 &c) { + float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y)); + + if (det < -EPSILON) + return -1; + else if (det > EPSILON) + return 1; + else + return 0; +} + +int orient2D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c) { + uv_float2 a_2d = {a.x, a.y}; + uv_float2 b_2d = {b.x, b.y}; + uv_float2 c_2d = {c.x, c.y}; + return orient2D(a_2d, b_2d, c_2d); +} + +void permuteTriLeft(Triangle &tri) { + uv_float3 tmp = tri.a; + tri.a = tri.b; + tri.b = tri.c; + tri.c = tmp; +} + +void permuteTriRight(Triangle &tri) { + uv_float3 tmp = tri.c; + tri.c = tri.b; + tri.b = tri.a; + tri.a = tmp; +} + +void makeTriCounterClockwise(Triangle &tri) { + if (orient2D(tri.a, tri.b, tri.c) < 0) { + uv_float3 tmp = tri.c; + tri.c = tri.b; + tri.b = tmp; + } +} + +void intersectPlane(const uv_float3 &a, const uv_float3 &b, const uv_float3 &p, + const uv_float3 &n, uv_float3 &target) { + uv_float3 u = b - a; + uv_float3 v = a - p; + float dot1 = dot(n, u); + float dot2 = dot(n, v); + u = u * (-dot2 / dot1); + target = a + u; +} + +void computeLineIntersection(const Triangle &t1, const Triangle &t2, + std::vector &target) { + uv_float3 n1, n2; + t1.getNormal(n1); + t2.getNormal(n2); + + 
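+  // Note on the construction used below (in the style of the Devillers-Guigue
+  // triangle-triangle test): the two orientation tests decide which edge of
+  // each triangle straddles the other triangle's supporting plane. The
+  // matching intersectPlane() calls in each branch then produce the two
+  // endpoints i1/i2 of the segment shared by the triangles; i2 is kept only
+  // when it is farther than EPSILON from i1.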
int o1 = orient3D(t1.a, t1.c, t2.b, t2.a); + int o2 = orient3D(t1.a, t1.b, t2.c, t2.a); + + uv_float3 i1, i2; + + if (o1 > 0) { + if (o2 > 0) { + intersectPlane(t1.a, t1.c, t2.a, n2, i1); + intersectPlane(t2.a, t2.c, t1.a, n1, i2); + } else { + intersectPlane(t1.a, t1.c, t2.a, n2, i1); + intersectPlane(t1.a, t1.b, t2.a, n2, i2); + } + } else { + if (o2 > 0) { + intersectPlane(t2.a, t2.b, t1.a, n1, i1); + intersectPlane(t2.a, t2.c, t1.a, n1, i2); + } else { + intersectPlane(t2.a, t2.b, t1.a, n1, i1); + intersectPlane(t1.a, t1.b, t2.a, n2, i2); + } + } + + target.push_back(i1); + if (distance_to(i1, i2) >= EPSILON) { + target.push_back(i2); + } +} + +void makeTriAVertexAlone(Triangle &tri, int oa, int ob, int oc) { + // Permute a, b, c so that a is alone on its side + if (oa == ob) { + // c is alone, permute right so c becomes a + permuteTriRight(tri); + } else if (oa == oc) { + // b is alone, permute so b becomes a + permuteTriLeft(tri); + } else if (ob != oc) { + // In case a, b, c have different orientation, put a on positive side + if (ob > 0) { + permuteTriLeft(tri); + } else if (oc > 0) { + permuteTriRight(tri); + } + } +} + +void makeTriAVertexPositive(Triangle &tri, const Triangle &other) { + int o = orient3D(other.a, other.b, other.c, tri.a); + if (o < 0) { + std::swap(tri.b, tri.c); + } +} + +bool crossIntersect(Triangle &t1, Triangle &t2, int o1a, int o1b, int o1c, + std::vector *target = nullptr) { + int o2a = orient3D(t1.a, t1.b, t1.c, t2.a); + int o2b = orient3D(t1.a, t1.b, t1.c, t2.b); + int o2c = orient3D(t1.a, t1.b, t1.c, t2.c); + + if (o2a == o2b && o2a == o2c) { + return false; + } + + // Make a vertex alone on its side for both triangles + makeTriAVertexAlone(t1, o1a, o1b, o1c); + makeTriAVertexAlone(t2, o2a, o2b, o2c); + + // Ensure the vertex on the positive side + makeTriAVertexPositive(t2, t1); + makeTriAVertexPositive(t1, t2); + + int o1 = orient3D(t1.a, t1.b, t2.a, t2.b); + int o2 = orient3D(t1.a, t1.c, t2.c, t2.a); + + if (o1 <= 0 && o2 <= 0) { + if (target) { + computeLineIntersection(t1, t2, *target); + } + return true; + } + + return false; +} + +void linesIntersect2d(const uv_float3 &a1, const uv_float3 &b1, + const uv_float3 &a2, const uv_float3 &b2, + uv_float3 &target) { + float dx1 = a1.x - b1.x; + float dx2 = a2.x - b2.x; + float dy1 = a1.y - b1.y; + float dy2 = a2.y - b2.y; + + float D = dx1 * dy2 - dx2 * dy1; + + float n1 = a1.x * b1.y - a1.y * b1.x; + float n2 = a2.x * b2.y - a2.y * b2.x; + + target.x = (n1 * dx2 - n2 * dx1) / D; + target.y = (n1 * dy2 - n2 * dy1) / D; + target.z = 0; +} + +void clipTriangle(const Triangle &t1, const Triangle &t2, + std::vector &target) { + std::vector clip = {t1.a, t1.b, t1.c}; + std::vector output = {t2.a, t2.b, t2.c}; + std::vector orients(output.size() * 3, 0); + uv_float3 inter; + + for (int i = 0; i < 3; ++i) { + const int i_prev = (i + 2) % 3; + std::vector input; + std::copy(output.begin(), output.end(), std::back_inserter(input)); + output.clear(); + + for (size_t j = 0; j < input.size(); ++j) { + orients[j] = orient2D(clip[i_prev], clip[i], input[j]); + } + + for (size_t j = 0; j < input.size(); ++j) { + const int j_prev = (j - 1 + input.size()) % input.size(); + + if (orients[j] >= 0) { + if (orients[j_prev] < 0) { + linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], + inter); + output.push_back({inter.x, inter.y, inter.z}); + } + output.push_back({input[j].x, input[j].y, input[j].z}); + } else if (orients[j_prev] >= 0) { + linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], inter); 
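+        // Inside-to-outside case of the clipping step: only the crossing
+        // point just computed is appended to the output polygon.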
+ output.push_back({inter.x, inter.y, inter.z}); + } + } + } + + // Clear duplicated points + for (const auto &point : output) { + int j = 0; + bool sameFound = false; + while (!sameFound && j < target.size()) { + sameFound = distance_to(point, target[j]) <= 1e-6; + j++; + } + + if (!sameFound) { + target.push_back(point); + } + } +} + +bool intersectionTypeR1(const Triangle &t1, const Triangle &t2) { + const uv_float3 &p1 = t1.a; + const uv_float3 &q1 = t1.b; + const uv_float3 &r1 = t1.c; + const uv_float3 &p2 = t2.a; + const uv_float3 &r2 = t2.c; + + if (orient2D(r2, p2, q1) >= 0) { // I + if (orient2D(r2, p1, q1) >= 0) { // II.a + if (orient2D(p1, p2, q1) >= 0) { // III.a + return true; + } else { + if (orient2D(p1, p2, r1) >= 0) { // IV.a + if (orient2D(q1, r1, p2) >= 0) { // V + return true; + } + } + } + } + } else { + if (orient2D(r2, p2, r1) >= 0) { // II.b + if (orient2D(q1, r1, r2) >= 0) { // III.b + if (orient2D(p1, p2, r1) >= 0) { // IV.b (diverges from paper) + return true; + } + } + } + } + + return false; +} + +bool intersectionTypeR2(const Triangle &t1, const Triangle &t2) { + const uv_float3 &p1 = t1.a; + const uv_float3 &q1 = t1.b; + const uv_float3 &r1 = t1.c; + const uv_float3 &p2 = t2.a; + const uv_float3 &q2 = t2.b; + const uv_float3 &r2 = t2.c; + + if (orient2D(r2, p2, q1) >= 0) { // I + if (orient2D(q2, r2, q1) >= 0) { // II.a + if (orient2D(p1, p2, q1) >= 0) { // III.a + if (orient2D(p1, q2, q1) <= 0) { // IV.a + return true; + } + } else { + if (orient2D(p1, p2, r1) >= 0) { // IV.b + if (orient2D(r2, p2, r1) <= 0) { // V.a + return true; + } + } + } + } else { + if (orient2D(p1, q2, q1) <= 0) { // III.b + if (orient2D(q2, r2, r1) >= 0) { // IV.c + if (orient2D(q1, r1, q2) >= 0) { // V.b + return true; + } + } + } + } + } else { + if (orient2D(r2, p2, r1) >= 0) { // II.b + if (orient2D(q1, r1, r2) >= 0) { // III.c + if (orient2D(r1, p1, p2) >= 0) { // IV.d + return true; + } + } else { + if (orient2D(q1, r1, q2) >= 0) { // IV.e + if (orient2D(q2, r2, r1) >= 0) { // V.c + return true; + } + } + } + } + } + + return false; +} + +bool coplanarIntersect(Triangle &t1, Triangle &t2, + std::vector *target = nullptr) { + uv_float3 normal, u, v; + t1.getNormal(normal); + normal = normalize(normal); + u = normalize(t1.a - t1.b); + v = cross(normal, u); + + // Move basis to t1.a + u = u + t1.a; + v = v + t1.a; + normal = normal + t1.a; + + Matrix4 _matrix; + _matrix.set(t1.a.x, u.x, v.x, normal.x, t1.a.y, u.y, v.y, normal.y, t1.a.z, + u.z, v.z, normal.z, 1, 1, 1, 1); + + Matrix4 _affineMatrix; + _affineMatrix.set(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1); + + _matrix.invert(); // Invert the _matrix + _matrix = _affineMatrix * _matrix; + + // Apply transformation + apply_matrix4(t1.a, _matrix); + apply_matrix4(t1.b, _matrix); + apply_matrix4(t1.c, _matrix); + apply_matrix4(t2.a, _matrix); + apply_matrix4(t2.b, _matrix); + apply_matrix4(t2.c, _matrix); + + makeTriCounterClockwise(t1); + makeTriCounterClockwise(t2); + + const uv_float3 &p1 = t1.a; + const uv_float3 &p2 = t2.a; + const uv_float3 &q2 = t2.b; + const uv_float3 &r2 = t2.c; + + int o_p2q2 = orient2D(p2, q2, p1); + int o_q2r2 = orient2D(q2, r2, p1); + int o_r2p2 = orient2D(r2, p2, p1); + + bool intersecting = false; + if (o_p2q2 >= 0) { + if (o_q2r2 >= 0) { + if (o_r2p2 >= 0) { + // + + + + intersecting = true; + } else { + // + + - + intersecting = intersectionTypeR1(t1, t2); + } + } else { + if (o_r2p2 >= 0) { + // + - + + permuteTriRight(t2); + intersecting = intersectionTypeR1(t1, t2); + } else { + // + - - 
+ intersecting = intersectionTypeR2(t1, t2); + } + } + } else { + if (o_q2r2 >= 0) { + if (o_r2p2 >= 0) { + // - + + + permuteTriLeft(t2); + intersecting = intersectionTypeR1(t1, t2); + } else { + // - + - + permuteTriLeft(t2); + intersecting = intersectionTypeR2(t1, t2); + } + } else { + if (o_r2p2 >= 0) { + // - - + + permuteTriRight(t2); + intersecting = intersectionTypeR2(t1, t2); + } else { + // - - - + std::cerr << "Triangles should not be flat." << std::endl; + return false; + } + } + } + + if (intersecting && target) { + clipTriangle(t1, t2, *target); + + _matrix.invert(); + // Apply the transform to each target point + for (int i = 0; i < target->size(); ++i) { + apply_matrix4(target->at(i), _matrix); + } + } + + return intersecting; +} + +// Helper function to calculate the area of a polygon +float polygon_area(const std::vector &polygon) { + if (polygon.size() < 3) + return 0.0f; // Not a polygon + + uv_float3 normal = {0.0f, 0.0f, 0.0f}; // Initialize normal vector + + // Calculate the cross product of edges around the polygon + for (size_t i = 0; i < polygon.size(); ++i) { + uv_float3 p1 = polygon[i]; + uv_float3 p2 = polygon[(i + 1) % polygon.size()]; + + normal = normal + cross(p1, p2); // Accumulate the normal vector + } + + float area = + magnitude(normal) / 2.0f; // Area is half the magnitude of the normal + return area; +} + +bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1, + uv_float2 p2, uv_float2 q2, uv_float2 r2) { + Triangle t1(p1, q1, r1); + Triangle t2(p2, q2, r2); + + if (isTriDegenerated(t1) || isTriDegenerated(t2)) { + // std::cerr << "Degenerated triangles provided, skipping." << std::endl; + return false; + } + + int o1a = orient3D(t2.a, t2.b, t2.c, t1.a); + int o1b = orient3D(t2.a, t2.b, t2.c, t1.b); + int o1c = orient3D(t2.a, t2.b, t2.c, t1.c); + + std::vector intersections; + bool intersects; + + if (o1a == o1b && o1a == o1c) // [[likely]] + { + intersects = o1a == 0 && coplanarIntersect(t1, t2, &intersections); + } else // [[unlikely]] + { + intersects = crossIntersect(t1, t2, o1a, o1b, o1c, &intersections); + } + + if (intersects) { + float area = polygon_area(intersections); + + // std::cout << "Intersection area: " << area << std::endl; + if (area < 1e-10f || std::isfinite(area) == false) { + // std::cout<<"Invalid area: " << area << std::endl; + return false; // Ignore intersection if the area is too small + } + } + + return intersects; +} diff --git a/uv_unwrapper/uv_unwrapper/csrc/intersect.h b/uv_unwrapper/uv_unwrapper/csrc/intersect.h new file mode 100644 index 0000000000000000000000000000000000000000..9ecfe80555592fb56449854aa9a937059c708c7a --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/csrc/intersect.h @@ -0,0 +1,10 @@ +#pragma once + +#include "common.h" +#include + +bool triangle_aabb_intersection(const uv_float2 &aabb_min, + const uv_float2 &aabb_max, const uv_float2 &v0, + const uv_float2 &v1, const uv_float2 &v2); +bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1, + uv_float2 p2, uv_float2 q2, uv_float2 r2); diff --git a/uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp b/uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e76bd62b116caaf2512c67e483772a17b8e44c60 --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp @@ -0,0 +1,271 @@ +#include "bvh.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// #define TIMING + +#if defined(_MSC_VER) +#include +typedef SSIZE_T 
ssize_t; +#endif + +namespace UVUnwrapper { +void create_bvhs(BVH *bvhs, Triangle *triangles, + std::vector> &triangle_per_face, int num_faces, + int start, int end) { +#pragma omp parallel for + for (int i = start; i < end; i++) { + int num_triangles = triangle_per_face[i].size(); + Triangle *triangles_per_face = new Triangle[num_triangles]; + int *indices = new int[num_triangles]; + int j = 0; + for (int idx : triangle_per_face[i]) { + triangles_per_face[j] = triangles[idx]; + indices[j++] = idx; + } + // Each thread writes to it's own memory space + // First check if the number of triangles is 0 + if (num_triangles == 0) { + bvhs[i - start] = std::move(BVH()); // Default constructor + } else { + bvhs[i - start] = std::move( + BVH(triangles_per_face, indices, + num_triangles)); // BVH now handles memory of triangles_per_face + } + delete[] triangles_per_face; + } +} + +void perform_intersection_check(BVH *bvhs, int num_bvhs, Triangle *triangles, + uv_float3 *vertex_tri_centroids, + int64_t *assign_indices_ptr, + ssize_t num_indices, int offset, + std::vector> &triangle_per_face) { + std::vector> + unique_intersections; // Store unique intersections as pairs of triangle + // indices + +// Step 1: Detect intersections in parallel +#pragma omp parallel for + for (int i = 0; i < num_indices; i++) { + if (assign_indices_ptr[i] < offset) { + continue; + } + + Triangle cur_tri = triangles[i]; + auto &cur_bvh = bvhs[assign_indices_ptr[i] - offset]; + + if (cur_bvh.bvhNode == nullptr) { + continue; + } + + std::vector intersections = cur_bvh.Intersect(cur_tri); + + if (!intersections.empty()) { + +#pragma omp critical + { + for (int intersect : intersections) { + if (i != intersect) { + // Ensure we only store unique pairs (A, B) where A < B to avoid + // duplication + if (i < intersect) { + unique_intersections.push_back(std::make_pair(i, intersect)); + } else { + unique_intersections.push_back(std::make_pair(intersect, i)); + } + } + } + } + } + } + + // Step 2: Process unique intersections + for (int idx = 0; idx < unique_intersections.size(); idx++) { + int first = unique_intersections[idx].first; + int second = unique_intersections[idx].second; + + int i_idx = assign_indices_ptr[first]; + + int norm_idx = i_idx % 6; + int axis = (norm_idx < 2) ? 0 : (norm_idx < 4) ? 
1 : 2; + bool use_max = (i_idx % 2) == 1; + + float pos_a = vertex_tri_centroids[first][axis]; + float pos_b = vertex_tri_centroids[second][axis]; + // Sort the intersections based on vertex_tri_centroids along the specified + // axis + if (use_max) { + if (pos_a < pos_b) { + std::swap(first, second); + } + } else { + if (pos_a > pos_b) { + std::swap(first, second); + } + } + + // Update the unique intersections + unique_intersections[idx].first = first; + unique_intersections[idx].second = second; + } + + // Now only get the second intersections from the pair and put them in a set + // The second intersection should always be the occluded triangle + std::set second_intersections; + for (int idx = 0; idx < (int)unique_intersections.size(); idx++) { + int second = unique_intersections[idx].second; + second_intersections.insert(second); + } + + for (int int_idx : second_intersections) { + // Move the second (occluded) triangle by 6 + int intersect_idx = assign_indices_ptr[int_idx]; + int new_index = intersect_idx + 6; + new_index = std::clamp(new_index, 0, 12); + + assign_indices_ptr[int_idx] = new_index; + triangle_per_face[intersect_idx].erase(int_idx); + triangle_per_face[new_index].insert(int_idx); + } +} + +torch::Tensor assign_faces_uv_to_atlas_index(torch::Tensor vertices, + torch::Tensor indices, + torch::Tensor face_uv, + torch::Tensor face_index) { + // Get the number of faces + int num_faces = indices.size(0); + torch::Tensor assign_indices = + torch::empty( + { + num_faces, + }, + torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU)) + .contiguous(); + + auto vert_accessor = vertices.accessor(); + auto indices_accessor = indices.accessor(); + auto face_uv_accessor = face_uv.accessor(); + + const int64_t *face_index_ptr = face_index.contiguous().data_ptr(); + int64_t *assign_indices_ptr = assign_indices.data_ptr(); + // copy face_index to assign_indices + memcpy(assign_indices_ptr, face_index_ptr, num_faces * sizeof(int64_t)); + +#ifdef TIMING + auto start = std::chrono::high_resolution_clock::now(); +#endif + uv_float3 *vertex_tri_centroids = new uv_float3[num_faces]; + Triangle *triangles = new Triangle[num_faces]; + + // Use std::set to store triangles for each face + std::vector> triangle_per_face; + triangle_per_face.resize(13); + +#pragma omp parallel for + for (int i = 0; i < num_faces; i++) { + int face_idx = i * 3; + triangles[i].v0 = {face_uv_accessor[face_idx + 0][0], + face_uv_accessor[face_idx + 0][1]}; + triangles[i].v1 = {face_uv_accessor[face_idx + 1][0], + face_uv_accessor[face_idx + 1][1]}; + triangles[i].v2 = {face_uv_accessor[face_idx + 2][0], + face_uv_accessor[face_idx + 2][1]}; + triangles[i].centroid = + triangle_centroid(triangles[i].v0, triangles[i].v1, triangles[i].v2); + + uv_float3 v0 = {vert_accessor[indices_accessor[i][0]][0], + vert_accessor[indices_accessor[i][0]][1], + vert_accessor[indices_accessor[i][0]][2]}; + uv_float3 v1 = {vert_accessor[indices_accessor[i][1]][0], + vert_accessor[indices_accessor[i][1]][1], + vert_accessor[indices_accessor[i][1]][2]}; + uv_float3 v2 = {vert_accessor[indices_accessor[i][2]][0], + vert_accessor[indices_accessor[i][2]][1], + vert_accessor[indices_accessor[i][2]][2]}; + vertex_tri_centroids[i] = triangle_centroid(v0, v1, v2); + +// Assign the triangle to the face index +#pragma omp critical + { triangle_per_face[face_index_ptr[i]].insert(i); } + } + +#ifdef TIMING + auto start_bvh = std::chrono::high_resolution_clock::now(); +#endif + + BVH *bvhs = new BVH[6]; + create_bvhs(bvhs, triangles, 
triangle_per_face, num_faces, 0, 6); + +#ifdef TIMING + auto end_bvh = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed_seconds = end_bvh - start_bvh; + std::cout << "BVH build time: " << elapsed_seconds.count() << "s\n"; + + auto start_intersection_1 = std::chrono::high_resolution_clock::now(); +#endif + + perform_intersection_check(bvhs, 6, triangles, vertex_tri_centroids, + assign_indices_ptr, num_faces, 0, + triangle_per_face); + +#ifdef TIMING + auto end_intersection_1 = std::chrono::high_resolution_clock::now(); + elapsed_seconds = end_intersection_1 - start_intersection_1; + std::cout << "Intersection 1 time: " << elapsed_seconds.count() << "s\n"; +#endif + // Create 6 new bvhs and delete the old ones + BVH *new_bvhs = new BVH[6]; + create_bvhs(new_bvhs, triangles, triangle_per_face, num_faces, 6, 12); + +#ifdef TIMING + auto end_bvh2 = std::chrono::high_resolution_clock::now(); + elapsed_seconds = end_bvh2 - end_intersection_1; + std::cout << "BVH 2 build time: " << elapsed_seconds.count() << "s\n"; + auto start_intersection_2 = std::chrono::high_resolution_clock::now(); +#endif + + perform_intersection_check(new_bvhs, 6, triangles, vertex_tri_centroids, + assign_indices_ptr, num_faces, 6, + triangle_per_face); + +#ifdef TIMING + auto end_intersection_2 = std::chrono::high_resolution_clock::now(); + elapsed_seconds = end_intersection_2 - start_intersection_2; + std::cout << "Intersection 2 time: " << elapsed_seconds.count() << "s\n"; + elapsed_seconds = end_intersection_2 - start; + std::cout << "Total time: " << elapsed_seconds.count() << "s\n"; +#endif + + // Cleanup + delete[] vertex_tri_centroids; + delete[] triangles; + delete[] bvhs; + delete[] new_bvhs; + + return assign_indices; +} + +// Registers _C as a Python extension module. 
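+// The pybind11 module body can stay empty: the operator schema is declared
+// with TORCH_LIBRARY and its CPU kernel is registered with TORCH_LIBRARY_IMPL
+// below, so Python code reaches it through the dispatcher as
+// torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index(...) (this is how
+// uv_unwrapper/unwrap.py calls it), while importing the compiled module
+// triggers the registration.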
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} + +// Defines the operators +TORCH_LIBRARY(UVUnwrapper, m) { + m.def("assign_faces_uv_to_atlas_index(Tensor vertices, Tensor indices, " + "Tensor face_uv, Tensor face_index) -> Tensor"); +} + +// Registers CPP implementations +TORCH_LIBRARY_IMPL(UVUnwrapper, CPU, m) { + m.impl("assign_faces_uv_to_atlas_index", &assign_faces_uv_to_atlas_index); +} + +} // namespace UVUnwrapper diff --git a/uv_unwrapper/uv_unwrapper/unwrap.py b/uv_unwrapper/uv_unwrapper/unwrap.py new file mode 100644 index 0000000000000000000000000000000000000000..64becd4da535edbe5cd9a5ec5efe0cde01c64a7b --- /dev/null +++ b/uv_unwrapper/uv_unwrapper/unwrap.py @@ -0,0 +1,669 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class Unwrapper(nn.Module): + def __init__(self): + super().__init__() + + def _box_assign_vertex_to_cube_face( + self, + vertex_positions: Tensor, + vertex_normals: Tensor, + triangle_idxs: Tensor, + bbox: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Assigns each vertex to a cube face based on the face normal + + Args: + vertex_positions (Tensor, Nv 3, float): Vertex positions + vertex_normals (Tensor, Nv 3, float): Vertex normals + triangle_idxs (Tensor, Nf 3, int): Triangle indices + bbox (Tensor, 2 3, float): Bounding box of the mesh + + Returns: + Tensor, Nf 3 2, float: UV coordinates + Tensor, Nf, int: Cube face indices + """ + + # Test to not have a scaled model to fit the space better + # bbox_min = bbox[:1].mean(-1, keepdim=True) + # bbox_max = bbox[1:].mean(-1, keepdim=True) + # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min) + + # Create a [0, 1] normalized vertex position + v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1]) + # And to [-1, 1] + v_pos_normalized = 2.0 * v_pos_normalized - 1.0 + + # Get all vertex positions for each triangle + # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos? 
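+        # The choice made below: average the three vertex normals of each
+        # face, then assign the face to whichever of the six axis-aligned
+        # directions (+x, -x, +y, -y, +z, -z) has the largest dot product
+        # with that averaged normal; the masks further down pick the matching
+        # in-plane (u, v) coordinates for the chosen cube face.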
+ v0 = v_pos_normalized[triangle_idxs[:, 0]] + v1 = v_pos_normalized[triangle_idxs[:, 1]] + v2 = v_pos_normalized[triangle_idxs[:, 2]] + tri_stack = torch.stack([v0, v1, v2], dim=1) + + vn0 = vertex_normals[triangle_idxs[:, 0]] + vn1 = vertex_normals[triangle_idxs[:, 1]] + vn2 = vertex_normals[triangle_idxs[:, 2]] + tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1) + + # Just average the normals per face + face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1) + + # Now decide based on the face normal in which box map we project + # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1) + abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1) + + axis = torch.tensor( + [ + [1, 0, 0], # 0 + [-1, 0, 0], # 1 + [0, 1, 0], # 2 + [0, -1, 0], # 3 + [0, 0, 1], # 4 + [0, 0, -1], # 5 + ], + device=face_normal.device, + dtype=face_normal.dtype, + ) + face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1) + index = face_normal_axis.argmax(-1) + + max_axis, uc, vc = ( + torch.ones_like(abs_x), + torch.zeros_like(tri_stack[..., :1]), + torch.zeros_like(tri_stack[..., :1]), + ) + mask_pos_x = index == 0 + max_axis[mask_pos_x] = abs_x[mask_pos_x] + uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2] + vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:] + + mask_neg_x = index == 1 + max_axis[mask_neg_x] = abs_x[mask_neg_x] + uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2] + vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:] + + mask_pos_y = index == 2 + max_axis[mask_pos_y] = abs_y[mask_pos_y] + uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1] + vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:] + + mask_neg_y = index == 3 + max_axis[mask_neg_y] = abs_y[mask_neg_y] + uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1] + vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:] + + mask_pos_z = index == 4 + max_axis[mask_pos_z] = abs_z[mask_pos_z] + uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1] + vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2] + + mask_neg_z = index == 5 + max_axis[mask_neg_z] = abs_z[mask_neg_z] + uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1] + vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2] + + # UC from [-1, 1] to [0, 1] + max_dim_div = max_axis.max(dim=0, keepdim=True).values + uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1) + vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1) + + uv = torch.stack([uc, vc], dim=-1) + + return uv, index + + def _assign_faces_uv_to_atlas_index( + self, + vertex_positions: Tensor, + triangle_idxs: Tensor, + face_uv: Tensor, + face_index: Tensor, + ) -> Tensor: # noqa: F821 + """ + Assigns the face UV to the atlas index + + Args: + vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions + triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices + face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates + face_index (Integer[Tensor, "Nf"]): Face indices + + Returns: + Integer[Tensor, "Nf"]: Atlas index + """ + return torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index( + vertex_positions.cpu(), + triangle_idxs.cpu(), + face_uv.view(-1, 2).cpu(), + face_index.cpu(), + ).to(vertex_positions.device) + + def _find_slice_offset_and_scale( + self, index: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # noqa: F821 + """ + Find the slice offset and scale + + Args: + index (Integer[Tensor, "Nf"]): Atlas index + + Returns: + Float[Tensor, "Nf"]: Offset x + Float[Tensor, "Nf"]: Offset y + Float[Tensor, "Nf"]: Division x + Float[Tensor, "Nf"]: Division y + """ + + # 6 due to the 6 cube faces + off = 1 / 3 + dupl_off = 1 / 6 + + # 
Here, we need to decide how to pack the textures in the case of overlap + def x_offset_calc(x, i): + offset_calc = i // 6 + # Initial coordinates - just 3x2 grid + if offset_calc == 0: + return off * x + else: + # Smaller 3x2 grid plus eventual shift to right for + # second overlap + return dupl_off * x + min(offset_calc - 1, 1) * 0.5 + + def y_offset_calc(x, i): + offset_calc = i // 6 + # Initial coordinates - just a 3x2 grid + if offset_calc == 0: + return off * x + else: + # Smaller coordinates in the lowest row + return dupl_off * x + off * 2 + + offset_x = torch.zeros_like(index, dtype=torch.float32) + offset_y = torch.zeros_like(index, dtype=torch.float32) + offset_x_vals = [0, 1, 2, 0, 1, 2] + offset_y_vals = [0, 0, 0, 1, 1, 1] + for i in range(index.max().item() + 1): + mask = index == i + if not mask.any(): + continue + offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i) + offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i) + + div_x = torch.full_like(index, 6 // 2, dtype=torch.float32) + # All overlap elements are saved in half scale + div_x[index >= 6] = 6 + div_y = div_x.clone() # Same for y + # Except for the random overlaps + div_x[index >= 12] = 2 + # But the random overlaps are saved in a large block in the lower thirds + div_y[index >= 12] = 3 + + return offset_x, offset_y, div_x, div_y + + def _calculate_tangents( + self, + vertex_positions: Tensor, + vertex_normals: Tensor, + triangle_idxs: Tensor, + face_uv: Tensor, + ) -> Tensor: + """ + Calculate the tangents for each triangle + + Args: + vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions + vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals + triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices + face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates + + Returns: + Float[Tensor, "Nf 3 4"]: Tangents + """ + vn_idx = [None] * 3 + pos = [None] * 3 + tex = face_uv.unbind(1) + for i in range(0, 3): + pos[i] = vertex_positions[triangle_idxs[:, i]] + # t_nrm_idx is always the same as t_pos_idx + vn_idx[i] = triangle_idxs[:, i] + + if torch.backends.mps.is_available(): + tangents = torch.zeros_like(vertex_normals).contiguous() + tansum = torch.zeros_like(vertex_normals).contiguous() + else: + tangents = torch.zeros_like(vertex_normals) + tansum = torch.zeros_like(vertex_normals) + + # Compute tangent space for each triangle + duv1 = tex[1] - tex[0] + duv2 = tex[2] - tex[0] + dpos1 = pos[1] - pos[0] + dpos2 = pos[2] - pos[0] + + tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2] + + denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1] + + # Avoid division by zero for degenerated texture coordinates + denom_safe = denom.clip(1e-6) + tang = tng_nom / denom_safe + + # Update all 3 vertices + for i in range(0, 3): + idx = vn_idx[i][:, None].repeat(1, 3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_( + 0, idx, torch.ones_like(tang) + ) # tansum[n_i] = tansum[n_i] + 1 + # Also normalize it. 
Here we do not normalize the individual triangles first so larger area + # triangles influence the tangent space more + tangents = tangents / tansum + + # Normalize and make sure tangent is perpendicular to normal + tangents = F.normalize(tangents, dim=1) + tangents = F.normalize( + tangents + - (tangents * vertex_normals).sum(-1, keepdim=True) * vertex_normals + ) + + return tangents + + def _rotate_uv_slices_consistent_space( + self, + vertex_positions: Tensor, + vertex_normals: Tensor, + triangle_idxs: Tensor, + uv: Tensor, + index: Tensor, + ) -> Tensor: + """ + Rotate the UV slices so they are in a consistent space + + Args: + vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions + vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals + triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices + uv (Float[Tensor, "Nf 3 2"]): UV coordinates + index (Integer[Tensor, "Nf"]): Atlas index + + Returns: + Float[Tensor, "Nf 3 2"]: Rotated UV coordinates + """ + + tangents = self._calculate_tangents( + vertex_positions, vertex_normals, triangle_idxs, uv + ) + pos_stack = torch.stack( + [ + -vertex_positions[..., 1], + vertex_positions[..., 0], + torch.zeros_like(vertex_positions[..., 0]), + ], + dim=-1, + ) + expected_tangents = F.normalize( + torch.linalg.cross( + vertex_normals, + torch.linalg.cross(pos_stack, vertex_normals, dim=-1), + dim=-1, + ), + -1, + ) + + actual_tangents = tangents[triangle_idxs] + expected_tangents = expected_tangents[triangle_idxs] + + def rotation_matrix_2d(theta): + c, s = torch.cos(theta), torch.sin(theta) + return torch.tensor([[c, -s], [s, c]]) + + # Now find the rotation + index_mod = index % 6 # Shouldn't happen. Just for safety + for i in range(6): + mask = index_mod == i + if not mask.any(): + continue + + actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1)) + expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1)) + + dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent) + cross_product = ( + actual_mean_tangent[0] * expected_mean_tangent[1] + - actual_mean_tangent[1] * expected_mean_tangent[0] + ) + angle = torch.atan2(cross_product, dot_product) + + rot_matrix = rotation_matrix_2d(angle).to(mask.device) + # Center the uv coordinate to be in the range of -1 to 1 and 0 centered + uv_cur = uv[mask] * 2 - 1 # Center it first + # Rotate it + uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur) + + # Rescale uv[mask] to be within the 0-1 range + uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min()) + + return uv + + def _handle_slice_uvs( + self, + uv: Tensor, + index: Tensor, # noqa: F821 + island_padding: float, + max_index: int = 6 * 2, + ) -> Tensor: # noqa: F821 + """ + Handle the slice UVs + + Args: + uv (Float[Tensor, "Nf 3 2"]): UV coordinates + index (Integer[Tensor, "Nf"]): Atlas index + island_padding (float): Island padding + max_index (int): Maximum index + + Returns: + Float[Tensor, "Nf 3 2"]: Updated UV coordinates + + """ + uc, vc = uv.unbind(-1) + + # Get the second slice (The first overlap) + index_filter = [index == i for i in range(6, max_index)] + + # Normalize them to always fully fill the atlas patch + for i, fi in enumerate(index_filter): + if fi.sum() > 0: + # Scale the slice but only up to a factor of 2 + # This keeps the texture resolution with the first slice in line (Half space in UV) + uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip( + 0.5 + ) + vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip( + 0.5 + ) + + uc_padded = (uc 
+
+    def _handle_slice_uvs(
+        self,
+        uv: Tensor,
+        index: Tensor,  # noqa: F821
+        island_padding: float,
+        max_index: int = 6 * 2,
+    ) -> Tensor:  # noqa: F821
+        """
+        Handle the slice UVs
+
+        Args:
+            uv (Float[Tensor, "Nf 3 2"]): UV coordinates
+            index (Integer[Tensor, "Nf"]): Atlas index
+            island_padding (float): Island padding
+            max_index (int): Maximum index
+
+        Returns:
+            Float[Tensor, "Nf 3 2"]: Updated UV coordinates
+        """
+        uc, vc = uv.unbind(-1)
+
+        # Get the second slice (the first overlap)
+        index_filter = [index == i for i in range(6, max_index)]
+
+        # Normalize them to always fully fill the atlas patch
+        for i, fi in enumerate(index_filter):
+            if fi.sum() > 0:
+                # Scale the slice but only up to a factor of 2
+                # This keeps the texture resolution in line with the first slice (which uses half the UV space)
+                uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(
+                    0.5
+                )
+                vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(
+                    0.5
+                )
+
+        uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
+        vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
+
+        return torch.stack([uc_padded, vc_padded], dim=-1)
+
+    def _handle_remaining_uvs(
+        self,
+        uv: Tensor,
+        index: Tensor,  # noqa: F821
+        island_padding: float,
+    ) -> Tensor:
+        """
+        Handle the remaining UVs (the ones that are not slices)
+
+        Args:
+            uv (Float[Tensor, "Nf 3 2"]): UV coordinates
+            index (Integer[Tensor, "Nf"]): Atlas index
+            island_padding (float): Island padding
+
+        Returns:
+            Float[Tensor, "Nf 3 2"]: Updated UV coordinates
+        """
+        uc, vc = uv.unbind(-1)
+        # Get all remaining elements
+        remaining_filter = index >= 6 * 2
+        squares_left = remaining_filter.sum()
+
+        if squares_left == 0:
+            return uv
+
+        uc = uc[remaining_filter]
+        vc = vc[remaining_filter]
+
+        # Our remaining triangles are distributed in a rectangle.
+        # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height,
+        # so its target aspect ratio is 0.5 / (1 / 3) = 1.5
+        ratio = 0.5 * (1 / 3)
+        # e.g. for 744 remaining squares: mult = sqrt(744 / (0.5 * (1 / 3)))
+
+        mult = math.sqrt(squares_left / ratio)
+        num_square_width = int(math.ceil(0.5 * mult))
+        num_square_height = int(math.ceil(squares_left / num_square_width))
+
+        width = 1 / num_square_width
+        height = 1 / num_square_height
+
+        # The idea is again to keep the texture resolution consistent with the first slice.
+        # The packed squares only occupy half the region of the texture chart, but the
+        # scaling of the squares assumes full coverage.
+        # Now normalize the UVs, taking into account the maximum scaling
+        uc = (uc - uc.min(dim=1, keepdim=True).values) / (
+            uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
+        )
+        vc = (vc - vc.min(dim=1, keepdim=True).values) / (
+            vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
+        )
+
+        # Add a small padding
+        uc = (
+            uc * (1 - island_padding * num_square_width * 0.25)
+            + island_padding * num_square_width * 0.125
+        ).clip(0, 1)
+        vc = (
+            vc * (1 - island_padding * num_square_height * 0.25)
+            + island_padding * num_square_height * 0.125
+        ).clip(0, 1)
+
+        uc = uc * width
+        vc = vc * height
+
+        # And calculate offsets for each element
+        idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
+        x_idx = idx % num_square_width
+        y_idx = idx // num_square_width
+        # And move each triangle to its own spot
+        uc = uc + x_idx[:, None] * width
+        vc = vc + y_idx[:, None] * height
+
+        uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
+        vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
+
+        uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
+
+        return uv
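The grid sizing in _handle_remaining_uvs is plain arithmetic: the leftover triangles share a rectangle that is 0.5 wide and 1/3 tall, so the column count is chosen as ceil(0.5 * sqrt(N / (0.5 * 1/3))) ≈ ceil(sqrt(1.5 * N)), which makes the column/row ratio track that aspect ratio. A quick worked example, using the 744-triangle figure from the comment above:

import math

squares_left = 744
ratio = 0.5 * (1 / 3)  # packing rectangle: 0.5 wide, 1/3 tall

mult = math.sqrt(squares_left / ratio)                               # ~66.8
num_square_width = int(math.ceil(0.5 * mult))                        # 34 columns
num_square_height = int(math.ceil(squares_left / num_square_width))  # 22 rows

# Enough cells for every leftover triangle, and 34 / 22 ≈ 1.55, close to the
# target aspect ratio 0.5 / (1 / 3) = 1.5
assert num_square_width * num_square_height >= squares_left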
+
+    def _distribute_individual_uvs_in_atlas(
+        self,
+        face_uv: Tensor,
+        assigned_faces: Tensor,
+        offset_x: Tensor,
+        offset_y: Tensor,
+        div_x: Tensor,
+        div_y: Tensor,
+        island_padding: float,
+    ) -> Tensor:
+        """
+        Distribute the individual UVs in the atlas
+
+        Args:
+            face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
+            assigned_faces (Integer[Tensor, "Nf"]): Assigned faces
+            offset_x (Float[Tensor, "Nf"]): Offset x
+            offset_y (Float[Tensor, "Nf"]): Offset y
+            div_x (Float[Tensor, "Nf"]): Division x
+            div_y (Float[Tensor, "Nf"]): Division y
+            island_padding (float): Island padding
+
+        Returns:
+            Float[Tensor, "Nf 3 2"]: Updated UV coordinates
+        """
+        # Place the slices first
+        placed_uv = self._handle_slice_uvs(face_uv, assigned_faces, island_padding)
+        # Then handle the remaining overlap elements
+        placed_uv = self._handle_remaining_uvs(
+            placed_uv, assigned_faces, island_padding
+        )
+
+        uc, vc = placed_uv.unbind(-1)
+        uc = uc / div_x[:, None] + offset_x[:, None]
+        vc = vc / div_y[:, None] + offset_y[:, None]
+
+        uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
+
+        return uv
+
+    def _get_unique_face_uv(
+        self,
+        uv: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the unique face UV
+
+        Args:
+            uv (Float[Tensor, "Nf 3 2"]): UV coordinates
+
+        Returns:
+            Float[Tensor, "Utex 2"]: Unique UV coordinates
+            Integer[Tensor, "Nf 3"]: Face to unique UV index mapping
+        """
+        unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
+        # And add the face to uv index mapping
+        vtex_idx = unique_idx.view(-1, 3)
+
+        return unique_uv, vtex_idx
+
+    def _align_mesh_with_main_axis(
+        self, vertex_positions: Tensor, vertex_normals: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Align the mesh with the main axes
+
+        Args:
+            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
+            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
+
+        Returns:
+            Float[Tensor, "Nv 3"]: Rotated vertex positions
+            Float[Tensor, "Nv 3"]: Rotated vertex normals
+        """
+
+        # Use PCA to find the two main axes (the third is derived by cross product)
+        # Set the random seed so it's repeatable
+        torch.manual_seed(0)
+        _, _, v = torch.pca_lowrank(vertex_positions, q=2)
+        main_axis, secondary_axis = v[:, 0], v[:, 1]
+
+        main_axis = F.normalize(main_axis, eps=1e-6, dim=-1)  # 3,
+        # Orthogonalize the second axis
+        secondary_axis = F.normalize(
+            secondary_axis
+            - (secondary_axis * main_axis).sum(-1, keepdim=True) * main_axis,
+            eps=1e-6,
+            dim=-1,
+        )  # 3,
+        # Create the perpendicular third axis
+        third_axis = F.normalize(
+            torch.cross(main_axis, secondary_axis, dim=-1), dim=-1, eps=1e-6
+        )  # 3,
+
+        # Check to which canonical axis each aligns
+        main_axis_max_idx = main_axis.abs().argmax().item()
+        secondary_axis_max_idx = secondary_axis.abs().argmax().item()
+        third_axis_max_idx = third_axis.abs().argmax().item()
+
+        # Now sort the axes based on the argmax so they align with the canonical axes.
+        # If two axes have the same argmax, move one of them
+        all_possible_axis = {0, 1, 2}
+        cur_index = 1
+        while (
+            len(set([main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx]))
+            != 3
+        ):
+            # Find the missing axis
+            missing_axis = all_possible_axis - set(
+                [main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx]
+            )
+            missing_axis = missing_axis.pop()
+            # Just assign it to the third axis as it had the smallest contribution to the
+            # overall shape
+            if cur_index == 1:
+                third_axis_max_idx = missing_axis
+            elif cur_index == 2:
+                secondary_axis_max_idx = missing_axis
+            else:
+                raise ValueError("Could not find 3 unique axes")
+            cur_index += 1
+
+        if len({main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx}) != 3:
+            raise ValueError("Could not find 3 unique axes")
+
+        axes = [None] * 3
+        axes[main_axis_max_idx] = main_axis
+        axes[secondary_axis_max_idx] = secondary_axis
+        axes[third_axis_max_idx] = third_axis
+        # Create the rotation matrix from the individual axes
+        rot_mat = torch.stack(axes, dim=1).T
+
+        # Now rotate the vertex positions and vertex normals so the mesh aligns with the main axes
+        vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
+        vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
+
+        return vertex_positions, vertex_normals
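_align_mesh_with_main_axis builds an orthonormal frame from the first two PCA directions plus their cross product and applies it as a rotation, so the mesh's dominant extents end up along canonical axes (the argmax-based reordering above then decides which canonical axis each frame vector maps to). A small self-contained sketch of the core idea on a synthetic point cloud; all names and values here are illustrative, not taken from the module:

import math

import torch
import torch.nn.functional as F

# Synthetic, elongated point cloud, deliberately rotated out of alignment
torch.manual_seed(0)
points = torch.randn(1024, 3) * torch.tensor([5.0, 1.0, 0.2])
c, s = math.cos(0.7), math.sin(0.7)
misalign = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
points = points @ misalign.T

# Orthonormal frame from the two strongest PCA directions plus their cross product
_, _, v = torch.pca_lowrank(points, q=2)
main_axis = F.normalize(v[:, 0], dim=-1)
secondary_axis = F.normalize(
    v[:, 1] - (v[:, 1] * main_axis).sum() * main_axis, dim=-1
)
third_axis = F.normalize(torch.cross(main_axis, secondary_axis, dim=-1), dim=-1)

# Rotate the cloud into that frame; the largest extent now sits on component 0
rot_mat = torch.stack([main_axis, secondary_axis, third_axis], dim=1).T
aligned = torch.einsum("ij,nj->ni", rot_mat, points)
print(aligned.abs().max(dim=0).values)  # roughly descending per-axis extents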
+
+    def forward(
+        self,
+        vertex_positions: Tensor,
+        vertex_normals: Tensor,
+        triangle_idxs: Tensor,
+        island_padding: float,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Unwrap the mesh
+
+        Args:
+            vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
+            vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
+            triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
+            island_padding (float): Island padding
+
+        Returns:
+            Float[Tensor, "Utex 2"]: Unique UV coordinates
+            Integer[Tensor, "Nf 3"]: Face to unique UV index mapping
+        """
+        vertex_positions, vertex_normals = self._align_mesh_with_main_axis(
+            vertex_positions, vertex_normals
+        )
+        bbox = torch.stack(
+            [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values],
+            dim=0,
+        )  # 2, 3
+
+        face_uv, face_index = self._box_assign_vertex_to_cube_face(
+            vertex_positions, vertex_normals, triangle_idxs, bbox
+        )
+
+        face_uv = self._rotate_uv_slices_consistent_space(
+            vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
+        )
+
+        assigned_atlas_index = self._assign_faces_uv_to_atlas_index(
+            vertex_positions, triangle_idxs, face_uv, face_index
+        )
+
+        offset_x, offset_y, div_x, div_y = self._find_slice_offset_and_scale(
+            assigned_atlas_index
+        )
+
+        placed_uv = self._distribute_individual_uvs_in_atlas(
+            face_uv,
+            assigned_atlas_index,
+            offset_x,
+            offset_y,
+            div_x,
+            div_y,
+            island_padding,
+        )
+
+        return self._get_unique_face_uv(placed_uv)
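Taken together, forward aligns the mesh, box-projects each face onto one of the six cube faces, rotates the resulting islands into a consistent orientation, packs slices and leftover triangles into the atlas, and de-duplicates the final UVs. A hedged usage sketch follows; the module's class name and construction live earlier in this file, so "unwrapper" below is a stand-in for an already-constructed instance, and the single triangle and padding value are illustrative only:

import torch

# Minimal single-triangle mesh, just to show the calling convention
vertex_positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
vertex_normals = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
triangle_idxs = torch.tensor([[0, 1, 2]], dtype=torch.long)

# unwrapper: assumed instance of the unwrapping module defined earlier in this file
unique_uv, face_uv_idx = unwrapper(
    vertex_positions, vertex_normals, triangle_idxs, island_padding=0.01
)
# unique_uv:   (Utex, 2) texture coordinates in [0, 1]
# face_uv_idx: (Nf, 3) indices into unique_uv, one per triangle corner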