Model-assisted labeling Python script
by Alex Cota
This script showcases the basic functionality of the Model-assisted labeling (MAL) workflow and is broken into two parts. For an overview of the MAL workflow, see Model-assisted labeling (import annotations).
- Create a project, dataset, ontology, and select a labeling frontend.
- Turn on MAL, get annotation schemas, and create annotations on the data row.
Note: To run this script, you will need to create an API key.
from labelbox import Client, Project, Dataset
import json
import os
from labelbox.schema.bulk_import_request import BulkImportRequest
from labelbox.schema.enums import BulkImportRequestState
import requests
import ndjson
API_KEY = "<API KEY>"
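# Alternatively, read the key from an environment variable instead of
# hardcoding it ("LABELBOX_API_KEY" is an illustrative variable name, not
# an SDK requirement):
# API_KEY = os.environ["LABELBOX_API_KEY"]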
# IMPORT_NAME must be unique per project
IMPORT_NAME = "<IMPORT NAME>"
def get_project_ontology(project_id: str) -> dict:
    """
    Gets the ontology of the given project.

    Args:
        project_id (str): The id of the project
    Returns:
        The ontology of the project as a dict
    """
    res = client.execute("""
        query get_ontology($proj_id: ID!) {
            project(where: {id: $proj_id}) {
                ontology {
                    normalized
                }
            }
        }
    """, {"proj_id": project_id})
    return res
def turn_on_model_assisted_labeling(client: Client, project_id: str) -> None:
    """
    Turns model-assisted labeling on for the given project.

    Args:
        client (Client): The client that is connected via API key
        project_id (str): The id of the project
    Returns:
        None
    """
    client.execute("""
        mutation TurnPredictionsOn($proj_id: ID!) {
            project(
                where: {id: $proj_id}
            ) {
                showPredictionsToLabelers(show: true) {
                    id
                    showingPredictionsToLabelers
                }
            }
        }
    """, {"proj_id": project_id})
def get_schema_ids(ontology: dict) -> dict:
    """
    Gets the schema IDs of each tool in the given ontology.

    Args:
        ontology (dict): The ontology to parse the schema IDs from
    Returns:
        A dict mapping each tool name to its schema information
    """
    schemas = {}
    for tool in ontology['tools']:
        schema = {
            'schemaNodeId': tool['featureSchemaId'],
            'color': tool['color'],
            'tooltype': tool['tool']
        }
        schemas[tool['name']] = schema
    return schemas
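# Example of the mapping returned by get_schema_ids (schema ID values are
# illustrative, not real):
# {'polygon tool': {'schemaNodeId': 'cka1b2...', 'color': 'navy', 'tooltype': 'polygon'}, ...}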
"""
PART ONE.
This will do the following:
1. Create a new project named "New MAL Project"
2. Create a new dataset named "New MAL Dataset" and attach it to the project
3. Find an updated Editor frontend and attach it to the project
4. Create a new ontology and attach it to the project
"""
client = Client(API_KEY)
new_project = client.create_project(name="New MAL Project")
new_dataset = client.create_dataset(name="New MAL Dataset", projects=new_project)
new_dataset.create_data_row(row_data="https://storage.googleapis.com/labelbox-sample-datasets/sample-mapillary/lb-segment-data_validation_images_--BJs76vloEaiH-wppzWNA.jpg")
all_frontends = list(client.get_labeling_frontends())
for frontend in all_frontends:
    if frontend.name == 'Editor':
        new_project_frontend = frontend
        break
new_project.labeling_frontend.connect(new_project_frontend)
ontology_json = {
    "tools": [
        {"required": False, "name": "polygon tool", "tool": "polygon", "color": "navy", "classifications": []},
        {"required": False, "name": "segmentation tool", "tool": "superpixel", "color": "#1CE6FF", "classifications": []},
        {"required": False, "name": "point tool", "tool": "point", "color": "#FF4A46", "classifications": []},
        {"required": False, "name": "bbox tool", "tool": "rectangle", "color": "#008941", "classifications": []},
        {"required": False, "name": "polyline tool", "tool": "line", "color": "#006FA6", "classifications": []}
    ],
    "classifications": [
        {
            "required": False,
            "instructions": "Are there classification options?",
            "name": "classification options",
            "type": "radio",
            "options": [
                {"label": "Yes", "value": "yes"},
                {"label": "Definitely", "value": "definitely"},
                {"label": "Third one?!", "value": "third one?!"}
            ]
        }
    ]
}
new_project_ontology = json.dumps(ontology_json)
new_project.setup(new_project_frontend, new_project_ontology)
new_project.datasets.connect(new_dataset)
print(f"The project id is: {new_project.uid}")
print(f"The dataset id is: {new_dataset.uid}")
"""
PART TWO.
This will do the following:
1. Turn on model assisted labeling for the project
2. Query for the existing ontology
3. Get the schemas from the queried ontology
4. Get the data row that we want to annotate
5. Create a list of annotations for each tool
6. Upload the annotations
7. Provide errors, if any
Note: importing annotations is not immediate and can take a few minutes.
If you would like to track progress while the import runs, include the following lines:
import logging
logging.basicConfig(level = logging.INFO)
"""
client = Client(API_KEY)
project_for_mal = client.get_project(new_project.uid)
dataset_for_mal = client.get_dataset(new_dataset.uid)
turn_on_model_assisted_labeling(client=client, project_id=project_for_mal.uid)
ontology = get_project_ontology(project_for_mal.uid)['project']['ontology']['normalized']
schemas = get_schema_ids(ontology)
datarow_id = list(dataset_for_mal.data_rows())[0].uid
annotations = [
    {
        "uuid": "d6fc18e4-13ed-11eb-8e85-acde48001122",
        "schemaId": schemas['polygon tool']['schemaNodeId'],
        "dataRow": {"id": datarow_id},
        "polygon": [
            {"x": 132.536, "y": 73.217},
            {"x": 177.494, "y": 69.363},
            {"x": 243.004, "y": 93.769},
            {"x": 198.046, "y": 208.09},
            {"x": 105.562, "y": 140.011}
        ]
    },
    {
        "uuid": "d6fc1a88-13ed-11eb-8e85-acde48001122",
        "schemaId": schemas['segmentation tool']['schemaNodeId'],
        "dataRow": {"id": datarow_id},
        "mask": {
            "instanceURI": "https://storage.googleapis.com/labelbox-sample-datasets/sample-mapillary/lb-segment-data_validation_labels_--BJs76vloEaiH-wppzWNA_mask.png",
            "colorRGB": [255, 255, 255]
        }
    },
    {
        "uuid": "d6fc1ac4-13ed-11eb-8e85-acde48001122",
        "schemaId": schemas['point tool']['schemaNodeId'],
        "dataRow": {"id": datarow_id},
        "point": {"x": 176, "y": 128}
    },
    {
        "uuid": "d6fc1aec-13ed-11eb-8e85-acde48001122",
        "schemaId": schemas['bbox tool']['schemaNodeId'],
        "dataRow": {"id": datarow_id},
        "bbox": {
            "top": 48,
            "left": 58,
            "height": 213,
            "width": 215
        }
    },
    {
        "uuid": "d6fc1b0a-13ed-11eb-8e85-acde48001122",
        "schemaId": schemas['polyline tool']['schemaNodeId'],
        "dataRow": {"id": datarow_id},
        "line": [
            {"x": 163.364, "y": 21.837},
            {"x": 269.978, "y": 59.087},
            {"x": 205.753, "y": 146.433},
            {"x": 225.021, "y": 175.977},
            {"x": 149.235, "y": 240.202},
            {"x": 85.01, "y": 169.554},
            {"x": 123.545, "y": 104.045},
            {"x": 82.441, "y": 74.501},
            {"x": 120.976, "y": 21.837}
        ]
    }
]
project_for_mal.upload_annotations(annotations=annotations, name=IMPORT_NAME)
upload_job = BulkImportRequest.from_name(client, project_id=project_for_mal.uid, name=IMPORT_NAME)
upload_job.wait_until_done()
print(f"The annotation import is: {upload_job.state}")
if upload_job.error_file_url:
    res = requests.get(upload_job.error_file_url)
    errors = ndjson.loads(res.text)
    print("\nErrors:")
    for error in errors:
        print(
            "An annotation failed to import for "
            f"data row: {error['dataRow']} due to: "
            f"{error['errors']}")
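# Optionally, you can inspect per-annotation statuses rather than only
# failures. This is a minimal sketch assuming the BulkImportRequest's
# status_file_url field, which points to an ndjson file; the record keys
# used below (uuid, status) are assumptions mirroring the error-file format.
if upload_job.status_file_url:
    res = requests.get(upload_job.status_file_url)
    for status in ndjson.loads(res.text):
        print(f"Annotation {status['uuid']}: {status['status']}")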