Working With SageMaker Part-II

This article covers SageMaker's built-in tools for training and testing your model.

Training:

Training a model is pretty easy using SageMaker's built-in tools. All you have to do is select a built-in algorithm from the list provided here and create an estimator with the required training details. You can also set hyperparameters specific to your model; for more information, visit this page.

Here is some minimal code to help you get started. Remember that the role passed to the estimator is an IAM role (here assumed to be the current execution role), and the output path declares where the artifacts of the trained model will be stored.

import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

session = sagemaker.Session()
role = get_execution_role()  # IAM role of the current notebook instance

# Retrieve the built-in XGBoost container for the current region
container = get_image_uri(session.boto_region_name, 'xgboost')

prefix = 'xgboost-example'  # S3 key prefix for this example (choose your own)

xgb = sagemaker.estimator.Estimator(container, role,
                                    train_instance_count=1,
                                    train_instance_type='ml.m4.xlarge',
                                    output_path='s3://{}/{}/output'.format(session.default_bucket(), prefix),
                                    sagemaker_session=session)
xgb.set_hyperparameters(max_depth=5,
                        gamma=4,
                        objective='reg:linear',
                        early_stopping_rounds=10,
                        num_round=200)

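The train_location and val_location used below are S3 URIs pointing at your training and validation CSV files. If the data is still local, here is a minimal sketch of uploading it with the session (assuming train.csv and validation.csv live in a hypothetical local data_dir):

import os

data_dir = './data'  # hypothetical local directory holding the CSV files

# upload_data copies a local file to the session's default bucket
# under the given prefix and returns the resulting S3 URI
train_location = session.upload_data(os.path.join(data_dir, 'train.csv'), key_prefix=prefix)
val_location = session.upload_data(os.path.join(data_dir, 'validation.csv'), key_prefix=prefix)
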
s3_input_train = sagemaker.s3_input(s3_data=train_location, content_type='csv')
s3_input_validation = sagemaker.s3_input(s3_data=val_location, content_type='csv')

xgb.fit({'train': s3_input_train, 'validation': s3_input_validation})

SageMaker also provides tools for hyperparameter tuning. You wrap the estimator in a tuner, specify the ranges to search over, and then call the tuner's own fit function with the same data channels.

from sagemaker.tuner import IntegerParameter, ContinuousParameter, HyperparameterTuner

xgb_hyperparameter_tuner = HyperparameterTuner(estimator = xgb, # The estimator object to use as the basis for the training jobs.
                                               objective_metric_name = 'validation:rmse', # The metric used to compare trained models.
                                               objective_type = 'Minimize', # Whether we wish to minimize or maximize the metric.
                                               max_jobs = 20, # The total number of models to train
                                               max_parallel_jobs = 3, # The number of models to train in parallel
                                               hyperparameter_ranges = {
                                                    'max_depth': IntegerParameter(3, 12),
                                                    'gamma': ContinuousParameter(0, 10),
                                               })
xgb_hyperparameter_tuner.fit({'train': s3_input_train, 'validation': s3_input_validation})
xgb_hyperparameter_tuner.wait()

# After this, you can retrieve the best model found by the tuner as an estimator:
xgb_attached = sagemaker.estimator.Estimator.attach(xgb_hyperparameter_tuner.best_training_job())
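
Once attached, xgb_attached behaves like any other trained estimator. For example, you can inspect the hyperparameter values the winning training job was trained with:

# hyperparameters() returns the hyperparameter dict of the attached job
print(xgb_attached.hyperparameters())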


Testing:

Testing a model is also a straightforward process. The batch transform functionality lets you create a transformer object from the trained model and start a batch transform job that runs inference on test data stored in S3. Have a look at some example code:

# test_location is the S3 URI of your test CSV data
xgb_transformer = xgb.transformer(instance_count = 1, instance_type = 'ml.m4.xlarge')
xgb_transformer.transform(test_location, content_type='text/csv', split_type='Line')
xgb_transformer.wait()

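The batch transform job writes its results to S3 rather than to local disk. Here is a minimal sketch of pulling them down, assuming you are in a notebook with the AWS CLI available and data_dir is the local directory used above:

# Copy the batch transform output from S3 to the local data_dir
# (notebook cell; output_path is where the transformer stored its results)
!aws s3 cp --recursive $xgb_transformer.output_path $data_dir
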
# Read the predictions produced by the batch transform job
import pandas as pd
Y_pred = pd.read_csv(os.path.join(data_dir, 'test.csv.out'), header=None)

# NOTE: If you ran hyperparameter tuning and attached the best job, build the transformer from the "attached" estimator instead
xgb_transformer = xgb_attached.transformer(instance_count = 1, instance_type = 'ml.m4.xlarge')

You can view your training jobs and hyperparameter tuning jobs, check their status, and read their output logs by navigating to the Training section in the SageMaker console sidebar.
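
If you prefer checking status from code instead of the console, here is a minimal sketch using boto3 (assumed available, as it is on SageMaker notebook instances):

import boto3

sm = boto3.client('sagemaker')

# Any training job name works here; we use the tuner's best job
info = sm.describe_training_job(TrainingJobName=xgb_hyperparameter_tuner.best_training_job())
print(info['TrainingJobStatus'])  # e.g. 'InProgress' or 'Completed'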


JBS
