ssalb committed on Jan 13, 2025
Commit 139df7a · 1 Parent(s): 7c0d92c
Update space with latest code and dependencies on Mon Jan 13 19:33:51 UTC 2025
Files changed:
- README.md +3 -3
- requirements.txt +11 -11
- story_beam_search/scoring.py +1 -1
- story_beam_search/stories_generator.py +1 -1
README.md
CHANGED
@@ -9,7 +9,7 @@ sdk_version: 5.9.1
 app_file: app.py
 pinned: false
 preload_from_hub:
--
+- HuggingFaceTB/SmolLM2-135M-Instruct
 - answerdotai/ModernBERT-base
 - facebook/bart-large-mnli
 license: mit
@@ -17,8 +17,8 @@ license: mit
 
 ## Project Overview
 
-
+This Story Generator leverages natural language processing models to generate coherent and engaging stories. By utilizing models such as SmolLMv2, BERT, and BART, this project aims to provide users with a tool to create narratives based on given prompts. The application is built using Gradio for an interactive user interface, making it easy to input prompts and receive generated stories in real-time.
 
-The main purpose of this project is to explore the idea of beam search for selecting stories with high coherence, fluency, and genre alignment scores. This ensures that the generated stories are not only creative but also maintain a logical flow and adhere to the specified genre.
+The main purpose of this project is to explore the idea of beam search for selecting stories with high coherence, fluency, and genre alignment scores in a process-based reward model (PRM) fashion. This ensures that the generated stories are not only creative but also maintain a logical flow and adhere to the specified genre.
 
 Note that the final implementation is not strictly beam search and was modified to allow more diversity (creativity) inspired by the DVTS method in [this blog post](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute).
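The selection loop described in the new overview can be pictured in a few lines: generate several candidate continuations, score each for coherence, fluency, and genre alignment, and keep the best candidate from each independent group rather than a single global beam (the DVTS-inspired twist). The sketch below is illustrative only; `Candidate`, `select_diverse_best`, and the toy scorers are assumptions, not the repo's actual API.

```python
# Hedged sketch of score-guided, diversity-preserving selection (names are hypothetical).
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Candidate:
    text: str
    score: float = 0.0


def select_diverse_best(
    groups: List[List[Candidate]],            # one list of candidates per independent subtree
    scorers: List[Callable[[str], float]],    # e.g. coherence, fluency, genre-alignment scorers
    weights: List[float],
) -> List[Candidate]:
    """Keep the highest-scoring candidate from each group, preserving diversity."""
    winners = []
    for group in groups:
        for cand in group:
            cand.score = sum(w * s(cand.text) for w, s in zip(weights, scorers))
        winners.append(max(group, key=lambda c: c.score))
    return winners


# Toy usage with stand-in scorers (real scorers would wrap the BERT/BART models).
picks = select_diverse_best(
    groups=[[Candidate("Once upon a time..."), Candidate("The dragon slept soundly.")]],
    scorers=[lambda t: len(t) / 100.0, lambda t: t.count(" ") / 10.0],
    weights=[0.5, 0.5],
)
```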
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
 accelerate==1.2.1 ; python_full_version == "3.10.13"
 aiofiles==23.2.1 ; python_full_version == "3.10.13"
 annotated-types==0.7.0 ; python_full_version == "3.10.13"
-anyio==4.
+anyio==4.8.0 ; python_full_version == "3.10.13"
 certifi==2024.12.14 ; python_full_version == "3.10.13"
 charset-normalizer==3.4.1 ; python_full_version == "3.10.13"
 click==8.1.8 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
@@ -16,7 +16,7 @@ gradio==5.9.1 ; python_full_version == "3.10.13"
 h11==0.14.0 ; python_full_version == "3.10.13"
 httpcore==1.0.7 ; python_full_version == "3.10.13"
 httpx==0.28.1 ; python_full_version == "3.10.13"
-huggingface-hub==0.27.
+huggingface-hub==0.27.1 ; python_full_version == "3.10.13"
 idna==3.10 ; python_full_version == "3.10.13"
 jinja2==3.1.5 ; python_full_version == "3.10.13"
 joblib==1.4.2 ; python_full_version == "3.10.13"
@@ -26,16 +26,16 @@ mdurl==0.1.2 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 mpmath==1.3.0 ; python_full_version == "3.10.13"
 networkx==3.4.2 ; python_full_version == "3.10.13"
 numpy==2.2.1 ; python_full_version == "3.10.13"
-orjson==3.10.
+orjson==3.10.14 ; python_full_version == "3.10.13"
 packaging==24.2 ; python_full_version == "3.10.13"
 pandas==2.2.3 ; python_full_version == "3.10.13"
 pillow==11.1.0 ; python_full_version == "3.10.13"
-protobuf==5.29.
+protobuf==5.29.3 ; python_full_version == "3.10.13"
 psutil==6.1.1 ; python_full_version == "3.10.13"
 pydantic-core==2.27.2 ; python_full_version == "3.10.13"
-pydantic==2.10.
+pydantic==2.10.5 ; python_full_version == "3.10.13"
 pydub==0.25.1 ; python_full_version == "3.10.13"
-pygments==2.
+pygments==2.19.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 python-dateutil==2.9.0.post0 ; python_full_version == "3.10.13"
 python-multipart==0.0.20 ; python_full_version == "3.10.13"
 pytz==2024.2 ; python_full_version == "3.10.13"
@@ -43,11 +43,11 @@ pyyaml==6.0.2 ; python_full_version == "3.10.13"
 regex==2024.11.6 ; python_full_version == "3.10.13"
 requests==2.32.3 ; python_full_version == "3.10.13"
 rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
-ruff==0.
+ruff==0.9.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 safehttpx==0.1.6 ; python_full_version == "3.10.13"
-safetensors==0.5.
-scikit-learn==1.6.
-scipy==1.15.
+safetensors==0.5.2 ; python_full_version == "3.10.13"
+scikit-learn==1.6.1 ; python_full_version == "3.10.13"
+scipy==1.15.1 ; python_full_version == "3.10.13"
 semantic-version==2.10.0 ; python_full_version == "3.10.13"
 shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 six==1.17.0 ; python_full_version == "3.10.13"
@@ -59,7 +59,7 @@ tokenizers==0.21.0 ; python_full_version == "3.10.13"
 tomlkit==0.13.2 ; python_full_version == "3.10.13"
 torch==2.4.0 ; python_full_version == "3.10.13"
 tqdm==4.67.1 ; python_full_version == "3.10.13"
-transformers
+transformers==4.48.0 ; python_full_version == "3.10.13"
 triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
 typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
 typing-extensions==4.12.2 ; python_full_version == "3.10.13"
story_beam_search/scoring.py
CHANGED
@@ -71,7 +71,7 @@ class CoherenceScorer(StoryScorer):
             outputs = self.model(**inputs, output_hidden_states=True)
             batch_embeddings = outputs.hidden_states[-1][
                 :, 0, :
-            ]
+            ]
             all_embeddings.extend(batch_embeddings.cpu().numpy())
 
         # Calculate coherence scores for each story
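For context on the hunk above: the CLS-token embeddings gathered into `all_embeddings` can be reduced to a coherence score, for instance as the mean cosine similarity between embeddings of consecutive chunks. The helper below is a hedged sketch of that idea, not the actual `CoherenceScorer` logic.

```python
# Illustrative only: one way to reduce a sequence of chunk embeddings to a coherence score.
import numpy as np


def coherence_from_embeddings(embeddings: list) -> float:
    """Mean cosine similarity between consecutive embeddings (higher = smoother flow)."""
    if len(embeddings) < 2:
        return 1.0
    sims = []
    for prev, curr in zip(embeddings, embeddings[1:]):
        denom = np.linalg.norm(prev) * np.linalg.norm(curr) + 1e-8
        sims.append(float(np.dot(prev, curr) / denom))
    return float(np.mean(sims))
```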
story_beam_search/stories_generator.py
CHANGED
@@ -16,7 +16,7 @@ auth_token = os.getenv("HF_TOKEN", None)
 
 @dataclass
 class ModelConfig:
-    text_model_name: str = "
+    text_model_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
     bert_name: str = "answerdotai/ModernBERT-base"
     zero_shot_name: str = "facebook/bart-large-mnli"
     device: str = (
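The change above swaps the default `text_model_name` to `HuggingFaceTB/SmolLM2-135M-Instruct`. As a rough illustration, a `ModelConfig` like this might be consumed as follows; the loader function itself is hypothetical, and only the standard `transformers` Auto classes are assumed.

```python
# Hypothetical helper, not the repo's actual loader: instantiate the generator
# model and tokenizer named in ModelConfig.text_model_name.
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_text_model(config: "ModelConfig"):
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)
    model = AutoModelForCausalLM.from_pretrained(config.text_model_name).to(config.device)
    return model, tokenizer
```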