16 February 2021

Sagemaker: Hyperparameter tuning

This post covers setting up hyperparameter tuning jobs on AWS Sagemaker. It assumes that Sagemaker is already configured as described in the Sagemaker Basic Setup guide from this blog.

Here, we cover the following topics:

  • Script to start tuning jobs
  • Search space for hyperparameters
  • Automated workflow through CircleCI

Script to start tuning jobs

In the basic Sagemaker setup, we created a file called run_sagemaker.py that calls the Sagemaker API to start a training job. To start a tuning job, we create a similar file, run_sagemaker_tuner.py, in which we again define an Estimator object and then pass it to an instance of the HyperparameterTuner class:


import os

import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner, IntegerParameter

# TODO: change me
BUCKET_NAME = "MY_BUCKET"
REPO_NAME = "REPO_NAME"

s3_model_output_location = f"s3://{BUCKET_NAME}/sagemaker/{REPO_NAME}"

sess = sagemaker.Session(default_bucket=BUCKET_NAME)
role = os.environ["SAGEMAKER_ROLE"]
tag = os.environ.get("CIRCLE_BRANCH") or "latest"
account_url = os.environ["AWS_ECR_ACCOUNT_URL"]

tf_estimator = Estimator(
    role=role,
    instance_count=1,
    instance_type="ml.g4dn.2xlarge",
    sagemaker_session=sess,
    base_job_name=tag,
    output_path=s3_model_output_location,
    image_uri=f"{account_url}/{REPO_NAME}:{tag}",
    metric_definitions=[
        {"Name": "train:loss", "Regex": "loss: (.*?) "},
        {"Name": "val:loss", "Regex": "val_loss: (.*?) "},
        {"Name": "train:precision", "Regex": "precision: (.*?) "},
        {"Name": "val:precision", "Regex": "val_precision: (.*?) "},
        {"Name": "train:recall", "Regex": "recall: (.*?) "},
        {"Name": "val:recall", "Regex": "val_recall: (.*?) "},
    ],
    use_spot_instances=True,
    max_wait=3600,  # 1 hour
    max_run=3600,
)

# Parameters kept fixed across all training jobs in the tuning run
static_hyperparams = {
    "epochs": 150,
    "model_version": "3.0",
}

tf_estimator.set_hyperparameters(**static_hyperparams)

# Ranges defining the search space; the tuner samples values from these
tuned_hyperparam_ranges = {
    "batch_size": IntegerParameter(10, 100),
    "dropout_rate": ContinuousParameter(0.2, 0.9),
    "decay_steps": IntegerParameter(10, 1000),
    "initial_learning_rate": ContinuousParameter(0.001, 0.5),
}

tuner = HyperparameterTuner(
    estimator=tf_estimator,
    metric_definitions=[{"Name": "val:loss", "Regex": "val_loss: (.*?) "}],
    objective_metric_name="val:loss",
    objective_type="Minimize",
    hyperparameter_ranges=tuned_hyperparam_ranges,
    max_jobs=100,
    max_parallel_jobs=2,
    early_stopping_type="Auto",
    base_tuning_job_name=f"TUNER-{tag}",
)

# Sagemaker creates SM_CHANNEL_<KEY> env variables from the input channel keys
tuner.fit(
    inputs={
        "data": f"s3://{BUCKET_NAME}/training_data/{REPO_NAME}",
    },
    wait=False,
)

The Estimator object is covered in detail in the basic Sagemaker setup post, so here we focus on the arguments of the HyperparameterTuner class:

  • estimator – the Estimator object we created above.
  • metric_definitions – regular expressions defining the objective metric(s) we want to monitor. Refer to the monitoring and debugging Sagemaker blog post for details.
  • objective_metric_name – the metric that the tuner optimises. It must match one of the names in the metric_definitions list.
  • objective_type – whether the objective metric should be minimised or maximised.
  • hyperparameter_ranges – ranges of the hyperparameters we want to optimise, as defined above. Three parameter types are available: integer, continuous (float) and categorical (see the sketch after this list).
  • max_jobs and max_parallel_jobs – maximum number of total and parallel jobs to run. Each job is a normal Sagemaker training job. With the default AWS service quotas you can't run more than 2 parallel jobs; to increase the limit, you need to create a support ticket:
    https://aws.amazon.com/premiumsupport/knowledge-center/resourcelimitexceeded-sagemaker/
  • early_stopping_type – off by default; set to "Auto" to let Sagemaker stop unpromising training jobs early.
  • base_tuning_job_name – similar to the base job name for a single training job. The prefix is limited to 20 characters.
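
For illustration, here is a minimal sketch of all three range types. The optimizer parameter is a hypothetical example and is not part of the script above:

from sagemaker.tuner import CategoricalParameter, ContinuousParameter, IntegerParameter

# A hypothetical search space showing the three available parameter types
example_ranges = {
    "batch_size": IntegerParameter(10, 100),  # sampled as integers
    "initial_learning_rate": ContinuousParameter(0.001, 0.5),  # sampled as floats
    "optimizer": CategoricalParameter(["adam", "sgd", "rmsprop"]),  # discrete choices
}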

The fit method of the tuner is then called with the same arguments as for a single training job.
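
Because fit is called with wait=False, the script returns as soon as the tuning job is submitted. A small sketch of how you can grab the job name and reconnect to the running job later, e.g. from a notebook:

# Name of the tuning job that was just started
job_name = tuner.latest_tuning_job.name
print(job_name)

# Later, re-attach to the running job and block until it finishes
attached_tuner = HyperparameterTuner.attach(job_name)
attached_tuner.wait()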

Results

After starting the job, you can follow the individual training jobs and their progress in the Sagemaker console.

You can also open the current best training job to see the best hyperparameters found so far.
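
The same information is available programmatically as well. A minimal sketch, assuming job_name and attached_tuner from the snippet above:

# Name of the best training job found so far
print(attached_tuner.best_training_job())

# All results as a pandas DataFrame: one row per training job,
# with its hyperparameters and final objective value
analytics = sagemaker.HyperparameterTuningJobAnalytics(job_name)
df = analytics.dataframe()
print(df.sort_values("FinalObjectiveValue").head())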

CircleCI config

CircleCI integration is configured similarly to that of a standard training job:

sagemaker_tuning:
  docker:
    - image: circleci/python:3.7
  steps:
    - setup_remote_docker:
        docker_layer_caching: true
    - checkout
    - aws-cli/setup:
        aws-region: AWS_REGION
    - run: pip install sagemaker
    - run:
        name: Run tuning script
        command: python ./sagemaker/run_sagemaker_tuner.py
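
To actually trigger the job, reference it from a workflow. A rough sketch, where build_and_push is a hypothetical job that builds and pushes the training image to ECR, and the branch filter is an assumption to adapt to your setup:

workflows:
  version: 2
  tuning:
    jobs:
      - build_and_push
      - sagemaker_tuning:
          requires:
            - build_and_push
          filters:
            branches:
              only: tuning  # start tuning jobs only from this branch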

That’s all!