ssalb committed
Commit 139df7a · Parent: 7c0d92c

Update space with latest code and dependencies on Mon Jan 13 19:33:51 UTC 2025

README.md CHANGED
@@ -9,7 +9,7 @@ sdk_version: 5.9.1
  app_file: app.py
  pinned: false
  preload_from_hub:
- - openai-community/gpt2
+ - HuggingFaceTB/SmolLM2-135M-Instruct
  - answerdotai/ModernBERT-base
  - facebook/bart-large-mnli
  license: mit
@@ -17,8 +17,8 @@ license: mit

  ## Project Overview

- The Story Generator project leverages advanced natural language processing models to generate coherent and engaging stories. By utilizing models such as GPT-2, BERT, and BART, this project aims to provide users with a tool to create narratives based on given prompts. The application is built using Gradio for an interactive user interface, making it easy to input prompts and receive generated stories in real time.
+ This Story Generator leverages natural language processing models to generate coherent and engaging stories. By utilizing models such as SmolLM2, BERT, and BART, this project aims to provide users with a tool to create narratives based on given prompts. The application is built using Gradio for an interactive user interface, making it easy to input prompts and receive generated stories in real time.

- The main purpose of this project is to explore the idea of beam search for selecting stories with high coherence, fluency, and genre alignment scores. This ensures that the generated stories are not only creative but also maintain a logical flow and adhere to the specified genre.
+ The main purpose of this project is to explore the idea of beam search for selecting stories with high coherence, fluency, and genre alignment scores, in a process-based reward model (PRM) fashion. This ensures that the generated stories are not only creative but also maintain a logical flow and adhere to the specified genre.

  Note that the final implementation is not strictly beam search and was modified to allow more diversity (creativity), inspired by the DVTS method in [this blog post](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute).
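The overview above describes ranking candidate continuations by coherence, fluency, and genre-alignment scores at each step. Below is a minimal sketch of that kind of weighted selection; the class, the weights, and the `beam_width` default are illustrative assumptions, not the Space's actual code.

```python
from dataclasses import dataclass


@dataclass
class ScoredStory:
    text: str
    coherence: float  # e.g. embedding similarity between consecutive chunks
    fluency: float    # e.g. normalized language-model likelihood
    genre: float      # e.g. zero-shot entailment score for the requested genre


def select_candidates(
    candidates: list[ScoredStory],
    beam_width: int = 3,
    weights: tuple[float, float, float] = (0.4, 0.3, 0.3),
) -> list[ScoredStory]:
    """Keep the top-`beam_width` candidates by a weighted sum of the three scores."""
    w_coh, w_flu, w_gen = weights
    ranked = sorted(
        candidates,
        key=lambda s: w_coh * s.coherence + w_flu * s.fluency + w_gen * s.genre,
        reverse=True,
    )
    return ranked[:beam_width]
```

In a DVTS-style variant, each surviving candidate would seed several new continuations before the next scoring round, trading some score optimality for diversity.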
requirements.txt CHANGED
@@ -1,7 +1,7 @@
  accelerate==1.2.1 ; python_full_version == "3.10.13"
  aiofiles==23.2.1 ; python_full_version == "3.10.13"
  annotated-types==0.7.0 ; python_full_version == "3.10.13"
- anyio==4.7.0 ; python_full_version == "3.10.13"
+ anyio==4.8.0 ; python_full_version == "3.10.13"
  certifi==2024.12.14 ; python_full_version == "3.10.13"
  charset-normalizer==3.4.1 ; python_full_version == "3.10.13"
  click==8.1.8 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
@@ -16,7 +16,7 @@ gradio==5.9.1 ; python_full_version == "3.10.13"
  h11==0.14.0 ; python_full_version == "3.10.13"
  httpcore==1.0.7 ; python_full_version == "3.10.13"
  httpx==0.28.1 ; python_full_version == "3.10.13"
- huggingface-hub==0.27.0 ; python_full_version == "3.10.13"
+ huggingface-hub==0.27.1 ; python_full_version == "3.10.13"
  idna==3.10 ; python_full_version == "3.10.13"
  jinja2==3.1.5 ; python_full_version == "3.10.13"
  joblib==1.4.2 ; python_full_version == "3.10.13"
@@ -26,16 +26,16 @@ mdurl==0.1.2 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
  mpmath==1.3.0 ; python_full_version == "3.10.13"
  networkx==3.4.2 ; python_full_version == "3.10.13"
  numpy==2.2.1 ; python_full_version == "3.10.13"
- orjson==3.10.13 ; python_full_version == "3.10.13"
+ orjson==3.10.14 ; python_full_version == "3.10.13"
  packaging==24.2 ; python_full_version == "3.10.13"
  pandas==2.2.3 ; python_full_version == "3.10.13"
  pillow==11.1.0 ; python_full_version == "3.10.13"
- protobuf==5.29.2 ; python_full_version == "3.10.13"
+ protobuf==5.29.3 ; python_full_version == "3.10.13"
  psutil==6.1.1 ; python_full_version == "3.10.13"
  pydantic-core==2.27.2 ; python_full_version == "3.10.13"
- pydantic==2.10.4 ; python_full_version == "3.10.13"
+ pydantic==2.10.5 ; python_full_version == "3.10.13"
  pydub==0.25.1 ; python_full_version == "3.10.13"
- pygments==2.18.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
+ pygments==2.19.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
  python-dateutil==2.9.0.post0 ; python_full_version == "3.10.13"
  python-multipart==0.0.20 ; python_full_version == "3.10.13"
  pytz==2024.2 ; python_full_version == "3.10.13"
@@ -43,11 +43,11 @@ pyyaml==6.0.2 ; python_full_version == "3.10.13"
  regex==2024.11.6 ; python_full_version == "3.10.13"
  requests==2.32.3 ; python_full_version == "3.10.13"
  rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
- ruff==0.8.5 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
+ ruff==0.9.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
  safehttpx==0.1.6 ; python_full_version == "3.10.13"
- safetensors==0.5.0 ; python_full_version == "3.10.13"
- scikit-learn==1.6.0 ; python_full_version == "3.10.13"
- scipy==1.15.0 ; python_full_version == "3.10.13"
+ safetensors==0.5.2 ; python_full_version == "3.10.13"
+ scikit-learn==1.6.1 ; python_full_version == "3.10.13"
+ scipy==1.15.1 ; python_full_version == "3.10.13"
  semantic-version==2.10.0 ; python_full_version == "3.10.13"
  shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
  six==1.17.0 ; python_full_version == "3.10.13"
@@ -59,7 +59,7 @@ tokenizers==0.21.0 ; python_full_version == "3.10.13"
  tomlkit==0.13.2 ; python_full_version == "3.10.13"
  torch==2.4.0 ; python_full_version == "3.10.13"
  tqdm==4.67.1 ; python_full_version == "3.10.13"
- transformers @ git+https://github.com/huggingface/transformers.git@e5fd865ebae062b7cf03a81b8c6affeb39f30bec ; python_full_version == "3.10.13"
+ transformers==4.48.0 ; python_full_version == "3.10.13"
  triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
  typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
  typing-extensions==4.12.2 ; python_full_version == "3.10.13"
story_beam_search/scoring.py CHANGED
@@ -71,7 +71,7 @@ class CoherenceScorer(StoryScorer):
          outputs = self.model(**inputs, output_hidden_states=True)
          batch_embeddings = outputs.hidden_states[-1][
              :, 0, :
-         ]  # Get CLS token embeddings
+         ]
          all_embeddings.extend(batch_embeddings.cpu().numpy())

      # Calculate coherence scores for each story
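For context, the `[:, 0, :]` slice keeps the final-layer hidden state of the first ([CLS]) token as a per-chunk embedding. A plausible way to turn those embeddings into a coherence score (an illustrative sketch, not necessarily the repository's exact logic) is the average cosine similarity between consecutive chunks:

```python
import numpy as np


def coherence_from_embeddings(embeddings: np.ndarray) -> float:
    """Average cosine similarity between consecutive chunk embeddings.

    `embeddings` has shape (num_chunks, hidden_dim); higher values mean
    adjacent chunks stay semantically close, i.e. the story reads as coherent.
    """
    if len(embeddings) < 2:
        return 1.0  # a single chunk is trivially coherent
    sims = []
    for a, b in zip(embeddings[:-1], embeddings[1:]):
        denom = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
        sims.append(float(np.dot(a, b) / denom))
    return float(np.mean(sims))
```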
story_beam_search/stories_generator.py CHANGED
@@ -16,7 +16,7 @@ auth_token = os.getenv("HF_TOKEN", None)

  @dataclass
  class ModelConfig:
-     text_model_name: str = "openai-community/gpt2"
+     text_model_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
      bert_name: str = "answerdotai/ModernBERT-base"
      zero_shot_name: str = "facebook/bart-large-mnli"
      device: str = (
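The config change swaps the generator from GPT-2 to the instruction-tuned SmolLM2-135M while keeping the ModernBERT and BART checkpoints. A hedged sketch of how such a `ModelConfig` is typically consumed with `transformers` (the loader below is illustrative, not taken from the repository):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


def load_models(config: ModelConfig, auth_token: str | None = None):
    """Load the generator and the zero-shot genre classifier named in the config."""
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name, token=auth_token)
    generator = AutoModelForCausalLM.from_pretrained(
        config.text_model_name,
        token=auth_token,
    ).to(config.device)
    # facebook/bart-large-mnli served through the zero-shot classification pipeline
    genre_classifier = pipeline(
        "zero-shot-classification",
        model=config.zero_shot_name,
        device=0 if config.device == "cuda" else -1,
    )
    return tokenizer, generator, genre_classifier
```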