# dadc / app.py
# Tristan's picture
# Update app.py
# 2d57d72
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import os
import random
from urllib.parse import parse_qs
import gradio as gr
import requests
from transformers import pipeline
from huggingface_hub import Repository
from dotenv import load_dotenv
from pathlib import Path
import json
from utils import force_git_push
import threading
# Settings for storing the mturk HITs in a Hugging Face dataset. They are read
# from the environment; a local ".env" file, when present, is loaded first.
if os.path.isfile(".env"):
    load_dotenv(".env")

DATASET_REPO_URL = os.environ.get("DATASET_REPO_URL")
FORCE_PUSH = os.environ.get("FORCE_PUSH")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Collected examples are appended to this JSON-lines file inside the local
# "data" directory, which is the working copy of the dataset repo.
DATA_FILENAME = "data.jsonl"
DATA_FILE = str(Path("data") / DATA_FILENAME)

repo = Repository(
    local_dir="data",
    clone_from=DATASET_REPO_URL,
    use_auth_token=HF_TOKEN,
)

# How many examples per HIT
TOTAL_CNT = 2
# This function pushes the HIT data written in data.jsonl to our Hugging Face
# dataset every minute. Adjust the frequency to suit your needs.
PUSH_FREQUENCY = 60  # seconds between two push attempts


def asynchronous_push(f_stop):
    """Commit and push pending changes of ``repo``, then schedule the next run.

    When the working copy is clean nothing is pushed. Otherwise all changes
    are added and committed, and pushed either with ``force_git_push`` (when
    ``FORCE_PUSH == "yes"``) or with a regular ``repo.git_push()``.

    Args:
        f_stop: A ``threading.Event``. While it is not set, the function
            re-arms a ``threading.Timer`` that calls it again after
            ``PUSH_FREQUENCY`` seconds; setting it ends the cycle.
    """
    try:
        if repo.is_repo_clean():
            print("Repo currently clean. Ignoring push_to_hub")
        else:
            repo.git_add(auto_lfs_track=True)
            repo.git_commit("Auto commit by space")
            if FORCE_PUSH == "yes":
                force_git_push(repo)
            else:
                repo.git_push()
    except Exception as err:
        # A failing git operation must not prevent the timer below from being
        # re-armed; otherwise a single error would stop all future pushes.
        print(f"Auto push failed: {err!r}")

    if not f_stop.is_set():
        # call again in PUSH_FREQUENCY seconds
        threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()


# Start the periodic push cycle; set f_stop to end it.
f_stop = threading.Event()
asynchronous_push(f_stop)
# Now let's run the app!
# `pipe` is the sentiment classifier that workers try to fool; `demo` is the
# Gradio Blocks container that the UI is built into.
pipe = pipeline(task="sentiment-analysis")
demo = gr.Blocks()
with demo:
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a JSON
    state_dict = {"assignmentId": "", "cnt": 0, "cnt_fooled": 0, "data": []}
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# DADC in Gradio example. See the README.md to run this space seamlessly in Amazon Mechanical Turk.")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    state_display = gr.Markdown(f"State: 0/{TOTAL_CNT} (0 fooled)")

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state, dummy):
        """Score one submitted example and advance the HIT state.

        Args:
            txt: The statement entered by the worker.
            tgt: The target label chosen in the radio ("Positive" or
                "Negative"); falsy when nothing has been selected.
            state: The state dict (assignmentId, cnt, cnt_fooled, data).
            dummy: ``window.location.search`` as supplied by the click
                handler's JavaScript; it may carry the mturk assignmentId.

        Returns:
            A tuple matching the ``outputs`` of ``submit_ex_button.click``:
            label confidences, result markdown, updated state, visibility
            updates for the example-submit / final-submit / preview-submit
            columns, the state summary markdown, and ``dummy``.
        """
        if not tgt:
            # Without a target label the example cannot be judged (and
            # tgt.lower() below would fail), so ask for one and leave every
            # other output as it is.
            return (
                gr.update(),
                "Please select a target label before submitting.",
                state,
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                dummy,
            )

        pred = pipe(txt)[0]
        raw_label = pred["label"].lower()
        other_label = "negative" if raw_label == "positive" else "positive"
        pred_confidences = {raw_label: pred["score"], other_label: 1 - pred["score"]}

        # Title-case so the prediction is comparable with the radio choices.
        pred_label = pred["label"].title()
        ret = f"Target: **{tgt}**. Model prediction: **{pred_label}**\n\n"
        fooled = pred_label != tgt
        if fooled:
            state["cnt_fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        state["cnt"] += 1

        done = state["cnt"] == TOTAL_CNT
        toggle_example_submit = gr.update(visible=not done)
        new_state_md = f"State: {state['cnt']}/{TOTAL_CNT} ({state['cnt_fooled']} fooled)"

        state["data"].append(
            {
                "cnt": state["cnt"],
                "text": txt,
                "target": tgt.lower(),
                "model_pred": pred_label.lower(),
                "fooled": fooled,
            }
        )

        # dummy holds the URL query string including its leading "?".
        query = parse_qs(dummy[1:])
        if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
            # It seems that someone is using this app on mturk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked. We can do this here in _predict. We need to save the
            # assignmentId so that the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]
            toggle_final_submit = gr.update(visible=done)
            toggle_final_submit_preview = gr.update(visible=False)
        else:
            toggle_final_submit_preview = gr.update(visible=done)
            toggle_final_submit = gr.update(visible=False)

        if done:
            # Write the HIT data to our local dataset because the worker has
            # submitted everything now.
            json_data_with_assignment_id = [
                json.dumps({"assignmentId": state["assignmentId"], **datum})
                for datum in state["data"]
            ]
            with open(DATA_FILE, "a", encoding="utf-8") as jsonlfile:
                jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")

        return (
            pred_confidences,
            ret,
            state,
            toggle_example_submit,
            toggle_final_submit,
            toggle_final_submit_preview,
            new_state_md,
            dummy,
        )

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
    with gr.Column(visible=False) as final_submit_preview:
        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no mturk HIT credit, but your examples will still be stored)")

    # Button event handlers

    # Runs in the browser before _predict: forwards the inputs unchanged but
    # replaces the dummy value with the page's query string.
    get_window_location_search_js = """
        function(text_input, label_input, state, dummy) {
            return [text_input, label_input, state, window.location.search];
        }
        """

    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, final_submit_preview, state_display, dummy],
        _js=get_window_location_search_js,
    )

    # Posts every top-level key of the state as a hidden form field to mturk.
    post_hit_js = """
        function(state) {
            // If there is an assignmentId, then the submitter is on mturk
            // and has accepted the HIT. So, we need to submit their HIT.
            const form = document.createElement('form');
            form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
            form.method = 'post';
            for (const key in state) {
                const hiddenField = document.createElement('input');
                hiddenField.type = 'hidden';
                hiddenField.name = key;
                // Nested values (such as the "data" list) would otherwise be
                // coerced to "[object Object]", so send them as JSON text.
                const value = state[key];
                hiddenField.value = (value !== null && typeof value === 'object') ? JSON.stringify(value) : value;
                form.appendChild(hiddenField);
            };
            document.body.appendChild(form);
            form.submit();
            return state;
        }
        """

    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

    refresh_app_js = """
        function(state) {
            // The following line here loads the app again so the user can
            // enter in another preview-mode "HIT".
            window.location.href = window.location.href;
            return state;
        }
        """

    submit_hit_button_preview.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=refresh_app_js,
    )

demo.launch()