using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Vision;
namespace ImageClassification
{
    /// <summary>
    /// Console entry point that fine-tunes an ML.NET image classifier on a folder of
    /// labelled images and saves the trained model to disk.
    /// </summary>
    class Program
    {
        // ImageClassificationTrainer decodes JPEG and PNG input only.
        private static readonly HashSet<string> SupportedImageExtensions =
            new HashSet<string>(StringComparer.OrdinalIgnoreCase) { ".jpg", ".jpeg", ".png" };

        /// <summary>
        /// Trains a classifier from the images under ./train (one subfolder per class)
        /// and writes the fitted model to ./model.zip.
        /// </summary>
        /// <param name="args">Unused command-line arguments.</param>
        /// <exception cref="DirectoryNotFoundException">The ./train folder does not exist.</exception>
        /// <exception cref="InvalidOperationException">No supported images were found.</exception>
        static void Main(string[] args)
        {
            // Paths for the training data and the saved model.
            var trainDataPath = Path.Combine(Environment.CurrentDirectory, "train");
            var modelPath = Path.Combine(Environment.CurrentDirectory, "model.zip");

            // Initialize the ML.NET environment.
            var context = new MLContext();

            // ML.NET has no "load from image folder" API: enumerate the files ourselves and
            // load the (path, label) rows as an IDataView.
            List<ImageData> images = LoadImagesFromDirectory(trainDataPath);
            if (images.Count == 0)
            {
                throw new InvalidOperationException(
                    $"No .jpg/.jpeg/.png images found in class subfolders of '{trainDataPath}'.");
            }

            // Files are enumerated class by class; shuffle so training batches mix labels.
            IDataView trainData = context.Data.ShuffleRows(context.Data.LoadFromEnumerable(images));

            var options = new ImageClassificationTrainer.Options
            {
                FeatureColumnName = "ImageBytes",
                LabelColumnName = "Label",
                Arch = ImageClassificationTrainer.Architecture.ResnetV250,
                Epoch = 100,
                BatchSize = 10,
                LearningRate = 0.01f,
                MetricsCallback = metrics => Console.WriteLine(metrics),
            };

            // The trainer consumes the raw encoded image bytes, so no ResizeImages /
            // ExtractPixels / Normalize steps are needed (or accepted) before it.
            // ImagePath values are absolute, so combining them with trainDataPath yields
            // the same path.
            var trainingPipeline = context.Transforms.Conversion.MapValueToKey("Label", "Label")
                .Append(context.Transforms.LoadRawImageBytes(
                    "ImageBytes", trainDataPath, nameof(ImageData.ImagePath)))
                .Append(context.MulticlassClassification.Trainers.ImageClassification(options))
                // Map the predicted key back to its string label exactly once, after the
                // trainer has produced the PredictedLabel column.
                .Append(context.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            // Train the model.
            ITransformer model = trainingPipeline.Fit(trainData);

            // Save the model together with the input schema.
            context.Model.Save(model, trainData.Schema, modelPath);
            Console.WriteLine("Training completed!");
        }

        /// <summary>
        /// Builds one <see cref="ImageData"/> row per supported image file found in the
        /// immediate subfolders of <paramref name="folder"/> (searched recursively). The
        /// row's label is the name of the top-level subfolder containing the file. Files
        /// placed directly in <paramref name="folder"/> are ignored because they have no class.
        /// </summary>
        /// <param name="folder">Root folder containing one subfolder per class.</param>
        /// <returns>The image rows; empty if no supported files were found.</returns>
        /// <exception cref="DirectoryNotFoundException"><paramref name="folder"/> does not exist.</exception>
        private static List<ImageData> LoadImagesFromDirectory(string folder)
        {
            if (!Directory.Exists(folder))
            {
                throw new DirectoryNotFoundException($"Training image folder not found: '{folder}'.");
            }

            return Directory.EnumerateDirectories(folder)
                .SelectMany(classDir => Directory
                    .EnumerateFiles(classDir, "*", SearchOption.AllDirectories)
                    .Where(file => SupportedImageExtensions.Contains(Path.GetExtension(file)))
                    .Select(file => new ImageData
                    {
                        ImagePath = Path.GetFullPath(file),
                        Label = Path.GetFileName(classDir),
                    }))
                .ToList();
        }
    }

    /// <summary>
    /// One training example: the location of an image file and its class label.
    /// </summary>
    public class ImageData
    {
        /// <summary>Absolute path of the image file; its bytes are loaded by the pipeline.</summary>
        public string ImagePath { get; set; }

        /// <summary>Class label, taken from the name of the class subfolder.</summary>
        public string Label { get; set; }
    }

    /// <summary>
    /// Prediction output row: the string label produced by the model's final
    /// key-to-value mapping of the <c>PredictedLabel</c> column.
    /// </summary>
    public class ImagePrediction
    {
        [ColumnName("PredictedLabel")]
        public string Prediction { get; set; }
    }
}
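// 此示例代码使用 ResNet V2 50(ResnetV250)架构,在训练集上训练 100 个 Epoch,并将训练后的模型保存到本地。你可以根据需要修改这些参数。注意:ML.NET 的 Microsoft.ML.Data 命名空间中并没有 BitmapImage 类型;ImageClassificationTrainer 以原始图像字节作为输入(可用 LoadRawImageBytes 根据图像文件路径读取),并且该训练器位于 Microsoft.ML.Vision 命名空间(需要引用 Microsoft.ML.Vision 包)。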