检查训练结果#
trainer.fit() 的返回值是一个 Result 对象。
该 Result 对象除了包含其他信息外,还包含:
最后报告的检查点(用于加载模型)及其附带的指标
如果发生任何错误,则包含错误消息
查看指标#
您可以从 Result 对象中检索附加到检查点的已报告指标。
常见的指标包括训练或验证损失,或预测准确率。
从 Result 对象中检索到的指标对应于您在 训练函数中 作为参数传递给 train.report 的指标。
注意
通过 ray.train.report(metrics, checkpoint=None) 报告的独立指标的持久化已弃用。这也意味着从 Result 对象中检索这些指标也已弃用。只有附加到检查点的指标才会被持久化。有关更多详细信息,请参阅 (已弃用) 报告仅指标。
最后报告的指标#
使用 Result.metrics 来检索附加到最后报告的检查点的指标。
result = trainer.fit()
print("Observed metrics:", result.metrics)
所有报告指标的 DataFrame#
使用 Result.metrics_dataframe 来检索与检查点一起报告的所有指标的 pandas DataFrame。
df = result.metrics_dataframe
print("Minimum loss", min(df["loss"]))
检索检查点#
您可以从 Result 对象中检索报告给 Ray Train 的检查点。
检查点 包含恢复训练状态所需的所有信息。这通常包括训练好的模型。
您可以使用检查点来执行常见的下游任务,例如 使用 Ray Data 进行离线批量推理 或 使用 Ray Serve 进行在线模型服务。
从 Result 对象中检索到的检查点对应于您在 训练函数中 作为参数传递给 train.report 的检查点。
最后保存的检查点#
使用 Result.checkpoint 来检索最后一个检查点。
print("Last checkpoint:", result.checkpoint)
with result.checkpoint.as_directory() as tmpdir:
# Load model from directory
...
其他检查点#
有时您可能需要访问更早的检查点。例如,如果您的损失在更多训练后由于过拟合而增加,您可能希望检索具有最低损失的检查点。
您可以使用 Result.best_checkpoints 来检索所有可用检查点及其指标的列表。
# Print available checkpoints
for checkpoint, metrics in result.best_checkpoints:
print("Loss", metrics["loss"], "checkpoint", checkpoint)
# Get checkpoint with minimal loss
best_checkpoint = min(
result.best_checkpoints, key=lambda checkpoint: checkpoint[1]["loss"]
)[0]
with best_checkpoint.as_directory() as tmpdir:
# Load model from directory
...
另请参阅
有关检查点设置的更多信息,请参阅 保存和加载检查点。
访问存储位置#
如果您以后需要检索结果,可以使用 Result.path 获取训练运行的存储位置。
此路径将对应于您在 RunConfig 中配置的 storage_path。它将是该路径下的一个(嵌套)子目录,通常格式为 TrainerName_date-string/TrainerName_id_00000_0_...。
结果还包含一个 pyarrow.fs.FileSystem,可用于访问存储位置,这在路径位于云存储时非常有用。
result_path: str = result.path
result_filesystem: pyarrow.fs.FileSystem = result.filesystem
print(f"Results location (fs, path) = ({result_filesystem}, {result_path})")
捕获错误#
如果在训练过程中发生错误,Result.error 将被设置并包含引发的异常。
try:
result = trainer.fit()
except ray.train.TrainingFailedError as e:
if isinstance(e, ray.train.WorkerGroupError):
print(e.worker_failures)
在持久化存储上查找结果#
所有训练结果,包括报告的指标和检查点,都存储在配置的 持久化存储 上。
请参阅 持久化存储指南 来为您的训练运行配置此位置。