|
|
|
|
|
import os |
|
import random |
|
from urllib.parse import parse_qs |
|
|
|
import gradio as gr |
|
import requests |
|
from transformers import pipeline |
|
from huggingface_hub import Repository |
|
from dotenv import load_dotenv |
|
from pathlib import Path |
|
import json |
|
from utils import force_git_push |
|
import threading |
|
|
|
|
|
# Pick up local settings from a .env file, if present, before reading config.
if Path(".env").is_file():
    load_dotenv(".env")

# Settings come from the environment; os.environ.get yields None when unset.
DATASET_REPO_URL = os.environ.get("DATASET_REPO_URL")
FORCE_PUSH = os.environ.get("FORCE_PUSH")  # "yes" makes the sync use force_git_push instead of git_push
HF_TOKEN = os.environ.get("HF_TOKEN")

# Collected examples are appended as JSON lines inside the local dataset clone.
DATA_FILENAME = "data.jsonl"
DATA_FILE = str(Path("data") / DATA_FILENAME)

# Local clone (in ./data) of the dataset repo that receives the examples.
repo = Repository(
    clone_from=DATASET_REPO_URL,
    local_dir="data",
    use_auth_token=HF_TOKEN,
)

# Number of examples a user submits before the final submit button is shown.
TOTAL_CNT = 2
|
|
|
|
|
|
|
# Seconds between two consecutive sync attempts.
PUSH_FREQUENCY = 60


def asynchronous_push(f_stop):
    """Commit and push pending changes in ``repo``, then schedule the next run.

    Each call performs one sync. Unless ``f_stop`` is set, it then starts a
    ``threading.Timer`` that calls this function again after
    ``PUSH_FREQUENCY`` seconds.

    Errors raised while committing or pushing (e.g. a network failure) are
    reported and swallowed. The next run is scheduled regardless, so a single
    failure no longer ends the periodic sync. Note that if the commit
    succeeded but the push failed, the tree is clean on the next run and
    nothing is pushed until new changes arrive. ``git push`` then sends all
    local commits on the branch, including the earlier one.

    Args:
        f_stop: a ``threading.Event``; once set, no further run is scheduled.
    """
    try:
        if repo.is_repo_clean():
            print("Repo currently clean. Ignoring push_to_hub")
        else:
            repo.git_add(auto_lfs_track=True)
            repo.git_commit("Auto commit by space")
            if FORCE_PUSH == "yes":
                force_git_push(repo)
            else:
                repo.git_push()
    except Exception:  # thread boundary: report and keep the sync loop alive
        import traceback

        print("Periodic sync with the dataset repo failed:")
        traceback.print_exc()
    finally:
        if not f_stop.is_set():
            threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
|
|
|
# Stop flag for the periodic sync; asynchronous_push stops rescheduling once it
# is set (nothing in the visible code sets it).
f_stop = threading.Event()
# Run one sync now, synchronously; asynchronous_push schedules the later runs.
asynchronous_push(f_stop)
|
|
|
|
|
# Sentiment classifier used to score user examples. The checkpoint is pinned
# explicitly: without `model=`, transformers falls back to a task default that
# may change between releases. _predict relies on this model's
# "POSITIVE"/"NEGATIVE" label names.
pipe = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
|
|
|
demo = gr.Blocks()

