jammmmm commited on
Commit
38dbec8
·
1 Parent(s): 73dc205

Add spar3d demo files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .gitignore +167 -0
  3. .pre-commit-config.yaml +24 -0
  4. LICENSE.md +51 -0
  5. README.md +3 -3
  6. __init__.py +358 -0
  7. demo_files/comp.gif +3 -0
  8. demo_files/examples/bird.png +3 -0
  9. demo_files/examples/castle.png +3 -0
  10. demo_files/examples/chest.png +3 -0
  11. demo_files/examples/doll.png +3 -0
  12. demo_files/examples/excavator.png +3 -0
  13. demo_files/examples/fish.png +3 -0
  14. demo_files/examples/horse-statue.png +3 -0
  15. demo_files/examples/penguin.png +3 -0
  16. demo_files/examples/pot.png +3 -0
  17. demo_files/examples/raccoon_wizard.png +3 -0
  18. demo_files/examples/stylized-rocks.png +3 -0
  19. demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
  20. demo_files/hdri/metro_noord_1k.hdr +0 -0
  21. demo_files/hdri/neon_photostudio_1k.hdr +0 -0
  22. demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
  23. demo_files/hdri/rainforest_trail_1k.hdr +0 -0
  24. demo_files/hdri/studio_small_08_1k.hdr +0 -0
  25. demo_files/hdri/urban_alley_01_1k.hdr +0 -0
  26. demo_files/turntable.gif +3 -0
  27. demo_files/workflows/spar3d_example.json +263 -0
  28. gradio_app.py +792 -0
  29. load/tets/160_tets.npz +3 -0
  30. requirements.txt +17 -0
  31. ruff.toml +3 -0
  32. run.py +180 -0
  33. spar3d/models/camera.py +32 -0
  34. spar3d/models/diffusion/gaussian_diffusion.py +524 -0
  35. spar3d/models/diffusion/sampler.py +134 -0
  36. spar3d/models/global_estimator/reni_estimator.py +112 -0
  37. spar3d/models/illumination/reni/components/film_siren.py +148 -0
  38. spar3d/models/illumination/reni/components/siren.py +118 -0
  39. spar3d/models/illumination/reni/components/transformer_decoder.py +189 -0
  40. spar3d/models/illumination/reni/components/vn_layers.py +548 -0
  41. spar3d/models/illumination/reni/env_map.py +93 -0
  42. spar3d/models/illumination/reni/field.py +736 -0
  43. spar3d/models/image_estimator/clip_based_estimator.py +184 -0
  44. spar3d/models/isosurface.py +229 -0
  45. spar3d/models/mesh.py +317 -0
  46. spar3d/models/network.py +223 -0
  47. spar3d/models/tokenizers/dinov2.py +1196 -0
  48. spar3d/models/tokenizers/image.py +99 -0
  49. spar3d/models/tokenizers/point.py +51 -0
  50. spar3d/models/tokenizers/triplane.py +49 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv*/
127
+ env/
128
+ venv*/
129
+ ENV/
130
+ env.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+ .vs/
163
+ .idea/
164
+ .vscode/
165
+
166
+ stabilityai/
167
+ output/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v4.4.0
7
+ hooks:
8
+ - id: trailing-whitespace
9
+ - id: check-ast
10
+ - id: check-merge-conflict
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: trailing-whitespace
14
+ args: [--markdown-linebreak-ext=md]
15
+
16
+ - repo: https://github.com/astral-sh/ruff-pre-commit
17
+ # Ruff version.
18
+ rev: v0.3.5
19
+ hooks:
20
+ # Run the linter.
21
+ - id: ruff
22
+ args: [ --fix ]
23
+ # Run the formatter.
24
+ - id: ruff-format
LICENSE.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+ Last Updated: July 5, 2024
3
+
4
+
5
+ I. INTRODUCTION
6
+
7
+ This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
8
+
9
+
10
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
11
+
12
+
13
+ By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
14
+
15
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
16
+
17
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
18
+
19
+ III. COMMERCIAL USE LICENSE
20
+
21
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
22
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
23
+
24
+ IV. GENERAL TERMS
25
+
26
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
27
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
28
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
29
+ c. Intellectual Property.
30
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
31
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
32
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
33
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
34
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
35
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
36
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
37
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
38
+ g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
39
+
40
+ V. DEFINITIONS
41
+
42
+ "Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
43
+ "Agreement" means this Stability AI Community License Agreement.
44
+ "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
45
+ "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
46
+ "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
47
+ "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
48
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
49
+ "Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
50
+ "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
51
+ "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
- title: Stable Point Aware 3d
3
  emoji: ⚡
4
  colorFrom: yellow
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
1
  ---
2
+ title: Stable Point-Aware 3D
3
  emoji: ⚡
4
  colorFrom: yellow
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.43.0
8
+ app_file: gradio_app.py
9
  pinned: false
10
  ---
11
 
__init__.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import logging
3
+ import os
4
+ import random
5
+ import sys
6
+
7
+ import comfy.model_management
8
+ import folder_paths
9
+ import numpy as np
10
+ import torch
11
+ import trimesh
12
+ from PIL import Image
13
+ from trimesh.exchange import gltf
14
+
15
+ sys.path.append(os.path.dirname(__file__))
16
+ from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
17
+ from spar3d.system import SPAR3D
18
+ from spar3d.utils import foreground_crop
19
+
20
+ SPAR3D_CATEGORY = "SPAR3D"
21
+ SPAR3D_MODEL_NAME = "stabilityai/spar3d"
22
+
23
+
24
+ class SPAR3DLoader:
25
+ CATEGORY = SPAR3D_CATEGORY
26
+ FUNCTION = "load"
27
+ RETURN_NAMES = ("spar3d_model",)
28
+ RETURN_TYPES = ("SPAR3D_MODEL",)
29
+
30
+ @classmethod
31
+ def INPUT_TYPES(cls):
32
+ return {"required": {}}
33
+
34
+ def load(self):
35
+ device = comfy.model_management.get_torch_device()
36
+ model = SPAR3D.from_pretrained(
37
+ SPAR3D_MODEL_NAME,
38
+ config_name="config.yaml",
39
+ weight_name="model.safetensors",
40
+ )
41
+ model.to(device)
42
+ model.eval()
43
+
44
+ return (model,)
45
+
46
+
47
+ class SPAR3DPreview:
48
+ CATEGORY = SPAR3D_CATEGORY
49
+ FUNCTION = "preview"
50
+ OUTPUT_NODE = True
51
+ RETURN_TYPES = ()
52
+
53
+ @classmethod
54
+ def INPUT_TYPES(s):
55
+ return {"required": {"mesh": ("MESH",)}}
56
+
57
+ def preview(self, mesh):
58
+ glbs = []
59
+ for m in mesh:
60
+ scene = trimesh.Scene(m)
61
+ glb_data = gltf.export_glb(scene, include_normals=True)
62
+ glb_base64 = base64.b64encode(glb_data).decode("utf-8")
63
+ glbs.append(glb_base64)
64
+ return {"ui": {"glbs": glbs}}
65
+
66
+
67
+ class SPAR3DSampler:
68
+ CATEGORY = SPAR3D_CATEGORY
69
+ FUNCTION = "predict"
70
+ RETURN_NAMES = ("mesh", "pointcloud")
71
+ RETURN_TYPES = ("MESH", "POINTCLOUD")
72
+
73
+ @classmethod
74
+ def INPUT_TYPES(s):
75
+ remesh_choices = ["none"]
76
+ if TRIANGLE_REMESH_AVAILABLE:
77
+ remesh_choices.append("triangle")
78
+ if QUAD_REMESH_AVAILABLE:
79
+ remesh_choices.append("quad")
80
+
81
+ opt_dict = {
82
+ "mask": ("MASK",),
83
+ "pointcloud": ("POINTCLOUD",),
84
+ "target_type": (["none", "vertex", "face"],),
85
+ "target_count": (
86
+ "INT",
87
+ {"default": 1000, "min": 3, "max": 20000, "step": 1},
88
+ ),
89
+ "guidance_scale": (
90
+ "FLOAT",
91
+ {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
92
+ ),
93
+ "seed": (
94
+ "INT",
95
+ {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
96
+ ),
97
+ }
98
+ if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
99
+ opt_dict["remesh"] = (remesh_choices,)
100
+
101
+ return {
102
+ "required": {
103
+ "model": ("SPAR3D_MODEL",),
104
+ "image": ("IMAGE",),
105
+ "foreground_ratio": (
106
+ "FLOAT",
107
+ {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
108
+ ),
109
+ "texture_resolution": (
110
+ "INT",
111
+ {"default": 1024, "min": 512, "max": 2048, "step": 256},
112
+ ),
113
+ },
114
+ "optional": opt_dict,
115
+ }
116
+
117
+ def predict(
118
+ s,
119
+ model,
120
+ image,
121
+ mask,
122
+ foreground_ratio,
123
+ texture_resolution,
124
+ pointcloud=None,
125
+ remesh="none",
126
+ target_type="none",
127
+ target_count=1000,
128
+ guidance_scale=3.0,
129
+ seed=42,
130
+ ):
131
+ if image.shape[0] != 1:
132
+ raise ValueError("Only one image can be processed at a time")
133
+
134
+ vertex_count = (
135
+ -1
136
+ if target_type == "none"
137
+ else (target_count // 2 if target_type == "face" else target_count)
138
+ )
139
+
140
+ pil_image = Image.fromarray(
141
+ torch.clamp(torch.round(255.0 * image[0]), 0, 255)
142
+ .type(torch.uint8)
143
+ .cpu()
144
+ .numpy()
145
+ )
146
+
147
+ if mask is not None:
148
+ print("Using Mask")
149
+ mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
150
+ np.uint8
151
+ )
152
+ mask_pil = Image.fromarray(mask_np, mode="L")
153
+ pil_image.putalpha(mask_pil)
154
+ else:
155
+ if image.shape[3] != 4:
156
+ print("No mask or alpha channel detected, Converting to RGBA")
157
+ pil_image = pil_image.convert("RGBA")
158
+
159
+ pil_image = foreground_crop(pil_image, foreground_ratio)
160
+
161
+ model.cfg.guidance_scale = guidance_scale
162
+ random.seed(seed)
163
+ torch.manual_seed(seed)
164
+ np.random.seed(seed)
165
+
166
+ print(remesh)
167
+ with torch.no_grad():
168
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
169
+ if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle":
170
+ raise ImportError(
171
+ "Triangle remeshing requires gpytoolbox to be installed"
172
+ )
173
+ if not QUAD_REMESH_AVAILABLE and remesh == "quad":
174
+ raise ImportError("Quad remeshing requires pynim to be installed")
175
+ mesh, glob_dict = model.run_image(
176
+ pil_image,
177
+ bake_resolution=texture_resolution,
178
+ pointcloud=pointcloud,
179
+ remesh=remesh,
180
+ vertex_count=vertex_count,
181
+ )
182
+
183
+ if mesh.vertices.shape[0] == 0:
184
+ raise ValueError("No subject detected in the image")
185
+
186
+ return (
187
+ [mesh],
188
+ glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
189
+ )
190
+
191
+
192
+ class SPAR3DSave:
193
+ CATEGORY = SPAR3D_CATEGORY
194
+ FUNCTION = "save"
195
+ OUTPUT_NODE = True
196
+ RETURN_TYPES = ()
197
+
198
+ @classmethod
199
+ def INPUT_TYPES(s):
200
+ return {
201
+ "required": {
202
+ "mesh": ("MESH",),
203
+ "filename_prefix": ("STRING", {"default": "SPAR3D"}),
204
+ }
205
+ }
206
+
207
+ def __init__(self):
208
+ self.type = "output"
209
+
210
+ def save(self, mesh, filename_prefix):
211
+ output_dir = folder_paths.get_output_directory()
212
+ glbs = []
213
+ for idx, m in enumerate(mesh):
214
+ scene = trimesh.Scene(m)
215
+ glb_data = gltf.export_glb(scene, include_normals=True)
216
+ logging.info(f"Generated GLB model with {len(glb_data)} bytes")
217
+
218
+ full_output_folder, filename, counter, subfolder, filename_prefix = (
219
+ folder_paths.get_save_image_path(filename_prefix, output_dir)
220
+ )
221
+ filename = filename.replace("%batch_num%", str(idx))
222
+ out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
223
+ with open(out_path, "wb") as f:
224
+ f.write(glb_data)
225
+ glbs.append(base64.b64encode(glb_data).decode("utf-8"))
226
+ return {"ui": {"glbs": glbs}}
227
+
228
+
229
+ class SPAR3DPointCloudLoader:
230
+ CATEGORY = SPAR3D_CATEGORY
231
+ FUNCTION = "load_pointcloud"
232
+ RETURN_TYPES = ("POINTCLOUD",)
233
+ RETURN_NAMES = ("pointcloud",)
234
+
235
+ @classmethod
236
+ def INPUT_TYPES(cls):
237
+ return {
238
+ "required": {
239
+ "file": ("STRING", {"default": None}),
240
+ }
241
+ }
242
+
243
+ def load_pointcloud(self, file):
244
+ if file is None or file == "":
245
+ return (None,)
246
+ # Load the mesh using trimesh
247
+ mesh = trimesh.load(file)
248
+
249
+ # Extract vertices and colors
250
+ vertices = mesh.vertices
251
+
252
+ # Get vertex colors, defaulting to white if none exist
253
+ if mesh.visual.vertex_colors is not None:
254
+ colors = (
255
+ mesh.visual.vertex_colors[:, :3] / 255.0
256
+ ) # Convert 0-255 to 0-1 range
257
+ else:
258
+ colors = np.ones((len(vertices), 3))
259
+
260
+ # Interleave XYZ and RGB values
261
+ point_cloud = []
262
+ for vertex, color in zip(vertices, colors):
263
+ point_cloud.extend(
264
+ [
265
+ float(vertex[0]),
266
+ float(vertex[1]),
267
+ float(vertex[2]),
268
+ float(color[0]),
269
+ float(color[1]),
270
+ float(color[2]),
271
+ ]
272
+ )
273
+
274
+ return (point_cloud,)
275
+
276
+
277
+ class SPAR3DPointCloudSaver:
278
+ CATEGORY = SPAR3D_CATEGORY
279
+ FUNCTION = "save_pointcloud"
280
+ OUTPUT_NODE = True
281
+ RETURN_TYPES = ()
282
+
283
+ @classmethod
284
+ def INPUT_TYPES(s):
285
+ return {
286
+ "required": {
287
+ "pointcloud": ("POINTCLOUD",),
288
+ "filename_prefix": ("STRING", {"default": "SPAR3D"}),
289
+ }
290
+ }
291
+
292
+ def save_pointcloud(self, pointcloud, filename_prefix):
293
+ if pointcloud is None:
294
+ return {"ui": {"text": "No point cloud data to save"}}
295
+
296
+ # Reshape the flat list into points with XYZ and RGB
297
+ points = np.array(pointcloud).reshape(-1, 6)
298
+
299
+ # Create vertex array for PLY
300
+ vertex_array = np.zeros(
301
+ len(points),
302
+ dtype=[
303
+ ("x", "f4"),
304
+ ("y", "f4"),
305
+ ("z", "f4"),
306
+ ("red", "u1"),
307
+ ("green", "u1"),
308
+ ("blue", "u1"),
309
+ ],
310
+ )
311
+
312
+ # Fill vertex array
313
+ vertex_array["x"] = points[:, 0]
314
+ vertex_array["y"] = points[:, 1]
315
+ vertex_array["z"] = points[:, 2]
316
+ # Convert RGB from 0-1 to 0-255 range
317
+ vertex_array["red"] = (points[:, 3] * 255).astype(np.uint8)
318
+ vertex_array["green"] = (points[:, 4] * 255).astype(np.uint8)
319
+ vertex_array["blue"] = (points[:, 5] * 255).astype(np.uint8)
320
+
321
+ # Create PLY object
322
+ ply_data = trimesh.PointCloud(
323
+ vertices=points[:, :3], colors=points[:, 3:] * 255
324
+ )
325
+
326
+ # Save to file
327
+ output_dir = folder_paths.get_output_directory()
328
+ full_output_folder, filename, counter, subfolder, filename_prefix = (
329
+ folder_paths.get_save_image_path(filename_prefix, output_dir)
330
+ )
331
+ out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.ply")
332
+
333
+ ply_data.export(out_path)
334
+
335
+ return {"ui": {"text": f"Saved point cloud to {out_path}"}}
336
+
337
+
338
+ NODE_DISPLAY_NAME_MAPPINGS = {
339
+ "SPAR3DLoader": "SPAR3D Loader",
340
+ "SPAR3DPreview": "SPAR3D Preview",
341
+ "SPAR3DSampler": "SPAR3D Sampler",
342
+ "SPAR3DSave": "SPAR3D Save",
343
+ "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
344
+ "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
345
+ }
346
+
347
+ NODE_CLASS_MAPPINGS = {
348
+ "SPAR3DLoader": SPAR3DLoader,
349
+ "SPAR3DPreview": SPAR3DPreview,
350
+ "SPAR3DSampler": SPAR3DSampler,
351
+ "SPAR3DSave": SPAR3DSave,
352
+ "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
353
+ "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
354
+ }
355
+
356
+ WEB_DIRECTORY = "./comfyui"
357
+
358
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
demo_files/comp.gif ADDED

Git LFS Details

  • SHA256: 6190ca0c3bd164d37152ba985abea53e642fe5e434ca0a932a3b2c4dce698f6b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
demo_files/examples/bird.png ADDED

Git LFS Details

  • SHA256: 83373e2b75ebaad76b6fe093973ea1dc96c92527c8376062cf520ed9215f3e82
  • Pointer size: 131 Bytes
  • Size of remote file: 560 kB
demo_files/examples/castle.png ADDED

Git LFS Details

  • SHA256: ededd2fe4c122cadfb4f2a485dfd82f83dc1ec6446c7a799d5fc1e1f103ae4b1
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
demo_files/examples/chest.png ADDED

Git LFS Details

  • SHA256: f1eec59b35c63aa50942edff37f0cbdea7d8360cd036a4b7eb9460afdfcbabd9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
demo_files/examples/doll.png ADDED

Git LFS Details

  • SHA256: fc5af86defd0a4fd7285e17a0eb8a108b9f33774408c194a594964d8d6e66c26
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
demo_files/examples/excavator.png ADDED

Git LFS Details

  • SHA256: 6f68c6ba4a9dc884d3786d98c4f0d835682bad02e85716d3a60fd2feedcb03d8
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
demo_files/examples/fish.png ADDED

Git LFS Details

  • SHA256: cd623d8b654de81e022e3741576a0d08dd26d6ba92ee1989605347ef26c399bb
  • Pointer size: 131 Bytes
  • Size of remote file: 838 kB
demo_files/examples/horse-statue.png ADDED

Git LFS Details

  • SHA256: c9c00f726efe9490b02d4c232293b629e0146dad6ce1ff8e22da8102345c5fe9
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB
demo_files/examples/penguin.png ADDED

Git LFS Details

  • SHA256: 7a1667d874e9379a8d36e676fb80327bd7b5d3673cb77d7d4cf27bb53408fb98
  • Pointer size: 131 Bytes
  • Size of remote file: 659 kB
demo_files/examples/pot.png ADDED

Git LFS Details

  • SHA256: 32d5d8c110646a46ca24a4d6994cb848ef79cc7ad78dcc7419be0e6f02476a86
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
demo_files/examples/raccoon_wizard.png ADDED

Git LFS Details

  • SHA256: 32cc3850d9f48548882c7b148e508e8ab149bc4f363611e9739adcbd38e8b16d
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
demo_files/examples/stylized-rocks.png ADDED

Git LFS Details

  • SHA256: 386c3be3a6f24ee52e13f130c1ebc02a1bc46eb2c0ebe90d79ce6f38751f0fc6
  • Pointer size: 131 Bytes
  • Size of remote file: 439 kB
demo_files/hdri/abandoned_tiled_room_1k.hdr ADDED
Binary file (478 kB). View file
 
demo_files/hdri/metro_noord_1k.hdr ADDED
Binary file (467 kB). View file
 
demo_files/hdri/neon_photostudio_1k.hdr ADDED
Binary file (438 kB). View file
 
demo_files/hdri/peppermint_powerplant_1k.hdr ADDED
Binary file (473 kB). View file
 
demo_files/hdri/rainforest_trail_1k.hdr ADDED
Binary file (512 kB). View file
 
demo_files/hdri/studio_small_08_1k.hdr ADDED
Binary file (412 kB). View file
 
demo_files/hdri/urban_alley_01_1k.hdr ADDED
Binary file (458 kB). View file
 
demo_files/turntable.gif ADDED

Git LFS Details

  • SHA256: ffb5cfca3da84a569de41535781dfc6103834b99207136eb6cbf72d097799c6c
  • Pointer size: 132 Bytes
  • Size of remote file: 7.58 MB
demo_files/workflows/spar3d_example.json ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "last_node_id": 17,
3
+ "last_link_id": 18,
4
+ "nodes": [
5
+ {
6
+ "id": 10,
7
+ "type": "SPAR3DLoader",
8
+ "pos": [
9
+ 52.92446517944336,
10
+ 394.328369140625
11
+ ],
12
+ "size": [
13
+ 210,
14
+ 26
15
+ ],
16
+ "flags": {},
17
+ "order": 0,
18
+ "mode": 0,
19
+ "inputs": [],
20
+ "outputs": [
21
+ {
22
+ "name": "spar3d_model",
23
+ "type": "SPAR3D_MODEL",
24
+ "links": [
25
+ 10
26
+ ],
27
+ "slot_index": 0
28
+ }
29
+ ],
30
+ "properties": {
31
+ "Node name for S&R": "SPAR3DLoader"
32
+ },
33
+ "widgets_values": []
34
+ },
35
+ {
36
+ "id": 13,
37
+ "type": "LoadImage",
38
+ "pos": [
39
+ -43.437347412109375,
40
+ 482.89678955078125
41
+ ],
42
+ "size": [
43
+ 315,
44
+ 314
45
+ ],
46
+ "flags": {},
47
+ "order": 1,
48
+ "mode": 0,
49
+ "inputs": [],
50
+ "outputs": [
51
+ {
52
+ "name": "IMAGE",
53
+ "type": "IMAGE",
54
+ "links": [
55
+ 11
56
+ ],
57
+ "slot_index": 0
58
+ },
59
+ {
60
+ "name": "MASK",
61
+ "type": "MASK",
62
+ "links": [
63
+ 16
64
+ ],
65
+ "slot_index": 1
66
+ }
67
+ ],
68
+ "properties": {
69
+ "Node name for S&R": "LoadImage"
70
+ },
71
+ "widgets_values": [
72
+ "cat1.png",
73
+ "image"
74
+ ]
75
+ },
76
+ {
77
+ "id": 16,
78
+ "type": "InvertMask",
79
+ "pos": [
80
+ 377.1180419921875,
81
+ 605.384765625
82
+ ],
83
+ "size": [
84
+ 210,
85
+ 26
86
+ ],
87
+ "flags": {},
88
+ "order": 2,
89
+ "mode": 0,
90
+ "inputs": [
91
+ {
92
+ "name": "mask",
93
+ "type": "MASK",
94
+ "link": 16
95
+ }
96
+ ],
97
+ "outputs": [
98
+ {
99
+ "name": "MASK",
100
+ "type": "MASK",
101
+ "links": [
102
+ 17
103
+ ],
104
+ "slot_index": 0
105
+ }
106
+ ],
107
+ "properties": {
108
+ "Node name for S&R": "InvertMask"
109
+ },
110
+ "widgets_values": []
111
+ },
112
+ {
113
+ "id": 17,
114
+ "type": "SPAR3DSave",
115
+ "pos": [
116
+ 1133.669921875,
117
+ 439.6551513671875
118
+ ],
119
+ "size": [
120
+ 315,
121
+ 58
122
+ ],
123
+ "flags": {},
124
+ "order": 4,
125
+ "mode": 0,
126
+ "inputs": [
127
+ {
128
+ "name": "mesh",
129
+ "type": "MESH",
130
+ "link": 18
131
+ }
132
+ ],
133
+ "outputs": [],
134
+ "properties": {
135
+ "Node name for S&R": "SPAR3DSave"
136
+ },
137
+ "widgets_values": [
138
+ "SPAR3D"
139
+ ]
140
+ },
141
+ {
142
+ "id": 11,
143
+ "type": "SPAR3DSampler",
144
+ "pos": [
145
+ 673.0637817382812,
146
+ 441.2229309082031
147
+ ],
148
+ "size": [
149
+ 315,
150
+ 286
151
+ ],
152
+ "flags": {},
153
+ "order": 3,
154
+ "mode": 0,
155
+ "inputs": [
156
+ {
157
+ "name": "model",
158
+ "type": "SPAR3D_MODEL",
159
+ "link": 10
160
+ },
161
+ {
162
+ "name": "image",
163
+ "type": "IMAGE",
164
+ "link": 11
165
+ },
166
+ {
167
+ "name": "mask",
168
+ "type": "MASK",
169
+ "link": 17,
170
+ "shape": 7
171
+ },
172
+ {
173
+ "name": "pointcloud",
174
+ "type": "POINTCLOUD",
175
+ "link": null,
176
+ "shape": 7
177
+ }
178
+ ],
179
+ "outputs": [
180
+ {
181
+ "name": "mesh",
182
+ "type": "MESH",
183
+ "links": [
184
+ 18
185
+ ],
186
+ "slot_index": 0
187
+ },
188
+ {
189
+ "name": "pointcloud",
190
+ "type": "POINTCLOUD",
191
+ "links": null
192
+ }
193
+ ],
194
+ "properties": {
195
+ "Node name for S&R": "SPAR3DSampler"
196
+ },
197
+ "widgets_values": [
198
+ 1.3,
199
+ 1024,
200
+ "none",
201
+ 1000,
202
+ 3,
203
+ 3727502160,
204
+ "randomize",
205
+ "none"
206
+ ]
207
+ }
208
+ ],
209
+ "links": [
210
+ [
211
+ 10,
212
+ 10,
213
+ 0,
214
+ 11,
215
+ 0,
216
+ "SPAR3D_MODEL"
217
+ ],
218
+ [
219
+ 11,
220
+ 13,
221
+ 0,
222
+ 11,
223
+ 1,
224
+ "IMAGE"
225
+ ],
226
+ [
227
+ 16,
228
+ 13,
229
+ 1,
230
+ 16,
231
+ 0,
232
+ "MASK"
233
+ ],
234
+ [
235
+ 17,
236
+ 16,
237
+ 0,
238
+ 11,
239
+ 2,
240
+ "MASK"
241
+ ],
242
+ [
243
+ 18,
244
+ 11,
245
+ 0,
246
+ 17,
247
+ 0,
248
+ "MESH"
249
+ ]
250
+ ],
251
+ "groups": [],
252
+ "config": {},
253
+ "extra": {
254
+ "ds": {
255
+ "scale": 0.953502721998243,
256
+ "offset": [
257
+ 266.21995970220667,
258
+ 116.75398112171928
259
+ ]
260
+ }
261
+ },
262
+ "version": 0.4
263
+ }
gradio_app.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.system("pip install ./texture_baker/ ./uv_unwrapper/")
4
+
5
+ import random
6
+ import tempfile
7
+ import time
8
+ from contextlib import nullcontext
9
+ from functools import lru_cache
10
+ from typing import Any
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ import trimesh
16
+ from gradio_litmodel3d import LitModel3D
17
+ from gradio_pointcloudeditor import PointCloudEditor
18
+ from PIL import Image
19
+ from transparent_background import Remover
20
+
21
+ import spar3d.utils as spar3d_utils
22
+ from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
23
+ from spar3d.system import SPAR3D
24
+
25
+ os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")
26
+
27
+ bg_remover = Remover() # default setting
28
+
29
+ COND_WIDTH = 512
30
+ COND_HEIGHT = 512
31
+ COND_DISTANCE = 2.2
32
+ COND_FOVY = 0.591627
33
+ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
34
+
35
+ # Cached. Doesn't change
36
+ c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
37
+ intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
38
+ COND_FOVY, COND_HEIGHT, COND_WIDTH
39
+ )
40
+
41
+ generated_files = []
42
+
43
+ # Delete previous gradio temp dir folder
44
+ if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
45
+ print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
46
+ import shutil
47
+
48
+ shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])
49
+
50
+ device = spar3d_utils.get_device()
51
+
52
+ model = SPAR3D.from_pretrained(
53
+ "stabilityai/stable-point-aware-3d",
54
+ config_name="config.yaml",
55
+ weight_name="model.safetensors",
56
+ )
57
+ model.eval()
58
+ model = model.to(device)
59
+
60
+ example_files = [
61
+ os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
62
+ ]
63
+
64
+
65
+ def forward_model(
66
+ batch,
67
+ system,
68
+ guidance_scale=3.0,
69
+ seed=0,
70
+ device="cuda",
71
+ remesh_option="none",
72
+ vertex_count=-1,
73
+ texture_resolution=1024,
74
+ ):
75
+ batch_size = batch["rgb_cond"].shape[0]
76
+
77
+ # prepare the condition for point cloud generation
78
+ # set seed
79
+ random.seed(seed)
80
+ torch.manual_seed(seed)
81
+ np.random.seed(seed)
82
+ cond_tokens = system.forward_pdiff_cond(batch)
83
+
84
+ if "pc_cond" not in batch:
85
+ sample_iter = system.sampler.sample_batch_progressive(
86
+ batch_size,
87
+ cond_tokens,
88
+ guidance_scale=guidance_scale,
89
+ device=device,
90
+ )
91
+ for x in sample_iter:
92
+ samples = x["xstart"]
93
+ batch["pc_cond"] = samples.permute(0, 2, 1).float()
94
+ batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])
95
+
96
+ # subsample to the 512 points
97
+ batch["pc_cond"] = batch["pc_cond"][
98
+ :, torch.randperm(batch["pc_cond"].shape[1])[:512]
99
+ ]
100
+
101
+ # get the point cloud
102
+ xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
103
+ color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
104
+ pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)
105
+
106
+ # forward for the final mesh
107
+ trimesh_mesh, _glob_dict = model.generate_mesh(
108
+ batch, texture_resolution, remesh=remesh_option, vertex_count=vertex_count
109
+ )
110
+ trimesh_mesh = trimesh_mesh[0]
111
+
112
+ return trimesh_mesh, pc_rgb_trimesh
113
+
114
+
115
+ def run_model(
116
+ input_image,
117
+ guidance_scale,
118
+ random_seed,
119
+ pc_cond,
120
+ remesh_option,
121
+ vertex_count,
122
+ texture_resolution,
123
+ ):
124
+ start = time.time()
125
+ with torch.no_grad():
126
+ with (
127
+ torch.autocast(device_type=device, dtype=torch.float16)
128
+ if "cuda" in device
129
+ else nullcontext()
130
+ ):
131
+ model_batch = create_batch(input_image)
132
+ model_batch = {k: v.to(device) for k, v in model_batch.items()}
133
+
134
+ if pc_cond is not None:
135
+ # Check if pc_cond is a list
136
+ if isinstance(pc_cond, list):
137
+ cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6)
138
+ xyz = cond_tensor[:, :3]
139
+ color_rgb = cond_tensor[:, 3:]
140
+ elif isinstance(pc_cond, dict):
141
+ xyz = torch.tensor(pc_cond["positions"]).float().cuda()
142
+ color_rgb = torch.tensor(pc_cond["colors"]).float().cuda()
143
+ else:
144
+ xyz = torch.tensor(pc_cond.vertices).float().cuda()
145
+ color_rgb = (
146
+ torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0
147
+ )
148
+ model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(
149
+ 0
150
+ )
151
+ # sub-sample the point cloud to the target number of points
152
+ if model_batch["pc_cond"].shape[1] > 512:
153
+ idx = torch.randperm(model_batch["pc_cond"].shape[1])[:512]
154
+ model_batch["pc_cond"] = model_batch["pc_cond"][:, idx]
155
+ elif model_batch["pc_cond"].shape[1] < 512:
156
+ num_points = model_batch["pc_cond"].shape[1]
157
+ gr.Warning(
158
+ f"The uploaded point cloud should have at least 512 points. This point cloud only has {num_points}. Results may be worse."
159
+ )
160
+ pad = 512 - num_points
161
+ sampled_idx = torch.randint(
162
+ 0, model_batch["pc_cond"].shape[1], (pad,)
163
+ )
164
+ model_batch["pc_cond"] = torch.cat(
165
+ [
166
+ model_batch["pc_cond"],
167
+ model_batch["pc_cond"][:, sampled_idx],
168
+ ],
169
+ dim=1,
170
+ )
171
+
172
+ trimesh_mesh, trimesh_pc = forward_model(
173
+ model_batch,
174
+ model,
175
+ guidance_scale=guidance_scale,
176
+ seed=random_seed,
177
+ device="cuda",
178
+ remesh_option=remesh_option.lower(),
179
+ vertex_count=vertex_count,
180
+ texture_resolution=texture_resolution,
181
+ )
182
+
183
+ # Create new tmp file
184
+ temp_dir = tempfile.mkdtemp()
185
+ tmp_file = os.path.join(temp_dir, "mesh.glb")
186
+
187
+ trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
188
+ generated_files.append(tmp_file)
189
+
190
+ tmp_file_pc = os.path.join(temp_dir, "points.ply")
191
+ trimesh_pc.export(tmp_file_pc)
192
+ generated_files.append(tmp_file_pc)
193
+
194
+ print("Generation took:", time.time() - start, "s")
195
+
196
+ return tmp_file, tmp_file_pc, trimesh_pc
197
+
198
+
199
+ def create_batch(input_image: Image) -> dict[str, Any]:
200
+ img_cond = (
201
+ torch.from_numpy(
202
+ np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
203
+ / 255.0
204
+ )
205
+ .float()
206
+ .clip(0, 1)
207
+ )
208
+ mask_cond = img_cond[:, :, -1:]
209
+ rgb_cond = torch.lerp(
210
+ torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
211
+ )
212
+
213
+ batch_elem = {
214
+ "rgb_cond": rgb_cond,
215
+ "mask_cond": mask_cond,
216
+ "c2w_cond": c2w_cond.unsqueeze(0),
217
+ "intrinsic_cond": intrinsic.unsqueeze(0),
218
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
219
+ }
220
+ # Add batch dim
221
+ batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
222
+ return batched
223
+
224
+
225
+ @lru_cache
226
+ def checkerboard(squares: int, size: int, min_value: float = 0.5):
227
+ base = np.zeros((squares, squares)) + min_value
228
+ base[1::2, ::2] = 1
229
+ base[::2, 1::2] = 1
230
+
231
+ repeat_mult = size // squares
232
+ return (
233
+ base.repeat(repeat_mult, axis=0)
234
+ .repeat(repeat_mult, axis=1)[:, :, None]
235
+ .repeat(3, axis=-1)
236
+ )
237
+
238
+
239
+ def remove_background(input_image: Image) -> Image:
240
+ return bg_remover.process(input_image.convert("RGB"))
241
+
242
+
243
+ def show_mask_img(input_image: Image) -> Image:
244
+ img_numpy = np.array(input_image)
245
+ alpha = img_numpy[:, :, 3] / 255.0
246
+ chkb = checkerboard(32, 512) * 255
247
+ new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
248
+ return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
249
+
250
+
251
+ def process_model_run(
252
+ background_state,
253
+ guidance_scale,
254
+ random_seed,
255
+ pc_cond,
256
+ remesh_option,
257
+ vertex_count_type,
258
+ vertex_count,
259
+ texture_resolution,
260
+ ):
261
+ # Adjust vertex count based on selection
262
+ final_vertex_count = (
263
+ -1
264
+ if vertex_count_type == "Keep Vertex Count"
265
+ else (
266
+ vertex_count // 2
267
+ if vertex_count_type == "Target Face Count"
268
+ else vertex_count
269
+ )
270
+ )
271
+ print(
272
+ f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
273
+ )
274
+
275
+ glb_file, pc_file, pc_plot = run_model(
276
+ background_state,
277
+ guidance_scale,
278
+ random_seed,
279
+ pc_cond,
280
+ remesh_option,
281
+ final_vertex_count,
282
+ texture_resolution,
283
+ )
284
+ # Create a single float list of x y z r g b
285
+ point_list = []
286
+ for i in range(pc_plot.vertices.shape[0]):
287
+ point_list.extend(
288
+ [
289
+ pc_plot.vertices[i, 0],
290
+ pc_plot.vertices[i, 1],
291
+ pc_plot.vertices[i, 2],
292
+ pc_plot.colors[i, 0] / 255.0,
293
+ pc_plot.colors[i, 1] / 255.0,
294
+ pc_plot.colors[i, 2] / 255.0,
295
+ ]
296
+ )
297
+
298
+ return glb_file, pc_file, point_list
299
+
300
+
301
+ def regenerate_run(
302
+ background_state,
303
+ guidance_scale,
304
+ random_seed,
305
+ pc_cond,
306
+ remesh_option,
307
+ vertex_count_type,
308
+ vertex_count,
309
+ texture_resolution,
310
+ ):
311
+ glb_file, pc_file, point_list = process_model_run(
312
+ background_state,
313
+ guidance_scale,
314
+ random_seed,
315
+ pc_cond,
316
+ remesh_option,
317
+ vertex_count_type,
318
+ vertex_count,
319
+ texture_resolution,
320
+ )
321
+ return (
322
+ gr.update(), # run_btn
323
+ gr.update(), # img_proc_state
324
+ gr.update(), # background_remove_state
325
+ gr.update(), # preview_removal
326
+ gr.update(value=glb_file, visible=True), # output_3d
327
+ gr.update(visible=True), # hdr_row
328
+ gr.update(visible=True), # point_cloud_row
329
+ gr.update(value=point_list), # point_cloud_editor
330
+ gr.update(value=pc_file), # pc_download
331
+ gr.update(visible=False), # regenerate_btn
332
+ )
333
+
334
+
335
+ def run_button(
336
+ run_btn,
337
+ input_image,
338
+ background_state,
339
+ foreground_ratio,
340
+ no_crop,
341
+ guidance_scale,
342
+ random_seed,
343
+ pc_upload,
344
+ pc_cond_file,
345
+ remesh_option,
346
+ vertex_count_type,
347
+ vertex_count,
348
+ texture_resolution,
349
+ ):
350
+ if run_btn == "Run":
351
+ if torch.cuda.is_available():
352
+ torch.cuda.reset_peak_memory_stats()
353
+
354
+ if pc_upload:
355
+ # make sure the pc_cond_file has been uploaded
356
+ try:
357
+ pc_cond = trimesh.load(pc_cond_file.name)
358
+ except Exception:
359
+ raise gr.Error(
360
+ "Please upload a valid point cloud ply file as condition."
361
+ )
362
+ else:
363
+ pc_cond = None
364
+
365
+ glb_file, pc_file, pc_list = process_model_run(
366
+ background_state,
367
+ guidance_scale,
368
+ random_seed,
369
+ pc_cond,
370
+ remesh_option,
371
+ vertex_count_type,
372
+ vertex_count,
373
+ texture_resolution,
374
+ )
375
+
376
+ if torch.cuda.is_available():
377
+ print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
378
+ elif torch.backends.mps.is_available():
379
+ print(
380
+ "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
381
+ )
382
+
383
+ return (
384
+ gr.update(), # run_btn
385
+ gr.update(), # img_proc_state
386
+ gr.update(), # background_remove_state
387
+ gr.update(), # preview_removal
388
+ gr.update(value=glb_file, visible=True), # output_3d
389
+ gr.update(visible=True), # hdr_row
390
+ gr.update(visible=True), # point_cloud_row
391
+ gr.update(value=pc_list), # point_cloud_editor
392
+ gr.update(value=pc_file), # pc_download
393
+ gr.update(visible=False), # regenerate_btn
394
+ )
395
+
396
+ elif run_btn == "Remove Background":
397
+ rem_removed = remove_background(input_image)
398
+
399
+ fr_res = spar3d_utils.foreground_crop(
400
+ rem_removed,
401
+ crop_ratio=foreground_ratio,
402
+ newsize=(COND_WIDTH, COND_HEIGHT),
403
+ no_crop=no_crop,
404
+ )
405
+
406
+ return (
407
+ gr.update(value="Run", visible=True), # run_btn
408
+ rem_removed, # img_proc_state,
409
+ fr_res, # background_remove_state
410
+ gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
411
+ gr.update(value=None, visible=False), # output_3d
412
+ gr.update(visible=False), # hdr_row
413
+ gr.update(visible=False), # point_cloud_row
414
+ gr.update(value=None), # point_cloud_editor
415
+ gr.update(value=None), # pc_download
416
+ gr.update(visible=False), # regenerate_btn
417
+ )
418
+
419
+
420
+ def requires_bg_remove(image, fr, no_crop):
421
+ if image is None:
422
+ return (
423
+ gr.update(visible=False, value="Run"), # run_Btn
424
+ None, # img_proc_state
425
+ None, # background_remove_state
426
+ gr.update(value=None, visible=False), # preview_removal
427
+ gr.update(value=None, visible=False), # output_3d
428
+ gr.update(visible=False), # hdr_row
429
+ gr.update(visible=False), # point_cloud_row
430
+ gr.update(value=None), # point_cloud_editor
431
+ gr.update(value=None), # pc_download
432
+ gr.update(visible=False), # regenerate_btn
433
+ )
434
+ alpha_channel = np.array(image.getchannel("A"))
435
+ min_alpha = alpha_channel.min()
436
+
437
+ if min_alpha == 0:
438
+ print("Already has alpha")
439
+ fr_res = spar3d_utils.foreground_crop(
440
+ image, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop
441
+ )
442
+ return (
443
+ gr.update(value="Run", visible=True), # run_Btn
444
+ image, # img_proc_state
445
+ fr_res, # background_remove_state
446
+ gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
447
+ gr.update(value=None, visible=False), # output_3d
448
+ gr.update(visible=False), # hdr_row
449
+ gr.update(visible=False), # point_cloud_row
450
+ gr.update(value=None), # point_cloud_editor
451
+ gr.update(value=None), # pc_download
452
+ gr.update(visible=False), # regenerate_btn
453
+ )
454
+ return (
455
+ gr.update(value="Remove Background", visible=True), # run_Btn
456
+ None, # img_proc_state
457
+ None, # background_remove_state
458
+ gr.update(value=None, visible=False), # preview_removal
459
+ gr.update(value=None, visible=False), # output_3d
460
+ gr.update(visible=False), # hdr_row
461
+ gr.update(visible=False), # point_cloud_row
462
+ gr.update(value=None), # point_cloud_editor
463
+ gr.update(value=None), # pc_download
464
+ gr.update(visible=False), # regenerate_btn
465
+ )
466
+
467
+
468
+ def update_foreground_ratio(img_proc, fr, no_crop):
469
+ foreground_res = spar3d_utils.foreground_crop(
470
+ img_proc, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop
471
+ )
472
+ return (
473
+ foreground_res,
474
+ gr.update(value=show_mask_img(foreground_res)),
475
+ )
476
+
477
+
478
+ def update_resolution_controls(remesh_choice, vertex_count_type):
479
+ show_controls = remesh_choice.lower() != "none"
480
+ show_vertex_count = vertex_count_type != "Keep Vertex Count"
481
+ return (
482
+ gr.update(visible=show_controls), # vertex_count_type
483
+ gr.update(visible=show_controls and show_vertex_count), # vertex_count_slider
484
+ )
485
+
486
+
487
+ with gr.Blocks() as demo:
488
+ img_proc_state = gr.State()
489
+ background_remove_state = gr.State()
490
+ gr.Markdown(
491
+ """
492
+ # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
493
+
494
+ SPAR3D is a state-of-the-art method for 3D mesh reconstruction from a single image. This demo lets you upload an image and generate a 3D mesh model from it. A key feature of SPAR3D is that it generates a point cloud as an intermediate representation before producing the mesh, so you can edit the point cloud to adjust the final mesh. We provide a simple point cloud editor in this demo, where you can drag, recolor and rescale points. If you have more advanced editing needs (e.g. box selection, duplication, local stretching), you can download the point cloud and edit it in software such as MeshLab or Blender. The edited point cloud can then be uploaded back to this demo to generate a new 3D model by checking the "Point cloud upload" box.
+
+ **Tips**
+
+ 1. If the image does not have a valid alpha channel, it will go through the background removal step. Our built-in background removal can sometimes be inaccurate, which results in poor mesh quality. In such cases, you can use an external background removal tool to obtain an RGBA image before uploading it here.
+ 2. You can adjust the foreground ratio to control the size of the foreground object. This can have a major impact on the final mesh.
+ 3. The guidance scale controls the strength of the image condition during point cloud generation. A higher value may yield higher mesh fidelity, but there will be less variability when changing the random seed. Note that the guidance scale and the seed have no effect when a point cloud is uploaded manually.
+ 4. Our online editor supports multi-selection by holding down the shift key, which allows you to recolor multiple points at once.
+ 5. Edits should mainly alter the unseen parts of the object. Visible parts can be edited too, but those edits should stay consistent with the image; editing the visible parts in a way that contradicts the image may result in poor mesh quality.
+ 6. You can upload your own HDR environment map to light the 3D model.
504
+ """
505
+ )
506
+ with gr.Row(variant="panel"):
507
+ with gr.Column():
508
+ with gr.Row():
509
+ input_img = gr.Image(
510
+ type="pil", label="Input Image", sources="upload", image_mode="RGBA"
511
+ )
512
+ preview_removal = gr.Image(
513
+ label="Preview Background Removal",
514
+ type="pil",
515
+ image_mode="RGB",
516
+ interactive=False,
517
+ visible=False,
518
+ )
519
+
520
+ gr.Markdown("### Input Controls")
521
+ with gr.Group():
522
+ with gr.Row():
523
+ no_crop = gr.Checkbox(label="No cropping", value=False)
524
+ pc_upload = gr.Checkbox(label="Point cloud upload", value=False)
525
+
526
+ pc_cond_file = gr.File(
527
+ label="Point Cloud Upload",
528
+ file_types=[".ply"],
529
+ file_count="single",
530
+ visible=False,
531
+ )
532
+
533
+ foreground_ratio = gr.Slider(
534
+ label="Padding Ratio",
535
+ minimum=1.0,
536
+ maximum=2.0,
537
+ value=1.3,
538
+ step=0.05,
539
+ )
540
+
541
+ pc_upload.change(
542
+ lambda x: gr.update(visible=x),
543
+ inputs=pc_upload,
544
+ outputs=[pc_cond_file],
545
+ )
546
+
547
+ no_crop.change(
548
+ update_foreground_ratio,
549
+ inputs=[img_proc_state, foreground_ratio, no_crop],
550
+ outputs=[background_remove_state, preview_removal],
551
+ )
552
+
553
+ foreground_ratio.change(
554
+ update_foreground_ratio,
555
+ inputs=[img_proc_state, foreground_ratio, no_crop],
556
+ outputs=[background_remove_state, preview_removal],
557
+ )
558
+
559
+ gr.Markdown("### Point Diffusion Controls")
560
+ with gr.Group():
561
+ guidance_scale = gr.Slider(
562
+ label="Guidance Scale",
563
+ minimum=1.0,
564
+ maximum=10.0,
565
+ value=3.0,
566
+ step=1.0,
567
+ )
568
+
569
+ random_seed = gr.Slider(
570
+ label="Seed",
571
+ minimum=0,
572
+ maximum=10000,
573
+ value=0,
574
+ step=1,
575
+ )
576
+
577
+ no_remesh = not TRIANGLE_REMESH_AVAILABLE and not QUAD_REMESH_AVAILABLE
578
+ gr.Markdown(
579
+ "### Texture Controls"
580
+ if no_remesh
581
+ else "### Meshing and Texture Controls"
582
+ )
583
+ with gr.Group():
584
+ remesh_choices = ["None"]
585
+ if TRIANGLE_REMESH_AVAILABLE:
586
+ remesh_choices.append("Triangle")
587
+ if QUAD_REMESH_AVAILABLE:
588
+ remesh_choices.append("Quad")
589
+
590
+ remesh_option = gr.Radio(
591
+ choices=remesh_choices,
592
+ label="Remeshing",
593
+ value="None",
594
+ visible=not no_remesh,
595
+ )
596
+
597
+ vertex_count_type = gr.Radio(
598
+ choices=[
599
+ "Keep Vertex Count",
600
+ "Target Vertex Count",
601
+ "Target Face Count",
602
+ ],
603
+ label="Mesh Resolution Control",
604
+ value="Keep Vertex Count",
605
+ visible=False,
606
+ )
607
+
608
+ vertex_count_slider = gr.Slider(
609
+ label="Target Count",
610
+ minimum=0,
611
+ maximum=20000,
612
+ value=2000,
613
+ visible=False,
614
+ )
615
+
616
+ texture_size = gr.Slider(
617
+ label="Texture Size",
618
+ minimum=512,
619
+ maximum=2048,
620
+ value=1024,
621
+ step=256,
622
+ visible=True,
623
+ )
624
+
625
+ remesh_option.change(
626
+ update_resolution_controls,
627
+ inputs=[remesh_option, vertex_count_type],
628
+ outputs=[vertex_count_type, vertex_count_slider],
629
+ )
630
+
631
+ vertex_count_type.change(
632
+ update_resolution_controls,
633
+ inputs=[remesh_option, vertex_count_type],
634
+ outputs=[vertex_count_type, vertex_count_slider],
635
+ )
636
+
637
+ run_btn = gr.Button("Run", variant="primary", visible=False)
638
+
639
+ with gr.Column():
640
+ with gr.Group(visible=False) as point_cloud_row:
641
+ point_size_slider = gr.Slider(
642
+ label="Point Size",
643
+ minimum=0.01,
644
+ maximum=1.0,
645
+ value=0.2,
646
+ step=0.01,
647
+ )
648
+ point_cloud_editor = PointCloudEditor(
649
+ up_axis="Z",
650
+ forward_axis="X",
651
+ lock_scale_z=True,
652
+ lock_scale_y=True,
653
+ visible=True,
654
+ )
655
+
656
+ pc_download = gr.File(
657
+ label="Point Cloud Download",
658
+ file_types=[".ply"],
659
+ file_count="single",
660
+ )
661
+ point_size_slider.change(
662
+ fn=lambda x: gr.update(point_size=x),
663
+ inputs=point_size_slider,
664
+ outputs=point_cloud_editor,
665
+ )
666
+
667
+ regenerate_btn = gr.Button(
668
+ "Re-run with point cloud", variant="primary", visible=False
669
+ )
670
+
671
+ output_3d = LitModel3D(
672
+ label="3D Model",
673
+ visible=False,
674
+ clear_color=[0.0, 0.0, 0.0, 0.0],
675
+ tonemapping="aces",
676
+ contrast=1.0,
677
+ scale=1.0,
678
+ )
679
+ with gr.Column(visible=False, scale=1.0) as hdr_row:
680
+ gr.Markdown(
681
+ """## HDR Environment Map
682
+
683
+ Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
684
+ """
685
+ )
686
+
687
+ with gr.Row():
688
+ hdr_illumination_file = gr.File(
689
+ label="HDR Env Map",
690
+ file_types=[".hdr"],
691
+ file_count="single",
692
+ )
693
+ example_hdris = [
694
+ os.path.join("demo_files/hdri", f)
695
+ for f in os.listdir("demo_files/hdri")
696
+ ]
697
+ hdr_illumination_example = gr.Examples(
698
+ examples=example_hdris,
699
+ inputs=hdr_illumination_file,
700
+ )
701
+
702
+ hdr_illumination_file.change(
703
+ lambda x: gr.update(env_map=x.name if x is not None else None),
704
+ inputs=hdr_illumination_file,
705
+ outputs=[output_3d],
706
+ )
707
+
708
+ examples = gr.Examples(
709
+ examples=example_files, inputs=input_img, examples_per_page=11
710
+ )
711
+
712
+ input_img.change(
713
+ requires_bg_remove,
714
+ inputs=[input_img, foreground_ratio, no_crop],
715
+ outputs=[
716
+ run_btn,
717
+ img_proc_state,
718
+ background_remove_state,
719
+ preview_removal,
720
+ output_3d,
721
+ hdr_row,
722
+ point_cloud_row,
723
+ point_cloud_editor,
724
+ pc_download,
725
+ regenerate_btn,
726
+ ],
727
+ )
728
+
729
+ point_cloud_editor.edit(
730
+ fn=lambda _x: gr.update(visible=True),
731
+ inputs=point_cloud_editor,
732
+ outputs=regenerate_btn,
733
+ )
734
+
735
+ regenerate_btn.click(
736
+ regenerate_run,
737
+ inputs=[
738
+ background_remove_state,
739
+ guidance_scale,
740
+ random_seed,
741
+ point_cloud_editor,
742
+ remesh_option,
743
+ vertex_count_type,
744
+ vertex_count_slider,
745
+ texture_size,
746
+ ],
747
+ outputs=[
748
+ run_btn,
749
+ img_proc_state,
750
+ background_remove_state,
751
+ preview_removal,
752
+ output_3d,
753
+ hdr_row,
754
+ point_cloud_row,
755
+ point_cloud_editor,
756
+ pc_download,
757
+ regenerate_btn,
758
+ ],
759
+ )
760
+
761
+ run_btn.click(
762
+ run_button,
763
+ inputs=[
764
+ run_btn,
765
+ input_img,
766
+ background_remove_state,
767
+ foreground_ratio,
768
+ no_crop,
769
+ guidance_scale,
770
+ random_seed,
771
+ pc_upload,
772
+ pc_cond_file,
773
+ remesh_option,
774
+ vertex_count_type,
775
+ vertex_count_slider,
776
+ texture_size,
777
+ ],
778
+ outputs=[
779
+ run_btn,
780
+ img_proc_state,
781
+ background_remove_state,
782
+ preview_removal,
783
+ output_3d,
784
+ hdr_row,
785
+ point_cloud_row,
786
+ point_cloud_editor,
787
+ pc_download,
788
+ regenerate_btn,
789
+ ],
790
+ )
791
+
792
+ demo.queue().launch()
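The tips in the demo above mention downloading the intermediate point cloud and editing it externally before re-uploading it through the "Point cloud upload" box. Below is a minimal, hedged sketch of such an offline edit using trimesh (already listed in requirements.txt); the file names and the specific edit are illustrative assumptions, not part of the demo.

```python
# Hypothetical offline edit of the point cloud exported by the demo,
# to be re-uploaded via the "Point cloud upload" checkbox.
import numpy as np
import trimesh

pc = trimesh.load("points.ply")               # PointCloud with per-vertex colors
verts = np.asarray(pc.vertices).copy()
colors = np.asarray(pc.colors).copy()

# Illustrative edit: nudge the (unseen) back half of the object along +X.
back = verts[:, 0] < verts[:, 0].mean()
verts[back, 0] -= 0.02

trimesh.PointCloud(vertices=verts, colors=colors).export("points_edited.ply")
```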
load/tets/160_tets.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
3
+ size 15408790
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ einops==0.7.0
2
+ jaxtyping==0.2.31
3
+ omegaconf==2.3.0
4
+ transformers==4.42.3
5
+ loralib==0.1.2
6
+ git+https://github.com/openai/CLIP.git
7
+ git+https://github.com/SunzeY/AlphaCLIP.git
8
+ trimesh==4.4.1
9
+ numpy==1.26.4
10
+ huggingface-hub==0.23.4
11
+ transparent-background==1.3.3
12
+ gradio==4.43.0
13
+ gradio-litmodel3d==0.0.1
14
+ gradio-pointcloudeditor==0.0.9
15
+ gpytoolbox==0.2.0
16
+ # ./texture_baker/
17
+ # ./uv_unwrapper/
ruff.toml ADDED
@@ -0,0 +1,3 @@
1
+ [lint]
2
+ ignore = ["F722", "F821"]
3
+ extend-select = ["I"]
run.py ADDED
@@ -0,0 +1,180 @@
1
+ import argparse
2
+ import os
3
+ from contextlib import nullcontext
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ from transparent_background import Remover
9
+
10
+ from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
11
+ from spar3d.system import SPAR3D
12
+ from spar3d.utils import foreground_crop, get_device, remove_background
13
+
14
+
15
+ def check_positive(value):
16
+ ivalue = int(value)
17
+ if ivalue <= 0:
18
+ raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
19
+ return ivalue
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "image", type=str, nargs="+", help="Path to input image(s) or folder."
26
+ )
27
+ parser.add_argument(
28
+ "--device",
29
+ default=get_device(),
30
+ type=str,
31
+ help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
32
+ )
33
+ parser.add_argument(
34
+ "--pretrained-model",
35
+ default="stabilityai/spar3d",
36
+ type=str,
37
+ help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/spar3d'",
38
+ )
39
+ parser.add_argument(
40
+ "--foreground-ratio",
41
+ default=1.3,
42
+ type=float,
43
+ help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
44
+ )
45
+ parser.add_argument(
46
+ "--output-dir",
47
+ default="output/",
48
+ type=str,
49
+ help="Output directory to save the results. Default: 'output/'",
50
+ )
51
+ parser.add_argument(
52
+ "--texture-resolution",
53
+ default=1024,
54
+ type=int,
55
+ help="Texture atlas resolution. Default: 1024",
56
+ )
57
+
58
+ remesh_choices = ["none"]
59
+ if TRIANGLE_REMESH_AVAILABLE:
60
+ remesh_choices.append("triangle")
61
+ if QUAD_REMESH_AVAILABLE:
62
+ remesh_choices.append("quad")
63
+ parser.add_argument(
64
+ "--remesh_option",
65
+ choices=remesh_choices,
66
+ default="none",
67
+ help="Remeshing option",
68
+ )
69
+ if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
70
+ parser.add_argument(
71
+ "--reduction_count_type",
72
+ choices=["keep", "vertex", "faces"],
73
+ default="keep",
74
+ help="Vertex count type",
75
+ )
76
+ parser.add_argument(
77
+ "--target_count",
78
+ type=check_positive,
79
+ help="Selected target count.",
80
+ default=2000,
81
+ )
82
+ parser.add_argument(
83
+ "--batch_size", default=1, type=int, help="Batch size for inference"
84
+ )
85
+ args = parser.parse_args()
86
+
87
+ # Ensure args.device contains cuda
88
+ devices = ["cuda", "mps", "cpu"]
89
+ if not any(device in args.device for device in devices):
90
+ raise ValueError("Invalid device. Use cuda, mps or cpu")
91
+
92
+ output_dir = args.output_dir
93
+ os.makedirs(output_dir, exist_ok=True)
94
+
95
+ device = args.device
96
+ if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
97
+ device = "cpu"
98
+
99
+ print("Device used: ", device)
100
+
101
+ model = SPAR3D.from_pretrained(
102
+ args.pretrained_model,
103
+ config_name="config.yaml",
104
+ weight_name="model.safetensors",
105
+ )
106
+ model.to(device)
107
+ model.eval()
108
+
109
+ bg_remover = Remover(device=device)
110
+ images = []
111
+ idx = 0
112
+ for image_path in args.image:
113
+
114
+ def handle_image(image_path, idx):
115
+ image = remove_background(
116
+ Image.open(image_path).convert("RGBA"), bg_remover
117
+ )
118
+ image = foreground_crop(image, args.foreground_ratio)
119
+ os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
120
+ image.save(os.path.join(output_dir, str(idx), "input.png"))
121
+ images.append(image)
122
+
123
+ if os.path.isdir(image_path):
124
+ image_paths = [
125
+ os.path.join(image_path, f)
126
+ for f in os.listdir(image_path)
127
+ if f.endswith((".png", ".jpg", ".jpeg"))
128
+ ]
129
+ for image_path in image_paths:
130
+ handle_image(image_path, idx)
131
+ idx += 1
132
+ else:
133
+ handle_image(image_path, idx)
134
+ idx += 1
135
+
136
+ vertex_count = (
137
+ -1
138
+ if args.reduction_count_type == "keep"
139
+ else (
140
+ args.target_count
141
+ if args.reduction_count_type == "vertex"
142
+ else args.target_count // 2
143
+ )
144
+ )
145
+
146
+ for i in tqdm(range(0, len(images), args.batch_size)):
147
+ image = images[i : i + args.batch_size]
148
+ if torch.cuda.is_available():
149
+ torch.cuda.reset_peak_memory_stats()
150
+ with torch.no_grad():
151
+ with (
152
+ torch.autocast(device_type=device, dtype=torch.float16)
153
+ if "cuda" in device
154
+ else nullcontext()
155
+ ):
156
+ mesh, glob_dict = model.run_image(
157
+ image,
158
+ bake_resolution=args.texture_resolution,
159
+ remesh=args.remesh_option,
160
+ vertex_count=vertex_count,
161
+ return_points=True,
162
+ )
163
+ if torch.cuda.is_available():
164
+ print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
165
+ elif torch.backends.mps.is_available():
166
+ print(
167
+ "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
168
+ )
169
+
170
+ if len(image) == 1:
171
+ out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
172
+ mesh.export(out_mesh_path, include_normals=True)
173
+ out_points_path = os.path.join(output_dir, str(i), "points.ply")
174
+ glob_dict["point_clouds"][0].export(out_points_path)
175
+ else:
176
+ for j in range(len(mesh)):
177
+ out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
178
+ mesh[j].export(out_mesh_path, include_normals=True)
179
+ out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
180
+ glob_dict["point_clouds"][j].export(out_points_path)
spar3d/models/camera.py ADDED
@@ -0,0 +1,32 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from spar3d.models.utils import BaseModule
8
+
9
+
10
+ class LinearCameraEmbedder(BaseModule):
11
+ @dataclass
12
+ class Config(BaseModule.Config):
13
+ in_channels: int = 25
14
+ out_channels: int = 768
15
+ conditions: List[str] = field(default_factory=list)
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
21
+
22
+ def forward(self, **kwargs):
23
+ cond_tensors = []
24
+ for cond_name in self.cfg.conditions:
25
+ assert cond_name in kwargs
26
+ cond = kwargs[cond_name]
27
+ # cond in shape (B, Nv, ...)
28
+ cond_tensors.append(cond.view(*cond.shape[:2], -1))
29
+ cond_tensor = torch.cat(cond_tensors, dim=-1)
30
+ assert cond_tensor.shape[-1] == self.cfg.in_channels
31
+ embedding = self.linear(cond_tensor)
32
+ return embedding
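A small usage sketch for LinearCameraEmbedder. Constructing it from a plain dict config (threestudio-style BaseModule) and splitting the 25 input channels into a flattened 4x4 camera-to-world matrix plus 9 intrinsic values are assumptions for illustration, not the model's actual configuration.

```python
import torch

from spar3d.models.camera import LinearCameraEmbedder

# Assumption: BaseModule accepts a plain dict config (threestudio-style).
embedder = LinearCameraEmbedder(
    {"in_channels": 25, "out_channels": 768, "conditions": ["c2w", "intrinsic"]}
)

B, Nv = 2, 1
c2w = torch.randn(B, Nv, 4, 4)      # flattened to 16 channels
intrinsic = torch.randn(B, Nv, 9)   # 9 channels, 16 + 9 = 25 = in_channels

embedding = embedder(c2w=c2w, intrinsic=intrinsic)
print(embedding.shape)              # torch.Size([2, 1, 768])
```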
spar3d/models/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,524 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from: https://github.com/openai/point-e
3
+ # Licensed under the MIT License
4
+ # Copyright (c) 2022 OpenAI
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ # --------------------------------------------------------
25
+
26
+ import math
27
+ from typing import Any, Dict, Iterable, Optional, Sequence, Union
28
+
29
+ import numpy as np
30
+ import torch as th
31
+
32
+
33
+ def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9):
34
+ def sigmoid(x):
35
+ return 1 / (1 + np.exp(-x))
36
+
37
+ v_start = sigmoid(start / tau)
38
+ v_end = sigmoid(end / tau)
39
+ output = sigmoid((t * (end - start) + start) / tau)
40
+ output = (v_end - output) / (v_end - v_start)
41
+ return np.clip(output, clip_min, 1.0)
42
+
43
+
44
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
45
+ """
46
+ This is the deprecated API for creating beta schedules.
47
+
48
+ See get_named_beta_schedule() for the new library of schedules.
49
+ """
50
+ if beta_schedule == "linear":
51
+ betas = np.linspace(
52
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
53
+ )
54
+ else:
55
+ raise NotImplementedError(beta_schedule)
56
+ assert betas.shape == (num_diffusion_timesteps,)
57
+ return betas
58
+
59
+
60
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, exp_p=12):
61
+ """
62
+ Get a pre-defined beta schedule for the given name.
63
+
64
+ The beta schedule library consists of beta schedules which remain similar
65
+ in the limit of num_diffusion_timesteps.
66
+ Beta schedules may be added, but should not be removed or changed once
67
+ they are committed to maintain backwards compatibility.
68
+ """
69
+ if schedule_name == "linear":
70
+ # Linear schedule from Ho et al, extended to work for any number of
71
+ # diffusion steps.
72
+ scale = 1000 / num_diffusion_timesteps
73
+ return get_beta_schedule(
74
+ "linear",
75
+ beta_start=scale * 0.0001,
76
+ beta_end=scale * 0.02,
77
+ num_diffusion_timesteps=num_diffusion_timesteps,
78
+ )
79
+ elif schedule_name == "cosine":
80
+ return betas_for_alpha_bar(
81
+ num_diffusion_timesteps,
82
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
83
+ )
84
+ elif schedule_name == "sigmoid":
85
+ # Sigmoid schedule passed through betas_for_alpha_bar
86
+ return betas_for_alpha_bar(
87
+ num_diffusion_timesteps, lambda t: sigmoid_schedule(t)
88
+ )
89
+ else:
90
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
91
+
92
+
93
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
94
+ """
95
+ Create a beta schedule that discretizes the given alpha_t_bar function,
96
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
97
+
98
+ :param num_diffusion_timesteps: the number of betas to produce.
99
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
100
+ produces the cumulative product of (1-beta) up to that
101
+ part of the diffusion process.
102
+ :param max_beta: the maximum beta to use; use values lower than 1 to
103
+ prevent singularities.
104
+ """
105
+ betas = []
106
+ for i in range(num_diffusion_timesteps):
107
+ t1 = i / num_diffusion_timesteps
108
+ t2 = (i + 1) / num_diffusion_timesteps
109
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
110
+ return np.array(betas)
111
+
112
+
113
+ def space_timesteps(num_timesteps, section_counts):
114
+ """
115
+ Create a list of timesteps to use from an original diffusion process,
116
+ given the number of timesteps we want to take from equally-sized portions
117
+ of the original process.
118
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
119
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
120
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
121
+ :param num_timesteps: the number of diffusion steps in the original
122
+ process to divide up.
123
+ :param section_counts: either a list of numbers, or a string containing
124
+ comma-separated numbers, indicating the step count
125
+ per section. As a special case, use "ddimN" where N
126
+ is a number of steps to use the striding from the
127
+ DDIM paper.
128
+ :return: a set of diffusion steps from the original process to use.
129
+ """
130
+ if isinstance(section_counts, str):
131
+ if section_counts.startswith("ddim"):
132
+ desired_count = int(section_counts[len("ddim") :])
133
+ for i in range(1, num_timesteps):
134
+ if len(range(0, num_timesteps, i)) == desired_count:
135
+ return set(range(0, num_timesteps, i))
136
+ raise ValueError(
137
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
138
+ )
139
+ elif section_counts.startswith("exact"):
140
+ res = set(int(x) for x in section_counts[len("exact") :].split(","))
141
+ for x in res:
142
+ if x < 0 or x >= num_timesteps:
143
+ raise ValueError(f"timestep out of bounds: {x}")
144
+ return res
145
+ section_counts = [int(x) for x in section_counts.split(",")]
146
+ size_per = num_timesteps // len(section_counts)
147
+ extra = num_timesteps % len(section_counts)
148
+ start_idx = 0
149
+ all_steps = []
150
+ for i, section_count in enumerate(section_counts):
151
+ size = size_per + (1 if i < extra else 0)
152
+ if size < section_count:
153
+ raise ValueError(
154
+ f"cannot divide section of {size} steps into {section_count}"
155
+ )
156
+ if section_count <= 1:
157
+ frac_stride = 1
158
+ else:
159
+ frac_stride = (size - 1) / (section_count - 1)
160
+ cur_idx = 0.0
161
+ taken_steps = []
162
+ for _ in range(section_count):
163
+ taken_steps.append(start_idx + round(cur_idx))
164
+ cur_idx += frac_stride
165
+ all_steps += taken_steps
166
+ start_idx += size
167
+ return set(all_steps)
168
+
169
+
170
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
171
+ """Extract values from a 1-D numpy array for a batch of indices."""
172
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
173
+ while len(res.shape) < len(broadcast_shape):
174
+ res = res[..., None]
175
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
176
+
177
+
178
+ class GaussianDiffusion:
179
+ """
180
+ Utilities for sampling from Gaussian diffusion models.
181
+ """
182
+
183
+ def __init__(
184
+ self,
185
+ *,
186
+ betas: Sequence[float],
187
+ model_mean_type: str,
188
+ model_var_type: str,
189
+ channel_scales: Optional[np.ndarray] = None,
190
+ channel_biases: Optional[np.ndarray] = None,
191
+ ):
192
+ self.model_mean_type = model_mean_type
193
+ self.model_var_type = model_var_type
194
+ self.channel_scales = channel_scales
195
+ self.channel_biases = channel_biases
196
+
197
+ # Use float64 for accuracy
198
+ betas = np.array(betas, dtype=np.float64)
199
+ self.betas = betas
200
+ assert len(betas.shape) == 1, "betas must be 1-D"
201
+ assert (betas > 0).all() and (betas <= 1).all()
202
+
203
+ self.num_timesteps = int(betas.shape[0])
204
+
205
+ alphas = 1.0 - betas
206
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
207
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
208
+
209
+ # calculations for diffusion q(x_t | x_{t-1}) and others
210
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
211
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
212
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
213
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
214
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
215
+ self.posterior_variance = (
216
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
217
+ )
218
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
219
+ self.posterior_log_variance_clipped = np.log(
220
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
221
+ )
222
+
223
+ self.posterior_mean_coef1 = (
224
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
225
+ )
226
+ self.posterior_mean_coef2 = (
227
+ (1.0 - self.alphas_cumprod_prev)
228
+ * np.sqrt(alphas)
229
+ / (1.0 - self.alphas_cumprod)
230
+ )
231
+
232
+ def scale_channels(self, x: th.Tensor) -> th.Tensor:
233
+ """Apply channel-wise scaling."""
234
+ if self.channel_scales is not None:
235
+ x = x * th.from_numpy(self.channel_scales).to(x).reshape(
236
+ [1, -1, *([1] * (len(x.shape) - 2))]
237
+ )
238
+ if self.channel_biases is not None:
239
+ x = x + th.from_numpy(self.channel_biases).to(x).reshape(
240
+ [1, -1, *([1] * (len(x.shape) - 2))]
241
+ )
242
+ return x
243
+
244
+ def unscale_channels(self, x: th.Tensor) -> th.Tensor:
245
+ """Remove channel-wise scaling."""
246
+ if self.channel_biases is not None:
247
+ x = x - th.from_numpy(self.channel_biases).to(x).reshape(
248
+ [1, -1, *([1] * (len(x.shape) - 2))]
249
+ )
250
+ if self.channel_scales is not None:
251
+ x = x / th.from_numpy(self.channel_scales).to(x).reshape(
252
+ [1, -1, *([1] * (len(x.shape) - 2))]
253
+ )
254
+ return x
255
+
256
+ def unscale_out_dict(
257
+ self, out: Dict[str, Union[th.Tensor, Any]]
258
+ ) -> Dict[str, Union[th.Tensor, Any]]:
259
+ return {
260
+ k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v)
261
+ for k, v in out.items()
262
+ }
263
+
264
+ def q_posterior_mean_variance(self, x_start, x_t, t):
265
+ """
266
+ Compute the mean and variance of the diffusion posterior:
267
+
268
+ q(x_{t-1} | x_t, x_0)
269
+
270
+ """
271
+ assert x_start.shape == x_t.shape
272
+ posterior_mean = (
273
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
274
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
275
+ )
276
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
277
+ posterior_log_variance_clipped = _extract_into_tensor(
278
+ self.posterior_log_variance_clipped, t, x_t.shape
279
+ )
280
+ assert (
281
+ posterior_mean.shape[0]
282
+ == posterior_variance.shape[0]
283
+ == posterior_log_variance_clipped.shape[0]
284
+ == x_start.shape[0]
285
+ )
286
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
287
+
288
+ def p_mean_variance(
289
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
290
+ ):
291
+ """
292
+ Apply the model to get p(x_{t-1} | x_t).
293
+ """
294
+ if model_kwargs is None:
295
+ model_kwargs = {}
296
+
297
+ B, C = x.shape[:2]
298
+ assert t.shape == (B,)
299
+
300
+ # Direct prediction of eps
301
+ model_output = model(x, t, **model_kwargs)
302
+ if isinstance(model_output, tuple):
303
+ model_output, prev_latent = model_output
304
+ model_kwargs["prev_latent"] = prev_latent
305
+
306
+ # Convert model output to mean and variance
307
+ model_variance, model_log_variance = {
308
+ # for fixedlarge, we set the initial (log-)variance like so
309
+ # to get a better decoder log likelihood.
310
+ "fixed_large": (
311
+ np.append(self.posterior_variance[1], self.betas[1:]),
312
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
313
+ ),
314
+ "fixed_small": (
315
+ self.posterior_variance,
316
+ self.posterior_log_variance_clipped,
317
+ ),
318
+ }[self.model_var_type]
319
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
320
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
321
+
322
+ def process_xstart(x):
323
+ if denoised_fn is not None:
324
+ x = denoised_fn(x)
325
+ if clip_denoised:
326
+ x = x.clamp(
327
+ -self.channel_scales[0] * 0.67, self.channel_scales[0] * 0.67
328
+ )
329
+ x[:, 3:] = x[:, 3:].clamp(
330
+ -self.channel_scales[3] * 0.5, self.channel_scales[3] * 0.5
331
+ )
332
+ return x
333
+ return x
334
+
335
+ if self.model_mean_type == "x_prev":
336
+ pred_xstart = process_xstart(
337
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
338
+ )
339
+ model_mean = model_output
340
+ elif self.model_mean_type in ["x_start", "epsilon"]:
341
+ if self.model_mean_type == "x_start":
342
+ pred_xstart = process_xstart(model_output)
343
+ else:
344
+ pred_xstart = process_xstart(
345
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
346
+ )
347
+ model_mean, _, _ = self.q_posterior_mean_variance(
348
+ x_start=pred_xstart, x_t=x, t=t
349
+ )
350
+ # print('p_mean_variance:', pred_xstart.min(), pred_xstart.max())
351
+ else:
352
+ raise NotImplementedError(self.model_mean_type)
353
+
354
+ assert (
355
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
356
+ )
357
+ return {
358
+ "mean": model_mean,
359
+ "variance": model_variance,
360
+ "log_variance": model_log_variance,
361
+ "pred_xstart": pred_xstart,
362
+ }
363
+
364
+ def _predict_xstart_from_eps(self, x_t, t, eps):
365
+ assert x_t.shape == eps.shape
366
+ return (
367
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
368
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
369
+ )
370
+
371
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
372
+ assert x_t.shape == xprev.shape
373
+ return ( # (xprev - coef2*x_t) / coef1
374
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
375
+ - _extract_into_tensor(
376
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
377
+ )
378
+ * x_t
379
+ )
380
+
381
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
382
+ return (
383
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
384
+ - pred_xstart
385
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
386
+
387
+ def ddim_sample_loop_progressive(
388
+ self,
389
+ model,
390
+ shape,
391
+ noise=None,
392
+ clip_denoised=True,
393
+ denoised_fn=None,
394
+ model_kwargs=None,
395
+ device=None,
396
+ progress=False,
397
+ eta=0.0,
398
+ ):
399
+ """
400
+ Use DDIM to sample from the model and yield intermediate samples.
401
+ """
402
+ if device is None:
403
+ device = next(model.parameters()).device
404
+ assert isinstance(shape, (tuple, list))
405
+ if noise is not None:
406
+ img = noise
407
+ else:
408
+ img = th.randn(*shape, device=device)
409
+
410
+ indices = list(range(self.num_timesteps))[::-1]
411
+
412
+ if progress:
413
+ from tqdm.auto import tqdm
414
+
415
+ indices = tqdm(indices)
416
+
417
+ for i in indices:
418
+ t = th.tensor([i] * shape[0], device=device)
419
+ with th.no_grad():
420
+ out = self.ddim_sample(
421
+ model,
422
+ img,
423
+ t,
424
+ clip_denoised=clip_denoised,
425
+ denoised_fn=denoised_fn,
426
+ model_kwargs=model_kwargs,
427
+ eta=eta,
428
+ )
429
+ yield self.unscale_out_dict(out)
430
+ img = out["sample"]
431
+
438
+ def ddim_sample(
439
+ self,
440
+ model,
441
+ x,
442
+ t,
443
+ clip_denoised=True,
444
+ denoised_fn=None,
445
+ model_kwargs=None,
446
+ eta=0.0,
447
+ ):
448
+ """
449
+ Sample x_{t-1} from the model using DDIM.
450
+ """
451
+ out = self.p_mean_variance(
452
+ model,
453
+ x,
454
+ t,
455
+ clip_denoised=clip_denoised,
456
+ denoised_fn=denoised_fn,
457
+ model_kwargs=model_kwargs,
458
+ )
459
+
460
+ # Usually our model outputs epsilon, but we re-derive it
461
+ # in case we used x_start or x_prev prediction.
462
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
463
+
464
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
465
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
466
+ sigma = (
467
+ eta
468
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
469
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
470
+ )
471
+
472
+ # Equation 12.
473
+ noise = th.randn_like(x)
474
+ mean_pred = (
475
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
476
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
477
+ )
478
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
479
+ sample = mean_pred + nonzero_mask * sigma * noise
480
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
481
+
482
+
483
+ class SpacedDiffusion(GaussianDiffusion):
484
+ """
485
+ A diffusion process which can skip steps in a base diffusion process.
486
+ """
487
+
488
+ def __init__(self, use_timesteps: Iterable[int], **kwargs):
489
+ self.use_timesteps = set(use_timesteps)
490
+ self.timestep_map = []
491
+ self.original_num_steps = len(kwargs["betas"])
492
+
493
+ base_diffusion = GaussianDiffusion(**kwargs)
494
+ last_alpha_cumprod = 1.0
495
+ new_betas = []
496
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
497
+ if i in self.use_timesteps:
498
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
499
+ last_alpha_cumprod = alpha_cumprod
500
+ self.timestep_map.append(i)
501
+ kwargs["betas"] = np.array(new_betas)
502
+ super().__init__(**kwargs)
503
+
504
+ def p_mean_variance(self, model, *args, **kwargs):
505
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
506
+
507
+ def _wrap_model(self, model):
508
+ if isinstance(model, _WrappedModel):
509
+ return model
510
+ return _WrappedModel(model, self.timestep_map, self.original_num_steps)
511
+
512
+
513
+ class _WrappedModel:
514
+ """Helper class to wrap models for SpacedDiffusion."""
515
+
516
+ def __init__(self, model, timestep_map, original_num_steps):
517
+ self.model = model
518
+ self.timestep_map = timestep_map
519
+ self.original_num_steps = original_num_steps
520
+
521
+ def __call__(self, x, ts, **kwargs):
522
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
523
+ new_ts = map_tensor[ts]
524
+ return self.model(x, new_ts, **kwargs)
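A hedged sketch of wiring together the pieces defined in this file: build a cosine beta schedule, wrap it in SpacedDiffusion so DDIM runs in 64 of the 1024 training steps, and sample progressively. The step counts and the 6-channel (xyz + rgb) scaling layout are illustrative assumptions, not the checkpoint's actual configuration.

```python
import numpy as np

from spar3d.models.diffusion.gaussian_diffusion import (
    SpacedDiffusion,
    get_named_beta_schedule,
    space_timesteps,
)

betas = get_named_beta_schedule("cosine", 1024)
diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(1024, "ddim64"),   # 64 evenly strided steps
    betas=betas,
    model_mean_type="epsilon",
    model_var_type="fixed_small",
    channel_scales=np.ones(6, dtype=np.float64),     # assumed xyz + rgb layout
    channel_biases=np.zeros(6, dtype=np.float64),
)
# diffusion.ddim_sample_loop_progressive(model, shape=(B, 6, N), ...) yields dicts
# whose "sample" / "pred_xstart" are already mapped back through unscale_channels.
```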
spar3d/models/diffusion/sampler.py ADDED
@@ -0,0 +1,134 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from: https://github.com/openai/point-e
3
+ # Licensed under the MIT License
4
+ # Copyright (c) 2022 OpenAI
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ # --------------------------------------------------------
25
+
26
+ from typing import Dict, Iterator
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ from .gaussian_diffusion import GaussianDiffusion
32
+
33
+
34
+ class PointCloudSampler:
35
+ """
36
+ A wrapper around a model that produces conditional sample tensors.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ model: nn.Module,
42
+ diffusion: GaussianDiffusion,
43
+ num_points: int,
44
+ point_dim: int = 3,
45
+ guidance_scale: float = 3.0,
46
+ clip_denoised: bool = True,
47
+ sigma_min: float = 1e-3,
48
+ sigma_max: float = 120,
49
+ s_churn: float = 3,
50
+ ):
51
+ self.model = model
52
+ self.num_points = num_points
53
+ self.point_dim = point_dim
54
+ self.guidance_scale = guidance_scale
55
+ self.clip_denoised = clip_denoised
56
+ self.sigma_min = sigma_min
57
+ self.sigma_max = sigma_max
58
+ self.s_churn = s_churn
59
+
60
+ self.diffusion = diffusion
61
+
62
+ def sample_batch_progressive(
63
+ self,
64
+ batch_size: int,
65
+ condition: torch.Tensor,
66
+ noise=None,
67
+ device=None,
68
+ guidance_scale=None,
69
+ ) -> Iterator[Dict[str, torch.Tensor]]:
70
+ """
71
+ Generate samples progressively using classifier-free guidance.
72
+
73
+ Args:
74
+ batch_size: Number of samples to generate
75
+ condition: Conditioning tensor
76
+ noise: Optional initial noise tensor
77
+ device: Device to run on
78
+ guidance_scale: Optional override for guidance scale
79
+
80
+ Returns:
81
+ Iterator of dicts containing intermediate samples
82
+ """
83
+ if guidance_scale is None:
84
+ guidance_scale = self.guidance_scale
85
+
86
+ sample_shape = (batch_size, self.point_dim, self.num_points)
87
+
88
+ # Double the batch for classifier-free guidance
89
+ if guidance_scale != 1 and guidance_scale != 0:
90
+ condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
91
+ if noise is not None:
92
+ noise = torch.cat([noise, noise], dim=0)
93
+ model_kwargs = {"condition": condition}
94
+
95
+ internal_batch_size = batch_size
96
+ if guidance_scale != 1 and guidance_scale != 0:
97
+ model = self._uncond_guide_model(self.model, guidance_scale)
98
+ internal_batch_size *= 2
99
+ else:
100
+ model = self.model
101
+
102
+ samples_it = self.diffusion.ddim_sample_loop_progressive(
103
+ model,
104
+ shape=(internal_batch_size, *sample_shape[1:]),
105
+ model_kwargs=model_kwargs,
106
+ device=device,
107
+ clip_denoised=self.clip_denoised,
108
+ noise=noise,
109
+ )
110
+
111
+ for x in samples_it:
112
+ samples = {
113
+ "xstart": x["pred_xstart"][:batch_size],
114
+ "xprev": x["sample"][:batch_size] if "sample" in x else x["x"],
115
+ }
116
+ yield samples
117
+
118
+ def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
119
+ """
120
+ Wraps the model for classifier-free guidance.
121
+ """
122
+
123
+ def model_fn(x_t, ts, **kwargs):
124
+ half = x_t[: len(x_t) // 2]
125
+ combined = torch.cat([half, half], dim=0)
126
+ model_out = model(combined, ts, **kwargs)
127
+
128
+ eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
129
+ cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
130
+ half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
131
+ eps = torch.cat([half_eps, half_eps], dim=0)
132
+ return torch.cat([eps, rest], dim=1)
133
+
134
+ return model_fn
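A hedged sketch of driving PointCloudSampler with classifier-free guidance. The DummyDenoiser and the zero conditioning tensor are placeholders so the snippet runs end to end; the real point denoiser, its conditioning tokens, and the 6-channel point layout are assumptions mirroring the sketch after gaussian_diffusion.py.

```python
import numpy as np
import torch
from torch import nn

from spar3d.models.diffusion.gaussian_diffusion import (
    SpacedDiffusion, get_named_beta_schedule, space_timesteps,
)
from spar3d.models.diffusion.sampler import PointCloudSampler

diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(1024, "ddim64"),
    betas=get_named_beta_schedule("cosine", 1024),
    model_mean_type="epsilon",
    model_var_type="fixed_small",
    channel_scales=np.ones(6, dtype=np.float64),
    channel_biases=np.zeros(6, dtype=np.float64),
)

class DummyDenoiser(nn.Module):
    """Placeholder for the real point denoiser: returns an epsilon of the same shape."""
    def forward(self, x_t, t, condition=None):
        return torch.zeros_like(x_t)

sampler = PointCloudSampler(
    model=DummyDenoiser(),
    diffusion=diffusion,
    num_points=512,
    point_dim=6,                    # assumed xyz + rgb layout
    guidance_scale=3.0,
)

cond = torch.zeros(1, 1, 768)       # placeholder image-conditioning tokens
final = None
for out in sampler.sample_batch_progressive(batch_size=1, condition=cond, device="cpu"):
    final = out["xstart"]           # (1, 6, 512); the last iterate is the final sample
```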
spar3d/models/global_estimator/reni_estimator.py ADDED
@@ -0,0 +1,112 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from spar3d.models.illumination.reni.env_map import RENIEnvMap
11
+ from spar3d.models.utils import BaseModule
12
+
13
+
14
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
15
+ assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)"
16
+
17
+ def proj_u2a(u, a):
18
+ r"""
19
+ u: batch x 3
20
+ a: batch x 3
21
+ """
22
+ inner_prod = torch.sum(u * a, dim=-1, keepdim=True)
23
+ norm2 = torch.sum(u**2, dim=-1, keepdim=True)
24
+ norm2 = torch.clamp(norm2, min=1e-8)
25
+ factor = inner_prod / (norm2 + 1e-10)
26
+ return factor * u
27
+
28
+ x_raw, y_raw = d6[..., :3], d6[..., 3:]
29
+
30
+ x = F.normalize(x_raw, dim=-1)
31
+ y = F.normalize(y_raw - proj_u2a(x, y_raw), dim=-1)
32
+ z = torch.cross(x, y, dim=-1)
33
+
34
+ return torch.stack((x, y, z), dim=-1)
35
+
36
+
37
+ class ReniLatentCodeEstimator(BaseModule):
38
+ @dataclass
39
+ class Config(BaseModule.Config):
40
+ triplane_features: int = 40
41
+
42
+ n_layers: int = 5
43
+ hidden_features: int = 512
44
+ activation: str = "relu"
45
+
46
+ pool: str = "mean"
47
+
48
+ reni_env_config: dict = field(default_factory=dict)
49
+
50
+ cfg: Config
51
+
52
+ def configure(self):
53
+ layers = []
54
+ cur_features = self.cfg.triplane_features * 3
55
+ for _ in range(self.cfg.n_layers):
56
+ layers.append(
57
+ nn.Conv2d(
58
+ cur_features,
59
+ self.cfg.hidden_features,
60
+ kernel_size=3,
61
+ padding=0,
62
+ stride=2,
63
+ )
64
+ )
65
+ layers.append(self.make_activation(self.cfg.activation))
66
+
67
+ cur_features = self.cfg.hidden_features
68
+
69
+ self.layers = nn.Sequential(*layers)
70
+
71
+ self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config)
72
+ self.latent_dim = self.reni_env_map.field.latent_dim
73
+
74
+ self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3)
75
+ nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3)
76
+
77
+ self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6)
78
+ nn.init.constant_(self.fc_rotations.bias, 0.0)
79
+ nn.init.normal_(
80
+ self.fc_rotations.weight, mean=0.0, std=0.01
81
+ ) # Small variance here
82
+
83
+ self.fc_scale = nn.Linear(self.cfg.hidden_features, 1)
84
+ nn.init.constant_(self.fc_scale.bias, 0.0)
85
+ nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01) # Small variance here
86
+
87
+ def make_activation(self, activation):
88
+ if activation == "relu":
89
+ return nn.ReLU(inplace=True)
90
+ elif activation == "silu":
91
+ return nn.SiLU(inplace=True)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ def forward(
96
+ self,
97
+ triplane: Float[Tensor, "B 3 F Ht Wt"],
98
+ ) -> dict[str, Any]:
99
+ x = self.layers(
100
+ triplane.reshape(
101
+ triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
102
+ )
103
+ )
104
+ x = x.mean(dim=[-2, -1])
105
+
106
+ latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
107
+ rotations = self.fc_rotations(x)
108
+ scale = self.fc_scale(x)
109
+
110
+ env_map = self.reni_env_map(latents, rotation_6d_to_matrix(rotations), scale)
111
+
112
+ return {"illumination": env_map["rgb"]}
spar3d/models/illumination/reni/components/film_siren.py ADDED
@@ -0,0 +1,148 @@
1
+ """FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/."""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ def kaiming_leaky_init(m):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Linear") != -1:
13
+ torch.nn.init.kaiming_normal_(
14
+ m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu"
15
+ )
16
+
17
+
18
+ def frequency_init(freq):
19
+ def init(m):
20
+ with torch.no_grad():
21
+ if isinstance(m, nn.Linear):
22
+ num_input = m.weight.size(-1)
23
+ m.weight.uniform_(
24
+ -np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq
25
+ )
26
+
27
+ return init
28
+
29
+
30
+ def first_layer_film_sine_init(m):
31
+ with torch.no_grad():
32
+ if isinstance(m, nn.Linear):
33
+ num_input = m.weight.size(-1)
34
+ m.weight.uniform_(-1 / num_input, 1 / num_input)
35
+
36
+
37
+ class CustomMappingNetwork(nn.Module):
38
+ def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
39
+ super().__init__()
40
+
41
+ self.network = []
42
+
43
+ for _ in range(map_hidden_layers):
44
+ self.network.append(nn.Linear(in_features, map_hidden_dim))
45
+ self.network.append(nn.LeakyReLU(0.2, inplace=True))
46
+ in_features = map_hidden_dim
47
+
48
+ self.network.append(nn.Linear(map_hidden_dim, map_output_dim))
49
+
50
+ self.network = nn.Sequential(*self.network)
51
+
52
+ self.network.apply(kaiming_leaky_init)
53
+ with torch.no_grad():
54
+ self.network[-1].weight *= 0.25
55
+
56
+ def forward(self, z):
57
+ frequencies_offsets = self.network(z)
58
+ frequencies = frequencies_offsets[
59
+ ..., : torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor")
60
+ ]
61
+ phase_shifts = frequencies_offsets[
62
+ ..., torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") :
63
+ ]
64
+
65
+ return frequencies, phase_shifts
66
+
67
+
68
+ class FiLMLayer(nn.Module):
69
+ def __init__(self, input_dim, hidden_dim):
70
+ super().__init__()
71
+ self.layer = nn.Linear(input_dim, hidden_dim)
72
+
73
+ def forward(self, x, freq, phase_shift):
74
+ x = self.layer(x)
75
+ freq = freq.expand_as(x)
76
+ phase_shift = phase_shift.expand_as(x)
77
+ return torch.sin(freq * x + phase_shift)
78
+
79
+
80
+ class FiLMSiren(nn.Module):
81
+ """FiLM Conditioned Siren network."""
82
+
83
+ def __init__(
84
+ self,
85
+ in_dim: int,
86
+ hidden_layers: int,
87
+ hidden_features: int,
88
+ mapping_network_in_dim: int,
89
+ mapping_network_layers: int,
90
+ mapping_network_features: int,
91
+ out_dim: int,
92
+ outermost_linear: bool = False,
93
+ out_activation: Optional[nn.Module] = None,
94
+ ) -> None:
95
+ super().__init__()
96
+ self.in_dim = in_dim
97
+ assert self.in_dim > 0
98
+ self.out_dim = out_dim if out_dim is not None else hidden_features
99
+ self.hidden_layers = hidden_layers
100
+ self.hidden_features = hidden_features
101
+ self.mapping_network_in_dim = mapping_network_in_dim
102
+ self.mapping_network_layers = mapping_network_layers
103
+ self.mapping_network_features = mapping_network_features
104
+ self.outermost_linear = outermost_linear
105
+ self.out_activation = out_activation
106
+
107
+ self.net = nn.ModuleList()
108
+
109
+ self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
110
+
111
+ for _ in range(self.hidden_layers - 1):
112
+ self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))
113
+
114
+ self.final_layer = None
115
+ if self.outermost_linear:
116
+ self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
117
+ self.final_layer.apply(frequency_init(25))
118
+ else:
119
+ final_layer = FiLMLayer(self.hidden_features, self.out_dim)
120
+ self.net.append(final_layer)
121
+
122
+ self.mapping_network = CustomMappingNetwork(
123
+ in_features=self.mapping_network_in_dim,
124
+ map_hidden_layers=self.mapping_network_layers,
125
+ map_hidden_dim=self.mapping_network_features,
126
+ map_output_dim=(len(self.net)) * self.hidden_features * 2,
127
+ )
128
+
129
+ self.net.apply(frequency_init(25))
130
+ self.net[0].apply(first_layer_film_sine_init)
131
+
132
+ def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
133
+ """Get conditiional frequencies and phase shifts from mapping network."""
134
+ frequencies = frequencies * 15 + 30
135
+
136
+ for index, layer in enumerate(self.net):
137
+ start = index * self.hidden_features
138
+ end = (index + 1) * self.hidden_features
139
+ x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
140
+
141
+ x = self.final_layer(x) if self.final_layer is not None else x
142
+ output = self.out_activation(x) if self.out_activation is not None else x
143
+ return output
144
+
145
+ def forward(self, x, conditioning_input):
146
+ """Forward pass."""
147
+ frequencies, phase_shifts = self.mapping_network(conditioning_input)
148
+ return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)
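A hedged shape check for FiLMSiren: a batch of directions is conditioned on a latent code through the FiLM mapping network. All sizes below are illustrative, not RENI's actual configuration.

```python
import torch
import torch.nn.functional as F

from spar3d.models.illumination.reni.components.film_siren import FiLMSiren

net = FiLMSiren(
    in_dim=3,                    # unit direction (x, y, z)
    hidden_layers=4,
    hidden_features=128,
    mapping_network_in_dim=36,   # flattened latent code (illustrative size)
    mapping_network_layers=3,
    mapping_network_features=128,
    out_dim=3,                   # RGB radiance
    outermost_linear=True,
)

dirs = F.normalize(torch.randn(1, 2048, 3), dim=-1)  # B x D x 3 directions
latent = torch.randn(1, 1, 36)                       # broadcast over the directions
rgb = net(dirs, latent)                              # -> (1, 2048, 3)
```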
spar3d/models/illumination/reni/components/siren.py ADDED
@@ -0,0 +1,118 @@
1
+ """Siren MLP https://www.vincentsitzmann.com/siren/"""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ class SineLayer(nn.Module):
11
+ """
12
+ Sine layer for the SIREN network.
13
+ """
14
+
15
+ def __init__(
16
+ self, in_features, out_features, bias=True, is_first=False, omega_0=30.0
17
+ ):
18
+ super().__init__()
19
+ self.omega_0 = omega_0
20
+ self.is_first = is_first
21
+
22
+ self.in_features = in_features
23
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
24
+
25
+ self.init_weights()
26
+
27
+ def init_weights(self):
28
+ with torch.no_grad():
29
+ if self.is_first:
30
+ self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
31
+ else:
32
+ self.linear.weight.uniform_(
33
+ -np.sqrt(6 / self.in_features) / self.omega_0,
34
+ np.sqrt(6 / self.in_features) / self.omega_0,
35
+ )
36
+
37
+ def forward(self, x):
38
+ return torch.sin(self.omega_0 * self.linear(x))
39
+
40
+
41
+ class Siren(nn.Module):
42
+ """Siren network.
43
+
44
+ Args:
45
+ in_dim: Input layer dimension
46
+ num_layers: Number of network layers
47
+ layer_width: Width of each MLP layer
48
+ out_dim: Output layer dimension. Uses layer_width if None.
49
+ activation: intermediate layer activation function.
50
+ out_activation: output activation function.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ in_dim: int,
56
+ hidden_layers: int,
57
+ hidden_features: int,
58
+ out_dim: Optional[int] = None,
59
+ outermost_linear: bool = False,
60
+ first_omega_0: float = 30,
61
+ hidden_omega_0: float = 30,
62
+ out_activation: Optional[nn.Module] = None,
63
+ ) -> None:
64
+ super().__init__()
65
+ self.in_dim = in_dim
66
+ assert self.in_dim > 0
67
+ self.out_dim = out_dim if out_dim is not None else hidden_features
68
+ self.outermost_linear = outermost_linear
69
+ self.first_omega_0 = first_omega_0
70
+ self.hidden_omega_0 = hidden_omega_0
71
+ self.hidden_layers = hidden_layers
72
+ self.layer_width = hidden_features
73
+ self.out_activation = out_activation
74
+
75
+ self.net = []
76
+ self.net.append(
77
+ SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
78
+ )
79
+
80
+ for _ in range(hidden_layers):
81
+ self.net.append(
82
+ SineLayer(
83
+ hidden_features,
84
+ hidden_features,
85
+ is_first=False,
86
+ omega_0=hidden_omega_0,
87
+ )
88
+ )
89
+
90
+ if outermost_linear:
91
+ final_layer = nn.Linear(hidden_features, self.out_dim)
92
+
93
+ with torch.no_grad():
94
+ final_layer.weight.uniform_(
95
+ -np.sqrt(6 / hidden_features) / hidden_omega_0,
96
+ np.sqrt(6 / hidden_features) / hidden_omega_0,
97
+ )
98
+
99
+ self.net.append(final_layer)
100
+ else:
101
+ self.net.append(
102
+ SineLayer(
103
+ hidden_features,
104
+ self.out_dim,
105
+ is_first=False,
106
+ omega_0=hidden_omega_0,
107
+ )
108
+ )
109
+
110
+ if self.out_activation is not None:
111
+ self.net.append(self.out_activation)
112
+
113
+ self.net = nn.Sequential(*self.net)
114
+
115
+ def forward(self, model_input):
116
+ """Forward pass through the network"""
117
+ output = self.net(model_input)
118
+ return output
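A hedged sketch of a plain Siren MLP mapping 3D directions to RGB; the layer sizes and the sigmoid output are illustrative choices, not the values used by the RENI field.

```python
import torch
from torch import nn

from spar3d.models.illumination.reni.components.siren import Siren

siren = Siren(
    in_dim=3,
    hidden_layers=4,
    hidden_features=256,
    out_dim=3,
    outermost_linear=True,
    out_activation=nn.Sigmoid(),
)
rgb = siren(torch.randn(4096, 3))   # -> (4096, 3), in [0, 1] via the sigmoid
```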
spar3d/models/illumination/reni/components/transformer_decoder.py ADDED
@@ -0,0 +1,189 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class MultiHeadAttention(nn.Module):
8
+ def __init__(
9
+ self,
10
+ direction_input_dim: int,
11
+ conditioning_input_dim: int,
12
+ latent_dim: int,
13
+ num_heads: int,
14
+ ):
15
+ """
16
+ Multi-Head Attention module.
17
+
18
+ Args:
19
+ direction_input_dim (int): The input dimension of the directional input.
20
+ conditioning_input_dim (int): The input dimension of the conditioning input.
21
+ latent_dim (int): The latent dimension of the module.
22
+ num_heads (int): The number of heads to use in the attention mechanism.
23
+ """
24
+ super().__init__()
25
+ assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
26
+ self.num_heads = num_heads
27
+ self.head_dim = latent_dim // num_heads
28
+ self.scale = self.head_dim**-0.5
29
+
30
+ self.query = nn.Linear(direction_input_dim, latent_dim)
31
+ self.key = nn.Linear(conditioning_input_dim, latent_dim)
32
+ self.value = nn.Linear(conditioning_input_dim, latent_dim)
33
+ self.fc_out = nn.Linear(latent_dim, latent_dim)
34
+
35
+ def forward(
36
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
37
+ ) -> torch.Tensor:
38
+ """
39
+ Forward pass of the Multi-Head Attention module.
40
+
41
+ Args:
42
+ query (torch.Tensor): The directional input tensor.
43
+ key (torch.Tensor): The conditioning input tensor for the keys.
44
+ value (torch.Tensor): The conditioning input tensor for the values.
45
+
46
+ Returns:
47
+ torch.Tensor: The output tensor of the Multi-Head Attention module.
48
+ """
49
+ batch_size = query.size(0)
50
+
51
+ Q = (
52
+ self.query(query)
53
+ .view(batch_size, -1, self.num_heads, self.head_dim)
54
+ .transpose(1, 2)
55
+ )
56
+ K = (
57
+ self.key(key)
58
+ .view(batch_size, -1, self.num_heads, self.head_dim)
59
+ .transpose(1, 2)
60
+ )
61
+ V = (
62
+ self.value(value)
63
+ .view(batch_size, -1, self.num_heads, self.head_dim)
64
+ .transpose(1, 2)
65
+ )
66
+
67
+ attention = (
68
+ torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
69
+ )
70
+ attention = torch.softmax(attention, dim=-1)
71
+
72
+ out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
73
+ out = (
74
+ out.transpose(1, 2)
75
+ .contiguous()
76
+ .view(batch_size, -1, self.num_heads * self.head_dim)
77
+ )
78
+
79
+ out = self.fc_out(out).squeeze(1)
80
+ return out
81
+
82
+
83
+ class AttentionLayer(nn.Module):
84
+ def __init__(
85
+ self,
86
+ direction_input_dim: int,
87
+ conditioning_input_dim: int,
88
+ latent_dim: int,
89
+ num_heads: int,
90
+ ):
91
+ """
92
+ Attention Layer module.
93
+
94
+ Args:
95
+ direction_input_dim (int): The input dimension of the directional input.
96
+ conditioning_input_dim (int): The input dimension of the conditioning input.
97
+ latent_dim (int): The latent dimension of the module.
98
+ num_heads (int): The number of heads to use in the attention mechanism.
99
+ """
100
+ super().__init__()
101
+ self.mha = MultiHeadAttention(
102
+ direction_input_dim, conditioning_input_dim, latent_dim, num_heads
103
+ )
104
+ self.norm1 = nn.LayerNorm(latent_dim)
105
+ self.norm2 = nn.LayerNorm(latent_dim)
106
+ self.fc = nn.Sequential(
107
+ nn.Linear(latent_dim, latent_dim),
108
+ nn.ReLU(),
109
+ nn.Linear(latent_dim, latent_dim),
110
+ )
111
+
112
+ def forward(
113
+ self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
114
+ ) -> torch.Tensor:
115
+ """
116
+ Forward pass of the Attention Layer module.
117
+
118
+ Args:
119
+ directional_input (torch.Tensor): The directional input tensor.
120
+ conditioning_input (torch.Tensor): The conditioning input tensor.
121
+
122
+ Returns:
123
+ torch.Tensor: The output tensor of the Attention Layer module.
124
+ """
125
+ attn_output = self.mha(
126
+ directional_input, conditioning_input, conditioning_input
127
+ )
128
+ out1 = self.norm1(attn_output + directional_input)
129
+ fc_output = self.fc(out1)
130
+ out2 = self.norm2(fc_output + out1)
131
+ return out2
132
+
133
+
134
+ class Decoder(nn.Module):
135
+ def __init__(
136
+ self,
137
+ in_dim: int,
138
+ conditioning_input_dim: int,
139
+ hidden_features: int,
140
+ num_heads: int,
141
+ num_layers: int,
142
+ out_activation: Optional[nn.Module],
143
+ ):
144
+ """
145
+ Decoder module.
146
+
147
+ Args:
148
+ in_dim (int): The input dimension of the module.
149
+ conditioning_input_dim (int): The input dimension of the conditioning input.
150
+ hidden_features (int): The number of hidden features in the module.
151
+ num_heads (int): The number of heads to use in the attention mechanism.
152
+ num_layers (int): The number of layers in the module.
153
+ out_activation (nn.Module): The activation function to use on the output tensor.
154
+ """
155
+ super().__init__()
156
+ self.residual_projection = nn.Linear(
157
+ in_dim, hidden_features
158
+ ) # projection for residual connection
159
+ self.layers = nn.ModuleList(
160
+ [
161
+ AttentionLayer(
162
+ hidden_features, conditioning_input_dim, hidden_features, num_heads
163
+ )
164
+ for i in range(num_layers)
165
+ ]
166
+ )
167
+ self.fc = nn.Linear(hidden_features, 3) # 3 for RGB
168
+ self.out_activation = out_activation
169
+
170
+ def forward(
171
+ self, x: torch.Tensor, conditioning_input: torch.Tensor
172
+ ) -> torch.Tensor:
173
+ """
174
+ Forward pass of the Decoder module.
175
+
176
+ Args:
177
+ x (torch.Tensor): The input tensor.
178
+ conditioning_input (torch.Tensor): The conditioning input tensor.
179
+
180
+ Returns:
181
+ torch.Tensor: The output tensor of the Decoder module.
182
+ """
183
+ x = self.residual_projection(x)
184
+ for layer in self.layers:
185
+ x = layer(x, conditioning_input)
186
+ x = self.fc(x)
187
+ if self.out_activation is not None:
188
+ x = self.out_activation(x)
189
+ return x
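Note: a minimal usage sketch for the attention-based Decoder defined above. The dimensions below are illustrative assumptions, not values taken from the SPAR3D configs; the import path follows the file layout of this commit.

import torch
from torch import nn

from spar3d.models.illumination.reni.components.transformer_decoder import Decoder

decoder = Decoder(
    in_dim=38,                    # size of the (encoded) directional input -- made up for the example
    conditioning_input_dim=1332,  # size of the invariant conditioning input -- made up for the example
    hidden_features=128,
    num_heads=8,
    num_layers=3,
    out_activation=nn.Sigmoid(),
)
directions = torch.randn(4096, 38)       # flattened (batch * num_rays) directional features
conditioning = torch.randn(4096, 1332)   # flattened conditioning features
rgb = decoder(directions, conditioning)  # -> [4096, 3] RGB values in (0, 1)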
spar3d/models/illumination/reni/components/vn_layers.py ADDED
@@ -0,0 +1,548 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Phil Wang
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ """All code taken from https://github.com/lucidrains/VN-transformer"""
24
+
25
+ from collections import namedtuple
26
+ from functools import wraps
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from einops import rearrange, reduce
31
+ from einops.layers.torch import Rearrange
32
+ from packaging import version
33
+ from torch import einsum, nn
34
+
35
+ # constants
36
+
37
+ FlashAttentionConfig = namedtuple(
38
+ "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
39
+ )
40
+
41
+ # helpers
42
+
43
+
44
+ def exists(val):
45
+ return val is not None
46
+
47
+
48
+ def once(fn):
49
+ called = False
50
+
51
+ @wraps(fn)
52
+ def inner(x):
53
+ nonlocal called
54
+ if called:
55
+ return
56
+ called = True
57
+ return fn(x)
58
+
59
+ return inner
60
+
61
+
62
+ print_once = once(print)
63
+
64
+ # main class
65
+
66
+
67
+ class Attend(nn.Module):
68
+ def __init__(self, dropout=0.0, flash=False, l2_dist=False):
69
+ super().__init__()
70
+ assert not (
71
+ flash and l2_dist
72
+ ), "flash attention is not compatible with l2 distance"
73
+ self.l2_dist = l2_dist
74
+
75
+ self.dropout = dropout
76
+ self.attn_dropout = nn.Dropout(dropout)
77
+
78
+ self.flash = flash
79
+ assert not (
80
+ flash and version.parse(torch.__version__) < version.parse("2.0.0")
81
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
82
+
83
+ # determine efficient attention configs for cuda and cpu
84
+
85
+ self.cpu_config = FlashAttentionConfig(True, True, True)
86
+ self.cuda_config = None
87
+
88
+ if not torch.cuda.is_available() or not flash:
89
+ return
90
+
91
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
92
+
93
+ if device_properties.major == 8 and device_properties.minor == 0:
94
+ print_once(
95
+ "A100 GPU detected, using flash attention if input tensor is on cuda"
96
+ )
97
+ self.cuda_config = FlashAttentionConfig(True, False, False)
98
+ else:
99
+ print_once(
100
+ "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
101
+ )
102
+ self.cuda_config = FlashAttentionConfig(False, True, True)
103
+
104
+ def flash_attn(self, q, k, v, mask=None):
105
+ _, heads, q_len, _, _, is_cuda = (
106
+ *q.shape,
107
+ k.shape[-2],
108
+ q.is_cuda,
109
+ )
110
+
111
+ # Check if mask exists and expand to compatible shape
112
+ # The mask is B L, so it would have to be expanded to B H N L
113
+
114
+ if exists(mask):
115
+ mask = mask.expand(-1, heads, q_len, -1)
116
+
117
+ # Check if there is a compatible device for flash attention
118
+
119
+ config = self.cuda_config if is_cuda else self.cpu_config
120
+
121
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
122
+
123
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=mask,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ )
131
+
132
+ return out
133
+
134
+ def forward(self, q, k, v, mask=None):
135
+ """
136
+ einstein notation
137
+ b - batch
138
+ h - heads
139
+ n, i, j - sequence length (base sequence length, source, target)
140
+ d - feature dimension
141
+ """
142
+ scale = q.shape[-1] ** -0.5
143
+
144
+ if exists(mask) and mask.ndim != 4:
145
+ mask = rearrange(mask, "b j -> b 1 1 j")
146
+
147
+ if self.flash:
148
+ return self.flash_attn(q, k, v, mask=mask)
149
+
150
+ # similarity
151
+
152
+ sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
153
+
154
+ # l2 distance
155
+
156
+ if self.l2_dist:
157
+ # -cdist squared == (-q^2 + 2qk - k^2)
158
+ # so simply work off the qk above
159
+ q_squared = reduce(q**2, "b h i d -> b h i 1", "sum")
160
+ k_squared = reduce(k**2, "b h j d -> b h 1 j", "sum")
161
+ sim = sim * 2 - q_squared - k_squared
162
+
163
+ # key padding mask
164
+
165
+ if exists(mask):
166
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
167
+
168
+ # attention
169
+
170
+ attn = sim.softmax(dim=-1)
171
+ attn = self.attn_dropout(attn)
172
+
173
+ # aggregate values
174
+
175
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
176
+
177
+ return out
178
+
179
+
180
+ # helper
181
+
182
+
183
+ def exists(val): # noqa: F811
184
+ return val is not None
185
+
186
+
187
+ def default(val, d):
188
+ return val if exists(val) else d
189
+
190
+
191
+ def inner_dot_product(x, y, *, dim=-1, keepdim=True):
192
+ return (x * y).sum(dim=dim, keepdim=keepdim)
193
+
194
+
195
+ # layernorm
196
+
197
+
198
+ class LayerNorm(nn.Module):
199
+ def __init__(self, dim):
200
+ super().__init__()
201
+ self.gamma = nn.Parameter(torch.ones(dim))
202
+ self.register_buffer("beta", torch.zeros(dim))
203
+
204
+ def forward(self, x):
205
+ return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
206
+
207
+
208
+ # equivariant modules
209
+
210
+
211
+ class VNLinear(nn.Module):
212
+ def __init__(self, dim_in, dim_out, bias_epsilon=0.0):
213
+ super().__init__()
214
+ self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
215
+
216
+ self.bias = None
217
+ self.bias_epsilon = bias_epsilon
218
+
219
+ # in this paper, they propose going for quasi-equivariance with a small bias, controllable with epsilon, which they claim lead to better stability and results
220
+
221
+ if bias_epsilon > 0.0:
222
+ self.bias = nn.Parameter(torch.randn(dim_out))
223
+
224
+ def forward(self, x):
225
+ out = einsum("... i c, o i -> ... o c", x, self.weight)
226
+
227
+ if exists(self.bias):
228
+ bias = F.normalize(self.bias, dim=-1) * self.bias_epsilon
229
+ out = out + rearrange(bias, "... -> ... 1")
230
+
231
+ return out
232
+
233
+
234
+ class VNReLU(nn.Module):
235
+ def __init__(self, dim, eps=1e-6):
236
+ super().__init__()
237
+ self.eps = eps
238
+ self.W = nn.Parameter(torch.randn(dim, dim))
239
+ self.U = nn.Parameter(torch.randn(dim, dim))
240
+
241
+ def forward(self, x):
242
+ q = einsum("... i c, o i -> ... o c", x, self.W)
243
+ k = einsum("... i c, o i -> ... o c", x, self.U)
244
+
245
+ qk = inner_dot_product(q, k)
246
+
247
+ k_norm = k.norm(dim=-1, keepdim=True).clamp(min=self.eps)
248
+ q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k
249
+
250
+ out = torch.where(qk >= 0.0, q, q_projected_on_k)
251
+
252
+ return out
253
+
254
+
255
+ class VNAttention(nn.Module):
256
+ def __init__(
257
+ self,
258
+ dim,
259
+ dim_head=64,
260
+ heads=8,
261
+ dim_coor=3,
262
+ bias_epsilon=0.0,
263
+ l2_dist_attn=False,
264
+ flash=False,
265
+ num_latents=None, # setting this would enable perceiver-like cross attention from latents to sequence, with the latents derived from VNWeightedPool
266
+ ):
267
+ super().__init__()
268
+ assert not (
269
+ l2_dist_attn and flash
270
+ ), "l2 distance attention is not compatible with flash attention"
271
+
272
+ self.scale = (dim_coor * dim_head) ** -0.5
273
+ dim_inner = dim_head * heads
274
+ self.heads = heads
275
+
276
+ self.to_q_input = None
277
+ if exists(num_latents):
278
+ self.to_q_input = VNWeightedPool(
279
+ dim, num_pooled_tokens=num_latents, squeeze_out_pooled_dim=False
280
+ )
281
+
282
+ self.to_q = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
283
+ self.to_k = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
284
+ self.to_v = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
285
+ self.to_out = VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon)
286
+
287
+ if l2_dist_attn and not exists(num_latents):
288
+ # tied queries and keys for l2 distance attention, and not perceiver-like attention
289
+ self.to_k = self.to_q
290
+
291
+ self.attend = Attend(flash=flash, l2_dist=l2_dist_attn)
292
+
293
+ def forward(self, x, mask=None):
294
+ """
295
+ einstein notation
296
+ b - batch
297
+ n - sequence
298
+ h - heads
299
+ d - feature dimension (channels)
300
+ c - coordinate dimension (3 for 3d space)
301
+ i - source sequence dimension
302
+ j - target sequence dimension
303
+ """
304
+
305
+ c = x.shape[-1]
306
+
307
+ if exists(self.to_q_input):
308
+ q_input = self.to_q_input(x, mask=mask)
309
+ else:
310
+ q_input = x
311
+
312
+ q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
313
+ q, k, v = map(
314
+ lambda t: rearrange(t, "b n (h d) c -> b h n (d c)", h=self.heads),
315
+ (q, k, v),
316
+ )
317
+
318
+ out = self.attend(q, k, v, mask=mask)
319
+
320
+ out = rearrange(out, "b h n (d c) -> b n (h d) c", c=c)
321
+ return self.to_out(out)
322
+
323
+
324
+ def VNFeedForward(dim, mult=4, bias_epsilon=0.0):
325
+ dim_inner = int(dim * mult)
326
+ return nn.Sequential(
327
+ VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon),
328
+ VNReLU(dim_inner),
329
+ VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon),
330
+ )
331
+
332
+
333
+ class VNLayerNorm(nn.Module):
334
+ def __init__(self, dim, eps=1e-6):
335
+ super().__init__()
336
+ self.eps = eps
337
+ self.ln = LayerNorm(dim)
338
+
339
+ def forward(self, x):
340
+ norms = x.norm(dim=-1)
341
+ x = x / rearrange(norms.clamp(min=self.eps), "... -> ... 1")
342
+ ln_out = self.ln(norms)
343
+ return x * rearrange(ln_out, "... -> ... 1")
344
+
345
+
346
+ class VNWeightedPool(nn.Module):
347
+ def __init__(
348
+ self, dim, dim_out=None, num_pooled_tokens=1, squeeze_out_pooled_dim=True
349
+ ):
350
+ super().__init__()
351
+ dim_out = default(dim_out, dim)
352
+ self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))
353
+ self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim
354
+
355
+ def forward(self, x, mask=None):
356
+ if exists(mask):
357
+ mask = rearrange(mask, "b n -> b n 1 1")
358
+ x = x.masked_fill(~mask, 0.0)
359
+ numer = reduce(x, "b n d c -> b d c", "sum")
360
+ denom = mask.sum(dim=1)
361
+ mean_pooled = numer / denom.clamp(min=1e-6)
362
+ else:
363
+ mean_pooled = reduce(x, "b n d c -> b d c", "mean")
364
+
365
+ out = einsum("b d c, m d e -> b m e c", mean_pooled, self.weight)
366
+
367
+ if not self.squeeze_out_pooled_dim:
368
+ return out
369
+
370
+ out = rearrange(out, "b 1 d c -> b d c")
371
+ return out
372
+
373
+
374
+ # equivariant VN transformer encoder
375
+
376
+
377
+ class VNTransformerEncoder(nn.Module):
378
+ def __init__(
379
+ self,
380
+ dim,
381
+ *,
382
+ depth,
383
+ dim_head=64,
384
+ heads=8,
385
+ dim_coor=3,
386
+ ff_mult=4,
387
+ final_norm=False,
388
+ bias_epsilon=0.0,
389
+ l2_dist_attn=False,
390
+ flash_attn=False,
391
+ ):
392
+ super().__init__()
393
+ self.dim = dim
394
+ self.dim_coor = dim_coor
395
+
396
+ self.layers = nn.ModuleList([])
397
+
398
+ for _ in range(depth):
399
+ self.layers.append(
400
+ nn.ModuleList(
401
+ [
402
+ VNAttention(
403
+ dim=dim,
404
+ dim_head=dim_head,
405
+ heads=heads,
406
+ bias_epsilon=bias_epsilon,
407
+ l2_dist_attn=l2_dist_attn,
408
+ flash=flash_attn,
409
+ ),
410
+ VNLayerNorm(dim),
411
+ VNFeedForward(dim=dim, mult=ff_mult, bias_epsilon=bias_epsilon),
412
+ VNLayerNorm(dim),
413
+ ]
414
+ )
415
+ )
416
+
417
+ self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()
418
+
419
+ def forward(self, x, mask=None):
420
+ *_, d, c = x.shape
421
+
422
+ assert (
423
+ x.ndim == 4 and d == self.dim and c == self.dim_coor
424
+ ), "input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))"
425
+
426
+ for attn, attn_post_ln, ff, ff_post_ln in self.layers:
427
+ x = attn_post_ln(attn(x, mask=mask)) + x
428
+ x = ff_post_ln(ff(x)) + x
429
+
430
+ return self.norm(x)
431
+
432
+
433
+ # invariant layers
434
+
435
+
436
+ class VNInvariant(nn.Module):
437
+ def __init__(
438
+ self,
439
+ dim,
440
+ dim_coor=3,
441
+ ):
442
+ super().__init__()
443
+ self.mlp = nn.Sequential(
444
+ VNLinear(dim, dim_coor), VNReLU(dim_coor), Rearrange("... d e -> ... e d")
445
+ )
446
+
447
+ def forward(self, x):
448
+ return einsum("b n d i, b n i o -> b n o", x, self.mlp(x))
449
+
450
+
451
+ # main class
452
+
453
+
454
+ class VNTransformer(nn.Module):
455
+ def __init__(
456
+ self,
457
+ *,
458
+ dim,
459
+ depth,
460
+ num_tokens=None,
461
+ dim_feat=None,
462
+ dim_head=64,
463
+ heads=8,
464
+ dim_coor=3,
465
+ reduce_dim_out=True,
466
+ bias_epsilon=0.0,
467
+ l2_dist_attn=False,
468
+ flash_attn=False,
469
+ translation_equivariance=False,
470
+ translation_invariant=False,
471
+ ):
472
+ super().__init__()
473
+ self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
474
+
475
+ dim_feat = default(dim_feat, 0)
476
+ self.dim_feat = dim_feat
477
+ self.dim_coor_total = dim_coor + dim_feat
478
+
479
+ assert (int(translation_equivariance) + int(translation_invariant)) <= 1
480
+ self.translation_equivariance = translation_equivariance
481
+ self.translation_invariant = translation_invariant
482
+
483
+ self.vn_proj_in = nn.Sequential(
484
+ Rearrange("... c -> ... 1 c"), VNLinear(1, dim, bias_epsilon=bias_epsilon)
485
+ )
486
+
487
+ self.encoder = VNTransformerEncoder(
488
+ dim=dim,
489
+ depth=depth,
490
+ dim_head=dim_head,
491
+ heads=heads,
492
+ bias_epsilon=bias_epsilon,
493
+ dim_coor=self.dim_coor_total,
494
+ l2_dist_attn=l2_dist_attn,
495
+ flash_attn=flash_attn,
496
+ )
497
+
498
+ if reduce_dim_out:
499
+ self.vn_proj_out = nn.Sequential(
500
+ VNLayerNorm(dim),
501
+ VNLinear(dim, 1, bias_epsilon=bias_epsilon),
502
+ Rearrange("... 1 c -> ... c"),
503
+ )
504
+ else:
505
+ self.vn_proj_out = nn.Identity()
506
+
507
+ def forward(
508
+ self, coors, *, feats=None, mask=None, return_concatted_coors_and_feats=False
509
+ ):
510
+ if self.translation_equivariance or self.translation_invariant:
511
+ coors_mean = reduce(coors, "... c -> c", "mean")
512
+ coors = coors - coors_mean
513
+
514
+ x = coors # [batch, num_points, 3]
515
+
516
+ if exists(feats):
517
+ if feats.dtype == torch.long:
518
+ assert exists(
519
+ self.token_emb
520
+ ), "num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices"
521
+ feats = self.token_emb(feats)
522
+
523
+ assert (
524
+ feats.shape[-1] == self.dim_feat
525
+ ), f"dim_feat should be set to {feats.shape[-1]}"
526
+ x = torch.cat((x, feats), dim=-1) # [batch, num_points, 3 + dim_feat]
527
+
528
+ assert x.shape[-1] == self.dim_coor_total
529
+
530
+ x = self.vn_proj_in(x) # [batch, num_points, hidden_dim, 3 + dim_feat]
531
+ x = self.encoder(x, mask=mask) # [batch, num_points, hidden_dim, 3 + dim_feat]
532
+ x = self.vn_proj_out(x) # [batch, num_points, 3 + dim_feat]
533
+
534
+ coors_out, feats_out = (
535
+ x[..., :3],
536
+ x[..., 3:],
537
+ ) # [batch, num_points, 3], [batch, num_points, dim_feat]
538
+
539
+ if self.translation_equivariance:
540
+ coors_out = coors_out + coors_mean
541
+
542
+ if not exists(feats):
543
+ return coors_out
544
+
545
+ if return_concatted_coors_and_feats:
546
+ return torch.cat((coors_out, feats_out), dim=-1)
547
+
548
+ return coors_out, feats_out
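Note: the vector-neuron (VN) layers above are constructed to be SO(3)-equivariant, so rotating the input point cloud should rotate the output correspondingly. A quick smoke test with arbitrary sizes (illustrative only; the import path follows this commit's layout):

import torch

from spar3d.models.illumination.reni.components.vn_layers import VNTransformer

model = VNTransformer(dim=32, depth=2, dim_head=16, heads=4)  # sizes chosen only for the example
coors = torch.randn(2, 128, 3)   # [batch, num_points, xyz]
out = model(coors)               # [2, 128, 3]; with feats=None only the coordinates are returned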
spar3d/models/illumination/reni/env_map.py ADDED
@@ -0,0 +1,93 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+ import torch
5
+ from jaxtyping import Float
6
+ from torch import Tensor
7
+
8
+ from spar3d.models.utils import BaseModule
9
+
10
+ from .field import RENIField
11
+
12
+
13
+ def _direction_from_coordinate(
14
+ coordinate: Float[Tensor, "*B 2"],
15
+ ) -> Float[Tensor, "*B 3"]:
16
+ # OpenGL Convention
17
+ # +X Right
18
+ # +Y Up
19
+ # +Z Backward
20
+
21
+ u, v = coordinate.unbind(-1)
22
+ theta = (2 * torch.pi * u) - torch.pi
23
+ phi = torch.pi * v
24
+
25
+ dir = torch.stack(
26
+ [
27
+ theta.sin() * phi.sin(),
28
+ phi.cos(),
29
+ -1 * theta.cos() * phi.sin(),
30
+ ],
31
+ -1,
32
+ )
33
+ return dir
34
+
35
+
36
+ def _get_sample_coordinates(
37
+ resolution: List[int], device: Optional[torch.device] = None
38
+ ) -> Float[Tensor, "H W 2"]:
39
+ return torch.stack(
40
+ torch.meshgrid(
41
+ (torch.arange(resolution[1], device=device) + 0.5) / resolution[1],
42
+ (torch.arange(resolution[0], device=device) + 0.5) / resolution[0],
43
+ indexing="xy",
44
+ ),
45
+ -1,
46
+ )
47
+
48
+
49
+ class RENIEnvMap(BaseModule):
50
+ @dataclass
51
+ class Config(BaseModule.Config):
52
+ reni_config: dict = field(default_factory=dict)
53
+ resolution: int = 128
54
+
55
+ cfg: Config
56
+
57
+ def configure(self):
58
+ self.field = RENIField(self.cfg.reni_config)
59
+ resolution = (self.cfg.resolution, self.cfg.resolution * 2)
60
+ sample_directions = _direction_from_coordinate(
61
+ _get_sample_coordinates(resolution)
62
+ )
63
+ self.img_shape = sample_directions.shape[:-1]
64
+
65
+ sample_directions_flat = sample_directions.view(-1, 3)
66
+ # Lastly these have y up but reni expects z up. Rotate 90 degrees on x axis
67
+ sample_directions_flat = torch.stack(
68
+ [
69
+ sample_directions_flat[:, 0],
70
+ -sample_directions_flat[:, 2],
71
+ sample_directions_flat[:, 1],
72
+ ],
73
+ -1,
74
+ )
75
+ self.sample_directions = torch.nn.Parameter(
76
+ sample_directions_flat, requires_grad=False
77
+ )
78
+
79
+ def forward(
80
+ self,
81
+ latent_codes: Float[Tensor, "B latent_dim 3"],
82
+ rotation: Optional[Float[Tensor, "B 3 3"]] = None,
83
+ scale: Optional[Float[Tensor, "B"]] = None,
84
+ ) -> Dict[str, Tensor]:
85
+ return {
86
+ k: v.view(latent_codes.shape[0], *self.img_shape, -1)
87
+ for k, v in self.field(
88
+ self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1),
89
+ latent_codes,
90
+ rotation=rotation,
91
+ scale=scale,
92
+ ).items()
93
+ }
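Note: the equirectangular UV-to-direction mapping in _direction_from_coordinate can be sanity-checked by restating the same math standalone; every direction it produces should be unit length:

import torch

u, v = torch.rand(8), torch.rand(8)    # equirectangular UV samples in [0, 1)
theta = 2 * torch.pi * u - torch.pi    # azimuth
phi = torch.pi * v                     # polar angle
dirs = torch.stack(
    [theta.sin() * phi.sin(), phi.cos(), -theta.cos() * phi.sin()], dim=-1
)
print(dirs.norm(dim=-1))               # all ~1.0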
spar3d/models/illumination/reni/field.py ADDED
@@ -0,0 +1,736 @@
1
+ # Copyright 2023 The University of York. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Mark Boss
16
+
17
+ """RENI field"""
18
+
19
+ import contextlib
20
+ from dataclasses import dataclass
21
+ from typing import Dict, Literal, Optional
22
+
23
+ import torch
24
+ from einops.layers.torch import Rearrange
25
+ from jaxtyping import Float
26
+ from torch import Tensor, nn
27
+
28
+ from spar3d.models.network import get_activation_module, trunc_exp
29
+ from spar3d.models.utils import BaseModule
30
+
31
+ from .components.film_siren import FiLMSiren
32
+ from .components.siren import Siren
33
+ from .components.transformer_decoder import Decoder
34
+ from .components.vn_layers import VNInvariant, VNLinear
35
+
36
+ # from nerfstudio.cameras.rays import RaySamples
37
+
38
+
39
+ def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
40
+ """Computes the expected value of sin(y) where y ~ N(x_means, x_vars)
41
+
42
+ Args:
43
+ x_means: Mean values.
44
+ x_vars: Variance of values.
45
+
46
+ Returns:
47
+ torch.Tensor: The expected value of sin.
48
+ """
49
+
50
+ return torch.exp(-0.5 * x_vars) * torch.sin(x_means)
51
+
52
+
53
+ class NeRFEncoding(torch.nn.Module):
54
+ """Multi-scale sinousoidal encodings. Support ``integrated positional encodings`` if covariances are provided.
55
+ Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.
56
+
57
+ Args:
58
+ in_dim: Input dimension of tensor
59
+ num_frequencies: Number of encoded frequencies per axis
60
+ min_freq_exp: Minimum frequency exponent
61
+ max_freq_exp: Maximum frequency exponent
62
+ include_input: Append the input coordinate to the encoding
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ in_dim: int,
68
+ num_frequencies: int,
69
+ min_freq_exp: float,
70
+ max_freq_exp: float,
71
+ include_input: bool = False,
72
+ off_axis: bool = False,
73
+ ) -> None:
74
+ super().__init__()
75
+
76
+ self.in_dim = in_dim
77
+ self.num_frequencies = num_frequencies
78
+ self.min_freq = min_freq_exp
79
+ self.max_freq = max_freq_exp
80
+ self.include_input = include_input
81
+
82
+ self.off_axis = off_axis
83
+
84
+ self.P = torch.tensor(
85
+ [
86
+ [0.8506508, 0, 0.5257311],
87
+ [0.809017, 0.5, 0.309017],
88
+ [0.5257311, 0.8506508, 0],
89
+ [1, 0, 0],
90
+ [0.809017, 0.5, -0.309017],
91
+ [0.8506508, 0, -0.5257311],
92
+ [0.309017, 0.809017, -0.5],
93
+ [0, 0.5257311, -0.8506508],
94
+ [0.5, 0.309017, -0.809017],
95
+ [0, 1, 0],
96
+ [-0.5257311, 0.8506508, 0],
97
+ [-0.309017, 0.809017, -0.5],
98
+ [0, 0.5257311, 0.8506508],
99
+ [-0.309017, 0.809017, 0.5],
100
+ [0.309017, 0.809017, 0.5],
101
+ [0.5, 0.309017, 0.809017],
102
+ [0.5, -0.309017, 0.809017],
103
+ [0, 0, 1],
104
+ [-0.5, 0.309017, 0.809017],
105
+ [-0.809017, 0.5, 0.309017],
106
+ [-0.809017, 0.5, -0.309017],
107
+ ]
108
+ ).T
109
+
110
+ def get_out_dim(self) -> int:
111
+ if self.in_dim is None:
112
+ raise ValueError("Input dimension has not been set")
113
+ out_dim = self.in_dim * self.num_frequencies * 2
114
+
115
+ if self.off_axis:
116
+ out_dim = self.P.shape[1] * self.num_frequencies * 2
117
+
118
+ if self.include_input:
119
+ out_dim += self.in_dim
120
+ return out_dim
121
+
122
+ def forward(
123
+ self,
124
+ in_tensor: Float[Tensor, "*b input_dim"],
125
+ covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None,
126
+ ) -> Float[Tensor, "*b output_dim"]:
127
+ """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
128
+ in mip-NeRF.
129
+
130
+ Args:
131
+ in_tensor: For best performance, the input tensor should be between 0 and 1.
132
+ covs: Covariances of input points.
133
+ Returns:
134
+ Output values will be between -1 and 1
135
+ """
136
+ # TODO: check the scaling here; the rescaling to [0, 2pi] below is left commented out for now
137
+ # in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
138
+ freqs = 2 ** torch.linspace(
139
+ self.min_freq, self.max_freq, self.num_frequencies
140
+ ).to(in_tensor.device)
141
+ # freqs = 2 ** (
142
+ # torch.sin(torch.linspace(self.min_freq, torch.pi / 2.0, self.num_frequencies)) * self.max_freq
143
+ # ).to(in_tensor.device)
144
+ # freqs = 2 ** (
145
+ # torch.linspace(self.min_freq, 1.0, self.num_frequencies).to(in_tensor.device) ** 0.2 * self.max_freq
146
+ # )
147
+
148
+ if self.off_axis:
149
+ scaled_inputs = (
150
+ torch.matmul(in_tensor, self.P.to(in_tensor.device))[..., None] * freqs
151
+ )
152
+ else:
153
+ scaled_inputs = (
154
+ in_tensor[..., None] * freqs
155
+ ) # [..., "input_dim", "num_scales"]
156
+ scaled_inputs = scaled_inputs.view(
157
+ *scaled_inputs.shape[:-2], -1
158
+ ) # [..., "input_dim" * "num_scales"]
159
+
160
+ if covs is None:
161
+ encoded_inputs = torch.sin(
162
+ torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)
163
+ )
164
+ else:
165
+ input_var = (
166
+ torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None]
167
+ * freqs[None, :] ** 2
168
+ )
169
+ input_var = input_var.reshape((*input_var.shape[:-2], -1))
170
+ encoded_inputs = expected_sin(
171
+ torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1),
172
+ torch.cat(2 * [input_var], dim=-1),
173
+ )
174
+
175
+ if self.include_input:
176
+ encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1)
177
+ return encoded_inputs
178
+
179
+
180
+ class RENIField(BaseModule):
181
+ @dataclass
182
+ class Config(BaseModule.Config):
183
+ """Configuration for model instantiation"""
184
+
185
+ fixed_decoder: bool = False
186
+ """Whether to fix the decoder weights"""
187
+ equivariance: str = "SO2"
188
+ """Type of equivariance to use: None, SO2, SO3"""
189
+ axis_of_invariance: str = "y"
190
+ """Which axis should SO2 equivariance be invariant to: x, y, z"""
191
+ invariant_function: str = "GramMatrix"
192
+ """Type of invariant function to use: GramMatrix, VN"""
193
+ conditioning: str = "Concat"
194
+ """Type of conditioning to use: FiLM, Concat, Attention"""
195
+ positional_encoding: str = "NeRF"
196
+ """Type of positional encoding to use. Currently only NeRF is supported"""
197
+ encoded_input: str = "Directions"
198
+ """Type of input to encode: None, Directions, Conditioning, Both"""
199
+ latent_dim: int = 36
200
+ """Dimensionality of latent code, N for a latent code size of (N x 3)"""
201
+ hidden_layers: int = 3
202
+ """Number of hidden layers"""
203
+ hidden_features: int = 128
204
+ """Number of hidden features"""
205
+ mapping_layers: int = 3
206
+ """Number of mapping layers"""
207
+ mapping_features: int = 128
208
+ """Number of mapping features"""
209
+ num_attention_heads: int = 8
210
+ """Number of attention heads"""
211
+ num_attention_layers: int = 3
212
+ """Number of attention layers"""
213
+ out_features: int = 3 # RGB
214
+ """Number of output features"""
215
+ last_layer_linear: bool = False
216
+ """Whether to use a linear layer as the last layer"""
217
+ output_activation: str = "exp"
218
+ """Activation function for output layer: sigmoid, tanh, relu, exp, None"""
219
+ first_omega_0: float = 30.0
220
+ """Omega_0 for first layer"""
221
+ hidden_omega_0: float = 30.0
222
+ """Omega_0 for hidden layers"""
223
+ fixed_decoder: bool = False
224
+ """Whether to fix the decoder weights"""
225
+ old_implementation: bool = False
226
+ """Whether to match implementation of old RENI, when using old checkpoints"""
227
+
228
+ cfg: Config
229
+
230
+ def configure(self):
231
+ self.equivariance = self.cfg.equivariance
232
+ self.conditioning = self.cfg.conditioning
233
+ self.latent_dim = self.cfg.latent_dim
234
+ self.hidden_layers = self.cfg.hidden_layers
235
+ self.hidden_features = self.cfg.hidden_features
236
+ self.mapping_layers = self.cfg.mapping_layers
237
+ self.mapping_features = self.cfg.mapping_features
238
+ self.out_features = self.cfg.out_features
239
+ self.last_layer_linear = self.cfg.last_layer_linear
240
+ self.output_activation = self.cfg.output_activation
241
+ self.first_omega_0 = self.cfg.first_omega_0
242
+ self.hidden_omega_0 = self.cfg.hidden_omega_0
243
+ self.old_implementation = self.cfg.old_implementation
244
+ self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance)
245
+
246
+ self.fixed_decoder = self.cfg.fixed_decoder
247
+ if self.cfg.invariant_function == "GramMatrix":
248
+ self.invariant_function = self.gram_matrix_invariance
249
+ else:
250
+ self.vn_proj_in = nn.Sequential(
251
+ Rearrange("... c -> ... 1 c"),
252
+ VNLinear(dim_in=1, dim_out=1, bias_epsilon=0),
253
+ )
254
+ dim_coor = 2 if self.cfg.equivariance == "SO2" else 3
255
+ self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor)
256
+ self.invariant_function = self.vn_invariance
257
+
258
+ self.network = self.setup_network()
259
+
260
+ if self.fixed_decoder:
261
+ for param in self.network.parameters():
262
+ param.requires_grad = False
263
+
264
+ if self.cfg.invariant_function == "VN":
265
+ for param in self.vn_proj_in.parameters():
266
+ param.requires_grad = False
267
+ for param in self.vn_invar.parameters():
268
+ param.requires_grad = False
269
+
270
+ @contextlib.contextmanager
271
+ def hold_decoder_fixed(self):
272
+ """Context manager to fix the decoder weights
273
+
274
+ Example usage:
275
+ ```
276
+ with instance_of_RENIField.hold_decoder_fixed():
277
+ # do stuff
278
+ ```
279
+ """
280
+ prev_state_network = {
281
+ name: p.requires_grad for name, p in self.network.named_parameters()
282
+ }
283
+ for param in self.network.parameters():
284
+ param.requires_grad = False
285
+ if self.cfg.invariant_function == "VN":
286
+ prev_state_proj_in = {
287
+ k: p.requires_grad for k, p in self.vn_proj_in.named_parameters()
288
+ }
289
+ prev_state_invar = {
290
+ k: p.requires_grad for k, p in self.vn_invar.named_parameters()
291
+ }
292
+ for param in self.vn_proj_in.parameters():
293
+ param.requires_grad = False
294
+ for param in self.vn_invar.parameters():
295
+ param.requires_grad = False
296
+
297
+ prev_decoder_state = self.fixed_decoder
298
+ self.fixed_decoder = True
299
+ try:
300
+ yield
301
+ finally:
302
+ # Restore the previous requires_grad state
303
+ for name, param in self.network.named_parameters():
304
+ param.requires_grad = prev_state_network[name]
305
+ if self.cfg.invariant_function == "VN":
306
+ for name, param in self.vn_proj_in.named_parameters():
307
+ param.requires_grad_(prev_state_proj_in[name])
308
+ for name, param in self.vn_invar.named_parameters():
309
+ param.requires_grad_(prev_state_invar[name])
310
+ self.fixed_decoder = prev_decoder_state
311
+
312
+ def vn_invariance(
313
+ self,
314
+ Z: Float[Tensor, "B latent_dim 3"],
315
+ D: Float[Tensor, "B num_rays 3"],
316
+ equivariance: Literal["None", "SO2", "SO3"] = "SO2",
317
+ axis_of_invariance: int = 1,
318
+ ):
319
+ """Generates a batched invariant representation from latent code Z and direction coordinates D.
320
+
321
+ Args:
322
+ Z: [B, latent_dim, 3] - Latent code.
323
+ D: [B num_rays, 3] - Direction coordinates.
324
+ equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'.
325
+ axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
326
+
327
+ Returns:
328
+ Tuple[Tensor, Tensor]: directional_input, conditioning_input
329
+ """
330
+ assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
331
+ other_axes = [i for i in range(3) if i != axis_of_invariance]
332
+
333
+ B, latent_dim, _ = Z.shape
334
+ _, num_rays, _ = D.shape
335
+
336
+ if equivariance == "None":
337
+ # get inner product between latent code and direction coordinates
338
+ innerprod = torch.sum(
339
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
340
+ ) # [B, num_rays, latent_dim]
341
+ z_input = (
342
+ Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
343
+ ) # [B, num_rays, latent_dim * 3]
344
+ return innerprod, z_input
345
+
346
+ if equivariance == "SO2":
347
+ z_other = torch.stack(
348
+ (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
349
+ ) # [B, latent_dim, 2]
350
+ d_other = torch.stack(
351
+ (D[..., other_axes[0]], D[..., other_axes[1]]), -1
352
+ ).unsqueeze(2) # [B, num_rays, 1, 2]
353
+ d_other = d_other.expand(
354
+ B, num_rays, latent_dim, 2
355
+ ) # [B, num_rays, latent_dim, 2]
356
+
357
+ z_other_emb = self.vn_proj_in(z_other) # [B, latent_dim, 1, 2]
358
+ z_other_invar = self.vn_invar(z_other_emb) # [B, latent_dim, 2]
359
+
360
+ # Get invariant component of Z along the axis of invariance
361
+ z_invar = Z[..., axis_of_invariance].unsqueeze(-1) # [B, latent_dim, 1]
362
+
363
+ # Innerproduct between projection of Z and D on the plane orthogonal to the axis of invariance.
364
+ # This encodes the rotational information. This is rotation-equivariant to rotations of either Z
365
+ # or D and is invariant to rotations of both Z and D.
366
+ innerprod = (z_other.unsqueeze(1) * d_other).sum(
367
+ dim=-1
368
+ ) # [B, num_rays, latent_dim]
369
+
370
+ # Compute norm along the axes orthogonal to the axis of invariance
371
+ d_other_norm = torch.sqrt(
372
+ D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
373
+ ).unsqueeze(-1) # [B num_rays, 1]
374
+
375
+ # Get invariant component of D along the axis of invariance
376
+ d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1]
377
+
378
+ directional_input = torch.cat(
379
+ (innerprod, d_invar, d_other_norm), -1
380
+ ) # [B, num_rays, latent_dim + 2]
381
+ conditioning_input = (
382
+ torch.cat((z_other_invar, z_invar), dim=-1)
383
+ .flatten(1)
384
+ .unsqueeze(1)
385
+ .expand(B, num_rays, latent_dim * 3)
386
+ ) # [B, num_rays, latent_dim * 3]
387
+
388
+ return directional_input, conditioning_input
389
+
390
+ if equivariance == "SO3":
391
+ z = self.vn_proj_in(Z) # [B, latent_dim, 1, 3]
392
+ z_invar = self.vn_invar(z) # [B, latent_dim, 3]
393
+ conditioning_input = (
394
+ z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
395
+ ) # [B, num_rays, latent_dim * 3]
396
+ # D [B, num_rays, 3] -> [B, num_rays, 1, 3]
397
+ # Z [B, latent_dim, 3] -> [B, 1, latent_dim, 3]
398
+ innerprod = torch.sum(
399
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
400
+ ) # [B, num_rays, latent_dim]
401
+ return innerprod, conditioning_input
402
+
403
+ def gram_matrix_invariance(
404
+ self,
405
+ Z: Float[Tensor, "B latent_dim 3"],
406
+ D: Float[Tensor, "B num_rays 3"],
407
+ equivariance: Literal["None", "SO2", "SO3"] = "SO2",
408
+ axis_of_invariance: int = 1,
409
+ ):
410
+ """Generates an invariant representation from latent code Z and direction coordinates D.
411
+
412
+ Args:
413
+ Z (torch.Tensor): Latent code (B x latent_dim x 3)
414
+ D (torch.Tensor): Direction coordinates (B x num_rays x 3)
415
+ equivariance (str): Type of equivariance to use. Options are 'none', 'SO2', and 'SO3'
416
+ axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
417
+ Default is 1 (y-axis).
418
+ Returns:
419
+ torch.Tensor: Invariant representation
420
+ """
421
+ assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
422
+ other_axes = [i for i in range(3) if i != axis_of_invariance]
423
+
424
+ B, latent_dim, _ = Z.shape
425
+ _, num_rays, _ = D.shape
426
+
427
+ if equivariance == "None":
428
+ # get inner product between latent code and direction coordinates
429
+ innerprod = torch.sum(
430
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
431
+ ) # [B, num_rays, latent_dim]
432
+ z_input = (
433
+ Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
434
+ ) # [B, num_rays, latent_dim * 3]
435
+ return innerprod, z_input
436
+
437
+ if equivariance == "SO2":
438
+ # Select components along axes orthogonal to the axis of invariance
439
+ z_other = torch.stack(
440
+ (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
441
+ ) # [B, latent_dim, 2]
442
+ d_other = torch.stack(
443
+ (D[..., other_axes[0]], D[..., other_axes[1]]), -1
444
+ ).unsqueeze(2) # [B, num_rays, 1, 2]
445
+ d_other = d_other.expand(
446
+ B, num_rays, latent_dim, 2
447
+ ) # size becomes [B, num_rays, latent_dim, 2]
448
+
449
+ # Invariant representation of Z: the Gram matrix G = Z*Z' has size B x latent_dim x latent_dim
450
+ G = torch.bmm(z_other, torch.transpose(z_other, 1, 2))
451
+
452
+ # Flatten G to be size B x latent_dim^2
453
+ z_other_invar = G.flatten(start_dim=1)
454
+
455
+ # Get invariant component of Z along the axis of invariance
456
+ z_invar = Z[..., axis_of_invariance] # [B, latent_dim]
457
+
458
+ # Inner product is size B x num_rays x latent_dim
459
+ innerprod = (z_other.unsqueeze(1) * d_other).sum(
460
+ dim=-1
461
+ ) # [B, num_rays, latent_dim]
462
+
463
+ # Compute norm along the axes orthogonal to the axis of invariance
464
+ d_other_norm = torch.sqrt(
465
+ D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
466
+ ).unsqueeze(-1) # [B, num_rays, 1]
467
+
468
+ # Get invariant component of D along the axis of invariance
469
+ d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1]
470
+
471
+ if not self.old_implementation:
472
+ directional_input = torch.cat(
473
+ (innerprod, d_invar, d_other_norm), -1
474
+ ) # [B, num_rays, latent_dim + 2]
475
+ conditioning_input = (
476
+ torch.cat((z_other_invar, z_invar), -1)
477
+ .unsqueeze(1)
478
+ .expand(B, num_rays, latent_dim * 3)
479
+ ) # [B, num_rays, latent_dim^2 + latent_dim]
480
+ else:
481
+ # this is matching the previous implementation of RENI, needed if using old checkpoints
482
+ z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1)
483
+ z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1)
484
+ return torch.cat(
485
+ (innerprod, z_other_invar, d_other_norm, z_invar, d_invar), 1
486
+ )
487
+
488
+ return directional_input, conditioning_input
489
+
490
+ if equivariance == "SO3":
491
+ G = Z @ torch.transpose(Z, 1, 2) # [B, latent_dim, latent_dim]
492
+ innerprod = torch.sum(
493
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
494
+ ) # [B, num_rays, latent_dim]
495
+ z_invar = (
496
+ G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1)
497
+ ) # [B, num_rays, latent_dim^2]
498
+ return innerprod, z_invar
499
+
500
+ def setup_network(self):
501
+ """Sets up the network architecture"""
502
+ base_input_dims = {
503
+ "VN": {
504
+ "None": {
505
+ "direction": self.latent_dim,
506
+ "conditioning": self.latent_dim * 3,
507
+ },
508
+ "SO2": {
509
+ "direction": self.latent_dim + 2,
510
+ "conditioning": self.latent_dim * 3,
511
+ },
512
+ "SO3": {
513
+ "direction": self.latent_dim,
514
+ "conditioning": self.latent_dim * 3,
515
+ },
516
+ },
517
+ "GramMatrix": {
518
+ "None": {
519
+ "direction": self.latent_dim,
520
+ "conditioning": self.latent_dim * 3,
521
+ },
522
+ "SO2": {
523
+ "direction": self.latent_dim + 2,
524
+ "conditioning": self.latent_dim**2 + self.latent_dim,
525
+ },
526
+ "SO3": {
527
+ "direction": self.latent_dim,
528
+ "conditioning": self.latent_dim**2,
529
+ },
530
+ },
531
+ }
532
+
533
+ # Extract the necessary input dimensions
534
+ input_types = ["direction", "conditioning"]
535
+ input_dims = {
536
+ key: base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][
537
+ key
538
+ ]
539
+ for key in input_types
540
+ }
541
+
542
+ # Helper function to create NeRF encoding
543
+ def create_nerf_encoding(in_dim):
544
+ return NeRFEncoding(
545
+ in_dim=in_dim,
546
+ num_frequencies=2,
547
+ min_freq_exp=0.0,
548
+ max_freq_exp=2.0,
549
+ include_input=True,
550
+ )
551
+
552
+ # Dictionary-based encoding setup
553
+ encoding_setup = {
554
+ "None": [],
555
+ "Conditioning": ["conditioning"],
556
+ "Directions": ["direction"],
557
+ "Both": ["direction", "conditioning"],
558
+ }
559
+
560
+ # Setting up the required encodings
561
+ for input_type in encoding_setup.get(self.cfg.encoded_input, []):
562
+ # create self.{input_type}_encoding and update input_dims
563
+ setattr(
564
+ self,
565
+ f"{input_type}_encoding",
566
+ create_nerf_encoding(input_dims[input_type]),
567
+ )
568
+ input_dims[input_type] = getattr(
569
+ self, f"{input_type}_encoding"
570
+ ).get_out_dim()
571
+
572
+ output_activation = get_activation_module(self.cfg.output_activation)
573
+
574
+ network = None
575
+ if self.conditioning == "Concat":
576
+ network = Siren(
577
+ in_dim=input_dims["direction"] + input_dims["conditioning"],
578
+ hidden_layers=self.hidden_layers,
579
+ hidden_features=self.hidden_features,
580
+ out_dim=self.out_features,
581
+ outermost_linear=self.last_layer_linear,
582
+ first_omega_0=self.first_omega_0,
583
+ hidden_omega_0=self.hidden_omega_0,
584
+ out_activation=output_activation,
585
+ )
586
+ elif self.conditioning == "FiLM":
587
+ network = FiLMSiren(
588
+ in_dim=input_dims["direction"],
589
+ hidden_layers=self.hidden_layers,
590
+ hidden_features=self.hidden_features,
591
+ mapping_network_in_dim=input_dims["conditioning"],
592
+ mapping_network_layers=self.mapping_layers,
593
+ mapping_network_features=self.mapping_features,
594
+ out_dim=self.out_features,
595
+ outermost_linear=True,
596
+ out_activation=output_activation,
597
+ )
598
+ elif self.conditioning == "Attention":
599
+ # transformer where K, V is from conditioning input and Q is from pos encoded directional input
600
+ network = Decoder(
601
+ in_dim=input_dims["direction"],
602
+ conditioning_input_dim=input_dims["conditioning"],
603
+ hidden_features=self.cfg.hidden_features,
604
+ num_heads=self.cfg.num_attention_heads,
605
+ num_layers=self.cfg.num_attention_layers,
606
+ out_activation=output_activation,
607
+ )
608
+ assert network is not None, "unknown conditioning type"
609
+ return network
610
+
611
+ def apply_positional_encoding(self, directional_input, conditioning_input):
612
+ # conditioning on just invariant directional input
613
+ if self.cfg.encoded_input == "Conditioning":
614
+ conditioning_input = self.conditioning_encoding(
615
+ conditioning_input
616
+ ) # [num_rays, embedding_dim]
617
+ elif self.cfg.encoded_input == "Directions":
618
+ directional_input = self.direction_encoding(
619
+ directional_input
620
+ ) # [num_rays, embedding_dim]
621
+ elif self.cfg.encoded_input == "Both":
622
+ directional_input = self.direction_encoding(directional_input)
623
+ conditioning_input = self.conditioning_encoding(conditioning_input)
624
+
625
+ return directional_input, conditioning_input
626
+
627
+ def get_outputs(
628
+ self,
629
+ rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore
630
+ latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore
631
+ rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore
632
+ scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore
633
+ ) -> Dict[str, Tensor]:
634
+ """Returns the outputs of the field.
635
+
636
+ Args:
637
+ rays_d: [batch_size num_rays 3]
638
+ latent_codes: [batch_size, latent_dim, 3]
639
+ rotation: [batch_size, 3, 3]
640
+ scale: [batch_size]
641
+ """
642
+ if rotation is not None:
643
+ if len(rotation.shape) == 3: # [batch_size, 3, 3]
644
+ # Expand latent_codes to match [batch_size, latent_dim, 3]
645
+ latent_codes = torch.einsum(
646
+ "bik,blk->bli",
647
+ rotation,
648
+ latent_codes,
649
+ )
650
+ else:
651
+ raise NotImplementedError(
652
+ "Unsupported rotation shape. Expected [batch_size, 3, 3]."
653
+ )
654
+
655
+ B, num_rays, _ = rays_d.shape
656
+ _, latent_dim, _ = latent_codes.shape
657
+
658
+ if not self.old_implementation:
659
+ directional_input, conditioning_input = self.invariant_function(
660
+ latent_codes,
661
+ rays_d,
662
+ equivariance=self.equivariance,
663
+ axis_of_invariance=self.axis_of_invariance,
664
+ ) # [B, num_rays, 3]
665
+
666
+ if self.cfg.positional_encoding == "NeRF":
667
+ directional_input, conditioning_input = self.apply_positional_encoding(
668
+ directional_input, conditioning_input
669
+ )
670
+
671
+ if self.conditioning == "Concat":
672
+ model_outputs = self.network(
673
+ torch.cat((directional_input, conditioning_input), dim=-1).reshape(
674
+ B * num_rays, -1
675
+ )
676
+ ).view(B, num_rays, 3) # returns -> [B num_rays, 3]
677
+ elif self.conditioning == "FiLM":
678
+ model_outputs = self.network(
679
+ directional_input.reshape(B * num_rays, -1),
680
+ conditioning_input.reshape(B * num_rays, -1),
681
+ ).view(B, num_rays, 3) # returns -> [B num_rays, 3]
682
+ elif self.conditioning == "Attention":
683
+ model_outputs = self.network(
684
+ directional_input.reshape(B * num_rays, -1),
685
+ conditioning_input.reshape(B * num_rays, -1),
686
+ ).view(B, num_rays, 3) # returns -> [B num_rays, 3]
687
+ else:
688
+ # in the old implementation directions were sampled with y-up not z-up so need to swap y and z in directions
689
+ directions = torch.stack(
690
+ (rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1
691
+ )
692
+ model_input = self.invariant_function(
693
+ latent_codes,
694
+ directions,
695
+ equivariance=self.equivariance,
696
+ axis_of_invariance=self.axis_of_invariance,
697
+ ) # [B, num_rays, 3]
698
+
699
+ model_outputs = self.network(model_input.view(B * num_rays, -1)).view(
700
+ B, num_rays, 3
701
+ )
702
+
703
+ outputs = {}
704
+
705
+ if scale is not None:
706
+ scale = trunc_exp(scale) # [batch_size] exp to ensure positive
707
+ model_outputs = model_outputs * scale.view(-1, 1, 1) # [num_rays, 3]
708
+
709
+ outputs["rgb"] = model_outputs
710
+
711
+ return outputs
712
+
713
+ def forward(
714
+ self,
715
+ rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore
716
+ latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore
717
+ rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore
718
+ scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore
719
+ ) -> Dict[str, Tensor]:
720
+ """Evaluates spherical field for a given ray bundle and rotation.
721
+
722
+ Args:
723
+ rays_d: [B num_rays 3]
724
+ latent_codes: [B, latent_dim, 3]
725
+ rotation: [batch_size, 3, 3]
726
+ scale: [batch_size]
727
+
728
+ Returns:
729
+ Dict[str, Tensor]: A dictionary containing the outputs of the field.
730
+ """
731
+ return self.get_outputs(
732
+ rays_d=rays_d,
733
+ latent_codes=latent_codes,
734
+ rotation=rotation,
735
+ scale=scale,
736
+ )
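Note: the GramMatrix invariant function above relies on the fact that G = Z @ Z.T is unchanged when the latent vectors are rotated in the plane orthogonal to the axis of invariance. A tiny numerical check with arbitrary values:

import math

import torch

Z = torch.randn(36, 2)  # latent components in the plane orthogonal to the invariant axis
a = 2 * math.pi * torch.rand(1).item()
R = torch.tensor([[math.cos(a), -math.sin(a)], [math.sin(a), math.cos(a)]])
print(torch.allclose(Z @ Z.T, (Z @ R.T) @ (Z @ R.T).T, atol=1e-5))  # True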
spar3d/models/image_estimator/clip_based_estimator.py ADDED
@@ -0,0 +1,184 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import alpha_clip
5
+ import torch
6
+ import torch.nn as nn
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+ from torchvision.transforms import Normalize
10
+
11
+ from spar3d.models.network import get_activation
12
+ from spar3d.models.utils import BaseModule
13
+
14
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
15
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
16
+
17
+
18
+ @dataclass
19
+ class HeadSpec:
20
+ name: str
21
+ out_channels: int
22
+ n_hidden_layers: int
23
+ output_activation: Optional[str] = None
24
+ output_bias: float = 0.0
25
+ add_to_decoder_features: bool = False
26
+ shape: Optional[list[int]] = None
27
+ distribution_eval: str = "sample"
28
+
29
+
30
+ class ClipBasedHeadEstimator(BaseModule):
31
+ @dataclass
32
+ class Config(BaseModule.Config):
33
+ model: str = "ViT-L/14@336px"
34
+
35
+ distribution: str = "beta"
36
+
37
+ # ["mean", "mode", "sample", "sample_mean"]
38
+ distribution_eval: str = "mode"
39
+
40
+ activation: str = "relu"
41
+ hidden_features: int = 512
42
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
43
+
44
+ cfg: Config
45
+
46
+ def configure(self):
47
+ self.model, _ = alpha_clip.load(
48
+ self.cfg.model,
49
+ ) # change to your own ckpt path
50
+ self.model.eval()
51
+
52
+ if not hasattr(self.model.visual, "input_resolution"):
53
+ self.img_size = 224
54
+ else:
55
+ self.img_size = self.model.visual.input_resolution
56
+ # Check if img_size is subscribable and pick the first element
57
+ if hasattr(self.img_size, "__getitem__"):
58
+ self.img_size = self.img_size[0]
59
+
60
+ # Do not add the weights in self.model to the optimizer
61
+ for param in self.model.parameters():
62
+ param.requires_grad = False
63
+
64
+ assert len(self.cfg.heads) > 0
65
+ heads = {}
66
+ for head in self.cfg.heads:
67
+ head_layers = []
68
+ in_feature = self.model.visual.output_dim
69
+
70
+ for i in range(head.n_hidden_layers):
71
+ head_layers += [
72
+ nn.Linear(
73
+ in_feature if i == 0 else self.cfg.hidden_features,
74
+ self.cfg.hidden_features,
75
+ ),
76
+ self.make_activation(self.cfg.activation),
77
+ ]
78
+
79
+ head_layers = [nn.Sequential(*head_layers)]
80
+ head_layers += [
81
+ nn.Sequential(
82
+ nn.Linear(
83
+ self.cfg.hidden_features,
84
+ self.cfg.hidden_features,
85
+ ),
86
+ self.make_activation(self.cfg.activation),
87
+ nn.Linear(self.cfg.hidden_features, 1),
88
+ )
89
+ for _ in range(2)
90
+ ]
91
+ heads[head.name] = nn.ModuleList(head_layers)
92
+ self.heads = nn.ModuleDict(heads)
93
+
94
+ def make_activation(self, activation):
95
+ if activation == "relu":
96
+ return nn.ReLU(inplace=True)
97
+ elif activation == "silu":
98
+ return nn.SiLU(inplace=True)
99
+ else:
100
+ raise NotImplementedError
101
+
102
+ def forward(
103
+ self,
104
+ cond_image: Float[Tensor, "B 1 H W 4"],
105
+ sample: bool = True,
106
+ ) -> dict[str, Any]:
107
+ # Run the model
108
+ # Resize cond_image to 224
109
+ cond_image = cond_image.flatten(0, 1)
110
+ cond_image = nn.functional.interpolate(
111
+ cond_image.permute(0, 3, 1, 2),
112
+ size=(self.img_size, self.img_size),
113
+ mode="bilinear",
114
+ align_corners=False,
115
+ )
116
+ mask = cond_image[:, 3:4]
117
+ cond_image = cond_image[:, :3] * mask
118
+ cond_image = Normalize(
119
+ mean=OPENAI_DATASET_MEAN,
120
+ std=OPENAI_DATASET_STD,
121
+ )(cond_image)
122
+ mask = Normalize(0.5, 0.26)(mask).half()
123
+ image_features = self.model.visual(cond_image.half(), mask).float()
124
+
125
+ # Run the heads
126
+ outputs = {}
127
+
128
+ for head_dict in self.cfg.heads:
129
+ head_name = head_dict.name
130
+ shared_head, d1_h, d2_h = self.heads[head_name]
131
+ shared_features = shared_head(image_features)
132
+ d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
133
+ if self.cfg.distribution == "normal":
134
+ mean = d1
135
+ var = d2
136
+ if mean.shape[-1] == 1:
137
+ outputs[head_name] = torch.distributions.Normal(
138
+ mean + head_dict.output_bias,
139
+ torch.nn.functional.softplus(var),
140
+ )
141
+ else:
142
+ outputs[head_name] = torch.distributions.MultivariateNormal(
143
+ mean + head_dict.output_bias,
144
+ torch.nn.functional.softplus(var).diag_embed(),
145
+ )
146
+ elif self.cfg.distribution == "beta":
147
+ outputs[head_name] = torch.distributions.Beta(
148
+ torch.nn.functional.softplus(d1 + head_dict.output_bias),
149
+ torch.nn.functional.softplus(d2 + head_dict.output_bias),
150
+ )
151
+ else:
152
+ raise NotImplementedError
153
+
154
+ if sample:
155
+ for head_dict in self.cfg.heads:
156
+ head_name = head_dict.name
157
+ dist = outputs[head_name]
158
+
159
+ if head_dict.distribution_eval == "mean":
160
+ out = dist.mean
161
+ elif head_dict.distribution_eval == "mode":
162
+ out = dist.mode
163
+ elif head_dict.distribution_eval == "sample_mean":
164
+ out = dist.sample([10]).mean(-1)
165
+ else:
166
+ # use rsample if gradient is needed
167
+ out = dist.rsample() if self.training else dist.sample()
168
+
169
+ outputs[head_name] = get_activation(head_dict.output_activation)(out)
170
+ outputs[f"{head_name}_dist"] = dist
171
+
172
+ for head in self.cfg.heads:
173
+ if head.shape:
174
+ if not sample:
175
+ raise ValueError(
176
+ "Cannot reshape non-sampled probabilisitic outputs"
177
+ )
178
+ outputs[head.name] = outputs[head.name].reshape(*head.shape)
179
+
180
+ if head.add_to_decoder_features:
181
+ outputs[f"decoder_{head.name}"] = outputs[head.name]
182
+ del outputs[head.name]
183
+
184
+ return outputs
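Note: the "beta" heads above turn two unconstrained scalars into a distribution over (0, 1) via softplus; the snippet below reproduces that idea with made-up values:

import torch

d1, d2 = torch.tensor(2.0), torch.tensor(1.0)  # raw outputs of the two head branches (made up)
dist = torch.distributions.Beta(
    torch.nn.functional.softplus(d1), torch.nn.functional.softplus(d2)
)
print(dist.mean, dist.sample())                # both values lie in (0, 1)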
spar3d/models/isosurface.py ADDED
@@ -0,0 +1,229 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
+ class IsosurfaceHelper(nn.Module):
13
+ points_range: Tuple[float, float] = (0, 1)
14
+
15
+ @property
16
+ def grid_vertices(self) -> Float[Tensor, "N 3"]:
17
+ raise NotImplementedError
18
+
19
+ @property
20
+ def requires_instance_per_batch(self) -> bool:
21
+ return False
22
+
23
+
24
+ class MarchingTetrahedraHelper(IsosurfaceHelper):
25
+ def __init__(self, resolution: int, tets_path: str):
26
+ super().__init__()
27
+ self.resolution = resolution
28
+ self.tets_path = tets_path
29
+
30
+ self.triangle_table: Float[Tensor, "..."]
31
+ self.register_buffer(
32
+ "triangle_table",
33
+ torch.as_tensor(
34
+ [
35
+ [-1, -1, -1, -1, -1, -1],
36
+ [1, 0, 2, -1, -1, -1],
37
+ [4, 0, 3, -1, -1, -1],
38
+ [1, 4, 2, 1, 3, 4],
39
+ [3, 1, 5, -1, -1, -1],
40
+ [2, 3, 0, 2, 5, 3],
41
+ [1, 4, 0, 1, 5, 4],
42
+ [4, 2, 5, -1, -1, -1],
43
+ [4, 5, 2, -1, -1, -1],
44
+ [4, 1, 0, 4, 5, 1],
45
+ [3, 2, 0, 3, 5, 2],
46
+ [1, 3, 5, -1, -1, -1],
47
+ [4, 1, 2, 4, 3, 1],
48
+ [3, 0, 4, -1, -1, -1],
49
+ [2, 0, 1, -1, -1, -1],
50
+ [-1, -1, -1, -1, -1, -1],
51
+ ],
52
+ dtype=torch.long,
53
+ ),
54
+ persistent=False,
55
+ )
56
+ self.num_triangles_table: Integer[Tensor, "..."]
57
+ self.register_buffer(
58
+ "num_triangles_table",
59
+ torch.as_tensor(
60
+ [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
61
+ ),
62
+ persistent=False,
63
+ )
64
+ self.base_tet_edges: Integer[Tensor, "..."]
65
+ self.register_buffer(
66
+ "base_tet_edges",
67
+ torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
68
+ persistent=False,
69
+ )
70
+
71
+ tets = np.load(self.tets_path)
72
+ self._grid_vertices: Float[Tensor, "..."]
73
+ self.register_buffer(
74
+ "_grid_vertices",
75
+ torch.from_numpy(tets["vertices"]).float(),
76
+ persistent=False,
77
+ )
78
+ self.indices: Integer[Tensor, "..."]
79
+ self.register_buffer(
80
+ "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
81
+ )
82
+
83
+ self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
84
+
85
+ center_indices, boundary_indices = self.get_center_boundary_index(
86
+ self._grid_vertices
87
+ )
88
+ self.center_indices: Integer[Tensor, "..."]
89
+ self.register_buffer("center_indices", center_indices, persistent=False)
90
+ self.boundary_indices: Integer[Tensor, "..."]
91
+ self.register_buffer("boundary_indices", boundary_indices, persistent=False)
92
+
93
+ def get_center_boundary_index(self, verts):
94
+ magn = torch.sum(verts**2, dim=-1)
95
+
96
+ center_idx = torch.argmin(magn)
97
+ boundary_neg = verts == verts.max()
98
+ boundary_pos = verts == verts.min()
99
+
100
+ boundary = torch.bitwise_or(boundary_pos, boundary_neg)
101
+ boundary = torch.sum(boundary.float(), dim=-1)
102
+
103
+ boundary_idx = torch.nonzero(boundary)
104
+ return center_idx, boundary_idx.squeeze(dim=-1)
105
+
106
+ def normalize_grid_deformation(
107
+ self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
108
+ ) -> Float[Tensor, "Nv 3"]:
109
+ return (
110
+ (self.points_range[1] - self.points_range[0])
111
+ / self.resolution # half tet size is approximately 1 / self.resolution
112
+ * torch.tanh(grid_vertex_offsets)
113
+ ) # FIXME: hard-coded activation
114
+
115
+ @property
116
+ def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
117
+ return self._grid_vertices
118
+
119
+ @property
120
+ def all_edges(self) -> Integer[Tensor, "Ne 2"]:
121
+ if self._all_edges is None:
122
+ # compute edges on the GPU, otherwise this is very slow (mainly due to the unique operation)
123
+ edges = torch.tensor(
124
+ [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
125
+ dtype=torch.long,
126
+ device=self.indices.device,
127
+ )
128
+ _all_edges = self.indices[:, edges].reshape(-1, 2)
129
+ _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
130
+ _all_edges = torch.unique(_all_edges_sorted, dim=0)
131
+ self._all_edges = _all_edges
132
+ return self._all_edges
133
+
134
+ def sort_edges(self, edges_ex2):
135
+ with torch.no_grad():
136
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
137
+ order = order.unsqueeze(dim=1)
138
+
139
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
140
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
141
+
142
+ return torch.stack([a, b], -1)
143
+
144
+ def _forward(self, pos_nx3, sdf_n, tet_fx4):
145
+ with torch.no_grad():
146
+ occ_n = sdf_n > 0
147
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
148
+ occ_sum = torch.sum(occ_fx4, -1)
149
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
150
+ occ_sum = occ_sum[valid_tets]
151
+
152
+ # find all vertices
153
+ all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
154
+ all_edges = self.sort_edges(all_edges)
155
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
156
+
157
+ unique_edges = unique_edges.long()
158
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
159
+ mapping = (
160
+ torch.ones(
161
+ (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
162
+ )
163
+ * -1
164
+ )
165
+ mapping[mask_edges] = torch.arange(
166
+ mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
167
+ )
168
+ idx_map = mapping[idx_map] # map edges to verts
169
+
170
+ interp_v = unique_edges[mask_edges]
171
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
172
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
173
+ edges_to_interp_sdf[:, -1] *= -1
174
+
175
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
176
+
177
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
178
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
179
+
180
+ idx_map = idx_map.reshape(-1, 6)
181
+
182
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
183
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
184
+ num_triangles = self.num_triangles_table[tetindex]
185
+
186
+ # Generate triangle indices
187
+ faces = torch.cat(
188
+ (
189
+ torch.gather(
190
+ input=idx_map[num_triangles == 1],
191
+ dim=1,
192
+ index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
193
+ ).reshape(-1, 3),
194
+ torch.gather(
195
+ input=idx_map[num_triangles == 2],
196
+ dim=1,
197
+ index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
198
+ ).reshape(-1, 3),
199
+ ),
200
+ dim=0,
201
+ )
202
+
203
+ return verts, faces
204
+
205
+ def forward(
206
+ self,
207
+ level: Float[Tensor, "N3 1"],
208
+ deformation: Optional[Float[Tensor, "N3 3"]] = None,
209
+ ) -> Mesh:
210
+ if deformation is not None:
211
+ grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
212
+ deformation
213
+ )
214
+ else:
215
+ grid_vertices = self.grid_vertices
216
+
217
+ v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
218
+
219
+ mesh = Mesh(
220
+ v_pos=v_pos,
221
+ t_pos_idx=t_pos_idx,
222
+ # extras
223
+ grid_vertices=grid_vertices,
224
+ tet_edges=self.all_edges,
225
+ grid_level=level,
226
+ grid_deformation=deformation,
227
+ )
228
+
229
+ return mesh
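
MarchingTetrahedraHelper converts per-vertex signed distance values defined on the precomputed tetrahedral grid into a triangle Mesh; positive values count as inside (see `occ_n = sdf_n > 0`). A minimal usage sketch, assuming the spar3d package and its uv_unwrapper dependency are installed and the bundled load/tets/160_tets.npz grid is available; the sphere radius is illustrative:

import torch
from spar3d.models.isosurface import MarchingTetrahedraHelper

# Build the helper from the tet grid shipped in load/tets/ and extract the
# isosurface of a simple sphere-like field centered on the grid.
helper = MarchingTetrahedraHelper(160, "load/tets/160_tets.npz")
verts = helper.grid_vertices                           # (Nv, 3) tet grid vertices
center = verts.mean(dim=0, keepdim=True)
level = 0.25 - torch.linalg.norm(verts - center, dim=-1, keepdim=True)
mesh = helper(level)                                   # deformation offsets omitted
print(mesh.v_pos.shape, mesh.t_pos_idx.shape)          # surface vertices and faces
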
spar3d/models/mesh.py ADDED
@@ -0,0 +1,317 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any, Dict, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import trimesh
10
+ from jaxtyping import Float, Integer
11
+ from torch import Tensor
12
+
13
+ from spar3d.models.utils import dot
14
+
15
+ try:
16
+ from uv_unwrapper import Unwrapper
17
+ except ImportError:
18
+ import logging
19
+
20
+ logging.warning(
21
+ "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
22
+ )
23
+ # Exit early to avoid further errors
24
+ raise ImportError("uv_unwrapper not found")
25
+
26
+ try:
27
+ import gpytoolbox
28
+
29
+ TRIANGLE_REMESH_AVAILABLE = True
30
+ except ImportError:
31
+ TRIANGLE_REMESH_AVAILABLE = False
32
+ import logging
33
+
34
+ logging.warning(
35
+ "Could not import gpytoolbox. Triangle remeshing functionality will be disabled. "
36
+ "Install via `pip install gpytoolbox`"
37
+ )
38
+
39
+ try:
40
+ import pynim
41
+
42
+ QUAD_REMESH_AVAILABLE = True
43
+ except ImportError:
44
+ QUAD_REMESH_AVAILABLE = False
45
+ import logging
46
+
47
+ logging.warning(
48
+ "Could not import pynim. Quad remeshing functionality will be disabled. "
49
+ "Install via `pip install git+https://github.com/vork/[email protected]`"
50
+ )
51
+
52
+
53
+ class Mesh:
54
+ def __init__(
55
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
56
+ ) -> None:
57
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
58
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
59
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
60
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
61
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
62
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
63
+ self.extras: Dict[str, Any] = {}
64
+ for k, v in kwargs.items():
65
+ self.add_extra(k, v)
66
+
67
+ self.unwrapper = Unwrapper()
68
+
69
+ def add_extra(self, k, v) -> None:
70
+ self.extras[k] = v
71
+
72
+ @property
73
+ def requires_grad(self):
74
+ return self.v_pos.requires_grad
75
+
76
+ @property
77
+ def v_nrm(self):
78
+ if self._v_nrm is None:
79
+ self._v_nrm = self._compute_vertex_normal()
80
+ return self._v_nrm
81
+
82
+ @property
83
+ def v_tng(self):
84
+ if self._v_tng is None:
85
+ self._v_tng = self._compute_vertex_tangent()
86
+ return self._v_tng
87
+
88
+ @property
89
+ def v_tex(self):
90
+ if self._v_tex is None:
91
+ self.unwrap_uv()
92
+ return self._v_tex
93
+
94
+ @property
95
+ def edges(self):
96
+ if self._edges is None:
97
+ self._edges = self._compute_edges()
98
+ return self._edges
99
+
100
+ def _compute_vertex_normal(self):
101
+ i0 = self.t_pos_idx[:, 0]
102
+ i1 = self.t_pos_idx[:, 1]
103
+ i2 = self.t_pos_idx[:, 2]
104
+
105
+ v0 = self.v_pos[i0, :]
106
+ v1 = self.v_pos[i1, :]
107
+ v2 = self.v_pos[i2, :]
108
+
109
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
110
+
111
+ # Splat face normals to vertices
112
+ v_nrm = torch.zeros_like(self.v_pos)
113
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
114
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
115
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
116
+
117
+ # Normalize, replace zero (degenerated) normals with some default value
118
+ v_nrm = torch.where(
119
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
120
+ )
121
+ v_nrm = F.normalize(v_nrm, dim=1)
122
+
123
+ if torch.is_anomaly_enabled():
124
+ assert torch.all(torch.isfinite(v_nrm))
125
+
126
+ return v_nrm
127
+
128
+ def _compute_vertex_tangent(self):
129
+ vn_idx = [None] * 3
130
+ pos = [None] * 3
131
+ tex = [None] * 3
132
+ for i in range(0, 3):
133
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
134
+ tex[i] = self.v_tex[self.t_pos_idx[:, i]]
135
+ # t_nrm_idx is always the same as t_pos_idx
136
+ vn_idx[i] = self.t_pos_idx[:, i]
137
+
138
+ tangents = torch.zeros_like(self.v_nrm)
139
+ tansum = torch.zeros_like(self.v_nrm)
140
+
141
+ # Compute tangent space for each triangle
142
+ duv1 = tex[1] - tex[0]
143
+ duv2 = tex[2] - tex[0]
144
+ dpos1 = pos[1] - pos[0]
145
+ dpos2 = pos[2] - pos[0]
146
+
147
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
148
+
149
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
150
+
151
+ # Avoid division by zero for degenerated texture coordinates
152
+ denom_safe = denom.clip(1e-6)
153
+ tang = tng_nom / denom_safe
154
+
155
+ # Update all 3 vertices
156
+ for i in range(0, 3):
157
+ idx = vn_idx[i][:, None].repeat(1, 3)
158
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
159
+ tansum.scatter_add_(
160
+ 0, idx, torch.ones_like(tang)
161
+ ) # tansum[n_i] = tansum[n_i] + 1
162
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
163
+ # triangles influence the tangent space more
164
+ tangents = tangents / tansum
165
+
166
+ # Normalize and make sure tangent is perpendicular to normal
167
+ tangents = F.normalize(tangents, dim=1)
168
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
169
+
170
+ if torch.is_anomaly_enabled():
171
+ assert torch.all(torch.isfinite(tangents))
172
+
173
+ return tangents
174
+
175
+ def quad_remesh(
176
+ self,
177
+ quad_vertex_count: int = -1,
178
+ quad_rosy: int = 4,
179
+ quad_crease_angle: float = -1.0,
180
+ quad_smooth_iter: int = 2,
181
+ quad_align_to_boundaries: bool = False,
182
+ ) -> Mesh:
183
+ if not QUAD_REMESH_AVAILABLE:
184
+ raise ImportError("Quad remeshing requires pynim to be installed")
185
+ if quad_vertex_count < 0:
186
+ quad_vertex_count = self.v_pos.shape[0]
187
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
188
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)
189
+
190
+ new_vert, new_faces = pynim.remesh(
191
+ v_pos,
192
+ t_pos_idx,
193
+ quad_vertex_count // 4,
194
+ rosy=quad_rosy,
195
+ posy=4,
196
+ creaseAngle=quad_crease_angle,
197
+ align_to_boundaries=quad_align_to_boundaries,
198
+ smooth_iter=quad_smooth_iter,
199
+ deterministic=False,
200
+ )
201
+
202
+ # Briefly load in trimesh
203
+ mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))
204
+
205
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
206
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()
207
+
208
+ # Create new mesh
209
+ return Mesh(v_pos, t_pos_idx)
210
+
211
+ def triangle_remesh(
212
+ self,
213
+ triangle_average_edge_length_multiplier: Optional[float] = None,
214
+ triangle_remesh_steps: int = 10,
215
+ triangle_vertex_count=-1,
216
+ ):
217
+ if not TRIANGLE_REMESH_AVAILABLE:
218
+ raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
219
+ if triangle_vertex_count > 0:
220
+ reduction = triangle_vertex_count / self.v_pos.shape[0]
221
+ print("Triangle reduction:", reduction)
222
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
223
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
224
+ if reduction > 1.0:
225
+ subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
226
+ print("Subdivide iters:", subdivide_iters)
227
+ v_pos, t_pos_idx = gpytoolbox.subdivide(
228
+ v_pos,
229
+ t_pos_idx,
230
+ iters=subdivide_iters,
231
+ )
232
+ reduction = triangle_vertex_count / v_pos.shape[0]
233
+
234
+ # Simplify
235
+ points_out, faces_out, _, _ = gpytoolbox.decimate(
236
+ v_pos,
237
+ t_pos_idx,
238
+ face_ratio=reduction,
239
+ )
240
+
241
+ # Convert back to torch
242
+ self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
243
+ self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
244
+ self._edges = None
245
+ triangle_average_edge_length_multiplier = None
246
+
247
+ edges = self.edges
248
+ if triangle_average_edge_length_multiplier is None:
249
+ h = None
250
+ else:
251
+ h = float(
252
+ torch.linalg.norm(
253
+ self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
254
+ )
255
+ .mean()
256
+ .item()
257
+ * triangle_average_edge_length_multiplier
258
+ )
259
+
260
+ # Convert to numpy
261
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
262
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
263
+
264
+ # Remesh
265
+ v_remesh, f_remesh = gpytoolbox.remesh_botsch(
266
+ v_pos,
267
+ t_pos_idx,
268
+ triangle_remesh_steps,
269
+ h,
270
+ )
271
+
272
+ # Convert back to torch
273
+ v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
274
+ t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()
275
+
276
+ # Create new mesh
277
+ return Mesh(v_pos, t_pos_idx)
278
+
279
+ @torch.no_grad()
280
+ def unwrap_uv(
281
+ self,
282
+ island_padding: float = 0.02,
283
+ ) -> Mesh:
284
+ uv, indices = self.unwrapper(
285
+ self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
286
+ )
287
+
288
+ # Store per-vertex UVs.
289
+ # This means we need to duplicate some vertices at the seams
290
+ individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
291
+ individual_faces = torch.arange(
292
+ individual_vertices.shape[0],
293
+ device=individual_vertices.device,
294
+ dtype=self.t_pos_idx.dtype,
295
+ ).reshape(-1, 3)
296
+ uv_flat = uv[indices].reshape((-1, 2))
297
+ # uv_flat[:, 1] = 1 - uv_flat[:, 1]
298
+
299
+ self.v_pos = individual_vertices
300
+ self.t_pos_idx = individual_faces
301
+ self._v_tex = uv_flat
302
+ self._v_nrm = self._compute_vertex_normal()
303
+ self._v_tng = self._compute_vertex_tangent()
304
+
305
+ def _compute_edges(self):
306
+ # Compute edges
307
+ edges = torch.cat(
308
+ [
309
+ self.t_pos_idx[:, [0, 1]],
310
+ self.t_pos_idx[:, [1, 2]],
311
+ self.t_pos_idx[:, [2, 0]],
312
+ ],
313
+ dim=0,
314
+ )
315
+ edges = edges.sort()[0]
316
+ edges = torch.unique(edges, dim=0)
317
+ return edges
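
Mesh computes vertex normals, tangents, UVs and edges lazily on first access. A small sketch using a regular tetrahedron (the vertex coordinates are illustrative); it assumes the spar3d package and the uv_unwrapper extension are importable, since Mesh constructs an Unwrapper on init:

import torch
from spar3d.models.mesh import Mesh

# A regular tetrahedron with outward-facing (counter-clockwise) triangles.
v_pos = torch.tensor(
    [[1.0, 1.0, 1.0], [1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
)
t_pos_idx = torch.tensor(
    [[0, 1, 2], [0, 3, 1], [0, 2, 3], [1, 3, 2]], dtype=torch.long
)
mesh = Mesh(v_pos, t_pos_idx)
print(mesh.v_nrm.shape)   # (4, 3) area-weighted, normalized vertex normals
print(mesh.edges.shape)   # (6, 2) unique undirected edges
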
spar3d/models/network.py ADDED
@@ -0,0 +1,223 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.autograd import Function
11
+ from torch.cuda.amp import custom_bwd, custom_fwd
12
+
13
+ from spar3d.models.utils import BaseModule, normalize
14
+ from spar3d.utils import get_device
15
+
16
+
17
+ def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
18
+ def wrapper(fn):
19
+ if condition:
20
+ if len(kwargs) == 0:
21
+ return decorator_with_args(fn)
22
+ return decorator_with_args(*args, **kwargs)(fn)
23
+ else:
24
+ return fn
25
+
26
+ return wrapper
27
+
28
+
29
+ class PixelShuffleUpsampleNetwork(BaseModule):
30
+ @dataclass
31
+ class Config(BaseModule.Config):
32
+ in_channels: int = 1024
33
+ out_channels: int = 40
34
+ scale_factor: int = 4
35
+
36
+ conv_layers: int = 4
37
+ conv_kernel_size: int = 3
38
+
39
+ cfg: Config
40
+
41
+ def configure(self) -> None:
42
+ layers = []
43
+ output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
44
+
45
+ in_channels = self.cfg.in_channels
46
+ for i in range(self.cfg.conv_layers):
47
+ cur_out_channels = (
48
+ in_channels if i != self.cfg.conv_layers - 1 else output_channels
49
+ )
50
+ layers.append(
51
+ nn.Conv2d(
52
+ in_channels,
53
+ cur_out_channels,
54
+ self.cfg.conv_kernel_size,
55
+ padding=(self.cfg.conv_kernel_size - 1) // 2,
56
+ )
57
+ )
58
+ if i != self.cfg.conv_layers - 1:
59
+ layers.append(nn.ReLU(inplace=True))
60
+
61
+ layers.append(nn.PixelShuffle(self.cfg.scale_factor))
62
+
63
+ self.upsample = nn.Sequential(*layers)
64
+
65
+ def forward(
66
+ self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
67
+ ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
68
+ return rearrange(
69
+ self.upsample(
70
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
71
+ ),
72
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
73
+ Np=3,
74
+ )
75
+
76
+
77
+ class _TruncExp(Function): # pylint: disable=abstract-method
78
+ # Implementation from torch-ngp:
79
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
80
+ @staticmethod
81
+ @conditional_decorator(
82
+ custom_fwd, "cuda" in get_device(), cast_inputs=torch.float32
83
+ )
84
+ def forward(ctx, x): # pylint: disable=arguments-differ
85
+ ctx.save_for_backward(x)
86
+ return torch.exp(x)
87
+
88
+ @staticmethod
89
+ @conditional_decorator(custom_bwd, "cuda" in get_device())
90
+ def backward(ctx, g): # pylint: disable=arguments-differ
91
+ x = ctx.saved_tensors[0]
92
+ return g * torch.exp(torch.clamp(x, max=15))
93
+
94
+
95
+ trunc_exp = _TruncExp.apply
96
+
97
+
98
+ def get_activation(name) -> Callable:
99
+ if name is None:
100
+ return lambda x: x
101
+ name = name.lower()
102
+ if name == "none" or name == "linear" or name == "identity":
103
+ return lambda x: x
104
+ elif name == "lin2srgb":
105
+ return lambda x: torch.where(
106
+ x > 0.0031308,
107
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
108
+ 12.92 * x,
109
+ ).clamp(0.0, 1.0)
110
+ elif name == "exp":
111
+ return lambda x: torch.exp(x)
112
+ elif name == "shifted_exp":
113
+ return lambda x: torch.exp(x - 1.0)
114
+ elif name == "trunc_exp":
115
+ return trunc_exp
116
+ elif name == "shifted_trunc_exp":
117
+ return lambda x: trunc_exp(x - 1.0)
118
+ elif name == "sigmoid":
119
+ return lambda x: torch.sigmoid(x)
120
+ elif name == "tanh":
121
+ return lambda x: torch.tanh(x)
122
+ elif name == "shifted_softplus":
123
+ return lambda x: F.softplus(x - 1.0)
124
+ elif name == "scale_-11_01":
125
+ return lambda x: x * 0.5 + 0.5
126
+ elif name == "negative":
127
+ return lambda x: -x
128
+ elif name == "normalize_channel_last":
129
+ return lambda x: normalize(x)
130
+ elif name == "normalize_channel_first":
131
+ return lambda x: normalize(x, dim=1)
132
+ else:
133
+ try:
134
+ return getattr(F, name)
135
+ except AttributeError:
136
+ raise ValueError(f"Unknown activation function: {name}")
137
+
138
+
139
+ class LambdaModule(torch.nn.Module):
140
+ def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]):
141
+ super().__init__()
142
+ self.lambd = lambd
143
+
144
+ def forward(self, x):
145
+ return self.lambd(x)
146
+
147
+
148
+ def get_activation_module(name) -> torch.nn.Module:
149
+ return LambdaModule(get_activation(name))
150
+
151
+
152
+ @dataclass
153
+ class HeadSpec:
154
+ name: str
155
+ out_channels: int
156
+ n_hidden_layers: int
157
+ output_activation: Optional[str] = None
158
+ out_bias: float = 0.0
159
+
160
+
161
+ class MaterialMLP(BaseModule):
162
+ @dataclass
163
+ class Config(BaseModule.Config):
164
+ in_channels: int = 120
165
+ n_neurons: int = 64
166
+ activation: str = "silu"
167
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
168
+
169
+ cfg: Config
170
+
171
+ def configure(self) -> None:
172
+ assert len(self.cfg.heads) > 0
173
+ heads = {}
174
+ for head in self.cfg.heads:
175
+ head_layers = []
176
+ for i in range(head.n_hidden_layers):
177
+ head_layers += [
178
+ nn.Linear(
179
+ self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
180
+ self.cfg.n_neurons,
181
+ ),
182
+ self.make_activation(self.cfg.activation),
183
+ ]
184
+ head_layers += [
185
+ nn.Linear(
186
+ self.cfg.n_neurons,
187
+ head.out_channels,
188
+ ),
189
+ ]
190
+ heads[head.name] = nn.Sequential(*head_layers)
191
+ self.heads = nn.ModuleDict(heads)
192
+
193
+ def make_activation(self, activation):
194
+ if activation == "relu":
195
+ return nn.ReLU(inplace=True)
196
+ elif activation == "silu":
197
+ return nn.SiLU(inplace=True)
198
+ else:
199
+ raise NotImplementedError
200
+
201
+ def keys(self):
202
+ return self.heads.keys()
203
+
204
+ def forward(
205
+ self, x, include: Optional[List] = None, exclude: Optional[List] = None
206
+ ):
207
+ if include is not None and exclude is not None:
208
+ raise ValueError("Cannot specify both include and exclude.")
209
+ if include is not None:
210
+ heads = [h for h in self.cfg.heads if h.name in include]
211
+ elif exclude is not None:
212
+ heads = [h for h in self.cfg.heads if h.name not in exclude]
213
+ else:
214
+ heads = self.cfg.heads
215
+
216
+ out = {
217
+ head.name: get_activation(head.output_activation)(
218
+ self.heads[head.name](x) + head.out_bias
219
+ )
220
+ for head in heads
221
+ }
222
+
223
+ return out
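
get_activation in network.py resolves the activation strings used throughout the configs; names it does not handle explicitly fall through to torch.nn.functional. A short sketch, assuming the spar3d package is importable (the input values are arbitrary):

import torch
from spar3d.models.network import get_activation, get_activation_module

x = torch.linspace(-2.0, 2.0, 5)
print(get_activation(None)(x))          # identity
print(get_activation("lin2srgb")(x))    # clamped linear-to-sRGB transfer curve
print(get_activation("trunc_exp")(x))   # exp with a clamped backward pass
print(get_activation("elu")(x))         # falls through to torch.nn.functional.elu

# get_activation_module wraps the same callables so they can sit in nn.Sequential,
# as MaterialMLP does for its per-head output activations.
head = torch.nn.Sequential(torch.nn.Linear(5, 5), get_activation_module("silu"))
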
spar3d/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1196 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
+ from transformers.pytorch_utils import (
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.utils.backbone_utils import BackboneMixin
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ # General docstring
52
+ _CONFIG_FOR_DOC = "Dinov2Config"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
60
+
61
+
62
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "facebook/dinov2-base",
64
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
65
+ ]
66
+
67
+
68
+ class Dinov2Embeddings(nn.Module):
69
+ """
70
+ Construct the CLS token, mask token, position and patch embeddings.
71
+ """
72
+
73
+ def __init__(self, config: Dinov2Config) -> None:
74
+ super().__init__()
75
+
76
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
77
+ # register as mask token as it's not used in optimization
78
+ # to avoid the use of find_unused_parameters_true
79
+ # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
80
+ self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
81
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
82
+ num_patches = self.patch_embeddings.num_patches
83
+ self.position_embeddings = nn.Parameter(
84
+ torch.randn(1, num_patches + 1, config.hidden_size)
85
+ )
86
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
87
+ self.config = config
88
+
89
+ def interpolate_pos_encoding(
90
+ self, embeddings: torch.Tensor, height: int, width: int
91
+ ) -> torch.Tensor:
92
+ """
93
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
94
+ resolution images.
95
+
96
+ Source:
97
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
98
+ """
99
+
100
+ num_patches = embeddings.shape[1] - 1
101
+ num_positions = self.position_embeddings.shape[1] - 1
102
+ if num_patches == num_positions and height == width:
103
+ return self.position_embeddings
104
+ class_pos_embed = self.position_embeddings[:, 0]
105
+ patch_pos_embed = self.position_embeddings[:, 1:]
106
+ dim = embeddings.shape[-1]
107
+ height = height // self.config.patch_size
108
+ width = width // self.config.patch_size
109
+ # we add a small number to avoid floating point error in the interpolation
110
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
111
+ height, width = height + 0.1, width + 0.1
112
+ patch_pos_embed = patch_pos_embed.reshape(
113
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
114
+ )
115
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
116
+ patch_pos_embed = nn.functional.interpolate(
117
+ patch_pos_embed,
118
+ scale_factor=(
119
+ height / math.sqrt(num_positions),
120
+ width / math.sqrt(num_positions),
121
+ ),
122
+ mode="bicubic",
123
+ align_corners=False,
124
+ )
125
+ if (
126
+ int(height) != patch_pos_embed.shape[-2]
127
+ or int(width) != patch_pos_embed.shape[-1]
128
+ ):
129
+ raise ValueError(
130
+ "Width or height does not match with the interpolated position embeddings"
131
+ )
132
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
133
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
134
+
135
+ def forward(
136
+ self,
137
+ pixel_values: torch.Tensor,
138
+ bool_masked_pos: Optional[torch.Tensor] = None,
139
+ ) -> torch.Tensor:
140
+ batch_size, _, height, width = pixel_values.shape
141
+ patch_embeddings = self.patch_embeddings(pixel_values)
142
+ embeddings = patch_embeddings
143
+
144
+ if bool_masked_pos is not None:
145
+ embeddings = torch.where(
146
+ bool_masked_pos.unsqueeze(-1),
147
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
148
+ embeddings,
149
+ )
150
+
151
+ # add the [CLS] token to the embedded patch tokens
152
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
153
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
154
+
155
+ # add positional encoding to each token
156
+ embeddings = embeddings + self.interpolate_pos_encoding(
157
+ embeddings, height, width
158
+ )
159
+
160
+ embeddings = self.dropout(embeddings)
161
+
162
+ return embeddings
163
+
164
+
165
+ class Dinov2PatchEmbeddings(nn.Module):
166
+ """
167
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
168
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
169
+ Transformer.
170
+ """
171
+
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ image_size, patch_size = config.image_size, config.patch_size
175
+ num_channels, hidden_size = config.num_channels, config.hidden_size
176
+
177
+ image_size = (
178
+ image_size
179
+ if isinstance(image_size, collections.abc.Iterable)
180
+ else (image_size, image_size)
181
+ )
182
+ patch_size = (
183
+ patch_size
184
+ if isinstance(patch_size, collections.abc.Iterable)
185
+ else (patch_size, patch_size)
186
+ )
187
+ num_patches = (image_size[1] // patch_size[1]) * (
188
+ image_size[0] // patch_size[0]
189
+ )
190
+ self.image_size = image_size
191
+ self.patch_size = patch_size
192
+ self.num_channels = num_channels
193
+ self.num_patches = num_patches
194
+
195
+ self.projection = nn.Conv2d(
196
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
197
+ )
198
+
199
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
200
+ """
201
+ num_channels = pixel_values.shape[1]
202
+ if num_channels != self.num_channels:
203
+ raise ValueError(
204
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
205
+ f" Expected {self.num_channels} but got {num_channels}."
206
+ )
207
+ """
208
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
209
+ return embeddings
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
+ class Dinov2SelfAttention(nn.Module):
214
+ def __init__(self, config: Dinov2Config) -> None:
215
+ super().__init__()
216
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
217
+ config, "embedding_size"
218
+ ):
219
+ raise ValueError(
220
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
221
+ f"heads {config.num_attention_heads}."
222
+ )
223
+
224
+ self.num_attention_heads = config.num_attention_heads
225
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
226
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
227
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
228
+
229
+ self.query = nn.Linear(
230
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
231
+ )
232
+ self.key = nn.Linear(
233
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
234
+ )
235
+ self.value = nn.Linear(
236
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
237
+ )
238
+
239
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
240
+
241
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
242
+ new_x_shape = x.size()[:-1] + (
243
+ self.num_attention_heads,
244
+ self.attention_head_size,
245
+ )
246
+ x = x.view(new_x_shape)
247
+ return x.permute(0, 2, 1, 3)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states,
252
+ head_mask: Optional[torch.Tensor] = None,
253
+ output_attentions: bool = False,
254
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
255
+ mixed_query_layer = self.query(hidden_states)
256
+
257
+ if hasattr(F, "scaled_dot_product_attention"):
258
+ assert head_mask is None and not output_attentions
259
+ new_size = hidden_states.size()[:-1] + (
260
+ self.num_attention_heads,
261
+ self.attention_head_size,
262
+ )
263
+ key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
264
+ value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
265
+ query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
266
+ context_layer = F.scaled_dot_product_attention(
267
+ query_layer,
268
+ key_layer,
269
+ value_layer,
270
+ dropout_p=self.attention_probs_dropout_prob,
271
+ is_causal=False,
272
+ )
273
+ context_layer = context_layer.transpose(1, 2).reshape(
274
+ *hidden_states.size()[:-1], -1
275
+ )
276
+ else:
277
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
278
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
279
+ query_layer = self.transpose_for_scores(mixed_query_layer)
280
+
281
+ # Take the dot product between "query" and "key" to get the raw attention scores.
282
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
+
284
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
285
+
286
+ # Normalize the attention scores to probabilities.
287
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
288
+
289
+ # This is actually dropping out entire tokens to attend to, which might
290
+ # seem a bit unusual, but is taken from the original Transformer paper.
291
+ attention_probs = self.dropout(attention_probs)
292
+
293
+ # Mask heads if we want to
294
+ if head_mask is not None:
295
+ attention_probs = attention_probs * head_mask
296
+
297
+ context_layer = torch.matmul(attention_probs, value_layer)
298
+
299
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
300
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
301
+ context_layer = context_layer.view(new_context_layer_shape)
302
+
303
+ outputs = (
304
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
305
+ )
306
+
307
+ return outputs
308
+
309
+
310
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
+ class Dinov2SelfOutput(nn.Module):
312
+ """
313
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
314
+ layernorm applied before each block.
315
+ """
316
+
317
+ def __init__(self, config: Dinov2Config) -> None:
318
+ super().__init__()
319
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
320
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
+
322
+ def forward(
323
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
324
+ ) -> torch.Tensor:
325
+ hidden_states = self.dense(hidden_states)
326
+ hidden_states = self.dropout(hidden_states)
327
+
328
+ return hidden_states
329
+
330
+
331
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
+ class Dinov2Attention(nn.Module):
333
+ def __init__(self, config: Dinov2Config) -> None:
334
+ super().__init__()
335
+ self.attention = Dinov2SelfAttention(config)
336
+ self.output = Dinov2SelfOutput(config)
337
+ self.pruned_heads = set()
338
+
339
+ def prune_heads(self, heads: Set[int]) -> None:
340
+ if len(heads) == 0:
341
+ return
342
+ heads, index = find_pruneable_heads_and_indices(
343
+ heads,
344
+ self.attention.num_attention_heads,
345
+ self.attention.attention_head_size,
346
+ self.pruned_heads,
347
+ )
348
+
349
+ # Prune linear layers
350
+ self.attention.query = prune_linear_layer(self.attention.query, index)
351
+ self.attention.key = prune_linear_layer(self.attention.key, index)
352
+ self.attention.value = prune_linear_layer(self.attention.value, index)
353
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
354
+
355
+ # Update hyper params and store pruned heads
356
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
357
+ heads
358
+ )
359
+ self.attention.all_head_size = (
360
+ self.attention.attention_head_size * self.attention.num_attention_heads
361
+ )
362
+ self.pruned_heads = self.pruned_heads.union(heads)
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ head_mask: Optional[torch.Tensor] = None,
368
+ output_attentions: bool = False,
369
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
370
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
371
+
372
+ attention_output = self.output(self_outputs[0], hidden_states)
373
+
374
+ outputs = (attention_output,) + self_outputs[
375
+ 1:
376
+ ] # add attentions if we output them
377
+ return outputs
378
+
379
+
380
+ class Dinov2LayerScale(nn.Module):
381
+ def __init__(self, config) -> None:
382
+ super().__init__()
383
+ self.lambda1 = nn.Parameter(
384
+ config.layerscale_value * torch.ones(config.hidden_size)
385
+ )
386
+
387
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
388
+ return hidden_state * self.lambda1
389
+
390
+
391
+ # Copied from transformers.models.beit.modeling_beit.drop_path
392
+ def drop_path(
393
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
394
+ ) -> torch.Tensor:
395
+ """
396
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
397
+
398
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
399
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
400
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
401
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
402
+ argument.
403
+ """
404
+ if drop_prob == 0.0 or not training:
405
+ return input
406
+ keep_prob = 1 - drop_prob
407
+ shape = (input.shape[0],) + (1,) * (
408
+ input.ndim - 1
409
+ ) # work with diff dim tensors, not just 2D ConvNets
410
+ random_tensor = keep_prob + torch.rand(
411
+ shape, dtype=input.dtype, device=input.device
412
+ )
413
+ random_tensor.floor_() # binarize
414
+ output = input.div(keep_prob) * random_tensor
415
+ return output
416
+
417
+
418
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
419
+ class Dinov2DropPath(nn.Module):
420
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
421
+
422
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
423
+ super().__init__()
424
+ self.drop_prob = drop_prob
425
+
426
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
427
+ return drop_path(hidden_states, self.drop_prob, self.training)
428
+
429
+ def extra_repr(self) -> str:
430
+ return "p={}".format(self.drop_prob)
431
+
432
+
433
+ class Dinov2MLP(nn.Module):
434
+ def __init__(self, config) -> None:
435
+ super().__init__()
436
+ in_features = out_features = config.hidden_size
437
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
438
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
439
+ if isinstance(config.hidden_act, str):
440
+ self.activation = ACT2FN[config.hidden_act]
441
+ else:
442
+ self.activation = config.hidden_act
443
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
444
+
445
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
446
+ hidden_state = self.fc1(hidden_state)
447
+ hidden_state = self.activation(hidden_state)
448
+ hidden_state = self.fc2(hidden_state)
449
+ return hidden_state
450
+
451
+
452
+ class Dinov2SwiGLUFFN(nn.Module):
453
+ def __init__(self, config) -> None:
454
+ super().__init__()
455
+ in_features = out_features = config.hidden_size
456
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
457
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
458
+
459
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
460
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
461
+
462
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
463
+ hidden_state = self.weights_in(hidden_state)
464
+ x1, x2 = hidden_state.chunk(2, dim=-1)
465
+ hidden = nn.functional.silu(x1) * x2
466
+ return self.weights_out(hidden)
467
+
468
+
469
+ class Dinov2Layer(nn.Module):
470
+ """This corresponds to the Block class in the original implementation."""
471
+
472
+ def __init__(self, config: Dinov2Config) -> None:
473
+ super().__init__()
474
+
475
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.norm1_modulation = None
477
+ self.attention = Dinov2Attention(config)
478
+ self.layer_scale1 = Dinov2LayerScale(config)
479
+ self.drop_path1 = (
480
+ Dinov2DropPath(config.drop_path_rate)
481
+ if config.drop_path_rate > 0.0
482
+ else nn.Identity()
483
+ )
484
+
485
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
486
+ self.norm2_modulation = None
487
+
488
+ if config.use_swiglu_ffn:
489
+ self.mlp = Dinov2SwiGLUFFN(config)
490
+ else:
491
+ self.mlp = Dinov2MLP(config)
492
+ self.layer_scale2 = Dinov2LayerScale(config)
493
+ self.drop_path2 = (
494
+ Dinov2DropPath(config.drop_path_rate)
495
+ if config.drop_path_rate > 0.0
496
+ else nn.Identity()
497
+ )
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ head_mask: Optional[torch.Tensor] = None,
503
+ modulation_cond: Optional[torch.Tensor] = None,
504
+ output_attentions: bool = False,
505
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
506
+ hidden_states_norm = self.norm1(hidden_states)
507
+ if self.norm1_modulation is not None:
508
+ assert modulation_cond is not None
509
+ hidden_states_norm = self.norm1_modulation(
510
+ hidden_states_norm, modulation_cond
511
+ )
512
+ self_attention_outputs = self.attention(
513
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ )
517
+ attention_output = self_attention_outputs[0]
518
+
519
+ attention_output = self.layer_scale1(attention_output)
520
+ outputs = self_attention_outputs[
521
+ 1:
522
+ ] # add self attentions if we output attention weights
523
+
524
+ # first residual connection
525
+ hidden_states = attention_output + hidden_states
526
+
527
+ # in Dinov2, layernorm is also applied after self-attention
528
+ layer_output = self.norm2(hidden_states)
529
+ if self.norm2_modulation is not None:
530
+ assert modulation_cond is not None
531
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
532
+ layer_output = self.mlp(layer_output)
533
+ layer_output = self.layer_scale2(layer_output)
534
+
535
+ # second residual connection
536
+ layer_output = layer_output + hidden_states
537
+
538
+ outputs = (layer_output,) + outputs
539
+
540
+ return outputs
541
+
542
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
543
+ self.norm1_modulation = norm1_mod
544
+ self.norm2_modulation = norm2_mod
545
+
546
+
547
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
548
+ class Dinov2Encoder(nn.Module):
549
+ def __init__(self, config: Dinov2Config) -> None:
550
+ super().__init__()
551
+ self.config = config
552
+ self.layer = nn.ModuleList(
553
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
554
+ )
555
+ self.gradient_checkpointing = False
556
+
557
+ def forward(
558
+ self,
559
+ hidden_states: torch.Tensor,
560
+ head_mask: Optional[torch.Tensor] = None,
561
+ modulation_cond: Optional[torch.Tensor] = None,
562
+ output_attentions: bool = False,
563
+ output_hidden_states: bool = False,
564
+ return_dict: bool = True,
565
+ ) -> Union[tuple, BaseModelOutput]:
566
+ all_hidden_states = () if output_hidden_states else None
567
+ all_self_attentions = () if output_attentions else None
568
+
569
+ for i, layer_module in enumerate(self.layer):
570
+ if output_hidden_states:
571
+ all_hidden_states = all_hidden_states + (hidden_states,)
572
+
573
+ layer_head_mask = head_mask[i] if head_mask is not None else None
574
+
575
+ if self.gradient_checkpointing and self.training:
576
+
577
+ def create_custom_forward(module):
578
+ def custom_forward(*inputs):
579
+ return module(*inputs, output_attentions)
580
+
581
+ return custom_forward
582
+
583
+ layer_outputs = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(layer_module),
585
+ hidden_states,
586
+ layer_head_mask,
587
+ modulation_cond,
588
+ use_reentrant=False,
589
+ )
590
+ else:
591
+ layer_outputs = layer_module(
592
+ hidden_states, layer_head_mask, modulation_cond, output_attentions
593
+ )
594
+
595
+ hidden_states = layer_outputs[0]
596
+
597
+ if output_attentions:
598
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
599
+
600
+ if output_hidden_states:
601
+ all_hidden_states = all_hidden_states + (hidden_states,)
602
+
603
+ if not return_dict:
604
+ return tuple(
605
+ v
606
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
607
+ if v is not None
608
+ )
609
+ return BaseModelOutput(
610
+ last_hidden_state=hidden_states,
611
+ hidden_states=all_hidden_states,
612
+ attentions=all_self_attentions,
613
+ )
614
+
615
+
616
+ class Dinov2PreTrainedModel(PreTrainedModel):
617
+ """
618
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
619
+ models.
620
+ """
621
+
622
+ config_class = Dinov2Config
623
+ base_model_prefix = "dinov2"
624
+ main_input_name = "pixel_values"
625
+ supports_gradient_checkpointing = True
626
+
627
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
628
+ """Initialize the weights"""
629
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
630
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
631
+ # `trunc_normal_cpu` not implemented in `half` issues
632
+ module.weight.data = nn.init.trunc_normal_(
633
+ module.weight.data.to(torch.float32),
634
+ mean=0.0,
635
+ std=self.config.initializer_range,
636
+ ).to(module.weight.dtype)
637
+ if module.bias is not None:
638
+ module.bias.data.zero_()
639
+ elif isinstance(module, nn.LayerNorm):
640
+ module.bias.data.zero_()
641
+ module.weight.data.fill_(1.0)
642
+ elif isinstance(module, Dinov2Embeddings):
643
+ module.position_embeddings.data = nn.init.trunc_normal_(
644
+ module.position_embeddings.data.to(torch.float32),
645
+ mean=0.0,
646
+ std=self.config.initializer_range,
647
+ ).to(module.position_embeddings.dtype)
648
+
649
+ module.cls_token.data = nn.init.trunc_normal_(
650
+ module.cls_token.data.to(torch.float32),
651
+ mean=0.0,
652
+ std=self.config.initializer_range,
653
+ ).to(module.cls_token.dtype)
654
+
655
+ def _set_gradient_checkpointing(
656
+ self, module: Dinov2Encoder, value: bool = False
657
+ ) -> None:
658
+ if isinstance(module, Dinov2Encoder):
659
+ module.gradient_checkpointing = value
660
+
661
+
662
+ DINOV2_START_DOCSTRING = r"""
663
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
664
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
665
+ behavior.
666
+
667
+ Parameters:
668
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
669
+ Initializing with a config file does not load the weights associated with the model, only the
670
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
+ """
672
+
673
+ DINOV2_BASE_INPUTS_DOCSTRING = r"""
674
+ Args:
675
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
676
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
677
+ [`BitImageProcessor.preprocess`] for details.
678
+
679
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
680
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
681
+ pre-training.
682
+
683
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
684
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
685
+
686
+ - 1 indicates the head is **not masked**,
687
+ - 0 indicates the head is **masked**.
688
+
689
+ output_attentions (`bool`, *optional*):
690
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
691
+ tensors for more detail.
692
+ output_hidden_states (`bool`, *optional*):
693
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
694
+ more detail.
695
+ return_dict (`bool`, *optional*):
696
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
697
+ """
698
+
699
+ DINOV2_INPUTS_DOCSTRING = r"""
700
+ Args:
701
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
702
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
703
+ [`BitImageProcessor.preprocess`] for details.
704
+
705
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
706
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
707
+
708
+ - 1 indicates the head is **not masked**,
709
+ - 0 indicates the head is **masked**.
710
+
711
+ output_attentions (`bool`, *optional*):
712
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
713
+ tensors for more detail.
714
+ output_hidden_states (`bool`, *optional*):
715
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
716
+ more detail.
717
+ return_dict (`bool`, *optional*):
718
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
719
+ """
720
+
721
+
722
+ @dataclass
723
+ class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
724
+ patch_embeddings: Optional[torch.FloatTensor] = None
725
+
726
+
727
+ @add_start_docstrings(
728
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
729
+ DINOV2_START_DOCSTRING,
730
+ )
731
+ class Dinov2Model(Dinov2PreTrainedModel):
732
+ def __init__(self, config: Dinov2Config):
733
+ super().__init__(config)
734
+ self.config = config
735
+
736
+ self.embeddings = Dinov2Embeddings(config)
737
+ self.encoder = Dinov2Encoder(config)
738
+
739
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
745
+ return self.embeddings.patch_embeddings
746
+
747
+ def expand_input_channels(self, extra_input_channels: int) -> None:
748
+ if extra_input_channels == 0:
749
+ return
750
+ conv_old = self.embeddings.patch_embeddings.projection
751
+ conv_new = nn.Conv2d(
752
+ self.config.num_channels + extra_input_channels,
753
+ self.config.hidden_size,
754
+ kernel_size=self.config.patch_size,
755
+ stride=self.config.patch_size,
756
+ ).to(self.device)
757
+ with torch.no_grad():
758
+ conv_new.weight[:, :3] = conv_old.weight
759
+ conv_new.bias = conv_old.bias
760
+ self.embeddings.patch_embeddings.projection = conv_new
761
+ del conv_old
762
+
763
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
764
+ """
765
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
766
+ class PreTrainedModel
767
+ """
768
+ for layer, heads in heads_to_prune.items():
769
+ self.encoder.layer[layer].attention.prune_heads(heads)
770
+
771
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
772
+ @add_code_sample_docstrings(
773
+ checkpoint=_CHECKPOINT_FOR_DOC,
774
+ output_type=BaseModelOutputWithPooling,
775
+ config_class=_CONFIG_FOR_DOC,
776
+ modality="vision",
777
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
778
+ )
779
+ def forward(
780
+ self,
781
+ pixel_values: Optional[torch.Tensor] = None,
782
+ bool_masked_pos: Optional[torch.Tensor] = None,
783
+ head_mask: Optional[torch.Tensor] = None,
784
+ modulation_cond: Optional[torch.Tensor] = None,
785
+ output_attentions: Optional[bool] = None,
786
+ output_hidden_states: Optional[bool] = None,
787
+ return_dict: Optional[bool] = None,
788
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
789
+ output_attentions = (
790
+ output_attentions
791
+ if output_attentions is not None
792
+ else self.config.output_attentions
793
+ )
794
+ output_hidden_states = (
795
+ output_hidden_states
796
+ if output_hidden_states is not None
797
+ else self.config.output_hidden_states
798
+ )
799
+ return_dict = (
800
+ return_dict if return_dict is not None else self.config.use_return_dict
801
+ )
802
+
803
+ if pixel_values is None:
804
+ raise ValueError("You have to specify pixel_values")
805
+
806
+ # Prepare head mask if needed
807
+ # 1.0 in head_mask indicates we keep the head
808
+ # attention_probs has shape bsz x n_heads x N x N
809
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
810
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
811
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
812
+
813
+ embedding_output = self.embeddings(
814
+ pixel_values, bool_masked_pos=bool_masked_pos
815
+ )
816
+
817
+ encoder_outputs = self.encoder(
818
+ embedding_output,
819
+ head_mask=head_mask,
820
+ modulation_cond=modulation_cond,
821
+ output_attentions=output_attentions,
822
+ output_hidden_states=output_hidden_states,
823
+ return_dict=return_dict,
824
+ )
825
+ sequence_output = encoder_outputs[0]
826
+ sequence_output = self.layernorm(sequence_output)
827
+ pooled_output = sequence_output[:, 0, :]
828
+
829
+ if not return_dict:
830
+ head_outputs = (sequence_output, pooled_output)
831
+ return head_outputs + encoder_outputs[1:]
832
+
833
+ return CustomBaseModelOutputWithPooling(
834
+ last_hidden_state=sequence_output,
835
+ pooler_output=pooled_output,
836
+ hidden_states=encoder_outputs.hidden_states,
837
+ attentions=encoder_outputs.attentions,
838
+ patch_embeddings=embedding_output,
839
+ )
840
+
841
+ def set_gradient_checkpointing(self, value: bool = False) -> None:
842
+ self._set_gradient_checkpointing(self.encoder, value)
843
+
844
+
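A minimal usage sketch for the custom `Dinov2Model` above, assuming the `facebook/dinov2-large` weights (the default used by the image tokenizer later in this commit) are reachable; shapes in the comments are illustrative.

```python
import torch

from spar3d.models.tokenizers.dinov2 import Dinov2Model

model = Dinov2Model.from_pretrained("facebook/dinov2-large").eval()

pixel_values = torch.randn(1, 3, 512, 512)  # B C H W
with torch.no_grad():
    out = model(pixel_values)  # modulation_cond defaults to None

print(out.last_hidden_state.shape)  # (1, 1 + num_patches, hidden_size)
print(out.pooler_output.shape)      # (1, hidden_size), the [CLS] token
print(out.patch_embeddings.shape)   # extra field from CustomBaseModelOutputWithPooling

# expand_input_channels widens the patch projection for extra input channels
# (e.g. a mask channel) and copies the pretrained RGB kernel into the first three.
model.expand_input_channels(extra_input_channels=1)
```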
845
+ @add_start_docstrings(
846
+ """
847
+ Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
848
+ of the [CLS] token) e.g. for ImageNet.
849
+ """,
850
+ DINOV2_START_DOCSTRING,
851
+ )
852
+ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
853
+ def __init__(self, config: Dinov2Config) -> None:
854
+ super().__init__(config)
855
+
856
+ self.num_labels = config.num_labels
857
+ self.dinov2 = Dinov2Model(config)
858
+
859
+ # Classifier head
860
+ self.classifier = (
861
+ nn.Linear(config.hidden_size * 2, config.num_labels)
862
+ if config.num_labels > 0
863
+ else nn.Identity()
864
+ )
865
+
866
+ # Initialize weights and apply final processing
867
+ self.post_init()
868
+
869
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
870
+ @add_code_sample_docstrings(
871
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
872
+ output_type=ImageClassifierOutput,
873
+ config_class=_CONFIG_FOR_DOC,
874
+ )
875
+ def forward(
876
+ self,
877
+ pixel_values: Optional[torch.Tensor] = None,
878
+ head_mask: Optional[torch.Tensor] = None,
879
+ labels: Optional[torch.Tensor] = None,
880
+ output_attentions: Optional[bool] = None,
881
+ output_hidden_states: Optional[bool] = None,
882
+ return_dict: Optional[bool] = None,
883
+ ) -> Union[tuple, ImageClassifierOutput]:
884
+ r"""
885
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
886
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
887
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
888
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
889
+ """
890
+ return_dict = (
891
+ return_dict if return_dict is not None else self.config.use_return_dict
892
+ )
893
+
894
+ outputs = self.dinov2(
895
+ pixel_values,
896
+ head_mask=head_mask,
897
+ output_attentions=output_attentions,
898
+ output_hidden_states=output_hidden_states,
899
+ return_dict=return_dict,
900
+ )
901
+
902
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
903
+
904
+ cls_token = sequence_output[:, 0]
905
+ patch_tokens = sequence_output[:, 1:]
906
+
907
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
908
+
909
+ logits = self.classifier(linear_input)
910
+
911
+ loss = None
912
+ if labels is not None:
913
+ # move labels to correct device to enable model parallelism
914
+ labels = labels.to(logits.device)
915
+ if self.config.problem_type is None:
916
+ if self.num_labels == 1:
917
+ self.config.problem_type = "regression"
918
+ elif self.num_labels > 1 and (
919
+ labels.dtype == torch.long or labels.dtype == torch.int
920
+ ):
921
+ self.config.problem_type = "single_label_classification"
922
+ else:
923
+ self.config.problem_type = "multi_label_classification"
924
+
925
+ if self.config.problem_type == "regression":
926
+ loss_fct = MSELoss()
927
+ if self.num_labels == 1:
928
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
929
+ else:
930
+ loss = loss_fct(logits, labels)
931
+ elif self.config.problem_type == "single_label_classification":
932
+ loss_fct = CrossEntropyLoss()
933
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
934
+ elif self.config.problem_type == "multi_label_classification":
935
+ loss_fct = BCEWithLogitsLoss()
936
+ loss = loss_fct(logits, labels)
937
+
938
+ if not return_dict:
939
+ output = (logits,) + outputs[2:]
940
+ return ((loss,) + output) if loss is not None else output
941
+
942
+ return ImageClassifierOutput(
943
+ loss=loss,
944
+ logits=logits,
945
+ hidden_states=outputs.hidden_states,
946
+ attentions=outputs.attentions,
947
+ )
948
+
949
+
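The classification head above takes `hidden_size * 2` inputs because the [CLS] token is concatenated with the mean of the patch tokens. A shape-only sketch of that step, with assumed dinov2-large dimensions:

```python
import torch

hidden_size, num_patches = 1024, 1369  # assumed: dinov2-large at 518 px (patch size 14)
sequence_output = torch.randn(2, 1 + num_patches, hidden_size)

cls_token = sequence_output[:, 0]      # (B, hidden_size)
patch_tokens = sequence_output[:, 1:]  # (B, num_patches, hidden_size)
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
print(linear_input.shape)              # torch.Size([2, 2048]), i.e. hidden_size * 2
```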
950
+ @add_start_docstrings(
951
+ """
952
+ Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
953
+ """,
954
+ DINOV2_START_DOCSTRING,
955
+ )
956
+ class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
957
+ def __init__(self, config):
958
+ super().__init__(config)
959
+ super()._init_backbone(config)
960
+
961
+ self.num_features = [
962
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
963
+ ]
964
+ self.embeddings = Dinov2Embeddings(config)
965
+ self.encoder = Dinov2Encoder(config)
966
+
967
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
968
+
969
+ # Initialize weights and apply final processing
970
+ self.post_init()
971
+
972
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
973
+ return self.embeddings.patch_embeddings
974
+
975
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
976
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
977
+ def forward(
978
+ self,
979
+ pixel_values: torch.Tensor,
980
+ output_hidden_states: Optional[bool] = None,
981
+ output_attentions: Optional[bool] = None,
982
+ return_dict: Optional[bool] = None,
983
+ ) -> BackboneOutput:
984
+ """
985
+ Returns:
986
+
987
+ Examples:
988
+
989
+ ```python
990
+ >>> from transformers import AutoImageProcessor, AutoBackbone
991
+ >>> import torch
992
+ >>> from PIL import Image
993
+ >>> import requests
994
+
995
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
996
+ >>> image = Image.open(requests.get(url, stream=True).raw)
997
+
998
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
999
+ >>> model = AutoBackbone.from_pretrained(
1000
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
1001
+ ... )
1002
+
1003
+ >>> inputs = processor(image, return_tensors="pt")
1004
+
1005
+ >>> outputs = model(**inputs)
1006
+ >>> feature_maps = outputs.feature_maps
1007
+ >>> list(feature_maps[-1].shape)
1008
+ [1, 768, 16, 16]
1009
+ ```"""
1010
+ return_dict = (
1011
+ return_dict if return_dict is not None else self.config.use_return_dict
1012
+ )
1013
+ output_hidden_states = (
1014
+ output_hidden_states
1015
+ if output_hidden_states is not None
1016
+ else self.config.output_hidden_states
1017
+ )
1018
+ output_attentions = (
1019
+ output_attentions
1020
+ if output_attentions is not None
1021
+ else self.config.output_attentions
1022
+ )
1023
+
1024
+ embedding_output = self.embeddings(pixel_values)
1025
+
1026
+ outputs = self.encoder(
1027
+ embedding_output,
1028
+ output_hidden_states=True,
1029
+ output_attentions=output_attentions,
1030
+ return_dict=return_dict,
1031
+ )
1032
+
1033
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
1034
+
1035
+ feature_maps = ()
1036
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
1037
+ if stage in self.out_features:
1038
+ if self.config.apply_layernorm:
1039
+ hidden_state = self.layernorm(hidden_state)
1040
+ if self.config.reshape_hidden_states:
1041
+ batch_size, _, height, width = pixel_values.shape
1042
+ patch_size = self.config.patch_size
1043
+ hidden_state = hidden_state[:, 1:, :].reshape(
1044
+ batch_size, width // patch_size, height // patch_size, -1
1045
+ )
1046
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1047
+ feature_maps += (hidden_state,)
1048
+
1049
+ if not return_dict:
1050
+ if output_hidden_states:
1051
+ output = (feature_maps,) + outputs[1:]
1052
+ else:
1053
+ output = (feature_maps,) + outputs[2:]
1054
+ return output
1055
+
1056
+ return BackboneOutput(
1057
+ feature_maps=feature_maps,
1058
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1059
+ attentions=outputs.attentions if output_attentions else None,
1060
+ )
1061
+
1062
+
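In the backbone, each selected hidden state is folded from a token sequence back into a channels-first feature map. A shape-only sketch of that reshape, using the 224 px / patch-14 setting implied by the docstring example above:

```python
import torch

batch_size, patch_size, height, width, dim = 1, 14, 224, 224, 768  # assumed values
num_patches = (height // patch_size) * (width // patch_size)
hidden_state = torch.randn(batch_size, 1 + num_patches, dim)  # [CLS] + patch tokens

grid = hidden_state[:, 1:, :].reshape(
    batch_size, width // patch_size, height // patch_size, -1
)
feature_map = grid.permute(0, 3, 1, 2).contiguous()
print(feature_map.shape)  # torch.Size([1, 768, 16, 16])
```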
1063
+ class CustomPatchEmbeddings(nn.Module):
1064
+ """
1065
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
1066
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
1067
+ Transformer.
1068
+ """
1069
+
1070
+ def __init__(
1071
+ self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1072
+ ):
1073
+ super().__init__()
1074
+
1075
+ image_size = (
1076
+ image_size
1077
+ if isinstance(image_size, collections.abc.Iterable)
1078
+ else (image_size, image_size)
1079
+ )
1080
+ patch_size = (
1081
+ patch_size
1082
+ if isinstance(patch_size, collections.abc.Iterable)
1083
+ else (patch_size, patch_size)
1084
+ )
1085
+ num_patches = (image_size[1] // patch_size[1]) * (
1086
+ image_size[0] // patch_size[0]
1087
+ )
1088
+ self.image_size = image_size
1089
+ self.patch_size = patch_size
1090
+ self.num_channels = num_channels
1091
+ self.num_patches = num_patches
1092
+
1093
+ self.projection = nn.Conv2d(
1094
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
1095
+ )
1096
+
1097
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
1098
+ num_channels = pixel_values.shape[1]
1099
+ if num_channels != self.num_channels:
1100
+ raise ValueError(
1101
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
1102
+ f" Expected {self.num_channels} but got {num_channels}."
1103
+ )
1104
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
1105
+ return embeddings
1106
+
1107
+
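`CustomPatchEmbeddings` is a single strided convolution, so the token count is simply `(H // patch_size) * (W // patch_size)`. A quick shape check with assumed sizes:

```python
import torch

from spar3d.models.tokenizers.dinov2 import CustomPatchEmbeddings

patchify = CustomPatchEmbeddings(image_size=512, patch_size=16, num_channels=3, hidden_size=768)
tokens = patchify(torch.randn(2, 3, 512, 512))
print(tokens.shape)  # torch.Size([2, 1024, 768]); (512 // 16) ** 2 = 1024 patches
```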
1108
+ class CustomEmbeddings(nn.Module):
1109
+ """
1110
+ Construct the CLS token, mask token, position and patch embeddings.
1111
+ """
1112
+
1113
+ def __init__(
1114
+ self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1115
+ ) -> None:
1116
+ super().__init__()
1117
+
1118
+ self.image_size = image_size
1119
+ self.patch_size = patch_size
1120
+ self.num_channels = num_channels
1121
+ self.hidden_size = hidden_size
1122
+
1123
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
1124
+
1125
+ self.patch_embeddings = CustomPatchEmbeddings(
1126
+ image_size, patch_size, num_channels, hidden_size
1127
+ )
1128
+ num_patches = self.patch_embeddings.num_patches
1129
+ self.position_embeddings = nn.Parameter(
1130
+ torch.randn(1, num_patches + 1, self.hidden_size)
1131
+ )
1132
+
1133
+ def interpolate_pos_encoding(
1134
+ self, embeddings: torch.Tensor, height: int, width: int
1135
+ ) -> torch.Tensor:
1136
+ """
1137
+ This method interpolates the pre-trained position encodings so that the model can be used on
1138
+ higher-resolution images.
1139
+
1140
+ Source:
1141
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
1142
+ """
1143
+
1144
+ num_patches = embeddings.shape[1] - 1
1145
+ num_positions = self.position_embeddings.shape[1] - 1
1146
+ if num_patches == num_positions and height == width:
1147
+ return self.position_embeddings
1148
+ class_pos_embed = self.position_embeddings[:, 0]
1149
+ patch_pos_embed = self.position_embeddings[:, 1:]
1150
+ dim = embeddings.shape[-1]
1151
+ height = height // self.patch_size
1152
+ width = width // self.patch_size
1153
+ # we add a small number to avoid floating point error in the interpolation
1154
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
1155
+ height, width = height + 0.1, width + 0.1
1156
+ patch_pos_embed = patch_pos_embed.reshape(
1157
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
1158
+ )
1159
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
1160
+ patch_pos_embed = nn.functional.interpolate(
1161
+ patch_pos_embed,
1162
+ scale_factor=(
1163
+ height / math.sqrt(num_positions),
1164
+ width / math.sqrt(num_positions),
1165
+ ),
1166
+ mode="bicubic",
1167
+ align_corners=False,
1168
+ )
1169
+ if (
1170
+ int(height) != patch_pos_embed.shape[-2]
1171
+ or int(width) != patch_pos_embed.shape[-1]
1172
+ ):
1173
+ raise ValueError(
1174
+ "Width or height does not match with the interpolated position embeddings"
1175
+ )
1176
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1177
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
1178
+
1179
+ def forward(
1180
+ self,
1181
+ pixel_values: torch.Tensor,
1182
+ ) -> torch.Tensor:
1183
+ batch_size, _, height, width = pixel_values.shape
1184
+ patch_embeddings = self.patch_embeddings(pixel_values)
1185
+ embeddings = patch_embeddings
1186
+
1187
+ # add the [CLS] token to the embedded patch tokens
1188
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1189
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
1190
+
1191
+ # add positional encoding to each token
1192
+ embeddings = embeddings + self.interpolate_pos_encoding(
1193
+ embeddings, height, width
1194
+ )
1195
+
1196
+ return embeddings
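Because `interpolate_pos_encoding` resizes the learned positional table, `CustomEmbeddings` also accepts resolutions other than the one it was configured with. A minimal sketch with assumed sizes:

```python
import torch

from spar3d.models.tokenizers.dinov2 import CustomEmbeddings

embed = CustomEmbeddings(image_size=512, patch_size=16, num_channels=3, hidden_size=768)
tokens = embed(torch.randn(1, 3, 256, 256))  # smaller than the configured 512
print(tokens.shape)  # torch.Size([1, 257, 768]); 1 CLS token + (256 // 16) ** 2 patches
```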
spar3d/models/tokenizers/image.py ADDED
@@ -0,0 +1,99 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from spar3d.models.tokenizers.dinov2 import Dinov2Model
11
+ from spar3d.models.transformers.attention import Modulation
12
+ from spar3d.models.utils import BaseModule
13
+
14
+
15
+ class DINOV2SingleImageTokenizer(BaseModule):
16
+ @dataclass
17
+ class Config(BaseModule.Config):
18
+ pretrained_model_name_or_path: str = "facebook/dinov2-large"
19
+ width: int = 512
20
+ height: int = 512
21
+ modulation_cond_dim: int = 768
22
+
23
+ cfg: Config
24
+
25
+ def configure(self) -> None:
26
+ self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
27
+
28
+ for p in self.model.parameters():
29
+ p.requires_grad_(False)
30
+ self.model.eval()
31
+
32
+ self.model.set_gradient_checkpointing(False)
33
+
34
+ # add modulation
35
+ modulations = []
36
+ for layer in self.model.encoder.layer:
37
+ norm1_modulation = Modulation(
38
+ self.model.config.hidden_size,
39
+ self.cfg.modulation_cond_dim,
40
+ zero_init=True,
41
+ single_layer=True,
42
+ )
43
+ norm2_modulation = Modulation(
44
+ self.model.config.hidden_size,
45
+ self.cfg.modulation_cond_dim,
46
+ zero_init=True,
47
+ single_layer=True,
48
+ )
49
+ layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
50
+ modulations += [norm1_modulation, norm2_modulation]
51
+ self.modulations = nn.ModuleList(modulations)
52
+
53
+ self.register_buffer(
54
+ "image_mean",
55
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
56
+ persistent=False,
57
+ )
58
+ self.register_buffer(
59
+ "image_std",
60
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
61
+ persistent=False,
62
+ )
63
+
64
+ def forward(
65
+ self,
66
+ images: Float[Tensor, "B *N C H W"],
67
+ modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
68
+ **kwargs,
69
+ ) -> Float[Tensor, "B *N Ct Nt"]:
70
+ model = self.model
71
+
72
+ packed = False
73
+ if images.ndim == 4:
74
+ packed = True
75
+ images = images.unsqueeze(1)
76
+ if modulation_cond is not None:
77
+ assert modulation_cond.ndim == 2
78
+ modulation_cond = modulation_cond.unsqueeze(1)
79
+
80
+ batch_size, n_input_views = images.shape[:2]
81
+ images = (images - self.image_mean) / self.image_std
82
+ out = model(
83
+ rearrange(images, "B N C H W -> (B N) C H W"),
84
+ modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
85
+ if modulation_cond is not None
86
+ else None,
87
+ )
88
+ local_features = out.last_hidden_state
89
+ local_features = local_features.permute(0, 2, 1)
90
+ local_features = rearrange(
91
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
92
+ )
93
+ if packed:
94
+ local_features = local_features.squeeze(1)
95
+
96
+ return local_features
97
+
98
+ def detokenize(self, *args, **kwargs):
99
+ raise NotImplementedError
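A hedged usage sketch for `DINOV2SingleImageTokenizer`. It assumes `BaseModule` follows the usual config-dict constructor pattern (passing a dict runs `configure()`); if the actual `spar3d.models.utils.BaseModule` API differs, adapt the instantiation. Inputs are RGB images in `[0, 1]`, since the ImageNet mean/std normalization happens inside `forward`.

```python
import torch

from spar3d.models.tokenizers.image import DINOV2SingleImageTokenizer

tokenizer = DINOV2SingleImageTokenizer({"modulation_cond_dim": 768})  # assumed constructor
images = torch.rand(2, 3, 512, 512)  # B C H W, RGB in [0, 1]
cond = torch.randn(2, 768)           # B Cc, must match modulation_cond_dim
tokens = tokenizer(images, cond)     # -> (B, hidden_size, 1 + num_patches)
```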
spar3d/models/tokenizers/point.py ADDED
@@ -0,0 +1,51 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from jaxtyping import Float
6
+ from torch import Tensor
7
+
8
+ from spar3d.models.transformers.transformer_1d import Transformer1D
9
+ from spar3d.models.utils import BaseModule
10
+
11
+
12
+ class TransformerPointTokenizer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ num_attention_heads: int = 16
16
+ attention_head_dim: int = 64
17
+ in_channels: Optional[int] = 6
18
+ out_channels: Optional[int] = 1024
19
+ num_layers: int = 16
20
+ norm_num_groups: int = 32
21
+ attention_bias: bool = False
22
+ activation_fn: str = "geglu"
23
+ norm_elementwise_affine: bool = True
24
+
25
+ cfg: Config
26
+
27
+ def configure(self) -> None:
28
+ transformer_cfg = dict(self.cfg.copy())
29
+ # override in_channels with the transformer width (num_attention_heads * attention_head_dim)
30
+ transformer_cfg["in_channels"] = (
31
+ self.cfg.num_attention_heads * self.cfg.attention_head_dim
32
+ )
33
+ self.model = Transformer1D(transformer_cfg)
34
+ self.linear_in = torch.nn.Linear(
35
+ self.cfg.in_channels, transformer_cfg["in_channels"]
36
+ )
37
+ self.linear_out = torch.nn.Linear(
38
+ transformer_cfg["in_channels"], self.cfg.out_channels
39
+ )
40
+
41
+ def forward(
42
+ self, points: Float[Tensor, "B N Ci"], **kwargs
43
+ ) -> Float[Tensor, "B N Cp"]:
44
+ assert points.ndim == 3
45
+ inputs = self.linear_in(points).permute(0, 2, 1) # B N Ci -> B Ci N
46
+ out = self.model(inputs).permute(0, 2, 1) # B Ci N -> B N Ci
47
+ out = self.linear_out(out) # B N Ci -> B N Co
48
+ return out
49
+
50
+ def detokenize(self, *args, **kwargs):
51
+ raise NotImplementedError
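A hedged shape sketch for `TransformerPointTokenizer`, under the same config-dict constructor assumption as above and the default `Config` values: six-channel points (typically XYZ plus RGB) are lifted to the transformer width, processed by `Transformer1D`, and projected to `out_channels`.

```python
import torch

from spar3d.models.tokenizers.point import TransformerPointTokenizer

tokenizer = TransformerPointTokenizer({})  # assumed constructor, default Config values
points = torch.randn(2, 512, 6)            # B N Ci: positions + colors
tokens = tokenizer(points)                 # -> (2, 512, 1024), i.e. B N out_channels
```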
spar3d/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,49 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from spar3d.models.utils import BaseModule
11
+
12
+
13
+ class TriplaneLearnablePositionalEmbedding(BaseModule):
14
+ @dataclass
15
+ class Config(BaseModule.Config):
16
+ plane_size: int = 96
17
+ num_channels: int = 1024
18
+
19
+ cfg: Config
20
+
21
+ def configure(self) -> None:
22
+ self.embeddings = nn.Parameter(
23
+ torch.randn(
24
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
25
+ dtype=torch.float32,
26
+ )
27
+ * 1
28
+ / math.sqrt(self.cfg.num_channels)
29
+ )
30
+
31
+ def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
32
+ return rearrange(
33
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
34
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
35
+ )
36
+
37
+ def detokenize(
38
+ self, tokens: Float[Tensor, "B Ct Nt"]
39
+ ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
40
+ batch_size, Ct, Nt = tokens.shape
41
+ assert Nt == self.cfg.plane_size**2 * 3
42
+ assert Ct == self.cfg.num_channels
43
+ return rearrange(
44
+ tokens,
45
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
46
+ Np=3,
47
+ Hp=self.cfg.plane_size,
48
+ Wp=self.cfg.plane_size,
49
+ )
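A hedged round-trip sketch for `TriplaneLearnablePositionalEmbedding`, under the same constructor assumption: `forward` flattens the three learned planes into a single token axis, and `detokenize` folds tokens back into per-plane feature maps.

```python
from spar3d.models.tokenizers.triplane import TriplaneLearnablePositionalEmbedding

embedder = TriplaneLearnablePositionalEmbedding({})  # assumed constructor; plane_size=96, num_channels=1024
tokens = embedder(batch_size=2)                      # -> (2, 1024, 3 * 96 * 96)
planes = embedder.detokenize(tokens)                 # -> (2, 3, 1024, 96, 96)
```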