Model-assisted labeling Python script

Updated by Alex Cota

This script showcases the basic functionality of the Model-assisted labeling workflow. This script is broken into two parts. For an overview of the MAL workflow, see Model-assisted labeling (import annotations).

  1. Create a project, dataset, ontology, and select a labeling frontend.
  2. Turn on MAL, get annotation schemas, and create annotations on the data row.

Note: In order to run this script, you will need to create an API key.

from labelbox import Client
from labelbox import Project
from labelbox import Dataset
import json
import os
from labelbox.schema.bulk_import_request import BulkImportRequest
from labelbox.schema.enums import BulkImportRequestState
import requests
import ndjson

# NOTE(review): replace with a real Labelbox API key before running.
API_KEY = "<API KEY>"
#IMPORT_NAME must be unique per project
IMPORT_NAME = "<IMPORT NAME>"

def get_project_ontology(project_id: str, api_client: Client = None) -> dict:
    """
    Gets the ontology of the given project.

    Args:
        project_id (str): The id of the project
        api_client (Client): Optional client used to execute the query.
            Defaults to the module-level ``client`` so existing callers
            (which pass only ``project_id``) keep working; passing the
            client explicitly matches the style of
            ``turn_on_model_assisted_labeling``.
    Returns:
        The ontology of the project in a dict format, i.e.
        ``{"project": {"ontology": {"normalized": ...}}}`` as returned by
        the GraphQL query below.
    """
    # Fall back to the module-level client for backward compatibility.
    executor = api_client if api_client is not None else client
    res_str = executor.execute("""
                    query get_ontology($proj_id: ID!) {
                        project(where: {id: $proj_id}) {
                            ontology {
                                normalized
                            }
                        }
                    }
                """, {"proj_id": project_id})
    return res_str

def turn_on_model_assisted_labeling(client: Client, project_id: str) -> None:
    """
    Turns model assisted labeling on for the given project.

    Args:
        client (Client): The client that is connected via API key
        project_id (str): The id of the project
    Returns:
        None
    """
    # BUG FIX: the original mutation read `showPredictionsToLabelers(show:true{`
    # — the argument list was never closed with `)`, making the GraphQL
    # document syntactically invalid and the request fail server-side.
    client.execute("""
         mutation TurnPredictionsOn($proj_id: ID!){
             project(
                 where: {id: $proj_id}
             ){
                 showPredictionsToLabelers(show: true){
                     id
                     showingPredictionsToLabelers
                 }
             }
         }
     """, {"proj_id": project_id})

def get_schema_ids(ontology: dict) -> dict:
    """
    Gets the schema id's of each tool given an ontology.

    Args:
        ontology (dict): The ontology that we are looking to parse the
            schema id's from
    Returns:
        A dict keyed by tool name, mapping to that tool's schema
        information (schemaNodeId, color, and tooltype)
    """
    return {
        tool['name']: {
            'schemaNodeId': tool['featureSchemaId'],
            'color': tool['color'],
            'tooltype': tool['tool'],
        }
        for tool in ontology['tools']
    }

"""
PART ONE.

This will do the following:
    1. Create a new project named "New MAL Project"
    2. Create a new dataset named "New MAL Dataset" and attach it to the project
    3. Find an updated Editor frontend and attach it to the project
    4. Create a new ontology and attach it to the project
"""

client = Client(API_KEY)
new_project = client.create_project(name="New MAL Project")

new_dataset = client.create_dataset(name="New MAL Dataset", projects = new_project)
new_dataset.create_data_row(row_data="https://storage.googleapis.com/labelbox-sample-datasets/sample-mapillary/lb-segment-data_validation_images_--BJs76vloEaiH-wppzWNA.jpg")

all_frontends = list(client.get_labeling_frontends())
for frontend in all_frontends:
    if frontend.name == 'Editor':
        new_project_frontend = frontend
        break

new_project.labeling_frontend.connect(new_project_frontend)
new_project_ontology = "{\"tools\": [{  \"required\": false, \"name\": \"polygon tool\", \"tool\": \"polygon\", \"color\": \"navy\", \"classifications\": []}, {  \"required\": false, \"name\": \"segmentation tool\", \"tool\": \"superpixel\", \"color\": \"#1CE6FF\", \"classifications\": []}, {  \"required\": false, \"name\": \"point tool\", \"tool\": \"point\", \"color\": \"#FF4A46\", \"classifications\": []}, {  \"required\": false, \"name\": \"bbox tool\", \"tool\": \"rectangle\", \"color\": \"#008941\", \"classifications\": []}, {  \"required\": false, \"name\": \"polyline tool\", \"tool\": \"line\", \"color\": \"#006FA6\", \"classifications\": []}], \"classifications\": [{  \"required\": false, \"instructions\": \"Are there classification options?\", \"name\": \"classification options\", \"type\": \"radio\", \"options\": [{  \"label\": \"Yes\", \"value\": \"yes\"}, {  \"label\": \"Definitely\", \"value\": \"definitely\"}, {  \"label\": \"Third one?!\", \"value\": \"third one?!\"}]}]}"
new_project.setup(new_project_frontend, new_project_ontology)