with demo:
    # Hidden textbox: the Submit button's JS hook replaces its value with
    # window.location.search so _predict can read the MTurk query parameters.
    dummy = gr.Textbox(visible=False)

    # Per-session progress, kept in a hidden JSON component and passed through
    # _predict on every submission.
    state_dict = {"assignmentId": "", "cnt": 0, "cnt_fooled": 0, "data": []}
    state = gr.JSON(state_dict, visible=False)

    def _state_markdown(state):
        """Return the progress line displayed by ``state_display``."""
        return f"State: {state['cnt']}/{TOTAL_CNT} ({state['cnt_fooled']} fooled)"

    def _save_examples(state):
        """Append the session's examples to DATA_FILE, one JSON object per line.

        Each line is an entry of ``state["data"]`` with the session's
        ``assignmentId`` added as its first key.
        """
        lines = [
            json.dumps({"assignmentId": state["assignmentId"], **datum})
            for datum in state["data"]
        ]
        if not lines:
            return
        with open(DATA_FILE, "a", encoding="utf-8") as jsonlfile:
            jsonlfile.write("\n".join(lines) + "\n")

    gr.Markdown("# DADC in Gradio example. See the README.md to run this space seamlessly in Amazon Mechanical Turk.")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    state_display = gr.Markdown(_state_markdown(state_dict))

    def _predict(txt, tgt, state, dummy):
        """Score one user-written example and advance the session state.

        Args:
            txt: the statement typed by the user.
            tgt: the label the user says is correct ("Positive" or "Negative"),
                or None/empty when no radio option was selected.
            state: the session dict held by the hidden ``state`` component.
            dummy: ``window.location.search`` injected by the JS hook.

        Returns:
            Values for, in order: label_output, text_output, state,
            example_submit, final_submit, final_submit_preview, state_display
            and dummy.
        """
        if not tgt or not txt or not txt.strip():
            # Reject incomplete submissions without counting them. A missing
            # target used to be scored as "fooled" and then crash on tgt.lower().
            return (
                gr.update(),
                "Please enter a statement and select the target (correct) label.",
                state,
                gr.update(),
                gr.update(),
                gr.update(),
                _state_markdown(state),
                dummy,
            )

        pred = pipe(txt)[0]
        # The pipeline reports only the winning label; give the other class
        # 1 - score so gr.Label can display both.
        other_label = "negative" if pred["label"].lower() == "positive" else "positive"
        pred_confidences = {pred["label"].lower(): pred["score"], other_label: 1 - pred["score"]}

        pred["label"] = pred["label"].title()
        ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
        fooled = pred["label"] != tgt
        if fooled:
            state["cnt_fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        state["cnt"] += 1

        done = state["cnt"] == TOTAL_CNT
        toggle_example_submit = gr.update(visible=not done)
        new_state_md = _state_markdown(state)

        state["data"].append({"cnt": state["cnt"], "text": txt, "target": tgt.lower(), "model_pred": pred["label"].lower(), "fooled": fooled})

        # An accepted HIT carries a real assignmentId (the literal
        # ASSIGNMENT_ID_NOT_AVAILABLE means preview). It gets the real
        # "Submit HIT" button; everyone else gets the preview variant.
        query = parse_qs((dummy or "")[1:])
        assignment_id = query.get("assignmentId", [""])[0]
        if assignment_id and assignment_id != "ASSIGNMENT_ID_NOT_AVAILABLE":
            state["assignmentId"] = assignment_id
            toggle_final_submit = gr.update(visible=done)
            toggle_final_submit_preview = gr.update(visible=False)
        else:
            toggle_final_submit = gr.update(visible=False)
            toggle_final_submit_preview = gr.update(visible=done)

        # `==` (not `>=`) so the examples are written exactly once per session.
        if done:
            _save_examples(state)

        return (
            pred_confidences,
            ret,
            state,
            toggle_example_submit,
            toggle_final_submit,
            toggle_final_submit_preview,
            new_state_md,
            dummy,
        )

    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    # Randomize the radio order (presumably to avoid a position bias).
    labels = ["Positive", "Negative"]
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()

    # Exactly one of these three button groups is visible at a time;
    # _predict toggles them.
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
    with gr.Column(visible=False) as final_submit_preview:
        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no mturk HIT credit, but your examples will still be stored)")

    # Runs in the browser before _predict: forwards the inputs unchanged and
    # swaps the hidden `dummy` value for the page's query string.
    get_window_location_search_js = """
        function(text_input, label_input, state, dummy) {
            return [text_input, label_input, state, window.location.search];
        }
    """

    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, final_submit_preview, state_display, dummy],
        _js=get_window_location_search_js,
    )

    post_hit_js = """
        function(state) {
            // If there is an assignmentId, then the submitter is on mturk
            // and has accepted the HIT. So, we need to submit their HIT.
            // MTurk passes the submit host as `turkSubmitTo`; only the two
            // official hosts are accepted, falling back to the sandbox.
            const allowedHosts = ['https://www.mturk.com', 'https://workersandbox.mturk.com'];
            const submitTo = new URLSearchParams(window.location.search).get('turkSubmitTo');
            const host = allowedHosts.includes(submitTo) ? submitTo : 'https://workersandbox.mturk.com';
            const form = document.createElement('form');
            form.action = host + '/mturk/externalSubmit';
            form.method = 'post';
            for (const key in state) {
                const hiddenField = document.createElement('input');
                hiddenField.type = 'hidden';
                hiddenField.name = key;
                // Serialize non-strings (e.g. the `data` array) as JSON;
                // plain assignment would coerce them to "[object Object]".
                const value = state[key];
                hiddenField.value = typeof value === 'string' ? value : JSON.stringify(value);
                form.appendChild(hiddenField);
            }
            document.body.appendChild(form);
            form.submit();
            return state;
        }
    """

    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

    refresh_app_js = """
        function(state) {
            // Reload the app so the user can enter in another preview-mode "HIT".
            // location.reload() also works when the URL has a #fragment, where
            // re-assigning location.href would only scroll to the fragment.
            window.location.reload();
            return state;
        }
    """

    submit_hit_button_preview.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=refresh_app_js,
    )
|
|
|
# Start serving the Gradio Blocks app built in `demo`.
demo.launch()