new_project.datasets.connect(new_dataset)

print(f"The project id is: {new_project.uid}")
print(f"The dataset id is: {new_dataset.uid}")


"""
PART TWO.

This will do the following:
    1. Turn on model assisted labeling for the project
    2. Query for the existing ontology
    3. Get the schemas from the queried ontology
    4. Get the datarow that we want to annotate on
    5. Create a list of annotations for each tool
    6. Upload the annotations
    7. Provide errors, if any

Note: importing annotations is not immediate and can take a few minutes.
     If you would like to track while it is importing, include the following lines:

import logging
logging.basicConfig(level = logging.INFO)
"""
client = Client(API_KEY)

project_for_mal = client.get_project(new_project.uid)
dataset_for_mal = client.get_dataset(new_dataset.uid)
turn_on_model_assisted_labeling(client = client, project_id = project_for_mal.uid)

ontology = get_project_ontology(project_for_mal.uid)['project']['ontology']['normalized']

schemas = get_schema_ids(ontology)

datarow_id = list(dataset_for_mal.data_rows())[0].uid

# One NDJSON-style annotation per tool in the ontology; each entry carries a
# unique uuid, the tool's schema id, the target data row, and the geometry
# keyed by annotation type (polygon / mask / point / bbox / line).
annotations = [
    {
         "uuid": "d6fc18e4-13ed-11eb-8e85-acde48001122",
         "schemaId": schemas['polygon tool']['schemaNodeId'],
         "dataRow": {"id": datarow_id},
         # Polygon: ordered list of vertex coordinates (pixels).
         "polygon": [
             {"x": 132.536, "y": 73.217},
             {"x": 177.494, "y": 69.363},
             {"x": 243.004, "y": 93.769},
             {"x": 198.046, "y": 208.09},
             {"x": 105.562, "y": 140.011}
]
    },
    {
         "uuid": "d6fc1a88-13ed-11eb-8e85-acde48001122",
         "schemaId": schemas['segmentation tool']['schemaNodeId'],
         "dataRow": {"id": datarow_id},
         # Segmentation mask: hosted PNG plus the RGB color that marks the
         # masked pixels.
         "mask": {
             "instanceURI": "https://storage.googleapis.com/labelbox-sample-datasets/sample-mapillary/lb-segment-data_validation_labels_--BJs76vloEaiH-wppzWNA_mask.png",              "colorRGB": [255, 255, 255]
     }
},
    {
         "uuid": "d6fc1ac4-13ed-11eb-8e85-acde48001122",
         "schemaId": schemas['point tool']['schemaNodeId'],
         "dataRow": {"id": datarow_id},
         "point": {"x": 176, "y": 128}
    },
    {
         "uuid": "d6fc1aec-13ed-11eb-8e85-acde48001122",
         "schemaId": schemas['bbox tool']['schemaNodeId'],
         "dataRow": {"id": datarow_id},
         # Bounding box in pixels: top/left corner plus height/width.
         "bbox": {
             "top": 48,
             "left": 58,
             "height": 213,
             "width": 215
         }
    },
    {
         "uuid": "d6fc1b0a-13ed-11eb-8e85-acde48001122",
         "schemaId": schemas['polyline tool']['schemaNodeId'],
         "dataRow": {"id": datarow_id},
         # Polyline: ordered list of vertices; not closed like a polygon.
         "line": [
             {"x": 163.364, "y": 21.837},
             {"x": 269.978, "y": 59.087},
             {"x": 205.753, "y": 146.433},
             {"x": 225.021, "y": 175.977},
             {"x": 149.235, "y": 240.202},
             {"x": 85.01, "y": 169.554},
             {"x": 123.545, "y": 104.045},
             {"x": 82.441, "y": 74.501},
             {"x": 120.976, "y": 21.837}
]
    }
]

# Kick off the bulk import, then look the job up by its (project, name) pair
# and block until the server finishes processing it.
project_for_mal.upload_annotations(annotations = annotations, name = IMPORT_NAME)
upload_job = BulkImportRequest.from_name(client, project_id = project_for_mal.uid, name = IMPORT_NAME)
upload_job.wait_until_done()

print(f"The annotation import is: {upload_job.state}")

# If any annotations were rejected, the server exposes an NDJSON error file;
# download it and print one line per failed annotation.
if upload_job.error_file_url:
    res = requests.get(upload_job.error_file_url)
    errors = ndjson.loads(res.text)
    print("\nErrors:")
    for error in errors:
         print(
             "An annotation failed to import for "
             f"datarow: {error['dataRow']} due to: "
             f"{error['errors']}")

Was this page helpful?

API reference

Contact